From df09d0830a1680a1028a5df0864d916d6b66d302 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Tue, 18 Jun 2024 15:03:31 +0200 Subject: [PATCH 001/115] feat(sqlalchemy): Replace peewee with sqlalchemy --- .github/workflows/integration-test.yml | 4 +- backend/alembic.ini | 114 +++++++ backend/apps/ollama/main.py | 7 +- backend/apps/openai/main.py | 4 +- backend/apps/socket/main.py | 4 +- backend/apps/webui/internal/db.py | 64 ++-- backend/apps/webui/internal/wrappers.py | 72 ----- backend/apps/webui/main.py | 6 +- backend/apps/webui/models/auths.py | 77 ++--- backend/apps/webui/models/chats.py | 303 ++++++++---------- backend/apps/webui/models/documents.py | 104 +++--- backend/apps/webui/models/files.py | 62 ++-- backend/apps/webui/models/functions.py | 74 +++-- backend/apps/webui/models/memories.py | 87 +++-- backend/apps/webui/models/models.py | 79 +++-- backend/apps/webui/models/prompts.py | 86 +++-- backend/apps/webui/models/tags.py | 198 +++++++----- backend/apps/webui/models/tools.py | 78 +++-- backend/apps/webui/models/users.py | 181 ++++++----- backend/apps/webui/routers/auths.py | 66 ++-- backend/apps/webui/routers/chats.py | 146 +++++---- backend/apps/webui/routers/documents.py | 40 ++- backend/apps/webui/routers/files.py | 25 +- backend/apps/webui/routers/functions.py | 27 +- backend/apps/webui/routers/memories.py | 29 +- backend/apps/webui/routers/models.py | 35 +- backend/apps/webui/routers/prompts.py | 32 +- backend/apps/webui/routers/tools.py | 35 +- backend/apps/webui/routers/users.py | 69 ++-- backend/apps/webui/routers/utils.py | 8 +- backend/main.py | 56 +++- backend/migrations/README | 4 + backend/migrations/env.py | 93 ++++++ backend/migrations/script.py.mako | 27 ++ .../migrations/versions/22b5ab2667b8_init.py | 188 +++++++++++ backend/requirements.txt | 13 +- backend/test/__init__.py | 0 backend/test/apps/webui/routers/test_auths.py | 209 ++++++++++++ backend/test/apps/webui/routers/test_chats.py | 239 ++++++++++++++ .../test/apps/webui/routers/test_documents.py | 106 ++++++ .../test/apps/webui/routers/test_models.py | 60 ++++ .../test/apps/webui/routers/test_prompts.py | 82 +++++ backend/test/apps/webui/routers/test_users.py | 170 ++++++++++ .../test/util/abstract_integration_test.py | 155 +++++++++ backend/test/util/mock_user.py | 45 +++ backend/utils/utils.py | 15 +- src/lib/apis/models/index.ts | 5 +- 47 files changed, 2580 insertions(+), 1003 deletions(-) create mode 100644 backend/alembic.ini delete mode 100644 backend/apps/webui/internal/wrappers.py create mode 100644 backend/migrations/README create mode 100644 backend/migrations/env.py create mode 100644 backend/migrations/script.py.mako create mode 100644 backend/migrations/versions/22b5ab2667b8_init.py create mode 100644 backend/test/__init__.py create mode 100644 backend/test/apps/webui/routers/test_auths.py create mode 100644 backend/test/apps/webui/routers/test_chats.py create mode 100644 backend/test/apps/webui/routers/test_documents.py create mode 100644 backend/test/apps/webui/routers/test_models.py create mode 100644 backend/test/apps/webui/routers/test_prompts.py create mode 100644 backend/test/apps/webui/routers/test_users.py create mode 100644 backend/test/util/abstract_integration_test.py create mode 100644 backend/test/util/mock_user.py diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 85810c2ed..c8e7c1672 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -171,7 +171,7 @@ jobs: fi # Check that service will reconnect to postgres when connection will be closed - status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health) + status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db) if [[ "$status_code" -ne 200 ]] ; then echo "Server has failed before postgres reconnect check" exit 1 @@ -183,7 +183,7 @@ jobs: cur = conn.cursor(); \ cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')" - status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health) + status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db) if [[ "$status_code" -ne 200 ]] ; then echo "Server has not reconnected to postgres after connection was closed: returned status $status_code" exit 1 diff --git a/backend/alembic.ini b/backend/alembic.ini new file mode 100644 index 000000000..72f2b762b --- /dev/null +++ b/backend/alembic.ini @@ -0,0 +1,114 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = REPLACE_WITH_DATABASE_URL + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 455dc89a5..85bb4c0df 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -31,6 +31,7 @@ from typing import Optional, List, Union from starlette.background import BackgroundTask +from apps.webui.internal.db import get_db from apps.webui.models.models import Models from apps.webui.models.users import Users from constants import ERROR_MESSAGES @@ -711,6 +712,7 @@ async def generate_chat_completion( form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db=Depends(get_db), ): log.debug( @@ -724,7 +726,7 @@ async def generate_chat_completion( } model_id = form_data.model - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(db, model_id) if model_info: if model_info.base_model_id: @@ -883,6 +885,7 @@ async def generate_openai_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db=Depends(get_db), ): form_data = OpenAIChatCompletionForm(**form_data) @@ -891,7 +894,7 @@ async def generate_openai_chat_completion( } model_id = form_data.model - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(db, model_id) if model_info: if model_info.base_model_id: diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 302dd8d98..bc40bc661 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -11,6 +11,7 @@ import logging from pydantic import BaseModel from starlette.background import BackgroundTask +from apps.webui.internal.db import get_db from apps.webui.models.models import Models from apps.webui.models.users import Users from constants import ERROR_MESSAGES @@ -353,12 +354,13 @@ async def generate_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), + db=Depends(get_db), ): idx = 0 payload = {**form_data} model_id = form_data.get("model") - model_info = Models.get_model_by_id(model_id) + model_info = Models.get_model_by_id(db, model_id) if model_info: if model_info.base_model_id: diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index 123ff31cd..bbbbccd79 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -24,7 +24,9 @@ async def connect(sid, environ, auth): data = decode_token(auth["token"]) if data is not None and "id" in data: - user = Users.get_user_by_id(data["id"]) + from apps.webui.internal.db import SessionLocal + + user = Users.get_user_by_id(SessionLocal(), data["id"]) if user: SESSION_POOL[sid] = user.id diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 80c30d652..5acf83d5c 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -1,18 +1,34 @@ import os import logging import json +from typing import Optional, Any +from typing_extensions import Self -from peewee import * -from peewee_migrate import Router +from sqlalchemy import create_engine, types, Dialect +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql.type_api import _T -from apps.webui.internal.wrappers import register_connection from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) -class JSONField(TextField): +class JSONField(types.TypeDecorator): + impl = types.Text + cache_ok = True + + def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any: + return json.dumps(value) + + def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any: + if value is not None: + return json.loads(value) + + def copy(self, **kw: Any) -> Self: + return JSONField(self.impl.length) + def db_value(self, value): return json.dumps(value) @@ -29,26 +45,24 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): else: pass +SQLALCHEMY_DATABASE_URL = DATABASE_URL +if "sqlite" in SQLALCHEMY_DATABASE_URL: + engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} + ) +else: + engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +Base = declarative_base() -# The `register_connection` function encapsulates the logic for setting up -# the database connection based on the connection string, while `connect` -# is a Peewee-specific method to manage the connection state and avoid errors -# when a connection is already open. -try: - DB = register_connection(DATABASE_URL) - log.info(f"Connected to a {DB.__class__.__name__} database.") -except Exception as e: - log.error(f"Failed to initialize the database connection: {e}") - raise -router = Router( - DB, - migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", - logger=log, -) -router.run() -try: - DB.connect(reuse_if_open=True) -except OperationalError as e: - log.info(f"Failed to connect to database again due to: {e}") - pass +def get_db(): + db = SessionLocal() + try: + yield db + db.commit() + except Exception as e: + db.rollback() + raise e + finally: + db.close() diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py deleted file mode 100644 index 2b5551ce2..000000000 --- a/backend/apps/webui/internal/wrappers.py +++ /dev/null @@ -1,72 +0,0 @@ -from contextvars import ContextVar -from peewee import * -from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError - -import logging -from playhouse.db_url import connect, parse -from playhouse.shortcuts import ReconnectMixin - -from config import SRC_LOG_LEVELS - -log = logging.getLogger(__name__) -log.setLevel(SRC_LOG_LEVELS["DB"]) - -db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} -db_state = ContextVar("db_state", default=db_state_default.copy()) - - -class PeeweeConnectionState(object): - def __init__(self, **kwargs): - super().__setattr__("_state", db_state) - super().__init__(**kwargs) - - def __setattr__(self, name, value): - self._state.get()[name] = value - - def __getattr__(self, name): - value = self._state.get()[name] - return value - - -class CustomReconnectMixin(ReconnectMixin): - reconnect_errors = ( - # psycopg2 - (OperationalError, "termin"), - (InterfaceError, "closed"), - # peewee - (PeeWeeInterfaceError, "closed"), - ) - - -class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): - pass - - -def register_connection(db_url): - db = connect(db_url) - if isinstance(db, PostgresqlDatabase): - # Enable autoconnect for SQLite databases, managed by Peewee - db.autoconnect = True - db.reuse_if_open = True - log.info("Connected to PostgreSQL database") - - # Get the connection details - connection = parse(db_url) - - # Use our custom database class that supports reconnection - db = ReconnectingPostgresqlDatabase( - connection["database"], - user=connection["user"], - password=connection["password"], - host=connection["host"], - port=connection["port"], - ) - db.connect(reuse_if_open=True) - elif isinstance(db, SqliteDatabase): - # Enable autoconnect for SQLite databases, managed by Peewee - db.autoconnect = True - db.reuse_if_open = True - log.info("Connected to SQLite database") - else: - raise ValueError("Unsupported database connection") - return db diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 28b1b4aac..8bef22c05 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -3,7 +3,7 @@ from fastapi.routing import APIRoute from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware - +from sqlalchemy.orm import Session from apps.webui.routers import ( auths, users, @@ -114,8 +114,8 @@ async def get_status(): } -async def get_pipe_models(): - pipes = Functions.get_functions_by_type("pipe", active_only=True) +async def get_pipe_models(db: Session): + pipes = Functions.get_functions_by_type(db, "pipe", active_only=True) pipe_models = [] for pipe in pipes: diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 9ea38abcb..5ff348dac 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -1,14 +1,14 @@ from pydantic import BaseModel -from typing import List, Union, Optional -import time +from typing import Optional import uuid import logging -from peewee import * +from sqlalchemy import String, Column, Boolean +from sqlalchemy.orm import Session from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import DB +from apps.webui.internal.db import Base from config import SRC_LOG_LEVELS @@ -20,14 +20,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### -class Auth(Model): - id = CharField(unique=True) - email = CharField() - password = TextField() - active = BooleanField() +class Auth(Base): + __tablename__ = "auth" - class Meta: - database = DB + id = Column(String, primary_key=True) + email = Column(String) + password = Column(String) + active = Column(Boolean) class AuthModel(BaseModel): @@ -94,12 +93,10 @@ class AddUserForm(SignupForm): class AuthsTable: - def __init__(self, db): - self.db = db - self.db.create_tables([Auth]) def insert_new_auth( self, + db: Session, email: str, password: str, name: str, @@ -114,24 +111,30 @@ class AuthsTable: auth = AuthModel( **{"id": id, "email": email, "password": password, "active": True} ) - result = Auth.create(**auth.model_dump()) + result = Auth(**auth.model_dump()) + db.add(result) user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub + db, id, name, email, profile_image_url, role, oauth_sub ) + db.commit() + db.refresh(result) + if result and user: return user else: return None - def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: + def authenticate_user( + self, db: Session, email: str, password: str + ) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = Auth.get(Auth.email == email, Auth.active == True) + auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) + user = Users.get_user_by_id(db, auth.id) return user else: return None @@ -140,55 +143,55 @@ class AuthsTable: except: return None - def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + def authenticate_user_by_api_key( + self, db: Session, api_key: str + ) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") # if no api_key, return None if not api_key: return None try: - user = Users.get_user_by_api_key(api_key) + user = Users.get_user_by_api_key(db, api_key) return user if user else None except: return False - def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: + def authenticate_user_by_trusted_header( + self, db: Session, email: str + ) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = Auth.get(Auth.email == email, Auth.active == True) + auth = db.query(Auth).filter(email=email, active=True).first() if auth: user = Users.get_user_by_id(auth.id) return user except: return None - def update_user_password_by_id(self, id: str, new_password: str) -> bool: + def update_user_password_by_id( + self, db: Session, id: str, new_password: str + ) -> bool: try: - query = Auth.update(password=new_password).where(Auth.id == id) - result = query.execute() - + result = db.query(Auth).filter_by(id=id).update({"password": new_password}) return True if result == 1 else False except: return False - def update_email_by_id(self, id: str, email: str) -> bool: + def update_email_by_id(self, db: Session, id: str, email: str) -> bool: try: - query = Auth.update(email=email).where(Auth.id == id) - result = query.execute() - + result = db.query(Auth).filter_by(id=id).update({"email": email}) return True if result == 1 else False except: return False - def delete_auth_by_id(self, id: str) -> bool: + def delete_auth_by_id(self, db: Session, id: str) -> bool: try: # Delete User - result = Users.delete_user_by_id(id) + result = Users.delete_user_by_id(db, id) if result: - # Delete Auth - query = Auth.delete().where(Auth.id == id) - query.execute() # Remove the rows, return number of rows removed. + db.query(Auth).filter_by(id=id).delete() return True else: @@ -197,4 +200,4 @@ class AuthsTable: return False -Auths = AuthsTable(DB) +Auths = AuthsTable() diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index a6f1ae923..dd92fd0a1 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -1,36 +1,39 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional -from peewee import * -from playhouse.shortcuts import model_to_dict import json import uuid import time -from apps.webui.internal.db import DB +from sqlalchemy import Column, String, BigInteger, Boolean +from sqlalchemy.orm import Session + +from apps.webui.internal.db import Base + #################### # Chat DB Schema #################### -class Chat(Model): - id = CharField(unique=True) - user_id = CharField() - title = TextField() - chat = TextField() # Save Chat JSON as Text +class Chat(Base): + __tablename__ = "chat" - created_at = BigIntegerField() - updated_at = BigIntegerField() + id = Column(String, primary_key=True) + user_id = Column(String) + title = Column(String) + chat = Column(String) # Save Chat JSON as Text - share_id = CharField(null=True, unique=True) - archived = BooleanField(default=False) + created_at = Column(BigInteger) + updated_at = Column(BigInteger) - class Meta: - database = DB + share_id = Column(String, unique=True, nullable=True) + archived = Column(Boolean, default=False) class ChatModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str user_id: str title: str @@ -75,11 +78,10 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def __init__(self, db): - self.db = db - db.create_tables([Chat]) - def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: + def insert_new_chat( + self, db: Session, user_id: str, form_data: ChatForm + ) -> Optional[ChatModel]: id = str(uuid.uuid4()) chat = ChatModel( **{ @@ -94,29 +96,36 @@ class ChatTable: } ) - result = Chat.create(**chat.model_dump()) - return chat if result else None + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None - def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: + def update_chat_by_id( + self, db: Session, id: str, chat: dict + ) -> Optional[ChatModel]: try: - query = Chat.update( - chat=json.dumps(chat), - title=chat["title"] if "title" in chat else "New Chat", - updated_at=int(time.time()), - ).where(Chat.id == id) - query.execute() + db.query(Chat).filter_by(id=id).update( + { + "chat": json.dumps(chat), + "title": chat["title"] if "title" in chat else "New Chat", + "updated_at": int(time.time()), + } + ) - chat = Chat.get(Chat.id == id) - return ChatModel(**model_to_dict(chat)) + return self.get_chat_by_id(db, id) except: return None - def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + def insert_shared_chat_by_chat_id( + self, db: Session, chat_id: str + ) -> Optional[ChatModel]: # Get the existing chat to share - chat = Chat.get(Chat.id == chat_id) + chat = db.get(Chat, chat_id) # Check if the chat is already shared if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared") # Create a new chat with the same data, but with a new ID shared_chat = ChatModel( **{ @@ -128,228 +137,196 @@ class ChatTable: "updated_at": int(time.time()), } ) - shared_result = Chat.create(**shared_chat.model_dump()) + shared_result = Chat(**shared_chat.model_dump()) + db.add(shared_result) + db.commit() + db.refresh(shared_result) # Update the original chat with the share_id result = ( - Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute() + db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id}) ) return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: + def update_shared_chat_by_chat_id( + self, db: Session, chat_id: str + ) -> Optional[ChatModel]: try: print("update_shared_chat_by_id") - chat = Chat.get(Chat.id == chat_id) + chat = db.get(Chat, chat_id) print(chat) - query = Chat.update( - title=chat.title, - chat=chat.chat, - ).where(Chat.id == chat.share_id) + db.query(Chat).filter_by(id=chat.share_id).update( + {"title": chat.title, "chat": chat.chat} + ) - query.execute() - - chat = Chat.get(Chat.id == chat.share_id) - return ChatModel(**model_to_dict(chat)) + return self.get_chat_by_id(db, chat.share_id) except: return None - def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: + def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool: try: - query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}") - query.execute() # Remove the rows, return number of rows removed. - + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() return True except: return False def update_chat_share_id_by_id( - self, id: str, share_id: Optional[str] + self, db: Session, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - query = Chat.update( - share_id=share_id, - ).where(Chat.id == id) - query.execute() + db.query(Chat).filter_by(id=id).update({"share_id": share_id}) - chat = Chat.get(Chat.id == id) - return ChatModel(**model_to_dict(chat)) + return self.get_chat_by_id(db, id) except: return None - def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: + def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]: try: - chat = self.get_chat_by_id(id) - query = Chat.update( - archived=(not chat.archived), - ).where(Chat.id == id) + chat = self.get_chat_by_id(db, id) + db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) - query.execute() - - chat = Chat.get(Chat.id == id) - return ChatModel(**model_to_dict(chat)) + return self.get_chat_by_id(db, id) except: return None - def archive_all_chats_by_user_id(self, user_id: str) -> bool: + def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool: try: - chats = self.get_chats_by_user_id(user_id) - for chat in chats: - query = Chat.update( - archived=True, - ).where(Chat.id == chat.id) - - query.execute() + db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) return True except: return False def get_archived_chat_list_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + self, db: Session, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == True) - .where(Chat.user_id == user_id) + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) .order_by(Chat.updated_at.desc()) - # .limit(limit) - # .offset(skip) - ] + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, + db: Session, user_id: str, include_archived: bool = False, skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - if include_archived: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.user_id == user_id) - .order_by(Chat.updated_at.desc()) - # .limit(limit) - # .offset(skip) - ] - else: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == False) - .where(Chat.user_id == user_id) - .order_by(Chat.updated_at.desc()) - # .limit(limit) - # .offset(skip) - ] + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( - self, chat_ids: List[str], skip: int = 0, limit: int = 50 + self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == False) - .where(Chat.id.in_(chat_ids)) + all_chats = ( + db.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) .order_by(Chat.updated_at.desc()) - ] + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id(self, id: str) -> Optional[ChatModel]: + def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]: try: - chat = Chat.get(Chat.id == id) - return ChatModel(**model_to_dict(chat)) + chat = db.get(Chat, id) + return ChatModel.model_validate(chat) except: return None - def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: + def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]: try: - chat = Chat.get(Chat.share_id == id) + chat = db.query(Chat).filter_by(share_id=id).first() if chat: - chat = Chat.get(Chat.id == id) - return ChatModel(**model_to_dict(chat)) + return self.get_chat_by_id(db, id) else: return None + except Exception as e: + return None + + def get_chat_by_id_and_user_id( + self, db: Session, id: str, user_id: str + ) -> Optional[ChatModel]: + try: + chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None - def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: - try: - chat = Chat.get(Chat.id == id, Chat.user_id == user_id) - return ChatModel(**model_to_dict(chat)) - except: - return None - - def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select().order_by(Chat.updated_at.desc()) + def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]: + all_chats = ( + db.query(Chat) # .limit(limit).offset(skip) - ] - - def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.user_id == user_id) .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - ] + ) + return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - return [ - ChatModel(**model_to_dict(chat)) - for chat in Chat.select() - .where(Chat.archived == True) - .where(Chat.user_id == user_id) + def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]: + all_chats = ( + db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] + + def get_archived_chats_by_user_id( + self, db: Session, user_id: str + ) -> List[ChatModel]: + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) .order_by(Chat.updated_at.desc()) - ] + ) + return [ChatModel.model_validate(chat) for chat in all_chats] - def delete_chat_by_id(self, id: str) -> bool: + def delete_chat_by_id(self, db: Session, id: str) -> bool: try: - query = Chat.delete().where((Chat.id == id)) - query.execute() # Remove the rows, return number of rows removed. + db.query(Chat).filter_by(id=id).delete() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(db, id) except: return False - def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: + def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool: try: - query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) - query.execute() # Remove the rows, return number of rows removed. + db.query(Chat).filter_by(id=id, user_id=user_id).delete() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(db, id) except: return False - def delete_chats_by_user_id(self, user_id: str) -> bool: + def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool: try: - self.delete_shared_chats_by_user_id(user_id) - - query = Chat.delete().where(Chat.user_id == user_id) - query.execute() # Remove the rows, return number of rows removed. + self.delete_shared_chats_by_user_id(db, user_id) + db.query(Chat).filter_by(user_id=user_id).delete() return True except: return False - def delete_shared_chats_by_user_id(self, user_id: str) -> bool: + def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool: try: - shared_chat_ids = [ - f"shared-{chat.id}" - for chat in Chat.select().where(Chat.user_id == user_id) - ] + chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - query = Chat.delete().where(Chat.user_id << shared_chat_ids) - query.execute() # Remove the rows, return number of rows removed. + db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() return True except: return False -Chats = ChatTable(DB) +Chats = ChatTable() diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 3b730535f..b272a5912 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -1,14 +1,12 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict -from typing import List, Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import List, Optional import time import logging -from utils.utils import decode_token -from utils.misc import get_gravatar_url +from sqlalchemy import String, Column, BigInteger +from sqlalchemy.orm import Session -from apps.webui.internal.db import DB +from apps.webui.internal.db import Base import json @@ -22,20 +20,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### -class Document(Model): - collection_name = CharField(unique=True) - name = CharField(unique=True) - title = TextField() - filename = TextField() - content = TextField(null=True) - user_id = CharField() - timestamp = BigIntegerField() +class Document(Base): + __tablename__ = "document" - class Meta: - database = DB + collection_name = Column(String, primary_key=True) + name = Column(String, unique=True) + title = Column(String) + filename = Column(String) + content = Column(String, nullable=True) + user_id = Column(String) + timestamp = Column(BigInteger) class DocumentModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + collection_name: str name: str title: str @@ -72,12 +71,9 @@ class DocumentForm(DocumentUpdateForm): class DocumentsTable: - def __init__(self, db): - self.db = db - self.db.create_tables([Document]) def insert_new_doc( - self, user_id: str, form_data: DocumentForm + self, db: Session, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: document = DocumentModel( **{ @@ -88,73 +84,69 @@ class DocumentsTable: ) try: - result = Document.create(**document.model_dump()) + result = Document(**document.model_dump()) + db.add(result) + db.commit() + db.refresh(result) if result: - return document + return DocumentModel.model_validate(result) else: return None except: return None - def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: + def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]: try: - document = Document.get(Document.name == name) - return DocumentModel(**model_to_dict(document)) + document = db.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None - def get_docs(self) -> List[DocumentModel]: - return [ - DocumentModel(**model_to_dict(doc)) - for doc in Document.select() - # .limit(limit).offset(skip) - ] + def get_docs(self, db: Session) -> List[DocumentModel]: + return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] def update_doc_by_name( - self, name: str, form_data: DocumentUpdateForm + self, db: Session, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - query = Document.update( - title=form_data.title, - name=form_data.name, - timestamp=int(time.time()), - ).where(Document.name == name) - query.execute() - - doc = Document.get(Document.name == form_data.name) - return DocumentModel(**model_to_dict(doc)) + db.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + return self.get_doc_by_name(db, form_data.name) except Exception as e: log.exception(e) return None def update_doc_content_by_name( - self, name: str, updated: dict + self, db: Session, name: str, updated: dict ) -> Optional[DocumentModel]: try: - doc = self.get_doc_by_name(name) + doc = self.get_doc_by_name(db, name) doc_content = json.loads(doc.content if doc.content else "{}") doc_content = {**doc_content, **updated} - query = Document.update( - content=json.dumps(doc_content), - timestamp=int(time.time()), - ).where(Document.name == name) - query.execute() + db.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) - doc = Document.get(Document.name == name) - return DocumentModel(**model_to_dict(doc)) + return self.get_doc_by_name(db, name) except Exception as e: log.exception(e) return None - def delete_doc_by_name(self, name: str) -> bool: + def delete_doc_by_name(self, db: Session, name: str) -> bool: try: - query = Document.delete().where((Document.name == name)) - query.execute() # Remove the rows, return number of rows removed. - + db.query(Document).filter_by(name=name).delete() return True except: return False -Documents = DocumentsTable(DB) +Documents = DocumentsTable() diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index 6459ad725..dc9f6be39 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -1,10 +1,12 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional import time import logging -from apps.webui.internal.db import DB, JSONField + +from sqlalchemy import Column, String, BigInteger +from sqlalchemy.orm import Session + +from apps.webui.internal.db import JSONField, Base import json @@ -18,15 +20,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### -class File(Model): - id = CharField(unique=True) - user_id = CharField() - filename = TextField() - meta = JSONField() - created_at = BigIntegerField() +class File(Base): + __tablename__ = "file" - class Meta: - database = DB + id = Column(String, primary_key=True) + user_id = Column(String) + filename = Column(String) + meta = Column(JSONField) + created_at = Column(BigInteger) class FileModel(BaseModel): @@ -36,6 +37,7 @@ class FileModel(BaseModel): meta: dict created_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) #################### # Forms @@ -57,11 +59,8 @@ class FileForm(BaseModel): class FilesTable: - def __init__(self, db): - self.db = db - self.db.create_tables([File]) - def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: + def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]: file = FileModel( **{ **form_data.model_dump(), @@ -71,42 +70,41 @@ class FilesTable: ) try: - result = File.create(**file.model_dump()) + result = File(**file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) if result: - return file + return FileModel.model_validate(result) else: return None except Exception as e: print(f"Error creating tool: {e}") return None - def get_file_by_id(self, id: str) -> Optional[FileModel]: + def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]: try: - file = File.get(File.id == id) - return FileModel(**model_to_dict(file)) + file = db.get(File, id) + return FileModel.model_validate(file) except: return None - def get_files(self) -> List[FileModel]: - return [FileModel(**model_to_dict(file)) for file in File.select()] + def get_files(self, db: Session) -> List[FileModel]: + return [FileModel.model_validate(file) for file in db.query(File).all()] - def delete_file_by_id(self, id: str) -> bool: + def delete_file_by_id(self, db: Session, id: str) -> bool: try: - query = File.delete().where((File.id == id)) - query.execute() # Remove the rows, return number of rows removed. - + db.query(File).filter_by(id=id).delete() return True except: return False - def delete_all_files(self) -> bool: + def delete_all_files(self, db: Session) -> bool: try: - query = File.delete() - query.execute() # Remove the rows, return number of rows removed. - + db.query(File).delete() return True except: return False -Files = FilesTable(DB) +Files = FilesTable() diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 261987981..88fa24a21 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -1,10 +1,12 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional import time import logging -from apps.webui.internal.db import DB, JSONField + +from sqlalchemy import Column, String, Text, BigInteger, Boolean +from sqlalchemy.orm import Session + +from apps.webui.internal.db import JSONField, Base from apps.webui.models.users import Users import json @@ -21,20 +23,19 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### -class Function(Model): - id = CharField(unique=True) - user_id = CharField() - name = TextField() - type = TextField() - content = TextField() - meta = JSONField() - valves = JSONField() - is_active = BooleanField(default=False) - updated_at = BigIntegerField() - created_at = BigIntegerField() +class Function(Base): + __tablename__ = "function" - class Meta: - database = DB + id = Column(String, primary_key=True) + user_id = Column(String) + name = Column(Text) + type = Column(Text) + content = Column(Text) + meta = Column(JSONField) + valves = Column(JSONField) + is_active = Column(Boolean) + updated_at = Column(BigInteger) + created_at = Column(BigInteger) class FunctionMeta(BaseModel): @@ -53,6 +54,8 @@ class FunctionModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -82,12 +85,9 @@ class FunctionValves(BaseModel): class FunctionsTable: - def __init__(self, db): - self.db = db - self.db.create_tables([Function]) def insert_new_function( - self, user_id: str, type: str, form_data: FunctionForm + self, db: Session, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: function = FunctionModel( **{ @@ -100,19 +100,22 @@ class FunctionsTable: ) try: - result = Function.create(**function.model_dump()) + result = Function(**function.model_dump()) + db.add(result) + db.commit() + db.refresh(result) if result: - return function + return FunctionModel.model_validate(result) else: return None except Exception as e: print(f"Error creating tool: {e}") return None - def get_function_by_id(self, id: str) -> Optional[FunctionModel]: + def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]: try: - function = Function.get(Function.id == id) - return FunctionModel(**model_to_dict(function)) + function = db.get(Function, id) + return FunctionModel.model_validate(function) except: return None @@ -211,14 +214,11 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: - query = Function.update( + db.query(Function).filter_by(id=id).update({ **updated, - updated_at=int(time.time()), - ).where(Function.id == id) - query.execute() - - function = Function.get(Function.id == id) - return FunctionModel(**model_to_dict(function)) + "updated_at": int(time.time()), + }) + return self.get_function_by_id(db, id) except: return None @@ -235,14 +235,12 @@ class FunctionsTable: except: return None - def delete_function_by_id(self, id: str) -> bool: + def delete_function_by_id(self, db: Session, id: str) -> bool: try: - query = Function.delete().where((Function.id == id)) - query.execute() # Remove the rows, return number of rows removed. - + db.query(Function).filter_by(id=id).delete() return True except: return False -Functions = FunctionsTable(DB) +Functions = FunctionsTable() diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index ef63674ab..f5f6d13fb 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -1,9 +1,10 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional -from apps.webui.internal.db import DB +from sqlalchemy import Column, String, BigInteger +from sqlalchemy.orm import Session + +from apps.webui.internal.db import Base from apps.webui.models.chats import Chats import time @@ -14,15 +15,14 @@ import uuid #################### -class Memory(Model): - id = CharField(unique=True) - user_id = CharField() - content = TextField() - updated_at = BigIntegerField() - created_at = BigIntegerField() +class Memory(Base): + __tablename__ = "memory" - class Meta: - database = DB + id = Column(String, primary_key=True) + user_id = Column(String) + content = Column(String) + updated_at = Column(BigInteger) + created_at = Column(BigInteger) class MemoryModel(BaseModel): @@ -32,6 +32,8 @@ class MemoryModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -39,12 +41,10 @@ class MemoryModel(BaseModel): class MemoriesTable: - def __init__(self, db): - self.db = db - self.db.create_tables([Memory]) def insert_new_memory( self, + db: Session, user_id: str, content: str, ) -> Optional[MemoryModel]: @@ -59,74 +59,73 @@ class MemoriesTable: "updated_at": int(time.time()), } ) - result = Memory.create(**memory.model_dump()) + result = Memory(**memory.dict()) + db.add(result) + db.commit() + db.refresh(result) if result: - return memory + return MemoryModel.model_validate(result) else: return None def update_memory_by_id( self, + db: Session, id: str, content: str, ) -> Optional[MemoryModel]: try: - memory = Memory.get(Memory.id == id) - memory.content = content - memory.updated_at = int(time.time()) - memory.save() - return MemoryModel(**model_to_dict(memory)) + db.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + return self.get_memory_by_id(db, id) except: return None - def get_memories(self) -> List[MemoryModel]: + def get_memories(self, db: Session) -> List[MemoryModel]: try: - memories = Memory.select() - return [MemoryModel(**model_to_dict(memory)) for memory in memories] + memories = db.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None - def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: + def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]: try: - memories = Memory.select().where(Memory.user_id == user_id) - return [MemoryModel(**model_to_dict(memory)) for memory in memories] + memories = db.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None - def get_memory_by_id(self, id) -> Optional[MemoryModel]: + def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]: try: - memory = Memory.get(Memory.id == id) - return MemoryModel(**model_to_dict(memory)) + memory = db.get(Memory, id) + return MemoryModel.model_validate(memory) except: return None - def delete_memory_by_id(self, id: str) -> bool: + def delete_memory_by_id(self, db: Session, id: str) -> bool: try: - query = Memory.delete().where(Memory.id == id) - query.execute() # Remove the rows, return number of rows removed. - + db.query(Memory).filter_by(id=id).delete() return True except: return False - def delete_memories_by_user_id(self, user_id: str) -> bool: + def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: try: - query = Memory.delete().where(Memory.user_id == user_id) - query.execute() - + db.query(Memory).filter_by(user_id=user_id).delete() return True except: return False - def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: + def delete_memory_by_id_and_user_id( + self, db: Session, id: str, user_id: str + ) -> bool: try: - query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id) - query.execute() - + db.query(Memory).filter_by(id=id, user_id=user_id).delete() return True except: return False -Memories = MemoriesTable(DB) +Memories = MemoriesTable() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 851352398..137333409 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -2,13 +2,11 @@ import json import logging from typing import Optional -import peewee as pw -from peewee import * - -from playhouse.shortcuts import model_to_dict from pydantic import BaseModel, ConfigDict +from sqlalchemy import String, Column, BigInteger +from sqlalchemy.orm import Session -from apps.webui.internal.db import DB, JSONField +from apps.webui.internal.db import Base, JSONField from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -46,41 +44,42 @@ class ModelMeta(BaseModel): pass -class Model(pw.Model): - id = pw.TextField(unique=True) +class Model(Base): + __tablename__ = "model" + + id = Column(String, primary_key=True) """ The model's id as used in the API. If set to an existing model, it will override the model. """ - user_id = pw.TextField() + user_id = Column(String) - base_model_id = pw.TextField(null=True) + base_model_id = Column(String, nullable=True) """ An optional pointer to the actual model that should be used when proxying requests. """ - name = pw.TextField() + name = Column(String) """ The human-readable display name of the model. """ - params = JSONField() + params = Column(JSONField) """ Holds a JSON encoded blob of parameters, see `ModelParams`. """ - meta = JSONField() + meta = Column(JSONField) """ Holds a JSON encoded blob of metadata, see `ModelMeta`. """ - updated_at = BigIntegerField() - created_at = BigIntegerField() - - class Meta: - database = DB + updated_at = Column(BigInteger) + created_at = Column(BigInteger) class ModelModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str user_id: str base_model_id: Optional[str] = None @@ -115,15 +114,9 @@ class ModelForm(BaseModel): class ModelsTable: - def __init__( - self, - db: pw.SqliteDatabase | pw.PostgresqlDatabase, - ): - self.db = db - self.db.create_tables([Model]) def insert_new_model( - self, form_data: ModelForm, user_id: str + self, db: Session, form_data: ModelForm, user_id: str ) -> Optional[ModelModel]: model = ModelModel( **{ @@ -134,46 +127,50 @@ class ModelsTable: } ) try: - result = Model.create(**model.model_dump()) + result = Model(**model.dict()) + db.add(result) + db.commit() + db.refresh(result) if result: - return model + return ModelModel.model_validate(result) else: return None except Exception as e: print(e) return None - def get_all_models(self) -> List[ModelModel]: - return [ModelModel(**model_to_dict(model)) for model in Model.select()] + def get_all_models(self, db: Session) -> List[ModelModel]: + return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_model_by_id(self, id: str) -> Optional[ModelModel]: + def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]: try: - model = Model.get(Model.id == id) - return ModelModel(**model_to_dict(model)) + model = db.get(Model, id) + return ModelModel.model_validate(model) except: return None - def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + def update_model_by_id( + self, db: Session, id: str, model: ModelForm + ) -> Optional[ModelModel]: try: # update only the fields that are present in the model - query = Model.update(**model.model_dump()).where(Model.id == id) - query.execute() - - model = Model.get(Model.id == id) - return ModelModel(**model_to_dict(model)) + model = db.query(Model).get(id) + model.update(**model.model_dump()) + db.commit() + db.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) return None - def delete_model_by_id(self, id: str) -> bool: + def delete_model_by_id(self, db: Session, id: str) -> bool: try: - query = Model.delete().where(Model.id == id) - query.execute() + db.query(Model).filter_by(id=id).delete() return True except: return False -Models = ModelsTable(DB) +Models = ModelsTable() diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index c4ac6be14..21c4de3e1 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -1,13 +1,11 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict -from typing import List, Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import List, Optional import time -from utils.utils import decode_token -from utils.misc import get_gravatar_url +from sqlalchemy import String, Column, BigInteger +from sqlalchemy.orm import Session -from apps.webui.internal.db import DB +from apps.webui.internal.db import Base import json @@ -16,15 +14,14 @@ import json #################### -class Prompt(Model): - command = CharField(unique=True) - user_id = CharField() - title = TextField() - content = TextField() - timestamp = BigIntegerField() +class Prompt(Base): + __tablename__ = "prompt" - class Meta: - database = DB + command = Column(String, primary_key=True) + user_id = Column(String) + title = Column(String) + content = Column(String) + timestamp = Column(BigInteger) class PromptModel(BaseModel): @@ -34,6 +31,8 @@ class PromptModel(BaseModel): content: str timestamp: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -48,12 +47,8 @@ class PromptForm(BaseModel): class PromptsTable: - def __init__(self, db): - self.db = db - self.db.create_tables([Prompt]) - def insert_new_prompt( - self, user_id: str, form_data: PromptForm + self, db: Session, user_id: str, form_data: PromptForm ) -> Optional[PromptModel]: prompt = PromptModel( **{ @@ -66,53 +61,48 @@ class PromptsTable: ) try: - result = Prompt.create(**prompt.model_dump()) + result = Prompt(**prompt.dict()) + db.add(result) + db.commit() + db.refresh(result) if result: - return prompt + return PromptModel.model_validate(result) else: return None - except: + except Exception as e: return None - def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]: try: - prompt = Prompt.get(Prompt.command == command) - return PromptModel(**model_to_dict(prompt)) + prompt = db.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) except: return None - def get_prompts(self) -> List[PromptModel]: - return [ - PromptModel(**model_to_dict(prompt)) - for prompt in Prompt.select() - # .limit(limit).offset(skip) - ] + def get_prompts(self, db: Session) -> List[PromptModel]: + return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] def update_prompt_by_command( - self, command: str, form_data: PromptForm + self, db: Session, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: - query = Prompt.update( - title=form_data.title, - content=form_data.content, - timestamp=int(time.time()), - ).where(Prompt.command == command) - - query.execute() - - prompt = Prompt.get(Prompt.command == command) - return PromptModel(**model_to_dict(prompt)) + db.query(Prompt).filter_by(command=command).update( + { + "title": form_data.title, + "content": form_data.content, + "timestamp": int(time.time()), + } + ) + return self.get_prompt_by_command(db, command) except: return None - def delete_prompt_by_command(self, command: str) -> bool: + def delete_prompt_by_command(self, db: Session, command: str) -> bool: try: - query = Prompt.delete().where((Prompt.command == command)) - query.execute() # Remove the rows, return number of rows removed. - + db.query(Prompt).filter_by(command=command).delete() return True except: return False -Prompts = PromptsTable(DB) +Prompts = PromptsTable() diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 4c4fa82e6..419425662 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -1,14 +1,15 @@ -from pydantic import BaseModel -from typing import List, Union, Optional -from peewee import * -from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict +from typing import List, Optional import json import uuid import time import logging -from apps.webui.internal.db import DB +from sqlalchemy import String, Column, BigInteger +from sqlalchemy.orm import Session + +from apps.webui.internal.db import Base from config import SRC_LOG_LEVELS @@ -20,25 +21,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### -class Tag(Model): - id = CharField(unique=True) - name = CharField() - user_id = CharField() - data = TextField(null=True) +class Tag(Base): + __tablename__ = "tag" - class Meta: - database = DB + id = Column(String, primary_key=True) + name = Column(String) + user_id = Column(String) + data = Column(String, nullable=True) -class ChatIdTag(Model): - id = CharField(unique=True) - tag_name = CharField() - chat_id = CharField() - user_id = CharField() - timestamp = BigIntegerField() +class ChatIdTag(Base): + __tablename__ = "chatidtag" - class Meta: - database = DB + id = Column(String, primary_key=True) + tag_name = Column(String) + chat_id = Column(String) + user_id = Column(String) + timestamp = Column(BigInteger) class TagModel(BaseModel): @@ -47,6 +46,8 @@ class TagModel(BaseModel): user_id: str data: Optional[str] = None + model_config = ConfigDict(from_attributes=True) + class ChatIdTagModel(BaseModel): id: str @@ -55,6 +56,8 @@ class ChatIdTagModel(BaseModel): user_id: str timestamp: int + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -75,37 +78,39 @@ class ChatTagsResponse(BaseModel): class TagTable: - def __init__(self, db): - self.db = db - db.create_tables([Tag, ChatIdTag]) - def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: + def insert_new_tag( + self, db: Session, name: str, user_id: str + ) -> Optional[TagModel]: id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: - result = Tag.create(**tag.model_dump()) + result = Tag(**tag.dict()) + db.add(result) + db.commit() + db.refresh(result) if result: - return tag + return TagModel.model_validate(result) else: return None except Exception as e: return None def get_tag_by_name_and_user_id( - self, name: str, user_id: str + self, db: Session, name: str, user_id: str ) -> Optional[TagModel]: try: - tag = Tag.get(Tag.name == name, Tag.user_id == user_id) - return TagModel(**model_to_dict(tag)) + tag = db.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None def add_tag_to_chat( - self, user_id: str, form_data: ChatIdTagForm + self, db: Session, user_id: str, form_data: ChatIdTagForm ) -> Optional[ChatIdTagModel]: - tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) + tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id) if tag == None: - tag = self.insert_new_tag(form_data.tag_name, user_id) + tag = self.insert_new_tag(db, form_data.tag_name, user_id) id = str(uuid.uuid4()) chatIdTag = ChatIdTagModel( @@ -118,120 +123,135 @@ class TagTable: } ) try: - result = ChatIdTag.create(**chatIdTag.model_dump()) + result = ChatIdTag(**chatIdTag.dict()) + db.add(result) + db.commit() + db.refresh(result) if result: - return chatIdTag + return ChatIdTagModel.model_validate(result) else: return None except: return None - def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: + def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]: tag_names = [ - ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name - for chat_id_tag in ChatIdTag.select() - .where(ChatIdTag.user_id == user_id) - .order_by(ChatIdTag.timestamp.desc()) + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) ] return [ - TagModel(**model_to_dict(tag)) - for tag in Tag.select() - .where(Tag.user_id == user_id) - .where(Tag.name.in_(tag_names)) + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) ] def get_tags_by_chat_id_and_user_id( - self, chat_id: str, user_id: str + self, db: Session, chat_id: str, user_id: str ) -> List[TagModel]: tag_names = [ - ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name - for chat_id_tag in ChatIdTag.select() - .where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)) - .order_by(ChatIdTag.timestamp.desc()) + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) ] return [ - TagModel(**model_to_dict(tag)) - for tag in Tag.select() - .where(Tag.user_id == user_id) - .where(Tag.name.in_(tag_names)) + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) ] def get_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> Optional[ChatIdTagModel]: + self, db: Session, tag_name: str, user_id: str + ) -> List[ChatIdTagModel]: return [ - ChatIdTagModel(**model_to_dict(chat_id_tag)) - for chat_id_tag in ChatIdTag.select() - .where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name)) - .order_by(ChatIdTag.timestamp.desc()) + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) ] def count_chat_ids_by_tag_name_and_user_id( - self, tag_name: str, user_id: str + self, db: Session, tag_name: str, user_id: str ) -> int: - return ( - ChatIdTag.select() - .where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)) - .count() - ) + return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() - def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: + def delete_tag_by_tag_name_and_user_id( + self, db: Session, tag_name: str, user_id: str + ) -> bool: try: - query = ChatIdTag.delete().where( - (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id) + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() ) - res = query.execute() # Remove the rows, return number of rows removed. log.debug(f"res: {res}") - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + db, tag_name, user_id + ) if tag_count == 0: # Remove tag item from Tag col as well - query = Tag.delete().where( - (Tag.name == tag_name) & (Tag.user_id == user_id) - ) - query.execute() # Remove the rows, return number of rows removed. - + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: log.error(f"delete_tag: {e}") return False def delete_tag_by_tag_name_and_chat_id_and_user_id( - self, tag_name: str, chat_id: str, user_id: str + self, db: Session, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - query = ChatIdTag.delete().where( - (ChatIdTag.tag_name == tag_name) - & (ChatIdTag.chat_id == chat_id) - & (ChatIdTag.user_id == user_id) + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() ) - res = query.execute() # Remove the rows, return number of rows removed. log.debug(f"res: {res}") - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + db, tag_name, user_id + ) if tag_count == 0: # Remove tag item from Tag col as well - query = Tag.delete().where( - (Tag.name == tag_name) & (Tag.user_id == user_id) - ) - query.execute() # Remove the rows, return number of rows removed. + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: log.error(f"delete_tag: {e}") return False - def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: - tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) + def delete_tags_by_chat_id_and_user_id( + self, db: Session, chat_id: str, user_id: str + ) -> bool: + tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id) for tag in tags: self.delete_tag_by_tag_name_and_chat_id_and_user_id( - tag.tag_name, chat_id, user_id + db, tag.tag_name, chat_id, user_id ) return True -Tags = TagTable(DB) +Tags = TagTable() diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 694081df9..b8df2e163 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -1,10 +1,11 @@ -from pydantic import BaseModel -from peewee import * -from playhouse.shortcuts import model_to_dict -from typing import List, Union, Optional +from pydantic import BaseModel, ConfigDict +from typing import List, Optional import time import logging -from apps.webui.internal.db import DB, JSONField +from sqlalchemy import String, Column, BigInteger +from sqlalchemy.orm import Session + +from apps.webui.internal.db import Base, JSONField from apps.webui.models.users import Users import json @@ -21,19 +22,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) #################### -class Tool(Model): - id = CharField(unique=True) - user_id = CharField() - name = TextField() - content = TextField() - specs = JSONField() - meta = JSONField() - valves = JSONField() - updated_at = BigIntegerField() - created_at = BigIntegerField() +class Tool(Base): + __tablename__ = "tool" - class Meta: - database = DB + id = Column(String, primary_key=True) + user_id = Column(String) + name = Column(String) + content = Column(String) + specs = Column(JSONField) + meta = Column(JSONField) + valves = Column(JSONField) + updated_at = Column(BigInteger) + created_at = Column(BigInteger) class ToolMeta(BaseModel): @@ -51,6 +51,8 @@ class ToolModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -78,12 +80,9 @@ class ToolValves(BaseModel): class ToolsTable: - def __init__(self, db): - self.db = db - self.db.create_tables([Tool]) def insert_new_tool( - self, user_id: str, form_data: ToolForm, specs: List[dict] + self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: tool = ToolModel( **{ @@ -96,24 +95,27 @@ class ToolsTable: ) try: - result = Tool.create(**tool.model_dump()) + result = Tool(**tool.dict()) + db.add(result) + db.commit() + db.refresh(result) if result: - return tool + return ToolModel.model_validate(result) else: return None except Exception as e: print(f"Error creating tool: {e}") return None - def get_tool_by_id(self, id: str) -> Optional[ToolModel]: + def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]: try: - tool = Tool.get(Tool.id == id) - return ToolModel(**model_to_dict(tool)) + tool = db.get(Tool, id) + return ToolModel.model_validate(tool) except: return None - def get_tools(self) -> List[ToolModel]: - return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] + def get_tools(self, db: Session) -> List[ToolModel]: + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: @@ -180,25 +182,19 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - query = Tool.update( - **updated, - updated_at=int(time.time()), - ).where(Tool.id == id) - query.execute() - - tool = Tool.get(Tool.id == id) - return ToolModel(**model_to_dict(tool)) + db.query(Tool).filter_by(id=id).update( + {**updated, "updated_at": int(time.time())} + ) + return self.get_tool_by_id(db, id) except: return None - def delete_tool_by_id(self, id: str) -> bool: + def delete_tool_by_id(self, db: Session, id: str) -> bool: try: - query = Tool.delete().where((Tool.id == id)) - query.execute() # Remove the rows, return number of rows removed. - + db.query(Tool).filter_by(id=id).delete() return True except: return False -Tools = ToolsTable(DB) +Tools = ToolsTable() diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index e3e1842b8..7202d2d71 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -1,11 +1,13 @@ -from pydantic import BaseModel, ConfigDict -from peewee import * -from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict, parse_obj_as from typing import List, Union, Optional import time + +from sqlalchemy import String, Column, BigInteger, Text +from sqlalchemy.orm import Session + from utils.misc import get_gravatar_url -from apps.webui.internal.db import DB, JSONField +from apps.webui.internal.db import Base, JSONField from apps.webui.models.chats import Chats #################### @@ -13,25 +15,24 @@ from apps.webui.models.chats import Chats #################### -class User(Model): - id = CharField(unique=True) - name = CharField() - email = CharField() - role = CharField() - profile_image_url = TextField() +class User(Base): + __tablename__ = "user" - last_active_at = BigIntegerField() - updated_at = BigIntegerField() - created_at = BigIntegerField() + id = Column(String, primary_key=True) + name = Column(String) + email = Column(String) + role = Column(String) + profile_image_url = Column(String) - api_key = CharField(null=True, unique=True) - settings = JSONField(null=True) - info = JSONField(null=True) + last_active_at = Column(BigInteger) + updated_at = Column(BigInteger) + created_at = Column(BigInteger) - oauth_sub = TextField(null=True, unique=True) + api_key = Column(String, nullable=True, unique=True) + settings = Column(JSONField, nullable=True) + info = Column(JSONField, nullable=True) - class Meta: - database = DB + oauth_sub = Column(Text, unique=True) class UserSettings(BaseModel): @@ -41,6 +42,8 @@ class UserSettings(BaseModel): class UserModel(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str name: str email: str @@ -76,12 +79,10 @@ class UserUpdateForm(BaseModel): class UsersTable: - def __init__(self, db): - self.db = db - self.db.create_tables([User]) def insert_new_user( self, + db: Session, id: str, name: str, email: str, @@ -102,30 +103,33 @@ class UsersTable: "oauth_sub": oauth_sub, } ) - result = User.create(**user.model_dump()) + result = User(**user.model_dump()) + db.add(result) + db.commit() + db.refresh(result) if result: return user else: return None - def get_user_by_id(self, id: str) -> Optional[UserModel]: + def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]: try: - user = User.get(User.id == id) - return UserModel(**model_to_dict(user)) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except Exception as e: + return None + + def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]: + try: + user = db.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) except: return None - def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]: try: - user = User.get(User.api_key == api_key) - return UserModel(**model_to_dict(user)) - except: - return None - - def get_user_by_email(self, email: str) -> Optional[UserModel]: - try: - user = User.get(User.email == email) - return UserModel(**model_to_dict(user)) + user = db.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) except: return None @@ -136,88 +140,94 @@ class UsersTable: except: return None - def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - return [ - UserModel(**model_to_dict(user)) - for user in User.select() - # .limit(limit).offset(skip) - ] + def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]: + users = ( + db.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] - def get_num_users(self) -> Optional[int]: - return User.select().count() + def get_num_users(self, db: Session) -> Optional[int]: + return db.query(User).count() - def get_first_user(self) -> UserModel: + def get_first_user(self, db: Session) -> UserModel: try: - user = User.select().order_by(User.created_at).first() - return UserModel(**model_to_dict(user)) + user = db.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) except: return None - def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: + def update_user_role_by_id( + self, db: Session, id: str, role: str + ) -> Optional[UserModel]: try: - query = User.update(role=role).where(User.id == id) - query.execute() + db.query(User).filter_by(id=id).update({"role": role}) + db.commit() - user = User.get(User.id == id) - return UserModel(**model_to_dict(user)) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_profile_image_url_by_id( - self, id: str, profile_image_url: str + self, db: Session, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - query = User.update(profile_image_url=profile_image_url).where( - User.id == id + db.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} ) - query.execute() + db.commit() - user = User.get(User.id == id) - return UserModel(**model_to_dict(user)) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None - def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: + def update_user_last_active_by_id( + self, db: Session, id: str + ) -> Optional[UserModel]: try: - query = User.update(last_active_at=int(time.time())).where(User.id == id) - query.execute() + db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) - user = User.get(User.id == id) - return UserModel(**model_to_dict(user)) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_oauth_sub_by_id( - self, id: str, oauth_sub: str + self, db: Session, id: str, oauth_sub: str ) -> Optional[UserModel]: try: - query = User.update(oauth_sub=oauth_sub).where(User.id == id) - query.execute() + db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = User.get(User.id == id) - return UserModel(**model_to_dict(user)) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None - def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: + def update_user_by_id( + self, db: Session, id: str, updated: dict + ) -> Optional[UserModel]: try: - query = User.update(**updated).where(User.id == id) - query.execute() + db.query(User).filter_by(id=id).update(updated) + db.commit() - user = User.get(User.id == id) - return UserModel(**model_to_dict(user)) - except: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) + except Exception as e: return None - def delete_user_by_id(self, id: str) -> bool: + def delete_user_by_id(self, db: Session, id: str) -> bool: try: # Delete User Chats - result = Chats.delete_chats_by_user_id(id) + result = Chats.delete_chats_by_user_id(db, id) if result: # Delete User - query = User.delete().where(User.id == id) - query.execute() # Remove the rows, return number of rows removed. + db.query(User).filter_by(id=id).delete() + db.commit() return True else: @@ -225,21 +235,20 @@ class UsersTable: except: return False - def update_user_api_key_by_id(self, id: str, api_key: str) -> str: + def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str: try: - query = User.update(api_key=api_key).where(User.id == id) - result = query.execute() - + result = db.query(User).filter_by(id=id).update({"api_key": api_key}) + db.commit() return True if result == 1 else False except: return False - def get_user_api_key_by_id(self, id: str) -> Optional[str]: + def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]: try: - user = User.get(User.id == id) + user = db.query(User).filter_by(id=id).first() return user.api_key - except: + except Exception as e: return None -Users = UsersTable(DB) +Users = UsersTable() diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index 1be79d259..e83ee8cb9 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -10,6 +10,7 @@ import re import uuid import csv +from apps.webui.internal.db import get_db from apps.webui.models.auths import ( SigninForm, SignupForm, @@ -78,10 +79,13 @@ async def get_session_user( @router.post("/update/profile", response_model=UserResponse) async def update_profile( - form_data: UpdateProfileForm, session_user=Depends(get_current_user) + form_data: UpdateProfileForm, + session_user=Depends(get_current_user), + db=Depends(get_db), ): if session_user: user = Users.update_user_by_id( + db, session_user.id, {"profile_image_url": form_data.profile_image_url, "name": form_data.name}, ) @@ -100,16 +104,18 @@ async def update_profile( @router.post("/update/password", response_model=bool) async def update_password( - form_data: UpdatePasswordForm, session_user=Depends(get_current_user) + form_data: UpdatePasswordForm, + session_user=Depends(get_current_user), + db=Depends(get_db), ): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: - user = Auths.authenticate_user(session_user.email, form_data.password) + user = Auths.authenticate_user(db, session_user.email, form_data.password) if user: hashed = get_password_hash(form_data.new_password) - return Auths.update_user_password_by_id(user.id, hashed) + return Auths.update_user_password_by_id(db, user.id, hashed) else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) else: @@ -122,7 +128,7 @@ async def update_password( @router.post("/signin", response_model=SigninResponse) -async def signin(request: Request, response: Response, form_data: SigninForm): +async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) @@ -133,32 +139,34 @@ async def signin(request: Request, response: Response, form_data: SigninForm): trusted_name = request.headers.get( WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email ) - if not Users.get_user_by_email(trusted_email.lower()): + if not Users.get_user_by_email(db, trusted_email.lower()): await signup( request, SignupForm( email=trusted_email, password=str(uuid.uuid4()), name=trusted_name ), + db, ) - user = Auths.authenticate_user_by_trusted_header(trusted_email) + user = Auths.authenticate_user_by_trusted_header(db, trusted_email) elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" - if Users.get_user_by_email(admin_email.lower()): - user = Auths.authenticate_user(admin_email.lower(), admin_password) + if Users.get_user_by_email(db, admin_email.lower()): + user = Auths.authenticate_user(db, admin_email.lower(), admin_password) else: - if Users.get_num_users() != 0: + if Users.get_num_users(db) != 0: raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( request, SignupForm(email=admin_email, password=admin_password, name="User"), + db, ) - user = Auths.authenticate_user(admin_email.lower(), admin_password) + user = Auths.authenticate_user(db, admin_email.lower(), admin_password) else: - user = Auths.authenticate_user(form_data.email.lower(), form_data.password) + user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password) if user: token = create_token( @@ -192,7 +200,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): @router.post("/signup", response_model=SigninResponse) -async def signup(request: Request, response: Response, form_data: SignupForm): +async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)): if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED @@ -203,17 +211,18 @@ async def signup(request: Request, response: Response, form_data: SignupForm): status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if Users.get_user_by_email(db, form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: role = ( "admin" - if Users.get_num_users() == 0 + if Users.get_num_users(db) == 0 else request.app.state.config.DEFAULT_USER_ROLE ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( + db, form_data.email.lower(), hashed, form_data.name, @@ -267,14 +276,16 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.post("/add", response_model=SigninResponse) -async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): +async def add_user( + form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db) +): if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower()): + if Users.get_user_by_email(db, form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -282,6 +293,7 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): print(form_data) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( + db, form_data.email.lower(), hashed, form_data.name, @@ -312,7 +324,9 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): @router.get("/admin/details") -async def get_admin_details(request: Request, user=Depends(get_current_user)): +async def get_admin_details( + request: Request, user=Depends(get_current_user), db=Depends(get_db) +): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None @@ -320,11 +334,11 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)): print(admin_email, admin_name) if admin_email: - admin = Users.get_user_by_email(admin_email) + admin = Users.get_user_by_email(db, admin_email) if admin: admin_name = admin.name else: - admin = Users.get_first_user() + admin = Users.get_first_user(db) if admin: admin_email = admin.email admin_name = admin.name @@ -397,9 +411,9 @@ async def update_admin_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def create_api_key_(user=Depends(get_current_user)): +async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)): api_key = create_api_key() - success = Users.update_user_api_key_by_id(user.id, api_key) + success = Users.update_user_api_key_by_id(db, user.id, api_key) if success: return { "api_key": api_key, @@ -410,15 +424,15 @@ async def create_api_key_(user=Depends(get_current_user)): # delete api key @router.delete("/api_key", response_model=bool) -async def delete_api_key(user=Depends(get_current_user)): - success = Users.update_user_api_key_by_id(user.id, None) +async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)): + success = Users.update_user_api_key_by_id(db, user.id, None) return success # get api key @router.get("/api_key", response_model=ApiKey) -async def get_api_key(user=Depends(get_current_user)): - api_key = Users.get_user_api_key_by_id(user.id) +async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)): + api_key = Users.get_user_api_key_by_id(db, user.id) if api_key: return { "api_key": api_key, diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 9d1cceaa1..1454d47bd 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -1,6 +1,8 @@ from fastapi import Depends, Request, HTTPException, status from datetime import datetime, timedelta from typing import List, Union, Optional + +from apps.webui.internal.db import get_db from utils.utils import get_current_user, get_admin_user from fastapi import APIRouter from pydantic import BaseModel @@ -43,9 +45,9 @@ router = APIRouter() @router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/list", response_model=List[ChatTitleIdResponse]) async def get_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 + user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) ): - return Chats.get_chat_list_by_user_id(user.id, skip, limit) + return Chats.get_chat_list_by_user_id(db, user.id, skip, limit) ############################ @@ -54,7 +56,9 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) -async def delete_all_user_chats(request: Request, user=Depends(get_current_user)): +async def delete_all_user_chats( + request: Request, user=Depends(get_current_user), db=Depends(get_db) +): if ( user.role == "user" @@ -65,7 +69,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user) detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chats_by_user_id(user.id) + result = Chats.delete_chats_by_user_id(db, user.id) return result @@ -76,10 +80,14 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user) @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( - user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 + user_id: str, + user=Depends(get_admin_user), + skip: int = 0, + limit: int = 50, + db=Depends(get_db), ): return Chats.get_chat_list_by_user_id( - user_id, include_archived=True, skip=skip, limit=limit + db, user_id, include_archived=True, skip=skip, limit=limit ) @@ -89,9 +97,11 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): +async def create_new_chat( + form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) +): try: - chat = Chats.insert_new_chat(user.id, form_data) + chat = Chats.insert_new_chat(db, user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) except Exception as e: log.exception(e) @@ -106,10 +116,10 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): @router.get("/all", response_model=List[ChatResponse]) -async def get_user_chats(user=Depends(get_current_user)): +async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats_by_user_id(user.id) + for chat in Chats.get_chats_by_user_id(db, user.id) ] @@ -119,10 +129,10 @@ async def get_user_chats(user=Depends(get_current_user)): @router.get("/all/archived", response_model=List[ChatResponse]) -async def get_user_chats(user=Depends(get_current_user)): +async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_archived_chats_by_user_id(user.id) + for chat in Chats.get_archived_chats_by_user_id(db, user.id) ] @@ -132,7 +142,7 @@ async def get_user_chats(user=Depends(get_current_user)): @router.get("/all/db", response_model=List[ChatResponse]) -async def get_all_user_chats_in_db(user=Depends(get_admin_user)): +async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)): if not ENABLE_ADMIN_EXPORT: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -140,7 +150,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ) return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats() + for chat in Chats.get_chats(db) ] @@ -151,9 +161,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): @router.get("/archived", response_model=List[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50 + user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) ): - return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) + return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit) ############################ @@ -162,8 +172,8 @@ async def get_archived_session_user_chat_list( @router.post("/archive/all", response_model=bool) -async def archive_all_chats(user=Depends(get_current_user)): - return Chats.archive_all_chats_by_user_id(user.id) +async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)): + return Chats.archive_all_chats_by_user_id(db, user.id) ############################ @@ -172,16 +182,18 @@ async def archive_all_chats(user=Depends(get_current_user)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): +async def get_shared_chat_by_id( + share_id: str, user=Depends(get_current_user), db=Depends(get_db) +): if user.role == "pending": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND ) if user.role == "user": - chat = Chats.get_chat_by_share_id(share_id) + chat = Chats.get_chat_by_share_id(db, share_id) elif user.role == "admin": - chat = Chats.get_chat_by_id(share_id) + chat = Chats.get_chat_by_id(db, share_id) if chat: return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -204,21 +216,23 @@ class TagNameForm(BaseModel): @router.post("/tags", response_model=List[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, user=Depends(get_current_user) + form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db) ): print(form_data) chat_ids = [ chat_id_tag.chat_id for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( - form_data.name, user.id + db, form_data.name, user.id ) ] - chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) + chats = Chats.get_chat_list_by_chat_ids( + db, chat_ids, form_data.skip, form_data.limit + ) if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) + Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id) return chats @@ -229,9 +243,9 @@ async def get_user_chat_list_by_tag_name( @router.get("/tags/all", response_model=List[TagModel]) -async def get_all_tags(user=Depends(get_current_user)): +async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)): try: - tags = Tags.get_tags_by_user_id(user.id) + tags = Tags.get_tags_by_user_id(db, user.id) return tags except Exception as e: log.exception(e) @@ -246,8 +260,8 @@ async def get_all_tags(user=Depends(get_current_user)): @router.get("/{id}", response_model=Optional[ChatResponse]) -async def get_chat_by_id(id: str, user=Depends(get_current_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): + chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) if chat: return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -264,13 +278,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): @router.post("/{id}", response_model=Optional[ChatResponse]) async def update_chat_by_id( - id: str, form_data: ChatForm, user=Depends(get_current_user) + id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) ): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) + chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) if chat: updated_chat = {**json.loads(chat.chat), **form_data.chat} - chat = Chats.update_chat_by_id(id, updated_chat) + chat = Chats.update_chat_by_id(db, id, updated_chat) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( @@ -285,10 +299,12 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) -async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): +async def delete_chat_by_id( + request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db) +): if user.role == "admin": - result = Chats.delete_chat_by_id(id) + result = Chats.delete_chat_by_id(db, id) return result else: if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]: @@ -297,7 +313,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chat_by_id_and_user_id(id, user.id) + result = Chats.delete_chat_by_id_and_user_id(db, id, user.id) return result @@ -307,8 +323,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ @router.get("/{id}/clone", response_model=Optional[ChatResponse]) -async def clone_chat_by_id(id: str, user=Depends(get_current_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): + chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) if chat: chat_body = json.loads(chat.chat) @@ -319,7 +335,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)): "title": f"Clone of {chat.title}", } - chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) + chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat})) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( @@ -333,10 +349,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)): @router.get("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id(id: str, user=Depends(get_current_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def archive_chat_by_id( + id: str, user=Depends(get_current_user), db=Depends(get_db) +): + chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) if chat: - chat = Chats.toggle_chat_archive_by_id(id) + chat = Chats.toggle_chat_archive_by_id(db, id) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( @@ -350,16 +368,16 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)): @router.post("/{id}/share", response_model=Optional[ChatResponse]) -async def share_chat_by_id(id: str, user=Depends(get_current_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): + chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) if chat: if chat.share_id: - shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) + shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id) return ChatResponse( **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} ) - shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) + shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id) if not shared_chat: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -382,14 +400,16 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)): @router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): - chat = Chats.get_chat_by_id_and_user_id(id, user.id) +async def delete_shared_chat_by_id( + id: str, user=Depends(get_current_user), db=Depends(get_db) +): + chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) if chat: if not chat.share_id: return False - result = Chats.delete_shared_chat_by_chat_id(id) - update_result = Chats.update_chat_share_id_by_id(id, None) + result = Chats.delete_shared_chat_by_chat_id(db, id) + update_result = Chats.update_chat_share_id_by_id(db, id, None) return result and update_result != None else: @@ -405,8 +425,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): @router.get("/{id}/tags", response_model=List[TagModel]) -async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) +async def get_chat_tags_by_id( + id: str, user=Depends(get_current_user), db=Depends(get_db) +): + tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) if tags != None: return tags @@ -423,12 +445,15 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) async def add_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) + id: str, + form_data: ChatIdTagForm, + user=Depends(get_current_user), + db=Depends(get_db), ): - tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) + tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) if form_data.tag_name not in tags: - tag = Tags.add_tag_to_chat(user.id, form_data) + tag = Tags.add_tag_to_chat(db, user.id, form_data) if tag: return tag @@ -450,10 +475,13 @@ async def add_chat_tag_by_id( @router.delete("/{id}/tags", response_model=Optional[bool]) async def delete_chat_tag_by_id( - id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) + id: str, + form_data: ChatIdTagForm, + user=Depends(get_current_user), + db=Depends(get_db), ): result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( - form_data.tag_name, id, user.id + db, form_data.tag_name, id, user.id ) if result: @@ -470,8 +498,10 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): - result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) +async def delete_all_chat_tags_by_id( + id: str, user=Depends(get_current_user), db=Depends(get_db) +): + result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id) if result: return result diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index 311455390..b9a42352a 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -6,6 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json +from apps.webui.internal.db import get_db from apps.webui.models.documents import ( Documents, DocumentForm, @@ -25,7 +26,7 @@ router = APIRouter() @router.get("/", response_model=List[DocumentResponse]) -async def get_documents(user=Depends(get_current_user)): +async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): docs = [ DocumentResponse( **{ @@ -33,7 +34,7 @@ async def get_documents(user=Depends(get_current_user)): "content": json.loads(doc.content if doc.content else "{}"), } ) - for doc in Documents.get_docs() + for doc in Documents.get_docs(db) ] return docs @@ -44,10 +45,12 @@ async def get_documents(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[DocumentResponse]) -async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): - doc = Documents.get_doc_by_name(form_data.name) +async def create_new_doc( + form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db) +): + doc = Documents.get_doc_by_name(db, form_data.name) if doc == None: - doc = Documents.insert_new_doc(user.id, form_data) + doc = Documents.insert_new_doc(db, user.id, form_data) if doc: return DocumentResponse( @@ -74,8 +77,10 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): @router.get("/doc", response_model=Optional[DocumentResponse]) -async def get_doc_by_name(name: str, user=Depends(get_current_user)): - doc = Documents.get_doc_by_name(name) +async def get_doc_by_name( + name: str, user=Depends(get_current_user), db=Depends(get_db) +): + doc = Documents.get_doc_by_name(db, name) if doc: return DocumentResponse( @@ -106,8 +111,12 @@ class TagDocumentForm(BaseModel): @router.post("/doc/tags", response_model=Optional[DocumentResponse]) -async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): - doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) +async def tag_doc_by_name( + form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db) +): + doc = Documents.update_doc_content_by_name( + db, form_data.name, {"tags": form_data.tags} + ) if doc: return DocumentResponse( @@ -130,9 +139,12 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u @router.post("/doc/update", response_model=Optional[DocumentResponse]) async def update_doc_by_name( - name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) + name: str, + form_data: DocumentUpdateForm, + user=Depends(get_admin_user), + db=Depends(get_db), ): - doc = Documents.update_doc_by_name(name, form_data) + doc = Documents.update_doc_by_name(db, name, form_data) if doc: return DocumentResponse( **{ @@ -153,6 +165,8 @@ async def update_doc_by_name( @router.delete("/doc/delete", response_model=bool) -async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): - result = Documents.delete_doc_by_name(name) +async def delete_doc_by_name( + name: str, user=Depends(get_admin_user), db=Depends(get_db) +): + result = Documents.delete_doc_by_name(db, name) return result diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index 3b6d44aa5..2ed119ad0 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -20,6 +20,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from pydantic import BaseModel import json +from apps.webui.internal.db import get_db from apps.webui.models.files import ( Files, FileForm, @@ -53,6 +54,7 @@ router = APIRouter() def upload_file( file: UploadFile = File(...), user=Depends(get_verified_user), + db=Depends(get_db) ): log.info(f"file.content_type: {file.content_type}") try: @@ -70,6 +72,7 @@ def upload_file( f.close() file = Files.insert_new_file( + db, user.id, FileForm( **{ @@ -106,8 +109,8 @@ def upload_file( @router.get("/", response_model=List[FileModel]) -async def list_files(user=Depends(get_verified_user)): - files = Files.get_files() +async def list_files(user=Depends(get_verified_user), db=Depends(get_db)): + files = Files.get_files(db) return files @@ -117,8 +120,8 @@ async def list_files(user=Depends(get_verified_user)): @router.delete("/all") -async def delete_all_files(user=Depends(get_admin_user)): - result = Files.delete_all_files() +async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)): + result = Files.delete_all_files(db) if result: folder = f"{UPLOAD_DIR}" @@ -154,8 +157,8 @@ async def delete_all_files(user=Depends(get_admin_user)): @router.get("/{id}", response_model=Optional[FileModel]) -async def get_file_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): + file = Files.get_file_by_id(db, id) if file: return file @@ -172,8 +175,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/content", response_model=Optional[FileModel]) -async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): + file = Files.get_file_by_id(db, id) if file: file_path = Path(file.meta["path"]) @@ -223,11 +226,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.delete("/{id}") -async def delete_file_by_id(id: str, user=Depends(get_verified_user)): - file = Files.get_file_by_id(id) +async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): + file = Files.get_file_by_id(db, id) if file: - result = Files.delete_file_by_id(id) + result = Files.delete_file_by_id(db, id) if result: return {"message": "File deleted successfully"} else: diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index 4c89ca487..f15566702 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -6,6 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json +from apps.webui.internal.db import get_db from apps.webui.models.functions import ( Functions, FunctionForm, @@ -31,8 +32,8 @@ router = APIRouter() @router.get("/", response_model=List[FunctionResponse]) -async def get_functions(user=Depends(get_verified_user)): - return Functions.get_functions() +async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)): + return Functions.get_functions(db) ############################ @@ -41,8 +42,8 @@ async def get_functions(user=Depends(get_verified_user)): @router.get("/export", response_model=List[FunctionModel]) -async def get_functions(user=Depends(get_admin_user)): - return Functions.get_functions() +async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)): + return Functions.get_functions(db) ############################ @@ -52,7 +53,7 @@ async def get_functions(user=Depends(get_admin_user)): @router.post("/create", response_model=Optional[FunctionResponse]) async def create_new_function( - request: Request, form_data: FunctionForm, user=Depends(get_admin_user) + request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) ): if not form_data.id.isidentifier(): raise HTTPException( @@ -62,7 +63,7 @@ async def create_new_function( form_data.id = form_data.id.lower() - function = Functions.get_function_by_id(form_data.id) + function = Functions.get_function_by_id(db, form_data.id) if function == None: function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") try: @@ -77,7 +78,7 @@ async def create_new_function( FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(user.id, function_type, form_data) + function = Functions.insert_new_function(db, user.id, function_type, form_data) function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) @@ -108,8 +109,8 @@ async def create_new_function( @router.get("/id/{id}", response_model=Optional[FunctionModel]) -async def get_function_by_id(id: str, user=Depends(get_admin_user)): - function = Functions.get_function_by_id(id) +async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): + function = Functions.get_function_by_id(db, id) if function: return function @@ -154,7 +155,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[FunctionModel]) async def update_function_by_id( - request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) + request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) ): function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") @@ -171,7 +172,7 @@ async def update_function_by_id( updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} print(updated) - function = Functions.update_function_by_id(id, updated) + function = Functions.update_function_by_id(db, id, updated) if function: return function @@ -195,9 +196,9 @@ async def update_function_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_function_by_id( - request: Request, id: str, user=Depends(get_admin_user) + request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) ): - result = Functions.delete_function_by_id(id) + result = Functions.delete_function_by_id(db, id) if result: FUNCTIONS = request.app.state.FUNCTIONS diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index e9ae96173..e7fafa37b 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -7,6 +7,7 @@ from fastapi import APIRouter from pydantic import BaseModel import logging +from apps.webui.internal.db import get_db from apps.webui.models.memories import Memories, MemoryModel from utils.utils import get_verified_user @@ -31,8 +32,8 @@ async def get_embeddings(request: Request): @router.get("/", response_model=List[MemoryModel]) -async def get_memories(user=Depends(get_verified_user)): - return Memories.get_memories_by_user_id(user.id) +async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)): + return Memories.get_memories_by_user_id(db, user.id) ############################ @@ -50,9 +51,12 @@ class MemoryUpdateModel(BaseModel): @router.post("/add", response_model=Optional[MemoryModel]) async def add_memory( - request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) + request: Request, + form_data: AddMemoryForm, + user=Depends(get_verified_user), + db=Depends(get_db), ): - memory = Memories.insert_new_memory(user.id, form_data.content) + memory = Memories.insert_new_memory(db, user.id, form_data.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") @@ -72,8 +76,9 @@ async def update_memory_by_id( request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user), + db=Depends(get_db), ): - memory = Memories.update_memory_by_id(memory_id, form_data.content) + memory = Memories.update_memory_by_id(db, memory_id, form_data.content) if memory is None: raise HTTPException(status_code=404, detail="Memory not found") @@ -124,12 +129,12 @@ async def query_memory( ############################ @router.get("/reset", response_model=bool) async def reset_memory_from_vector_db( - request: Request, user=Depends(get_verified_user) + request: Request, user=Depends(get_verified_user), db=Depends(get_db) ): CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") - memories = Memories.get_memories_by_user_id(user.id) + memories = Memories.get_memories_by_user_id(db, user.id) for memory in memories: memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) collection.upsert( @@ -146,8 +151,8 @@ async def reset_memory_from_vector_db( @router.delete("/user", response_model=bool) -async def delete_memory_by_user_id(user=Depends(get_verified_user)): - result = Memories.delete_memories_by_user_id(user.id) +async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)): + result = Memories.delete_memories_by_user_id(db, user.id) if result: try: @@ -165,8 +170,10 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)): @router.delete("/{memory_id}", response_model=bool) -async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): - result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) +async def delete_memory_by_id( + memory_id: str, user=Depends(get_verified_user), db=Depends(get_db) +): + result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id) if result: collection = CHROMA_CLIENT.get_or_create_collection( diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py index acc1c6b47..f151e8864 100644 --- a/backend/apps/webui/routers/models.py +++ b/backend/apps/webui/routers/models.py @@ -5,6 +5,8 @@ from typing import List, Union, Optional from fastapi import APIRouter from pydantic import BaseModel import json + +from apps.webui.internal.db import get_db from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from utils.utils import get_verified_user, get_admin_user @@ -18,8 +20,8 @@ router = APIRouter() @router.get("/", response_model=List[ModelResponse]) -async def get_models(user=Depends(get_verified_user)): - return Models.get_all_models() +async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): + return Models.get_all_models(db) ############################ @@ -29,7 +31,10 @@ async def get_models(user=Depends(get_verified_user)): @router.post("/add", response_model=Optional[ModelModel]) async def add_new_model( - request: Request, form_data: ModelForm, user=Depends(get_admin_user) + request: Request, + form_data: ModelForm, + user=Depends(get_admin_user), + db=Depends(get_db), ): if form_data.id in request.app.state.MODELS: raise HTTPException( @@ -37,7 +42,7 @@ async def add_new_model( detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) else: - model = Models.insert_new_model(form_data, user.id) + model = Models.insert_new_model(db, form_data, user.id) if model: return model @@ -53,9 +58,9 @@ async def add_new_model( ############################ -@router.get("/", response_model=Optional[ModelModel]) -async def get_model_by_id(id: str, user=Depends(get_verified_user)): - model = Models.get_model_by_id(id) +@router.get("/{id}", response_model=Optional[ModelModel]) +async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): + model = Models.get_model_by_id(db, id) if model: return model @@ -73,15 +78,19 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): @router.post("/update", response_model=Optional[ModelModel]) async def update_model_by_id( - request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) + request: Request, + id: str, + form_data: ModelForm, + user=Depends(get_admin_user), + db=Depends(get_db), ): - model = Models.get_model_by_id(id) + model = Models.get_model_by_id(db, id) if model: - model = Models.update_model_by_id(id, form_data) + model = Models.update_model_by_id(db, id, form_data) return model else: if form_data.id in request.app.state.MODELS: - model = Models.insert_new_model(form_data, user.id) + model = Models.insert_new_model(db, form_data, user.id) if model: return model else: @@ -102,6 +111,6 @@ async def update_model_by_id( @router.delete("/delete", response_model=bool) -async def delete_model_by_id(id: str, user=Depends(get_admin_user)): - result = Models.delete_model_by_id(id) +async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): + result = Models.delete_model_by_id(db, id) return result diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index 47d8c7012..c8f173a1e 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -6,6 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json +from apps.webui.internal.db import get_db from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from utils.utils import get_current_user, get_admin_user @@ -19,8 +20,8 @@ router = APIRouter() @router.get("/", response_model=List[PromptModel]) -async def get_prompts(user=Depends(get_current_user)): - return Prompts.get_prompts() +async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)): + return Prompts.get_prompts(db) ############################ @@ -29,10 +30,12 @@ async def get_prompts(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): - prompt = Prompts.get_prompt_by_command(form_data.command) +async def create_new_prompt( + form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db) +): + prompt = Prompts.get_prompt_by_command(db, form_data.command) if prompt == None: - prompt = Prompts.insert_new_prompt(user.id, form_data) + prompt = Prompts.insert_new_prompt(db, user.id, form_data) if prompt: return prompt @@ -52,8 +55,10 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)) @router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command(command: str, user=Depends(get_current_user)): - prompt = Prompts.get_prompt_by_command(f"/{command}") +async def get_prompt_by_command( + command: str, user=Depends(get_current_user), db=Depends(get_db) +): + prompt = Prompts.get_prompt_by_command(db, f"/{command}") if prompt: return prompt @@ -71,9 +76,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)): @router.post("/command/{command}/update", response_model=Optional[PromptModel]) async def update_prompt_by_command( - command: str, form_data: PromptForm, user=Depends(get_admin_user) + command: str, + form_data: PromptForm, + user=Depends(get_admin_user), + db=Depends(get_db), ): - prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) + prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data) if prompt: return prompt else: @@ -89,6 +97,8 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): - result = Prompts.delete_prompt_by_command(f"/{command}") +async def delete_prompt_by_command( + command: str, user=Depends(get_admin_user), db=Depends(get_db) +): + result = Prompts.delete_prompt_by_command(db, f"/{command}") return result diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index d20584c22..4eb6d1caf 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -6,7 +6,7 @@ from fastapi import APIRouter from pydantic import BaseModel import json - +from apps.webui.internal.db import get_db from apps.webui.models.users import Users from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.utils import load_toolkit_module_by_id @@ -34,7 +34,7 @@ router = APIRouter() @router.get("/", response_model=List[ToolResponse]) -async def get_toolkits(user=Depends(get_verified_user)): +async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)): toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits @@ -45,8 +45,8 @@ async def get_toolkits(user=Depends(get_verified_user)): @router.get("/export", response_model=List[ToolModel]) -async def get_toolkits(user=Depends(get_admin_user)): - toolkits = [toolkit for toolkit in Tools.get_tools()] +async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)): + toolkits = [toolkit for toolkit in Tools.get_tools(db)] return toolkits @@ -57,7 +57,10 @@ async def get_toolkits(user=Depends(get_admin_user)): @router.post("/create", response_model=Optional[ToolResponse]) async def create_new_toolkit( - request: Request, form_data: ToolForm, user=Depends(get_admin_user) + request: Request, + form_data: ToolForm, + user=Depends(get_admin_user), + db=Depends(get_db), ): if not form_data.id.isidentifier(): raise HTTPException( @@ -67,7 +70,7 @@ async def create_new_toolkit( form_data.id = form_data.id.lower() - toolkit = Tools.get_tool_by_id(form_data.id) + toolkit = Tools.get_tool_by_id(db, form_data.id) if toolkit == None: toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") try: @@ -81,7 +84,7 @@ async def create_new_toolkit( TOOLS[form_data.id] = toolkit_module specs = get_tools_specs(TOOLS[form_data.id]) - toolkit = Tools.insert_new_tool(user.id, form_data, specs) + toolkit = Tools.insert_new_tool(db, user.id, form_data, specs) tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) @@ -112,8 +115,8 @@ async def create_new_toolkit( @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): - toolkit = Tools.get_tool_by_id(id) +async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): + toolkit = Tools.get_tool_by_id(db, id) if toolkit: return toolkit @@ -131,7 +134,11 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[ToolModel]) async def update_toolkit_by_id( - request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) + request: Request, + id: str, + form_data: ToolForm, + user=Depends(get_admin_user), + db=Depends(get_db), ): toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") @@ -153,7 +160,7 @@ async def update_toolkit_by_id( } print(updated) - toolkit = Tools.update_tool_by_id(id, updated) + toolkit = Tools.update_tool_by_id(db, id, updated) if toolkit: return toolkit @@ -176,8 +183,10 @@ async def update_toolkit_by_id( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): - result = Tools.delete_tool_by_id(id) +async def delete_toolkit_by_id( + request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) +): + result = Tools.delete_tool_by_id(db, id) if result: TOOLS = request.app.state.TOOLS diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index 270d72a23..46a418fc1 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -9,6 +9,7 @@ import time import uuid import logging +from apps.webui.internal.db import get_db from apps.webui.models.users import ( UserModel, UserUpdateForm, @@ -40,8 +41,10 @@ router = APIRouter() @router.get("/", response_model=List[UserModel]) -async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): - return Users.get_users(skip, limit) +async def get_users( + skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db) +): + return Users.get_users(db, skip, limit) ############################ @@ -68,10 +71,12 @@ async def update_user_permissions( @router.post("/update/role", response_model=Optional[UserModel]) -async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): +async def update_user_role( + form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db) +): - if user.id != form_data.id and form_data.id != Users.get_first_user().id: - return Users.update_user_role_by_id(form_data.id, form_data.role) + if user.id != form_data.id and form_data.id != Users.get_first_user(db).id: + return Users.update_user_role_by_id(db, form_data.id, form_data.role) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -85,8 +90,10 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin @router.get("/user/settings", response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_settings_by_session_user( + user=Depends(get_verified_user), db=Depends(get_db) +): + user = Users.get_user_by_id(db, user.id) if user: return user.settings else: @@ -103,9 +110,9 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)): @router.post("/user/settings/update", response_model=UserSettings) async def update_user_settings_by_session_user( - form_data: UserSettings, user=Depends(get_verified_user) + form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db) ): - user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()}) + user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()}) if user: return user.settings else: @@ -121,8 +128,10 @@ async def update_user_settings_by_session_user( @router.get("/user/info", response_model=Optional[dict]) -async def get_user_info_by_session_user(user=Depends(get_verified_user)): - user = Users.get_user_by_id(user.id) +async def get_user_info_by_session_user( + user=Depends(get_verified_user), db=Depends(get_db) +): + user = Users.get_user_by_id(db, user.id) if user: return user.info else: @@ -138,15 +147,17 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): @router.post("/user/info/update", response_model=Optional[dict]) -async def update_user_settings_by_session_user( - form_data: dict, user=Depends(get_verified_user) +async def update_user_info_by_session_user( + form_data: dict, user=Depends(get_verified_user), db=Depends(get_db) ): - user = Users.get_user_by_id(user.id) + user = Users.get_user_by_id(db, user.id) if user: if user.info is None: user.info = {} - user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) + user = Users.update_user_by_id( + db, user.id, {"info": {**user.info, **form_data}} + ) if user: return user.info else: @@ -172,13 +183,15 @@ class UserResponse(BaseModel): @router.get("/{user_id}", response_model=UserResponse) -async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): +async def get_user_by_id( + user_id: str, user=Depends(get_verified_user), db=Depends(get_db) +): # Check if user_id is a shared chat # If it is, get the user_id from the chat if user_id.startswith("shared-"): chat_id = user_id.replace("shared-", "") - chat = Chats.get_chat_by_id(chat_id) + chat = Chats.get_chat_by_id(db, chat_id) if chat: user_id = chat.user_id else: @@ -187,7 +200,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.USER_NOT_FOUND, ) - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(db, user_id) if user: return UserResponse(name=user.name, profile_image_url=user.profile_image_url) @@ -205,13 +218,16 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): @router.post("/{user_id}/update", response_model=Optional[UserModel]) async def update_user_by_id( - user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) + user_id: str, + form_data: UserUpdateForm, + session_user=Depends(get_admin_user), + db=Depends(get_db), ): - user = Users.get_user_by_id(user_id) + user = Users.get_user_by_id(db, user_id) if user: if form_data.email.lower() != user.email: - email_user = Users.get_user_by_email(form_data.email.lower()) + email_user = Users.get_user_by_email(db, form_data.email.lower()) if email_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -221,10 +237,11 @@ async def update_user_by_id( if form_data.password: hashed = get_password_hash(form_data.password) log.debug(f"hashed: {hashed}") - Auths.update_user_password_by_id(user_id, hashed) + Auths.update_user_password_by_id(db, user_id, hashed) - Auths.update_email_by_id(user_id, form_data.email.lower()) + Auths.update_email_by_id(db, user_id, form_data.email.lower()) updated_user = Users.update_user_by_id( + db, user_id, { "name": form_data.name, @@ -253,9 +270,11 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): +async def delete_user_by_id( + user_id: str, user=Depends(get_admin_user), db=Depends(get_db) +): if user.id != user_id: - result = Auths.delete_auth_by_id(user_id) + result = Auths.delete_auth_by_id(db, user_id) if result: return True diff --git a/backend/apps/webui/routers/utils.py b/backend/apps/webui/routers/utils.py index 8f6d663b4..780ed6b43 100644 --- a/backend/apps/webui/routers/utils.py +++ b/backend/apps/webui/routers/utils.py @@ -1,6 +1,5 @@ from fastapi import APIRouter, UploadFile, File, Response from fastapi import Depends, HTTPException, status -from peewee import SqliteDatabase from starlette.responses import StreamingResponse, FileResponse from pydantic import BaseModel @@ -10,7 +9,6 @@ import markdown import black -from apps.webui.internal.db import DB from utils.utils import get_admin_user from utils.misc import calculate_sha256, get_gravatar_url @@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if not isinstance(DB, SqliteDatabase): + from apps.webui.internal.db import engine + + if engine.name != "sqlite": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DB_NOT_SQLITE, ) return FileResponse( - DB.database, + engine.url.database, media_type="application/octet-stream", filename="webui.db", ) diff --git a/backend/main.py b/backend/main.py index 52da33155..d80c6a729 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,6 @@ import base64 import uuid +import subprocess from contextlib import asynccontextmanager from authlib.integrations.starlette_client import OAuth @@ -27,6 +28,8 @@ from fastapi.responses import JSONResponse from fastapi import HTTPException from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy import text +from sqlalchemy.orm import Session from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -54,6 +57,7 @@ from apps.webui.main import ( get_pipe_models, generate_function_chat_completion, ) +from apps.webui.internal.db import get_db, SessionLocal from pydantic import BaseModel @@ -124,6 +128,8 @@ from config import ( WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, AppConfig, + BACKEND_DIR, + DATABASE_URL, ) from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from utils.webhook import post_webhook @@ -166,8 +172,19 @@ https://github.com/open-webui/open-webui ) +def run_migrations(): + from alembic.config import Config + from alembic import command + + alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini") + alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL) + alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations") + command.upgrade(alembic_cfg, "head") + + @asynccontextmanager async def lifespan(app: FastAPI): + run_migrations() yield @@ -393,6 +410,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), + SessionLocal(), ) # Flag to skip RAG completions if file_handler is present in tools/functions skip_files = False @@ -736,6 +754,7 @@ class PipelineMiddleware(BaseHTTPMiddleware): user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), + SessionLocal(), ) try: @@ -781,7 +800,9 @@ app.add_middleware( @app.middleware("http") async def check_url(request: Request, call_next): if len(app.state.MODELS) == 0: - await get_all_models() + db = SessionLocal() + await get_all_models(db) + db.commit() else: pass @@ -815,12 +836,12 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION -async def get_all_models(): +async def get_all_models(db: Session): pipe_models = [] openai_models = [] ollama_models = [] - pipe_models = await get_pipe_models() + pipe_models = await get_pipe_models(db) if app.state.config.ENABLE_OPENAI_API: openai_models = await get_openai_models() @@ -842,7 +863,7 @@ async def get_all_models(): models = pipe_models + openai_models + ollama_models - custom_models = Models.get_all_models() + custom_models = Models.get_all_models(db) for custom_model in custom_models: if custom_model.base_model_id == None: for model in models: @@ -882,8 +903,8 @@ async def get_all_models(): @app.get("/api/models") -async def get_models(user=Depends(get_verified_user)): - models = await get_all_models() +async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): + models = await get_all_models(db) # Filter out filter pipelines models = [ @@ -1584,9 +1605,12 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use @app.get("/api/pipelines/{pipeline_id}/valves") async def get_pipeline_valves( - urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user) + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), + db=Depends(get_db), ): - models = await get_all_models() + models = await get_all_models(db) r = None try: @@ -1622,9 +1646,12 @@ async def get_pipeline_valves( @app.get("/api/pipelines/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( - urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user) + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), + db=Depends(get_db), ): - models = await get_all_models() + models = await get_all_models(db) r = None try: @@ -1663,8 +1690,9 @@ async def update_pipeline_valves( pipeline_id: str, form_data: dict, user=Depends(get_admin_user), + db=Depends(get_db), ): - models = await get_all_models() + models = await get_all_models(db) r = None try: @@ -2011,6 +2039,12 @@ async def healthcheck(): return {"status": True} +@app.get("/health/db") +async def healthcheck_with_db(db: Session = Depends(get_db)): + result = db.execute(text("SELECT 1;")).all() + return {"status": True} + + app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") diff --git a/backend/migrations/README b/backend/migrations/README new file mode 100644 index 000000000..f1d93dff9 --- /dev/null +++ b/backend/migrations/README @@ -0,0 +1,4 @@ +Generic single-database configuration. + +Create new migrations with +DATABASE_URL= alembic revision --autogenerate -m "a description" diff --git a/backend/migrations/env.py b/backend/migrations/env.py new file mode 100644 index 000000000..836893bbe --- /dev/null +++ b/backend/migrations/env.py @@ -0,0 +1,93 @@ +import os +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +from apps.webui.models.auths import Auth +from apps.webui.models.chats import Chat +from apps.webui.models.documents import Document +from apps.webui.models.memories import Memory +from apps.webui.models.models import Model +from apps.webui.models.prompts import Prompt +from apps.webui.models.tags import Tag, ChatIdTag +from apps.webui.models.tools import Tool +from apps.webui.models.users import User +from apps.webui.models.files import File +from apps.webui.models.functions import Function + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Auth.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + +database_url = os.getenv("DATABASE_URL", None) +if database_url: + config.set_main_option("sqlalchemy.url", database_url) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/backend/migrations/script.py.mako b/backend/migrations/script.py.mako new file mode 100644 index 000000000..5f667ccfe --- /dev/null +++ b/backend/migrations/script.py.mako @@ -0,0 +1,27 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import apps.webui.internal.db +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/backend/migrations/versions/22b5ab2667b8_init.py b/backend/migrations/versions/22b5ab2667b8_init.py new file mode 100644 index 000000000..af10dc2cf --- /dev/null +++ b/backend/migrations/versions/22b5ab2667b8_init.py @@ -0,0 +1,188 @@ +"""init + +Revision ID: 22b5ab2667b8 +Revises: +Create Date: 2024-06-20 13:22:40.397002 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.engine.reflection import Inspector + +import apps.webui.internal.db + + +# revision identifiers, used by Alembic. +revision: str = "22b5ab2667b8" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + con = op.get_bind() + inspector = Inspector.from_engine(con) + tables = set(inspector.get_table_names()) + + # ### commands auto generated by Alembic - please adjust! ### + if not "auth" in tables: + op.create_table( + "auth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=True), + sa.Column("password", sa.String(), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + if not "chat" in tables: + op.create_table( + "chat", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("chat", sa.String(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("share_id", sa.String(), nullable=True), + sa.Column("archived", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("share_id"), + ) + + if not "chatidtag" in tables: + op.create_table( + "chatidtag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("tag_name", sa.String(), nullable=True), + sa.Column("chat_id", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + if not "document" in tables: + op.create_table( + "document", + sa.Column("collection_name", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("filename", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("collection_name"), + sa.UniqueConstraint("name"), + ) + + if not "memory" in tables: + op.create_table( + "memory", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + if not "model" in tables: + op.create_table( + "model", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("base_model_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + if not "prompt" in tables: + op.create_table( + "prompt", + sa.Column("command", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("command"), + ) + + if not "tag" in tables: + op.create_table( + "tag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("data", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + if not "tool" in tables: + op.create_table( + "tool", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + if not "user" in tables: + op.create_table( + "user", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("profile_image_url", sa.String(), nullable=True), + sa.Column("last_active_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("api_key"), + ) + + if not "file" in tables: + op.create_table('file', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('filename', sa.String(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if not "function" in tables: + op.create_table('function', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # do nothing as we assume we had previous migrations from peewee-migrate + pass + # ### end Alembic commands ### diff --git a/backend/requirements.txt b/backend/requirements.txt index a36af5497..720809471 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,8 +12,10 @@ passlib[bcrypt]==1.7.4 requests==2.32.2 aiohttp==3.9.5 -peewee==3.17.5 -peewee-migrate==1.12.2 +sqlalchemy==2.0.30 +alembic==1.13.1 +# peewee==3.17.5 +# peewee-migrate==1.12.2 psycopg2-binary==2.9.9 PyMySQL==1.1.1 bcrypt==4.1.3 @@ -67,4 +69,9 @@ pytube==15.0.0 extract_msg pydub -duckduckgo-search~=6.1.5 \ No newline at end of file +duckduckgo-search~=6.1.5 + +## Tests +docker~=7.1.0 +pytest~=8.2.1 +pytest-docker~=3.1.1 diff --git a/backend/test/__init__.py b/backend/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/test/apps/webui/routers/test_auths.py b/backend/test/apps/webui/routers/test_auths.py new file mode 100644 index 000000000..3450f57c6 --- /dev/null +++ b/backend/test/apps/webui/routers/test_auths.py @@ -0,0 +1,209 @@ +import pytest + +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +class TestAuths(AbstractPostgresTest): + BASE_PATH = "/api/v1/auths" + + def setup_class(cls): + super().setup_class() + from apps.webui.models.users import Users + from apps.webui.models.auths import Auths + + cls.users = Users + cls.auths = Auths + + def test_get_session_user(self): + with mock_webui_user(): + response = self.fast_api_client.get(self.create_url("")) + assert response.status_code == 200 + assert response.json() == { + "id": "1", + "name": "John Doe", + "email": "john.doe@openwebui.com", + "role": "user", + "profile_image_url": "/user.png", + } + + def test_update_profile(self): + from utils.utils import get_password_hash + + user = self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password=get_password_hash("old_password"), + name="John Doe", + profile_image_url="/user.png", + role="user", + ) + + with mock_webui_user(id=user.id): + response = self.fast_api_client.post( + self.create_url("/update/profile"), + json={"name": "John Doe 2", "profile_image_url": "/user2.png"}, + ) + assert response.status_code == 200 + db_user = self.users.get_user_by_id(self.db_session, user.id) + assert db_user.name == "John Doe 2" + assert db_user.profile_image_url == "/user2.png" + + def test_update_password(self): + from utils.utils import get_password_hash + + user = self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password=get_password_hash("old_password"), + name="John Doe", + profile_image_url="/user.png", + role="user", + ) + + with mock_webui_user(id=user.id): + response = self.fast_api_client.post( + self.create_url("/update/password"), + json={"password": "old_password", "new_password": "new_password"}, + ) + assert response.status_code == 200 + + old_auth = self.auths.authenticate_user( + self.db_session, "john.doe@openwebui.com", "old_password" + ) + assert old_auth is None + new_auth = self.auths.authenticate_user( + self.db_session, "john.doe@openwebui.com", "new_password" + ) + assert new_auth is not None + + def test_signin(self): + from utils.utils import get_password_hash + + user = self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password=get_password_hash("password"), + name="John Doe", + profile_image_url="/user.png", + role="user", + ) + response = self.fast_api_client.post( + self.create_url("/signin"), + json={"email": "john.doe@openwebui.com", "password": "password"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == user.id + assert data["name"] == "John Doe" + assert data["email"] == "john.doe@openwebui.com" + assert data["role"] == "user" + assert data["profile_image_url"] == "/user.png" + assert data["token"] is not None and len(data["token"]) > 0 + assert data["token_type"] == "Bearer" + + def test_signup(self): + response = self.fast_api_client.post( + self.create_url("/signup"), + json={ + "name": "John Doe", + "email": "john.doe@openwebui.com", + "password": "password", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] is not None and len(data["id"]) > 0 + assert data["name"] == "John Doe" + assert data["email"] == "john.doe@openwebui.com" + assert data["role"] in ["admin", "user", "pending"] + assert data["profile_image_url"] == "/user.png" + assert data["token"] is not None and len(data["token"]) > 0 + assert data["token_type"] == "Bearer" + + def test_add_user(self): + with mock_webui_user(): + response = self.fast_api_client.post( + self.create_url("/add"), + json={ + "name": "John Doe 2", + "email": "john.doe2@openwebui.com", + "password": "password2", + "role": "admin", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] is not None and len(data["id"]) > 0 + assert data["name"] == "John Doe 2" + assert data["email"] == "john.doe2@openwebui.com" + assert data["role"] == "admin" + assert data["profile_image_url"] == "/user.png" + assert data["token"] is not None and len(data["token"]) > 0 + assert data["token_type"] == "Bearer" + + def test_get_admin_details(self): + self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password="password", + name="John Doe", + profile_image_url="/user.png", + role="admin", + ) + with mock_webui_user(): + response = self.fast_api_client.get(self.create_url("/admin/details")) + + assert response.status_code == 200 + assert response.json() == { + "name": "John Doe", + "email": "john.doe@openwebui.com", + } + + def test_create_api_key_(self): + user = self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password="password", + name="John Doe", + profile_image_url="/user.png", + role="admin", + ) + with mock_webui_user(id=user.id): + response = self.fast_api_client.post(self.create_url("/api_key")) + assert response.status_code == 200 + data = response.json() + assert data["api_key"] is not None + assert len(data["api_key"]) > 0 + + def test_delete_api_key(self): + user = self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password="password", + name="John Doe", + profile_image_url="/user.png", + role="admin", + ) + self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") + with mock_webui_user(id=user.id): + response = self.fast_api_client.delete(self.create_url("/api_key")) + assert response.status_code == 200 + assert response.json() == True + db_user = self.users.get_user_by_id(self.db_session, user.id) + assert db_user.api_key is None + + def test_get_api_key(self): + user = self.auths.insert_new_auth( + self.db_session, + email="john.doe@openwebui.com", + password="password", + name="John Doe", + profile_image_url="/user.png", + role="admin", + ) + self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") + with mock_webui_user(id=user.id): + response = self.fast_api_client.get(self.create_url("/api_key")) + assert response.status_code == 200 + assert response.json() == {"api_key": "abc"} diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py new file mode 100644 index 000000000..2d1145c06 --- /dev/null +++ b/backend/test/apps/webui/routers/test_chats.py @@ -0,0 +1,239 @@ +import uuid + +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +class TestChats(AbstractPostgresTest): + + BASE_PATH = "/api/v1/chats" + + def setup_class(cls): + super().setup_class() + + def setup_method(self): + super().setup_method() + from apps.webui.models.chats import ChatForm + from apps.webui.models.chats import Chats + + self.chats = Chats + self.chats.insert_new_chat( + self.db_session, + "2", + ChatForm( + **{ + "chat": { + "name": "chat1", + "description": "chat1 description", + "tags": ["tag1", "tag2"], + "history": {"currentId": "1", "messages": []}, + } + } + ), + ) + + def test_get_session_user_chat_list(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + first_chat = response.json()[0] + assert first_chat["id"] is not None + assert first_chat["title"] == "New Chat" + assert first_chat["created_at"] is not None + assert first_chat["updated_at"] is not None + + def test_delete_all_user_chats(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.delete(self.create_url("/")) + assert response.status_code == 200 + assert len(self.chats.get_chats(self.db_session)) == 0 + + def test_get_user_chat_list_by_user_id(self): + with mock_webui_user(id="3"): + response = self.fast_api_client.get(self.create_url("/list/user/2")) + assert response.status_code == 200 + first_chat = response.json()[0] + assert first_chat["id"] is not None + assert first_chat["title"] == "New Chat" + assert first_chat["created_at"] is not None + assert first_chat["updated_at"] is not None + + def test_create_new_chat(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/new"), + json={ + "chat": { + "name": "chat2", + "description": "chat2 description", + "tags": ["tag1", "tag2"], + } + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["archived"] is False + assert data["chat"] == { + "name": "chat2", + "description": "chat2 description", + "tags": ["tag1", "tag2"], + } + assert data["user_id"] == "2" + assert data["id"] is not None + assert data["share_id"] is None + assert data["title"] == "New Chat" + assert data["updated_at"] is not None + assert data["created_at"] is not None + assert len(self.chats.get_chats(self.db_session)) == 2 + + def test_get_user_chats(self): + self.test_get_session_user_chat_list() + + def test_get_user_archived_chats(self): + self.chats.archive_all_chats_by_user_id(self.db_session, "2") + self.db_session.commit() + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/all/archived")) + assert response.status_code == 200 + first_chat = response.json()[0] + assert first_chat["id"] is not None + assert first_chat["title"] == "New Chat" + assert first_chat["created_at"] is not None + assert first_chat["updated_at"] is not None + + def test_get_all_user_chats_in_db(self): + with mock_webui_user(id="4"): + response = self.fast_api_client.get(self.create_url("/all/db")) + assert response.status_code == 200 + assert len(response.json()) == 1 + + def test_get_archived_session_user_chat_list(self): + self.test_get_user_archived_chats() + + def test_archive_all_chats(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.post(self.create_url("/archive/all")) + assert response.status_code == 200 + assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1 + + def test_get_shared_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id) + self.db_session.commit() + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}")) + assert response.status_code == 200 + data = response.json() + assert data["id"] == chat_id + assert data["chat"] == { + "name": "chat1", + "description": "chat1 description", + "tags": ["tag1", "tag2"], + "history": {"currentId": "1", "messages": []}, + } + assert data["id"] == chat_id + assert data["share_id"] == chat_id + assert data["title"] == "New Chat" + + def test_get_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/{chat_id}")) + assert response.status_code == 200 + data = response.json() + assert data["id"] == chat_id + assert data["chat"] == { + "name": "chat1", + "description": "chat1 description", + "tags": ["tag1", "tag2"], + "history": {"currentId": "1", "messages": []}, + } + assert data["share_id"] is None + assert data["title"] == "New Chat" + assert data["user_id"] == "2" + + def test_update_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url(f"/{chat_id}"), + json={ + "chat": { + "name": "chat2", + "description": "chat2 description", + "tags": ["tag2", "tag4"], + "title": "Just another title", + } + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == chat_id + assert data["chat"] == { + "name": "chat2", + "title": "Just another title", + "description": "chat2 description", + "tags": ["tag2", "tag4"], + "history": {"currentId": "1", "messages": []}, + } + assert data["share_id"] is None + assert data["title"] == "Just another title" + assert data["user_id"] == "2" + + def test_delete_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.delete(self.create_url(f"/{chat_id}")) + assert response.status_code == 200 + assert response.json() is True + + def test_clone_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone")) + + assert response.status_code == 200 + data = response.json() + assert data["id"] != chat_id + assert data["chat"] == { + "branchPointMessageId": "1", + "description": "chat1 description", + "history": {"currentId": "1", "messages": []}, + "name": "chat1", + "originalChatId": chat_id, + "tags": ["tag1", "tag2"], + "title": "Clone of New Chat", + } + assert data["share_id"] is None + assert data["title"] == "Clone of New Chat" + assert data["user_id"] == "2" + + def test_archive_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive")) + assert response.status_code == 200 + + chat = self.chats.get_chat_by_id(self.db_session, chat_id) + assert chat.archived is True + + def test_share_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share")) + assert response.status_code == 200 + + chat = self.chats.get_chat_by_id(self.db_session, chat_id) + assert chat.share_id is not None + + def test_delete_shared_chat_by_id(self): + chat_id = self.chats.get_chats(self.db_session)[0].id + share_id = str(uuid.uuid4()) + self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id) + self.db_session.commit() + with mock_webui_user(id="2"): + response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share")) + assert response.status_code + + chat = self.chats.get_chat_by_id(self.db_session, chat_id) + assert chat.share_id is None diff --git a/backend/test/apps/webui/routers/test_documents.py b/backend/test/apps/webui/routers/test_documents.py new file mode 100644 index 000000000..53ef3d2aa --- /dev/null +++ b/backend/test/apps/webui/routers/test_documents.py @@ -0,0 +1,106 @@ +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +class TestDocuments(AbstractPostgresTest): + + BASE_PATH = "/api/v1/documents" + + def setup_class(cls): + super().setup_class() + from apps.webui.models.documents import Documents + + cls.documents = Documents + + def test_documents(self): + # Empty database + assert len(self.documents.get_docs(self.db_session)) == 0 + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 0 + + # Create a new document + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/create"), + json={ + "name": "doc_name", + "title": "doc title", + "collection_name": "custom collection", + "filename": "doc_name.pdf", + "content": "", + }, + ) + assert response.status_code == 200 + assert response.json()["name"] == "doc_name" + assert len(self.documents.get_docs(self.db_session)) == 1 + + # Get the document + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/doc?name=doc_name")) + assert response.status_code == 200 + data = response.json() + assert data["collection_name"] == "custom collection" + assert data["name"] == "doc_name" + assert data["title"] == "doc title" + assert data["filename"] == "doc_name.pdf" + assert data["content"] == {} + + # Create another document + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/create"), + json={ + "name": "doc_name 2", + "title": "doc title 2", + "collection_name": "custom collection 2", + "filename": "doc_name2.pdf", + "content": "", + }, + ) + assert response.status_code == 200 + assert response.json()["name"] == "doc_name 2" + assert len(self.documents.get_docs(self.db_session)) == 2 + + # Get all documents + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 2 + + # Update the first document + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/doc/update?name=doc_name"), + json={"name": "doc_name rework", "title": "updated title"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "doc_name rework" + assert data["title"] == "updated title" + + # Tag the first document + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/doc/tags"), + json={ + "name": "doc_name rework", + "tags": [{"name": "testing-tag"}, {"name": "another-tag"}], + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["name"] == "doc_name rework" + assert data["content"] == { + "tags": [{"name": "testing-tag"}, {"name": "another-tag"}] + } + assert len(self.documents.get_docs(self.db_session)) == 2 + + # Delete the first document + with mock_webui_user(id="2"): + response = self.fast_api_client.delete( + self.create_url("/doc/delete?name=doc_name rework") + ) + assert response.status_code == 200 + assert len(self.documents.get_docs(self.db_session)) == 1 diff --git a/backend/test/apps/webui/routers/test_models.py b/backend/test/apps/webui/routers/test_models.py new file mode 100644 index 000000000..991c83bee --- /dev/null +++ b/backend/test/apps/webui/routers/test_models.py @@ -0,0 +1,60 @@ +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +class TestModels(AbstractPostgresTest): + + BASE_PATH = "/api/v1/models" + + def setup_class(cls): + super().setup_class() + from apps.webui.models.models import Model + + cls.models = Model + + def test_models(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 0 + + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/add"), + json={ + "id": "my-model", + "base_model_id": "base-model-id", + "name": "Hello World", + "meta": { + "profile_image_url": "/favicon.png", + "description": "description", + "capabilities": None, + "model_config": {}, + }, + "params": {}, + }, + ) + assert response.status_code == 200 + + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 1 + + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/my-model")) + assert response.status_code == 200 + data = response.json() + assert data["id"] == "my-model" + assert data["name"] == "Hello World" + + with mock_webui_user(id="2"): + response = self.fast_api_client.delete( + self.create_url("/delete?id=my-model") + ) + assert response.status_code == 200 + + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 0 diff --git a/backend/test/apps/webui/routers/test_prompts.py b/backend/test/apps/webui/routers/test_prompts.py new file mode 100644 index 000000000..cd2fcec87 --- /dev/null +++ b/backend/test/apps/webui/routers/test_prompts.py @@ -0,0 +1,82 @@ +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +class TestPrompts(AbstractPostgresTest): + + BASE_PATH = "/api/v1/prompts" + + def test_prompts(self): + # Get all prompts + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 0 + + # Create a two new prompts + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/create"), + json={ + "command": "/my-command", + "title": "Hello World", + "content": "description", + }, + ) + assert response.status_code == 200 + with mock_webui_user(id="3"): + response = self.fast_api_client.post( + self.create_url("/create"), + json={ + "command": "/my-command2", + "title": "Hello World 2", + "content": "description 2", + }, + ) + assert response.status_code == 200 + + # Get all prompts + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 2 + + # Get prompt by command + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/command/my-command")) + assert response.status_code == 200 + data = response.json() + assert data["command"] == "/my-command" + assert data["title"] == "Hello World" + assert data["content"] == "description" + assert data["user_id"] == "2" + + # Update prompt + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/command/my-command2/update"), + json={ + "command": "irrelevant for request", + "title": "Hello World Updated", + "content": "description Updated", + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["command"] == "/my-command2" + assert data["title"] == "Hello World Updated" + assert data["content"] == "description Updated" + assert data["user_id"] == "3" + + # Delete prompt + with mock_webui_user(id="2"): + response = self.fast_api_client.delete( + self.create_url("/command/my-command/delete") + ) + assert response.status_code == 200 + + # Get all prompts + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + assert len(response.json()) == 1 diff --git a/backend/test/apps/webui/routers/test_users.py b/backend/test/apps/webui/routers/test_users.py new file mode 100644 index 000000000..35b662304 --- /dev/null +++ b/backend/test/apps/webui/routers/test_users.py @@ -0,0 +1,170 @@ +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +def _get_user_by_id(data, param): + return next((item for item in data if item["id"] == param), None) + + +def _assert_user(data, id, **kwargs): + user = _get_user_by_id(data, id) + assert user is not None + comparison_data = { + "name": f"user {id}", + "email": f"user{id}@openwebui.com", + "profile_image_url": f"/user{id}.png", + "role": "user", + **kwargs, + } + for key, value in comparison_data.items(): + assert user[key] == value + + +class TestUsers(AbstractPostgresTest): + + BASE_PATH = "/api/v1/users" + + def setup_class(cls): + super().setup_class() + from apps.webui.models.users import Users + + cls.users = Users + + def setup_method(self): + super().setup_method() + self.users.insert_new_user( + self.db_session, + id="1", + name="user 1", + email="user1@openwebui.com", + profile_image_url="/user1.png", + role="user", + ) + self.users.insert_new_user( + self.db_session, + id="2", + name="user 2", + email="user2@openwebui.com", + profile_image_url="/user2.png", + role="user", + ) + + def test_users(self): + # Get all users + with mock_webui_user(id="3"): + response = self.fast_api_client.get(self.create_url("")) + assert response.status_code == 200 + assert len(response.json()) == 2 + data = response.json() + _assert_user(data, "1") + _assert_user(data, "2") + + # update role + with mock_webui_user(id="3"): + response = self.fast_api_client.post( + self.create_url("/update/role"), json={"id": "2", "role": "admin"} + ) + assert response.status_code == 200 + _assert_user([response.json()], "2", role="admin") + + # Get all users + with mock_webui_user(id="3"): + response = self.fast_api_client.get(self.create_url("")) + assert response.status_code == 200 + assert len(response.json()) == 2 + data = response.json() + _assert_user(data, "1") + _assert_user(data, "2", role="admin") + + # Get (empty) user settings + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/user/settings")) + assert response.status_code == 200 + assert response.json() is None + + # Update user settings + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/user/settings/update"), + json={ + "ui": {"attr1": "value1", "attr2": "value2"}, + "model_config": {"attr3": "value3", "attr4": "value4"}, + }, + ) + assert response.status_code == 200 + + # Get user settings + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/user/settings")) + assert response.status_code == 200 + assert response.json() == { + "ui": {"attr1": "value1", "attr2": "value2"}, + "model_config": {"attr3": "value3", "attr4": "value4"}, + } + + # Get (empty) user info + with mock_webui_user(id="1"): + response = self.fast_api_client.get(self.create_url("/user/info")) + assert response.status_code == 200 + assert response.json() is None + + # Update user info + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/user/info/update"), + json={"attr1": "value1", "attr2": "value2"}, + ) + assert response.status_code == 200 + + # Get user info + with mock_webui_user(id="1"): + response = self.fast_api_client.get(self.create_url("/user/info")) + assert response.status_code == 200 + assert response.json() == {"attr1": "value1", "attr2": "value2"} + + # Get user by id + with mock_webui_user(id="1"): + response = self.fast_api_client.get(self.create_url("/2")) + assert response.status_code == 200 + assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"} + + # Update user by id + with mock_webui_user(id="1"): + response = self.fast_api_client.post( + self.create_url("/2/update"), + json={ + "name": "user 2 updated", + "email": "user2-updated@openwebui.com", + "profile_image_url": "/user2-updated.png", + }, + ) + assert response.status_code == 200 + + # Get all users + with mock_webui_user(id="3"): + response = self.fast_api_client.get(self.create_url("")) + assert response.status_code == 200 + assert len(response.json()) == 2 + data = response.json() + _assert_user(data, "1") + _assert_user( + data, + "2", + role="admin", + name="user 2 updated", + email="user2-updated@openwebui.com", + profile_image_url="/user2-updated.png", + ) + + # Delete user by id + with mock_webui_user(id="1"): + response = self.fast_api_client.delete(self.create_url("/2")) + assert response.status_code == 200 + + # Get all users + with mock_webui_user(id="3"): + response = self.fast_api_client.get(self.create_url("")) + assert response.status_code == 200 + assert len(response.json()) == 1 + data = response.json() + _assert_user(data, "1") diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py new file mode 100644 index 000000000..9cbf42d47 --- /dev/null +++ b/backend/test/util/abstract_integration_test.py @@ -0,0 +1,155 @@ +import logging +import os +import time + +import docker +import pytest +from docker import DockerClient +from pytest_docker.plugin import get_docker_ip +from fastapi.testclient import TestClient +from sqlalchemy import text, create_engine + +log = logging.getLogger(__name__) + + +def get_fast_api_client(): + from main import app + + with TestClient(app) as c: + return c + + +class AbstractIntegrationTest: + BASE_PATH = None + + def create_url(self, path): + if self.BASE_PATH is None: + raise Exception("BASE_PATH is not set") + parts = self.BASE_PATH.split("/") + parts = [part.strip() for part in parts if part.strip() != ""] + path_parts = path.split("/") + path_parts = [part.strip() for part in path_parts if part.strip() != ""] + return "/".join(parts + path_parts) + + @classmethod + def setup_class(cls): + pass + + def setup_method(self): + pass + + @classmethod + def teardown_class(cls): + pass + + def teardown_method(self): + pass + + +class AbstractPostgresTest(AbstractIntegrationTest): + DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" + docker_client: DockerClient + + def get_db(self): + from apps.webui.internal.db import SessionLocal + + return SessionLocal() + + @classmethod + def _create_db_url(cls, env_vars_postgres: dict) -> str: + host = get_docker_ip() + user = env_vars_postgres["POSTGRES_USER"] + pw = env_vars_postgres["POSTGRES_PASSWORD"] + port = 8081 + db = env_vars_postgres["POSTGRES_DB"] + return f"postgresql://{user}:{pw}@{host}:{port}/{db}" + + @classmethod + def setup_class(cls): + super().setup_class() + try: + env_vars_postgres = { + "POSTGRES_USER": "user", + "POSTGRES_PASSWORD": "example", + "POSTGRES_DB": "openwebui", + } + cls.docker_client = docker.from_env() + cls.docker_client.containers.run( + "postgres:16.2", + detach=True, + environment=env_vars_postgres, + name=cls.DOCKER_CONTAINER_NAME, + ports={5432: ("0.0.0.0", 8081)}, + command="postgres -c log_statement=all", + ) + time.sleep(0.5) + + database_url = cls._create_db_url(env_vars_postgres) + os.environ["DATABASE_URL"] = database_url + retries = 10 + db = None + while retries > 0: + try: + from config import BACKEND_DIR + db = create_engine(database_url, pool_pre_ping=True) + db = db.connect() + log.info("postgres is ready!") + break + except Exception as e: + log.warning(e) + time.sleep(3) + retries -= 1 + + if db: + # import must be after setting env! + cls.fast_api_client = get_fast_api_client() + db.close() + else: + raise Exception("Could not connect to Postgres") + except Exception as ex: + log.error(ex) + cls.teardown_class() + pytest.fail(f"Could not setup test environment: {ex}") + + def _check_db_connection(self): + retries = 10 + while retries > 0: + try: + self.db_session.execute(text("SELECT 1")) + self.db_session.commit() + break + except Exception as e: + self.db_session.rollback() + log.warning(e) + time.sleep(3) + retries -= 1 + + def setup_method(self): + super().setup_method() + self.db_session = self.get_db() + self._check_db_connection() + + @classmethod + def teardown_class(cls) -> None: + super().teardown_class() + cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) + + def teardown_method(self): + # rollback everything not yet committed + self.db_session.commit() + + # truncate all tables + tables = [ + "auth", + "chat", + "chatidtag", + "document", + "memory", + "model", + "prompt", + "tag", + '"user"', + ] + for table in tables: + self.db_session.execute(text(f"TRUNCATE TABLE {table}")) + self.db_session.commit() diff --git a/backend/test/util/mock_user.py b/backend/test/util/mock_user.py new file mode 100644 index 000000000..8d0300d3f --- /dev/null +++ b/backend/test/util/mock_user.py @@ -0,0 +1,45 @@ +from contextlib import contextmanager + +from fastapi import FastAPI + + +@contextmanager +def mock_webui_user(**kwargs): + from apps.webui.main import app + + with mock_user(app, **kwargs): + yield + + +@contextmanager +def mock_user(app: FastAPI, **kwargs): + from utils.utils import ( + get_current_user, + get_verified_user, + get_admin_user, + get_current_user_by_api_key, + ) + from apps.webui.models.users import User + + def create_user(): + user_parameters = { + "id": "1", + "name": "John Doe", + "email": "john.doe@openwebui.com", + "role": "user", + "profile_image_url": "/user.png", + "last_active_at": 1627351200, + "updated_at": 1627351200, + "created_at": 162735120, + **kwargs, + } + return User(**user_parameters) + + app.dependency_overrides = { + get_current_user: create_user, + get_verified_user: create_user, + get_admin_user: create_user, + get_current_user_by_api_key: create_user, + } + yield + app.dependency_overrides = {} diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 8c3c899bd..f1225ec0e 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -1,6 +1,8 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends, Request +from sqlalchemy.orm import Session +from apps.webui.internal.db import get_db from apps.webui.models.users import Users from pydantic import BaseModel @@ -77,6 +79,7 @@ def get_http_authorization_cred(auth_header: str): def get_current_user( request: Request, auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), + db=Depends(get_db), ): token = None @@ -91,19 +94,19 @@ def get_current_user( # auth by api key if token.startswith("sk-"): - return get_current_user_by_api_key(token) + return get_current_user_by_api_key(db, token) # auth by jwt token data = decode_token(token) if data != None and "id" in data: - user = Users.get_user_by_id(data["id"]) + user = Users.get_user_by_id(db, data["id"]) if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.INVALID_TOKEN, ) else: - Users.update_user_last_active_by_id(user.id) + Users.update_user_last_active_by_id(db, user.id) return user else: raise HTTPException( @@ -112,8 +115,8 @@ def get_current_user( ) -def get_current_user_by_api_key(api_key: str): - user = Users.get_user_by_api_key(api_key) +def get_current_user_by_api_key(db: Session, api_key: str): + user = Users.get_user_by_api_key(db, api_key) if user is None: raise HTTPException( @@ -121,7 +124,7 @@ def get_current_user_by_api_key(api_key: str): detail=ERROR_MESSAGES.INVALID_TOKEN, ) else: - Users.update_user_last_active_by_id(user.id) + Users.update_user_last_active_by_id(db, user.id) return user diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 9faa358d3..17d11d816 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -63,10 +63,7 @@ export const getModelInfos = async (token: string = '') => { export const getModelById = async (token: string, id: string) => { let error = null; - const searchParams = new URLSearchParams(); - searchParams.append('id', id); - - const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, { method: 'GET', headers: { Accept: 'application/json', From bee835cb65a8b3feba6824d2e6c9378b95f6e990 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Fri, 21 Jun 2024 14:58:57 +0200 Subject: [PATCH 002/115] feat(sqlalchemy): remove session reference from router --- backend/apps/ollama/main.py | 7 +- backend/apps/openai/main.py | 4 +- backend/apps/webui/internal/db.py | 9 +- backend/apps/webui/main.py | 4 +- backend/apps/webui/models/auths.py | 158 ++++---- backend/apps/webui/models/chats.py | 336 ++++++++++-------- backend/apps/webui/models/documents.py | 83 +++-- backend/apps/webui/models/files.py | 45 ++- backend/apps/webui/models/functions.py | 117 +++--- backend/apps/webui/models/memories.py | 63 ++-- backend/apps/webui/models/models.py | 55 +-- backend/apps/webui/models/prompts.py | 102 +++--- backend/apps/webui/models/tags.py | 217 +++++------ backend/apps/webui/models/tools.py | 70 ++-- backend/apps/webui/models/users.py | 309 ++++++++-------- backend/apps/webui/routers/auths.py | 60 ++-- backend/apps/webui/routers/chats.py | 116 +++--- backend/apps/webui/routers/documents.py | 26 +- backend/apps/webui/routers/files.py | 27 +- backend/apps/webui/routers/functions.py | 27 +- backend/apps/webui/routers/memories.py | 23 +- backend/apps/webui/routers/models.py | 23 +- backend/apps/webui/routers/prompts.py | 22 +- backend/apps/webui/routers/tools.py | 23 +- backend/apps/webui/routers/users.py | 49 ++- backend/main.py | 31 +- .../migrations/versions/22b5ab2667b8_init.py | 188 ---------- .../migrations/versions/ba76b0bae648_init.py | 161 +++++++++ backend/test/apps/webui/routers/test_auths.py | 19 +- backend/test/apps/webui/routers/test_chats.py | 38 +- .../test/apps/webui/routers/test_documents.py | 10 +- .../test/apps/webui/routers/test_prompts.py | 10 + backend/test/apps/webui/routers/test_users.py | 2 - backend/utils/utils.py | 8 +- 34 files changed, 1231 insertions(+), 1211 deletions(-) delete mode 100644 backend/migrations/versions/22b5ab2667b8_init.py create mode 100644 backend/migrations/versions/ba76b0bae648_init.py diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 85bb4c0df..455dc89a5 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -31,7 +31,6 @@ from typing import Optional, List, Union from starlette.background import BackgroundTask -from apps.webui.internal.db import get_db from apps.webui.models.models import Models from apps.webui.models.users import Users from constants import ERROR_MESSAGES @@ -712,7 +711,6 @@ async def generate_chat_completion( form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), - db=Depends(get_db), ): log.debug( @@ -726,7 +724,7 @@ async def generate_chat_completion( } model_id = form_data.model - model_info = Models.get_model_by_id(db, model_id) + model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: @@ -885,7 +883,6 @@ async def generate_openai_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), - db=Depends(get_db), ): form_data = OpenAIChatCompletionForm(**form_data) @@ -894,7 +891,7 @@ async def generate_openai_chat_completion( } model_id = form_data.model - model_info = Models.get_model_by_id(db, model_id) + model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index bc40bc661..302dd8d98 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -11,7 +11,6 @@ import logging from pydantic import BaseModel from starlette.background import BackgroundTask -from apps.webui.internal.db import get_db from apps.webui.models.models import Models from apps.webui.models.users import Users from constants import ERROR_MESSAGES @@ -354,13 +353,12 @@ async def generate_chat_completion( form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), - db=Depends(get_db), ): idx = 0 payload = {**form_data} model_id = form_data.get("model") - model_info = Models.get_model_by_id(db, model_id) + model_info = Models.get_model_by_id(model_id) if model_info: if model_info.base_model_id: diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 5acf83d5c..3c37bb09b 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -1,6 +1,7 @@ import os import logging import json +from contextlib import contextmanager from typing import Optional, Any from typing_extensions import Self @@ -52,11 +53,12 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) Base = declarative_base() -def get_db(): +@contextmanager +def get_session(): db = SessionLocal() try: yield db @@ -64,5 +66,4 @@ def get_db(): except Exception as e: db.rollback() raise e - finally: - db.close() + diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 8bef22c05..1ba8a080e 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -114,8 +114,8 @@ async def get_status(): } -async def get_pipe_models(db: Session): - pipes = Functions.get_functions_by_type(db, "pipe", active_only=True) +async def get_pipe_models(): + pipes = Functions.get_functions_by_type("pipe", active_only=True) pipe_models = [] for pipe in pipes: diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 5ff348dac..fd2934bb1 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base +from apps.webui.internal.db import Base, get_session from config import SRC_LOG_LEVELS @@ -96,7 +96,6 @@ class AuthsTable: def insert_new_auth( self, - db: Session, email: str, password: str, name: str, @@ -104,100 +103,107 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - log.info("insert_new_auth") + with get_session() as db: + log.info("insert_new_auth") - id = str(uuid.uuid4()) + id = str(uuid.uuid4()) - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) - result = Auth(**auth.model_dump()) - db.add(result) + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = Auth(**auth.model_dump()) + db.add(result) - user = Users.insert_new_user( - db, id, name, email, profile_image_url, role, oauth_sub - ) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub + ) - db.commit() - db.refresh(result) + db.commit() + db.refresh(result) - if result and user: - return user - else: - return None - - def authenticate_user( - self, db: Session, email: str, password: str - ) -> Optional[UserModel]: - log.info(f"authenticate_user: {email}") - try: - auth = db.query(Auth).filter_by(email=email, active=True).first() - if auth: - if verify_password(password, auth.password): - user = Users.get_user_by_id(db, auth.id) - return user - else: - return None + if result and user: + return user else: return None - except: - return None + + def authenticate_user( + self, email: str, password: str + ) -> Optional[UserModel]: + log.info(f"authenticate_user: {email}") + with get_session() as db: + try: + auth = db.query(Auth).filter_by(email=email, active=True).first() + if auth: + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + return user + else: + return None + else: + return None + except: + return None def authenticate_user_by_api_key( - self, db: Session, api_key: str + self, api_key: str ) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") - # if no api_key, return None - if not api_key: - return None + with get_session() as db: + # if no api_key, return None + if not api_key: + return None - try: - user = Users.get_user_by_api_key(db, api_key) - return user if user else None - except: - return False + try: + user = Users.get_user_by_api_key(api_key) + return user if user else None + except: + return False def authenticate_user_by_trusted_header( - self, db: Session, email: str + self, email: str ) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") - try: - auth = db.query(Auth).filter(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id) - return user - except: - return None + with get_session() as db: + try: + auth = db.query(Auth).filter(email=email, active=True).first() + if auth: + user = Users.get_user_by_id(auth.id) + return user + except: + return None def update_user_password_by_id( - self, db: Session, id: str, new_password: str + self, id: str, new_password: str ) -> bool: - try: - result = db.query(Auth).filter_by(id=id).update({"password": new_password}) - return True if result == 1 else False - except: - return False - - def update_email_by_id(self, db: Session, id: str, email: str) -> bool: - try: - result = db.query(Auth).filter_by(id=id).update({"email": email}) - return True if result == 1 else False - except: - return False - - def delete_auth_by_id(self, db: Session, id: str) -> bool: - try: - # Delete User - result = Users.delete_user_by_id(db, id) - - if result: - db.query(Auth).filter_by(id=id).delete() - - return True - else: + with get_session() as db: + try: + result = db.query(Auth).filter_by(id=id).update({"password": new_password}) + return True if result == 1 else False + except: + return False + + def update_email_by_id(self, id: str, email: str) -> bool: + with get_session() as db: + try: + result = db.query(Auth).filter_by(id=id).update({"email": email}) + return True if result == 1 else False + except: + return False + + def delete_auth_by_id(self, id: str) -> bool: + with get_session() as db: + try: + # Delete User + result = Users.delete_user_by_id(id) + + if result: + db.query(Auth).filter_by(id=id).delete() + + return True + else: + return False + except: return False - except: - return False Auths = AuthsTable() diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index dd92fd0a1..d71ffd992 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -8,7 +8,7 @@ import time from sqlalchemy import Column, String, BigInteger, Boolean from sqlalchemy.orm import Session -from apps.webui.internal.db import Base +from apps.webui.internal.db import Base, get_session #################### @@ -80,249 +80,269 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat( - self, db: Session, user_id: str, form_data: ChatForm + self, user_id: str, form_data: ChatForm ) -> Optional[ChatModel]: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) - - result = Chat(**chat.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - return ChatModel.model_validate(result) if result else None - - def update_chat_by_id( - self, db: Session, id: str, chat: dict - ) -> Optional[ChatModel]: - try: - db.query(Chat).filter_by(id=id).update( - { - "chat": json.dumps(chat), - "title": chat["title"] if "title" in chat else "New Chat", + with get_session() as db: + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] if "title" in form_data.chat else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), "updated_at": int(time.time()), } ) - return self.get_chat_by_id(db, id) - except: - return None + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None + + def update_chat_by_id( + self, id: str, chat: dict + ) -> Optional[ChatModel]: + with get_session() as db: + try: + chat_obj = db.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + db.commit() + db.refresh(chat_obj) + + return ChatModel.model_validate(chat_obj) + except Exception as e: + return None def insert_shared_chat_by_chat_id( - self, db: Session, chat_id: str + self, chat_id: str ) -> Optional[ChatModel]: - # Get the existing chat to share - chat = db.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - db.add(shared_result) - db.commit() - db.refresh(shared_result) - # Update the original chat with the share_id - result = ( - db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id}) - ) - - return shared_chat if (shared_result and result) else None - - def update_shared_chat_by_chat_id( - self, db: Session, chat_id: str - ) -> Optional[ChatModel]: - try: - print("update_shared_chat_by_id") + with get_session() as db: + # Get the existing chat to share chat = db.get(Chat, chat_id) - print(chat) - - db.query(Chat).filter_by(id=chat.share_id).update( - {"title": chat.title, "chat": chat.chat} + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + db.add(shared_result) + db.commit() + db.refresh(shared_result) + # Update the original chat with the share_id + result = ( + db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id}) ) - return self.get_chat_by_id(db, chat.share_id) - except: - return None + return shared_chat if (shared_result and result) else None - def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool: + def update_shared_chat_by_chat_id( + self, chat_id: str + ) -> Optional[ChatModel]: + with get_session() as db: + try: + print("update_shared_chat_by_id") + chat = db.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + db.commit() + db.refresh(chat) + + return self.get_chat_by_id(chat.share_id) + except: + return None + + def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + with get_session() as db: + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() return True except: return False def update_chat_share_id_by_id( - self, db: Session, id: str, share_id: Optional[str] + self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - db.query(Chat).filter_by(id=id).update({"share_id": share_id}) - - return self.get_chat_by_id(db, id) + with get_session() as db: + chat = db.get(Chat, id) + chat.share_id = share_id + db.commit() + db.refresh(chat) + return chat except: return None - def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]: + def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = self.get_chat_by_id(db, id) - db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) + with get_session() as db: + chat = self.get_chat_by_id(id) + db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) - return self.get_chat_by_id(db, id) + return self.get_chat_by_id(id) except: return None - def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool: + def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) + with get_session() as db: + db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) return True except: return False def get_archived_chat_list_by_user_id( - self, db: Session, user_id: str, skip: int = 0, limit: int = 50 + self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_session() as db: + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, - db: Session, user_id: str, include_archived: bool = False, skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - query = db.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_session() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( - self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50 + self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - db.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_session() as db: + all_chats = ( + db.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]: + def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = db.get(Chat, id) - return ChatModel.model_validate(chat) + with get_session() as db: + chat = db.get(Chat, id) + return ChatModel.model_validate(chat) except: return None - def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]: + def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - chat = db.query(Chat).filter_by(share_id=id).first() + with get_session() as db: + chat = db.query(Chat).filter_by(share_id=id).first() - if chat: - return self.get_chat_by_id(db, id) - else: - return None + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id( - self, db: Session, id: str, user_id: str + self, id: str, user_id: str ) -> Optional[ChatModel]: try: - chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + with get_session() as db: + chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None - def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]: - all_chats = ( - db.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: + with get_session() as db: + all_chats = ( + db.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] - def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]: - all_chats = ( - db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + with get_session() as db: + all_chats = ( + db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id( - self, db: Session, user_id: str + self, user_id: str ) -> List[ChatModel]: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_session() as db: + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] - def delete_chat_by_id(self, db: Session, id: str) -> bool: + def delete_chat_by_id(self, id: str) -> bool: try: - db.query(Chat).filter_by(id=id).delete() + with get_session() as db: + db.query(Chat).filter_by(id=id).delete() - return True and self.delete_shared_chat_by_chat_id(db, id) + return True and self.delete_shared_chat_by_chat_id(id) except: return False - def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool: + def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - db.query(Chat).filter_by(id=id, user_id=user_id).delete() + with get_session() as db: + db.query(Chat).filter_by(id=id, user_id=user_id).delete() - return True and self.delete_shared_chat_by_chat_id(db, id) + return True and self.delete_shared_chat_by_chat_id(id) except: return False - def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool: + def delete_chats_by_user_id(self, user_id: str) -> bool: try: + with get_session() as db: + self.delete_shared_chats_by_user_id(user_id) - self.delete_shared_chats_by_user_id(db, user_id) - - db.query(Chat).filter_by(user_id=user_id).delete() + db.query(Chat).filter_by(user_id=user_id).delete() return True except: return False - def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool: + def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + with get_session() as db: + chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() return True except: diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index b272a5912..6348967db 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -6,7 +6,7 @@ import logging from sqlalchemy import String, Column, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import Base +from apps.webui.internal.db import Base, get_session import json @@ -73,7 +73,7 @@ class DocumentForm(DocumentUpdateForm): class DocumentsTable: def insert_new_doc( - self, db: Session, user_id: str, form_data: DocumentForm + self, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: document = DocumentModel( **{ @@ -84,66 +84,73 @@ class DocumentsTable: ) try: - result = Document(**document.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: - return None + with get_session() as db: + result = Document(**document.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None except: return None - def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]: + def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - document = db.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + with get_session() as db: + document = db.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None - def get_docs(self, db: Session) -> List[DocumentModel]: - return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] + def get_docs(self) -> List[DocumentModel]: + with get_session() as db: + return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] def update_doc_by_name( - self, db: Session, name: str, form_data: DocumentUpdateForm + self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - db.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - return self.get_doc_by_name(db, form_data.name) + with get_session() as db: + db.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None def update_doc_content_by_name( - self, db: Session, name: str, updated: dict + self, name: str, updated: dict ) -> Optional[DocumentModel]: try: - doc = self.get_doc_by_name(db, name) - doc_content = json.loads(doc.content if doc.content else "{}") - doc_content = {**doc_content, **updated} + with get_session() as db: + doc = self.get_doc_by_name(name) + doc_content = json.loads(doc.content if doc.content else "{}") + doc_content = {**doc_content, **updated} - db.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - - return self.get_doc_by_name(db, name) + db.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None - def delete_doc_by_name(self, db: Session, name: str) -> bool: + def delete_doc_by_name(self, name: str) -> bool: try: - db.query(Document).filter_by(name=name).delete() + with get_session() as db: + db.query(Document).filter_by(name=name).delete() return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index dc9f6be39..d2565db3d 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -6,7 +6,7 @@ import logging from sqlalchemy import Column, String, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import JSONField, Base +from apps.webui.internal.db import JSONField, Base, get_session import json @@ -60,7 +60,7 @@ class FileForm(BaseModel): class FilesTable: - def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]: + def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: file = FileModel( **{ **form_data.model_dump(), @@ -70,38 +70,45 @@ class FilesTable: ) try: - result = File(**file.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return FileModel.model_validate(result) - else: - return None + with get_session() as db: + result = File(**file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None - def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]: + def get_file_by_id(self, id: str) -> Optional[FileModel]: try: - file = db.get(File, id) - return FileModel.model_validate(file) + with get_session() as db: + file = db.get(File, id) + return FileModel.model_validate(file) except: return None - def get_files(self, db: Session) -> List[FileModel]: - return [FileModel.model_validate(file) for file in db.query(File).all()] + def get_files(self) -> List[FileModel]: + with get_session() as db: + return [FileModel.model_validate(file) for file in db.query(File).all()] - def delete_file_by_id(self, db: Session, id: str) -> bool: + def delete_file_by_id(self, id: str) -> bool: try: - db.query(File).filter_by(id=id).delete() + with get_session() as db: + db.query(File).filter_by(id=id).delete() + db.commit() return True except: return False - def delete_all_files(self, db: Session) -> bool: + def delete_all_files(self) -> bool: try: - db.query(File).delete() + with get_session() as db: + db.query(File).delete() + db.commit() return True except: return False diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 88fa24a21..417e52329 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -6,7 +6,7 @@ import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean from sqlalchemy.orm import Session -from apps.webui.internal.db import JSONField, Base +from apps.webui.internal.db import JSONField, Base, get_session from apps.webui.models.users import Users import json @@ -87,7 +87,7 @@ class FunctionValves(BaseModel): class FunctionsTable: def insert_new_function( - self, db: Session, user_id: str, type: str, form_data: FunctionForm + self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: function = FunctionModel( **{ @@ -100,57 +100,64 @@ class FunctionsTable: ) try: - result = Function(**function.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + with get_session() as db: + result = Function(**function.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None - def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]: + def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - function = db.get(Function, id) - return FunctionModel.model_validate(function) + with get_session() as db: + function = db.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: if active_only: - return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select().where(Function.is_active == True) - ] + with get_session() as db: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(is_active=True).all() + ] else: - return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select() - ] + with get_session() as db: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: if active_only: - return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select().where( - Function.type == type, Function.is_active == True - ) - ] + with get_session() as db: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by( + type=type, is_active=True + ).all() + ] else: - return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select().where(Function.type == type) - ] + with get_session() as db: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(type=type).all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: try: - function = Function.get(Function.id == id) - return function.valves if function.valves else {} + with get_session() as db: + function = db.get(Function, id) + return function.valves if function.valves else {} except Exception as e: print(f"An error occurred: {e}") return None @@ -159,14 +166,12 @@ class FunctionsTable: self, id: str, valves: dict ) -> Optional[FunctionValves]: try: - query = Function.update( - **{"valves": valves}, - updated_at=int(time.time()), - ).where(Function.id == id) - query.execute() - - function = Function.get(Function.id == id) - return FunctionValves(**model_to_dict(function)) + with get_session() as db: + db.query(Function).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + db.commit() + return self.get_function_by_id(id) except: return None @@ -214,30 +219,32 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: - db.query(Function).filter_by(id=id).update({ - **updated, - "updated_at": int(time.time()), - }) - return self.get_function_by_id(db, id) + with get_session() as db: + db.query(Function).filter_by(id=id).update({ + **updated, + "updated_at": int(time.time()), + }) + db.commit() + return self.get_function_by_id(id) except: return None def deactivate_all_functions(self) -> Optional[bool]: try: - query = Function.update( - **{"is_active": False}, - updated_at=int(time.time()), - ) - - query.execute() - + with get_session() as db: + db.query(Function).update({ + "is_active": False, + "updated_at": int(time.time()), + }) + db.commit() return True except: return None - def delete_function_by_id(self, db: Session, id: str) -> bool: + def delete_function_by_id(self, id: str) -> bool: try: - db.query(Function).filter_by(id=id).delete() + with get_session() as db: + db.query(Function).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index f5f6d13fb..941da5b26 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -4,7 +4,7 @@ from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import Base +from apps.webui.internal.db import Base, get_session from apps.webui.models.chats import Chats import time @@ -44,7 +44,6 @@ class MemoriesTable: def insert_new_memory( self, - db: Session, user_id: str, content: str, ) -> Optional[MemoryModel]: @@ -59,53 +58,59 @@ class MemoriesTable: "updated_at": int(time.time()), } ) - result = Memory(**memory.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + with get_session() as db: + result = Memory(**memory.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, - db: Session, id: str, content: str, ) -> Optional[MemoryModel]: try: - db.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - return self.get_memory_by_id(db, id) + with get_session() as db: + db.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + db.commit() + return self.get_memory_by_id(id) except: return None - def get_memories(self, db: Session) -> List[MemoryModel]: + def get_memories(self) -> List[MemoryModel]: try: - memories = db.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] + with get_session() as db: + memories = db.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None - def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]: + def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: try: - memories = db.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] + with get_session() as db: + memories = db.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None - def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]: + def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: try: - memory = db.get(Memory, id) - return MemoryModel.model_validate(memory) + with get_session() as db: + memory = db.get(Memory, id) + return MemoryModel.model_validate(memory) except: return None - def delete_memory_by_id(self, db: Session, id: str) -> bool: + def delete_memory_by_id(self, id: str) -> bool: try: - db.query(Memory).filter_by(id=id).delete() + with get_session() as db: + db.query(Memory).filter_by(id=id).delete() return True except: @@ -113,7 +118,8 @@ class MemoriesTable: def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: try: - db.query(Memory).filter_by(user_id=user_id).delete() + with get_session() as db: + db.query(Memory).filter_by(user_id=user_id).delete() return True except: return False @@ -122,7 +128,8 @@ class MemoriesTable: self, db: Session, id: str, user_id: str ) -> bool: try: - db.query(Memory).filter_by(id=id, user_id=user_id).delete() + with get_session() as db: + db.query(Memory).filter_by(id=id, user_id=user_id).delete() return True except: return False diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 137333409..7641ee5a0 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, JSONField +from apps.webui.internal.db import Base, JSONField, get_session from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -78,8 +78,6 @@ class Model(Base): class ModelModel(BaseModel): - model_config = ConfigDict(from_attributes=True) - id: str user_id: str base_model_id: Optional[str] = None @@ -91,6 +89,8 @@ class ModelModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -116,7 +116,7 @@ class ModelForm(BaseModel): class ModelsTable: def insert_new_model( - self, db: Session, form_data: ModelForm, user_id: str + self, form_data: ModelForm, user_id: str ) -> Optional[ModelModel]: model = ModelModel( **{ @@ -127,47 +127,52 @@ class ModelsTable: } ) try: - result = Model(**model.dict()) - db.add(result) - db.commit() - db.refresh(result) + with get_session() as db: + result = Model(**model.model_dump()) + db.add(result) + db.commit() + db.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None - def get_all_models(self, db: Session) -> List[ModelModel]: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] + def get_all_models(self) -> List[ModelModel]: + with get_session() as db: + return [ModelModel.model_validate(model) for model in db.query(Model).all()] - def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]: + def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - model = db.get(Model, id) - return ModelModel.model_validate(model) + with get_session() as db: + model = db.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id( - self, db: Session, id: str, model: ModelForm + self, id: str, model: ModelForm ) -> Optional[ModelModel]: try: # update only the fields that are present in the model - model = db.query(Model).get(id) - model.update(**model.model_dump()) - db.commit() - db.refresh(model) - return ModelModel.model_validate(model) + with get_session() as db: + model = db.query(Model).get(id) + model.update(**model.model_dump()) + db.commit() + db.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) return None - def delete_model_by_id(self, db: Session, id: str) -> bool: + def delete_model_by_id(self, id: str) -> bool: try: - db.query(Model).filter_by(id=id).delete() + with get_session() as db: + db.query(Model).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index 21c4de3e1..2157153d8 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -5,7 +5,7 @@ import time from sqlalchemy import String, Column, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import Base +from apps.webui.internal.db import Base, get_session import json @@ -48,61 +48,65 @@ class PromptForm(BaseModel): class PromptsTable: def insert_new_prompt( - self, db: Session, user_id: str, form_data: PromptForm + self, user_id: str, form_data: PromptForm ) -> Optional[PromptModel]: - prompt = PromptModel( - **{ - "user_id": user_id, - "command": form_data.command, - "title": form_data.title, - "content": form_data.content, - "timestamp": int(time.time()), - } - ) - - try: - result = Prompt(**prompt.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None - except Exception as e: - return None - - def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]: - try: - prompt = db.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) - except: - return None - - def get_prompts(self, db: Session) -> List[PromptModel]: - return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] - - def update_prompt_by_command( - self, db: Session, command: str, form_data: PromptForm - ) -> Optional[PromptModel]: - try: - db.query(Prompt).filter_by(command=command).update( - { + with get_session() as db: + prompt = PromptModel( + **{ + "user_id": user_id, + "command": form_data.command, "title": form_data.title, "content": form_data.content, "timestamp": int(time.time()), } ) - return self.get_prompt_by_command(db, command) - except: - return None - def delete_prompt_by_command(self, db: Session, command: str) -> bool: - try: - db.query(Prompt).filter_by(command=command).delete() - return True - except: - return False + try: + result = Prompt(**prompt.dict()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return PromptModel.model_validate(result) + else: + return None + except Exception as e: + return None + + def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: + with get_session() as db: + try: + prompt = db.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) + except: + return None + + def get_prompts(self) -> List[PromptModel]: + with get_session() as db: + return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] + + def update_prompt_by_command( + self, command: str, form_data: PromptForm + ) -> Optional[PromptModel]: + with get_session() as db: + try: + prompt = db.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + db.commit() + return prompt + # return self.get_prompt_by_command(command) + except: + return None + + def delete_prompt_by_command(self, command: str) -> bool: + with get_session() as db: + try: + db.query(Prompt).filter_by(command=command).delete() + return True + except: + return False Prompts = PromptsTable() diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 419425662..5ad176c37 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -9,7 +9,7 @@ import logging from sqlalchemy import String, Column, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import Base +from apps.webui.internal.db import Base, get_session from config import SRC_LOG_LEVELS @@ -80,37 +80,39 @@ class ChatTagsResponse(BaseModel): class TagTable: def insert_new_tag( - self, db: Session, name: str, user_id: str + self, name: str, user_id: str ) -> Optional[TagModel]: id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: - result = Tag(**tag.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return TagModel.model_validate(result) - else: - return None + with get_session() as db: + result = Tag(**tag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None except Exception as e: return None def get_tag_by_name_and_user_id( - self, db: Session, name: str, user_id: str + self, name: str, user_id: str ) -> Optional[TagModel]: try: - tag = db.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + with get_session() as db: + tag = db.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None def add_tag_to_chat( - self, db: Session, user_id: str, form_data: ChatIdTagForm + self, user_id: str, form_data: ChatIdTagForm ) -> Optional[ChatIdTagModel]: - tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id) + tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) if tag == None: - tag = self.insert_new_tag(db, form_data.tag_name, user_id) + tag = self.insert_new_tag(form_data.tag_name, user_id) id = str(uuid.uuid4()) chatIdTag = ChatIdTagModel( @@ -123,118 +125,127 @@ class TagTable: } ) try: - result = ChatIdTag(**chatIdTag.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + with get_session() as db: + result = ChatIdTag(**chatIdTag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None - def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: + with get_session() as db: + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( - self, db: Session, chat_id: str, user_id: str + self, chat_id: str, user_id: str ) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_session() as db: + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( - self, db: Session, tag_name: str, user_id: str + self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_session() as db: + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( - self, db: Session, tag_name: str, user_id: str + self, tag_name: str, user_id: str ) -> int: - return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() + with get_session() as db: + return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() def delete_tag_by_tag_name_and_user_id( - self, db: Session, tag_name: str, user_id: str + self, tag_name: str, user_id: str ) -> bool: try: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") + with get_session() as db: + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - db, tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: log.error(f"delete_tag: {e}") return False def delete_tag_by_tag_name_and_chat_id_and_user_id( - self, db: Session, tag_name: str, chat_id: str, user_id: str + self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") + with get_session() as db: + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - db, tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: @@ -242,13 +253,13 @@ class TagTable: return False def delete_tags_by_chat_id_and_user_id( - self, db: Session, chat_id: str, user_id: str + self, chat_id: str, user_id: str ) -> bool: - tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id) + tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) for tag in tags: self.delete_tag_by_tag_name_and_chat_id_and_user_id( - db, tag.tag_name, chat_id, user_id + tag.tag_name, chat_id, user_id ) return True diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index b8df2e163..534a4e3e8 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import String, Column, BigInteger from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, JSONField +from apps.webui.internal.db import Base, JSONField, get_session from apps.webui.models.users import Users import json @@ -82,7 +82,7 @@ class ToolValves(BaseModel): class ToolsTable: def insert_new_tool( - self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict] + self, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: tool = ToolModel( **{ @@ -95,46 +95,48 @@ class ToolsTable: ) try: - result = Tool(**tool.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ToolModel.model_validate(result) - else: - return None + with get_session() as db: + result = Tool(**tool.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None - def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]: + def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - tool = db.get(Tool, id) - return ToolModel.model_validate(tool) + with get_session() as db: + tool = db.get(Tool, id) + return ToolModel.model_validate(tool) except: return None - def get_tools(self, db: Session) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + def get_tools(self) -> List[ToolModel]: + with get_session() as db: + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - tool = Tool.get(Tool.id == id) - return tool.valves if tool.valves else {} + with get_session() as db: + tool = db.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - query = Tool.update( - **{"valves": valves}, - updated_at=int(time.time()), - ).where(Tool.id == id) - query.execute() - - tool = Tool.get(Tool.id == id) - return ToolValves(**model_to_dict(tool)) + with get_session() as db: + db.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + db.commit() + return self.get_tool_by_id(id) except: return None @@ -172,8 +174,7 @@ class ToolsTable: user_settings["tools"]["valves"][id] = valves # Update the user settings in the database - query = Users.update_user_by_id(user_id, {"settings": user_settings}) - query.execute() + Users.update_user_by_id(user_id, {"settings": user_settings}) return user_settings["tools"]["valves"][id] except Exception as e: @@ -182,16 +183,19 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - db.query(Tool).filter_by(id=id).update( - {**updated, "updated_at": int(time.time())} - ) - return self.get_tool_by_id(db, id) + with get_session() as db: + db.query(Tool).filter_by(id=id).update( + {**updated, "updated_at": int(time.time())} + ) + db.commit() + return self.get_tool_by_id(id) except: return None - def delete_tool_by_id(self, db: Session, id: str) -> bool: + def delete_tool_by_id(self, id: str) -> bool: try: - db.query(Tool).filter_by(id=id).delete() + with get_session() as db: + db.query(Tool).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 7202d2d71..bef15185b 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField +from apps.webui.internal.db import Base, JSONField, get_session from apps.webui.models.chats import Chats #################### @@ -42,8 +42,6 @@ class UserSettings(BaseModel): class UserModel(BaseModel): - model_config = ConfigDict(from_attributes=True) - id: str name: str email: str @@ -60,6 +58,8 @@ class UserModel(BaseModel): oauth_sub: Optional[str] = None + model_config = ConfigDict(from_attributes=True) + #################### # Forms @@ -82,7 +82,6 @@ class UsersTable: def insert_new_user( self, - db: Session, id: str, name: str, email: str, @@ -90,165 +89,181 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return user - else: - return None + with get_session() as db: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return user + else: + return None - def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]: - try: - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except Exception as e: - return None + def get_user_by_id(self, id: str) -> Optional[UserModel]: + with get_session() as db: + try: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except Exception as e: + return None - def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]: - try: - user = db.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) - except: - return None + def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: + with get_session() as db: + try: + user = db.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) + except: + return None - def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]: - try: - user = db.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) - except: - return None + def get_user_by_email(self, email: str) -> Optional[UserModel]: + with get_session() as db: + try: + user = db.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) + except: + return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: - try: - user = User.get(User.oauth_sub == sub) - return UserModel(**model_to_dict(user)) - except: - return None + with get_session() as db: + try: + user = db.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) + except: + return None - def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]: - users = ( - db.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: + with get_session() as db: + users = ( + db.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] - def get_num_users(self, db: Session) -> Optional[int]: - return db.query(User).count() + def get_num_users(self) -> Optional[int]: + with get_session() as db: + return db.query(User).count() - def get_first_user(self, db: Session) -> UserModel: - try: - user = db.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) - except: - return None + def get_first_user(self) -> UserModel: + with get_session() as db: + try: + user = db.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) + except: + return None def update_user_role_by_id( - self, db: Session, id: str, role: str + self, id: str, role: str ) -> Optional[UserModel]: - try: - db.query(User).filter_by(id=id).update({"role": role}) - db.commit() - - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None - - def update_user_profile_image_url_by_id( - self, db: Session, id: str, profile_image_url: str - ) -> Optional[UserModel]: - try: - db.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - db.commit() - - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None - - def update_user_last_active_by_id( - self, db: Session, id: str - ) -> Optional[UserModel]: - try: - db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) - - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None - - def update_user_oauth_sub_by_id( - self, db: Session, id: str, oauth_sub: str - ) -> Optional[UserModel]: - try: - db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None - - def update_user_by_id( - self, db: Session, id: str, updated: dict - ) -> Optional[UserModel]: - try: - db.query(User).filter_by(id=id).update(updated) - db.commit() - - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) - except Exception as e: - return None - - def delete_user_by_id(self, db: Session, id: str) -> bool: - try: - # Delete User Chats - result = Chats.delete_chats_by_user_id(db, id) - - if result: - # Delete User - db.query(User).filter_by(id=id).delete() + with get_session() as db: + try: + db.query(User).filter_by(id=id).update({"role": role}) db.commit() - return True - else: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None + + def update_user_profile_image_url_by_id( + self, id: str, profile_image_url: str + ) -> Optional[UserModel]: + with get_session() as db: + try: + db.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None + + def update_user_last_active_by_id( + self, id: str + ) -> Optional[UserModel]: + with get_session() as db: + try: + db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None + + def update_user_oauth_sub_by_id( + self, id: str, oauth_sub: str + ) -> Optional[UserModel]: + with get_session() as db: + try: + db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None + + def update_user_by_id( + self, id: str, updated: dict + ) -> Optional[UserModel]: + with get_session() as db: + try: + db.query(User).filter_by(id=id).update(updated) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) + except Exception as e: + return None + + def delete_user_by_id(self, id: str) -> bool: + with get_session() as db: + try: + # Delete User Chats + result = Chats.delete_chats_by_user_id(id) + + if result: + # Delete User + db.query(User).filter_by(id=id).delete() + db.commit() + + return True + else: + return False + except: return False - except: - return False - def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str: - try: - result = db.query(User).filter_by(id=id).update({"api_key": api_key}) - db.commit() - return True if result == 1 else False - except: - return False + def update_user_api_key_by_id(self, id: str, api_key: str) -> str: + with get_session() as db: + try: + result = db.query(User).filter_by(id=id).update({"api_key": api_key}) + db.commit() + return True if result == 1 else False + except: + return False - def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]: - try: - user = db.query(User).filter_by(id=id).first() - return user.api_key - except Exception as e: - return None + def get_user_api_key_by_id(self, id: str) -> Optional[str]: + with get_session() as db: + try: + user = db.query(User).filter_by(id=id).first() + return user.api_key + except Exception as e: + return None Users = UsersTable() diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index e83ee8cb9..f32b074b1 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -10,7 +10,6 @@ import re import uuid import csv -from apps.webui.internal.db import get_db from apps.webui.models.auths import ( SigninForm, SignupForm, @@ -80,12 +79,10 @@ async def get_session_user( @router.post("/update/profile", response_model=UserResponse) async def update_profile( form_data: UpdateProfileForm, - session_user=Depends(get_current_user), - db=Depends(get_db), + session_user=Depends(get_current_user) ): if session_user: user = Users.update_user_by_id( - db, session_user.id, {"profile_image_url": form_data.profile_image_url, "name": form_data.name}, ) @@ -105,17 +102,16 @@ async def update_profile( @router.post("/update/password", response_model=bool) async def update_password( form_data: UpdatePasswordForm, - session_user=Depends(get_current_user), - db=Depends(get_db), + session_user=Depends(get_current_user) ): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: - user = Auths.authenticate_user(db, session_user.email, form_data.password) + user = Auths.authenticate_user(session_user.email, form_data.password) if user: hashed = get_password_hash(form_data.new_password) - return Auths.update_user_password_by_id(db, user.id, hashed) + return Auths.update_user_password_by_id(user.id, hashed) else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) else: @@ -128,7 +124,7 @@ async def update_password( @router.post("/signin", response_model=SigninResponse) -async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)): +async def signin(request: Request, response: Response, form_data: SigninForm): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) @@ -139,34 +135,32 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db trusted_name = request.headers.get( WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email ) - if not Users.get_user_by_email(db, trusted_email.lower()): + if not Users.get_user_by_email(trusted_email.lower()): await signup( request, SignupForm( email=trusted_email, password=str(uuid.uuid4()), name=trusted_name ), - db, ) - user = Auths.authenticate_user_by_trusted_header(db, trusted_email) + user = Auths.authenticate_user_by_trusted_header(trusted_email) elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" - if Users.get_user_by_email(db, admin_email.lower()): - user = Auths.authenticate_user(db, admin_email.lower(), admin_password) + if Users.get_user_by_email(admin_email.lower()): + user = Auths.authenticate_user(admin_email.lower(), admin_password) else: - if Users.get_num_users(db) != 0: + if Users.get_num_users() != 0: raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( request, SignupForm(email=admin_email, password=admin_password, name="User"), - db, ) - user = Auths.authenticate_user(db, admin_email.lower(), admin_password) + user = Auths.authenticate_user(admin_email.lower(), admin_password) else: - user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password) + user = Auths.authenticate_user(form_data.email.lower(), form_data.password) if user: token = create_token( @@ -200,7 +194,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db @router.post("/signup", response_model=SigninResponse) -async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)): +async def signup(request: Request, response: Response, form_data: SignupForm): if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED @@ -211,18 +205,17 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(db, form_data.email.lower()): + if Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: role = ( "admin" - if Users.get_num_users(db) == 0 + if Users.get_num_users() == 0 else request.app.state.config.DEFAULT_USER_ROLE ) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( - db, form_data.email.lower(), hashed, form_data.name, @@ -277,7 +270,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db @router.post("/add", response_model=SigninResponse) async def add_user( - form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db) + form_data: AddUserForm, user=Depends(get_admin_user) ): if not validate_email_format(form_data.email.lower()): @@ -285,7 +278,7 @@ async def add_user( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(db, form_data.email.lower()): + if Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -293,7 +286,6 @@ async def add_user( print(form_data) hashed = get_password_hash(form_data.password) user = Auths.insert_new_auth( - db, form_data.email.lower(), hashed, form_data.name, @@ -325,7 +317,7 @@ async def add_user( @router.get("/admin/details") async def get_admin_details( - request: Request, user=Depends(get_current_user), db=Depends(get_db) + request: Request, user=Depends(get_current_user) ): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL @@ -334,11 +326,11 @@ async def get_admin_details( print(admin_email, admin_name) if admin_email: - admin = Users.get_user_by_email(db, admin_email) + admin = Users.get_user_by_email(admin_email) if admin: admin_name = admin.name else: - admin = Users.get_first_user(db) + admin = Users.get_first_user() if admin: admin_email = admin.email admin_name = admin.name @@ -411,9 +403,9 @@ async def update_admin_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)): +async def create_api_key_(user=Depends(get_current_user)): api_key = create_api_key() - success = Users.update_user_api_key_by_id(db, user.id, api_key) + success = Users.update_user_api_key_by_id(user.id, api_key) if success: return { "api_key": api_key, @@ -424,15 +416,15 @@ async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)): # delete api key @router.delete("/api_key", response_model=bool) -async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)): - success = Users.update_user_api_key_by_id(db, user.id, None) +async def delete_api_key(user=Depends(get_current_user)): + success = Users.update_user_api_key_by_id(user.id, None) return success # get api key @router.get("/api_key", response_model=ApiKey) -async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)): - api_key = Users.get_user_api_key_by_id(db, user.id) +async def get_api_key(user=Depends(get_current_user)): + api_key = Users.get_user_api_key_by_id(user.id) if api_key: return { "api_key": api_key, diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 1454d47bd..8b2b9987a 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -2,7 +2,6 @@ from fastapi import Depends, Request, HTTPException, status from datetime import datetime, timedelta from typing import List, Union, Optional -from apps.webui.internal.db import get_db from utils.utils import get_current_user, get_admin_user from fastapi import APIRouter from pydantic import BaseModel @@ -45,9 +44,9 @@ router = APIRouter() @router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/list", response_model=List[ChatTitleIdResponse]) async def get_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) + user=Depends(get_current_user), skip: int = 0, limit: int = 50 ): - return Chats.get_chat_list_by_user_id(db, user.id, skip, limit) + return Chats.get_chat_list_by_user_id(user.id, skip, limit) ############################ @@ -57,7 +56,7 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) async def delete_all_user_chats( - request: Request, user=Depends(get_current_user), db=Depends(get_db) + request: Request, user=Depends(get_current_user) ): if ( @@ -69,7 +68,7 @@ async def delete_all_user_chats( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chats_by_user_id(db, user.id) + result = Chats.delete_chats_by_user_id(user.id) return result @@ -84,10 +83,9 @@ async def get_user_chat_list_by_user_id( user=Depends(get_admin_user), skip: int = 0, limit: int = 50, - db=Depends(get_db), ): return Chats.get_chat_list_by_user_id( - db, user_id, include_archived=True, skip=skip, limit=limit + user_id, include_archived=True, skip=skip, limit=limit ) @@ -98,10 +96,10 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) async def create_new_chat( - form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) + form_data: ChatForm, user=Depends(get_current_user) ): try: - chat = Chats.insert_new_chat(db, user.id, form_data) + chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) except Exception as e: log.exception(e) @@ -116,10 +114,10 @@ async def create_new_chat( @router.get("/all", response_model=List[ChatResponse]) -async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)): +async def get_user_chats(user=Depends(get_current_user)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats_by_user_id(db, user.id) + for chat in Chats.get_chats_by_user_id(user.id) ] @@ -129,10 +127,10 @@ async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)): @router.get("/all/archived", response_model=List[ChatResponse]) -async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)): +async def get_user_archived_chats(user=Depends(get_current_user)): return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_archived_chats_by_user_id(db, user.id) + for chat in Chats.get_archived_chats_by_user_id(user.id) ] @@ -142,7 +140,7 @@ async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get @router.get("/all/db", response_model=List[ChatResponse]) -async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)): +async def get_all_user_chats_in_db(user=Depends(get_admin_user)): if not ENABLE_ADMIN_EXPORT: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -150,7 +148,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_ ) return [ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) - for chat in Chats.get_chats(db) + for chat in Chats.get_chats() ] @@ -161,9 +159,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_ @router.get("/archived", response_model=List[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( - user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) + user=Depends(get_current_user), skip: int = 0, limit: int = 50 ): - return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit) + return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) ############################ @@ -172,8 +170,8 @@ async def get_archived_session_user_chat_list( @router.post("/archive/all", response_model=bool) -async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)): - return Chats.archive_all_chats_by_user_id(db, user.id) +async def archive_all_chats(user=Depends(get_current_user)): + return Chats.archive_all_chats_by_user_id(user.id) ############################ @@ -183,7 +181,7 @@ async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) async def get_shared_chat_by_id( - share_id: str, user=Depends(get_current_user), db=Depends(get_db) + share_id: str, user=Depends(get_current_user) ): if user.role == "pending": raise HTTPException( @@ -191,9 +189,9 @@ async def get_shared_chat_by_id( ) if user.role == "user": - chat = Chats.get_chat_by_share_id(db, share_id) + chat = Chats.get_chat_by_share_id(share_id) elif user.role == "admin": - chat = Chats.get_chat_by_id(db, share_id) + chat = Chats.get_chat_by_id(share_id) if chat: return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -216,23 +214,23 @@ class TagNameForm(BaseModel): @router.post("/tags", response_model=List[ChatTitleIdResponse]) async def get_user_chat_list_by_tag_name( - form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db) + form_data: TagNameForm, user=Depends(get_current_user) ): print(form_data) chat_ids = [ chat_id_tag.chat_id for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( - db, form_data.name, user.id + form_data.name, user.id ) ] chats = Chats.get_chat_list_by_chat_ids( - db, chat_ids, form_data.skip, form_data.limit + chat_ids, form_data.skip, form_data.limit ) if len(chats) == 0: - Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id) + Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) return chats @@ -243,9 +241,9 @@ async def get_user_chat_list_by_tag_name( @router.get("/tags/all", response_model=List[TagModel]) -async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)): +async def get_all_tags(user=Depends(get_current_user)): try: - tags = Tags.get_tags_by_user_id(db, user.id) + tags = Tags.get_tags_by_user_id(user.id) return tags except Exception as e: log.exception(e) @@ -260,8 +258,8 @@ async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)): @router.get("/{id}", response_model=Optional[ChatResponse]) -async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): - chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) +async def get_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -278,13 +276,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get @router.post("/{id}", response_model=Optional[ChatResponse]) async def update_chat_by_id( - id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) + id: str, form_data: ChatForm, user=Depends(get_current_user) ): - chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: updated_chat = {**json.loads(chat.chat), **form_data.chat} - chat = Chats.update_chat_by_id(db, id, updated_chat) + chat = Chats.update_chat_by_id(id, updated_chat) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( @@ -300,11 +298,11 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) async def delete_chat_by_id( - request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db) + request: Request, id: str, user=Depends(get_current_user) ): if user.role == "admin": - result = Chats.delete_chat_by_id(db, id) + result = Chats.delete_chat_by_id(id) return result else: if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]: @@ -313,7 +311,7 @@ async def delete_chat_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - result = Chats.delete_chat_by_id_and_user_id(db, id, user.id) + result = Chats.delete_chat_by_id_and_user_id(id, user.id) return result @@ -323,8 +321,8 @@ async def delete_chat_by_id( @router.get("/{id}/clone", response_model=Optional[ChatResponse]) -async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): - chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) +async def clone_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat_body = json.loads(chat.chat) @@ -335,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g "title": f"Clone of {chat.title}", } - chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat})) + chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( @@ -350,11 +348,11 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g @router.get("/{id}/archive", response_model=Optional[ChatResponse]) async def archive_chat_by_id( - id: str, user=Depends(get_current_user), db=Depends(get_db) + id: str, user=Depends(get_current_user) ): - chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: - chat = Chats.toggle_chat_archive_by_id(db, id) + chat = Chats.toggle_chat_archive_by_id(id) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) else: raise HTTPException( @@ -368,16 +366,16 @@ async def archive_chat_by_id( @router.post("/{id}/share", response_model=Optional[ChatResponse]) -async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): - chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) +async def share_chat_by_id(id: str, user=Depends(get_current_user)): + chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if chat.share_id: - shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id) + shared_chat = Chats.update_shared_chat_by_chat_id(chat.id) return ChatResponse( **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} ) - shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id) + shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id) if not shared_chat: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -401,15 +399,15 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g @router.delete("/{id}/share", response_model=Optional[bool]) async def delete_shared_chat_by_id( - id: str, user=Depends(get_current_user), db=Depends(get_db) + id: str, user=Depends(get_current_user) ): - chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) + chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if not chat.share_id: return False - result = Chats.delete_shared_chat_by_chat_id(db, id) - update_result = Chats.update_chat_share_id_by_id(db, id, None) + result = Chats.delete_shared_chat_by_chat_id(id) + update_result = Chats.update_chat_share_id_by_id(id, None) return result and update_result != None else: @@ -426,9 +424,9 @@ async def delete_shared_chat_by_id( @router.get("/{id}/tags", response_model=List[TagModel]) async def get_chat_tags_by_id( - id: str, user=Depends(get_current_user), db=Depends(get_db) + id: str, user=Depends(get_current_user) ): - tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) + tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) if tags != None: return tags @@ -447,13 +445,12 @@ async def get_chat_tags_by_id( async def add_chat_tag_by_id( id: str, form_data: ChatIdTagForm, - user=Depends(get_current_user), - db=Depends(get_db), + user=Depends(get_current_user) ): - tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) + tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) if form_data.tag_name not in tags: - tag = Tags.add_tag_to_chat(db, user.id, form_data) + tag = Tags.add_tag_to_chat(user.id, form_data) if tag: return tag @@ -478,10 +475,9 @@ async def delete_chat_tag_by_id( id: str, form_data: ChatIdTagForm, user=Depends(get_current_user), - db=Depends(get_db), ): result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( - db, form_data.tag_name, id, user.id + form_data.tag_name, id, user.id ) if result: @@ -499,9 +495,9 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) async def delete_all_chat_tags_by_id( - id: str, user=Depends(get_current_user), db=Depends(get_db) + id: str, user=Depends(get_current_user) ): - result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id) + result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) if result: return result diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index b9a42352a..f358e033c 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -6,7 +6,6 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.webui.internal.db import get_db from apps.webui.models.documents import ( Documents, DocumentForm, @@ -26,7 +25,7 @@ router = APIRouter() @router.get("/", response_model=List[DocumentResponse]) -async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): +async def get_documents(user=Depends(get_current_user)): docs = [ DocumentResponse( **{ @@ -34,7 +33,7 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): "content": json.loads(doc.content if doc.content else "{}"), } ) - for doc in Documents.get_docs(db) + for doc in Documents.get_docs() ] return docs @@ -46,11 +45,11 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): @router.post("/create", response_model=Optional[DocumentResponse]) async def create_new_doc( - form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db) + form_data: DocumentForm, user=Depends(get_admin_user) ): - doc = Documents.get_doc_by_name(db, form_data.name) + doc = Documents.get_doc_by_name(form_data.name) if doc == None: - doc = Documents.insert_new_doc(db, user.id, form_data) + doc = Documents.insert_new_doc(user.id, form_data) if doc: return DocumentResponse( @@ -78,9 +77,9 @@ async def create_new_doc( @router.get("/doc", response_model=Optional[DocumentResponse]) async def get_doc_by_name( - name: str, user=Depends(get_current_user), db=Depends(get_db) + name: str, user=Depends(get_current_user) ): - doc = Documents.get_doc_by_name(db, name) + doc = Documents.get_doc_by_name(name) if doc: return DocumentResponse( @@ -112,10 +111,10 @@ class TagDocumentForm(BaseModel): @router.post("/doc/tags", response_model=Optional[DocumentResponse]) async def tag_doc_by_name( - form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db) + form_data: TagDocumentForm, user=Depends(get_current_user) ): doc = Documents.update_doc_content_by_name( - db, form_data.name, {"tags": form_data.tags} + form_data.name, {"tags": form_data.tags} ) if doc: @@ -142,9 +141,8 @@ async def update_doc_by_name( name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user), - db=Depends(get_db), ): - doc = Documents.update_doc_by_name(db, name, form_data) + doc = Documents.update_doc_by_name(name, form_data) if doc: return DocumentResponse( **{ @@ -166,7 +164,7 @@ async def update_doc_by_name( @router.delete("/doc/delete", response_model=bool) async def delete_doc_by_name( - name: str, user=Depends(get_admin_user), db=Depends(get_db) + name: str, user=Depends(get_admin_user) ): - result = Documents.delete_doc_by_name(db, name) + result = Documents.delete_doc_by_name(name) return result diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index 2ed119ad0..e98d1da58 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -20,7 +20,6 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from pydantic import BaseModel import json -from apps.webui.internal.db import get_db from apps.webui.models.files import ( Files, FileForm, @@ -53,8 +52,7 @@ router = APIRouter() @router.post("/") def upload_file( file: UploadFile = File(...), - user=Depends(get_verified_user), - db=Depends(get_db) + user=Depends(get_verified_user) ): log.info(f"file.content_type: {file.content_type}") try: @@ -72,7 +70,6 @@ def upload_file( f.close() file = Files.insert_new_file( - db, user.id, FileForm( **{ @@ -109,8 +106,8 @@ def upload_file( @router.get("/", response_model=List[FileModel]) -async def list_files(user=Depends(get_verified_user), db=Depends(get_db)): - files = Files.get_files(db) +async def list_files(user=Depends(get_verified_user)): + files = Files.get_files() return files @@ -120,8 +117,8 @@ async def list_files(user=Depends(get_verified_user), db=Depends(get_db)): @router.delete("/all") -async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)): - result = Files.delete_all_files(db) +async def delete_all_files(user=Depends(get_admin_user)): + result = Files.delete_all_files() if result: folder = f"{UPLOAD_DIR}" @@ -157,8 +154,8 @@ async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)): @router.get("/{id}", response_model=Optional[FileModel]) -async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): - file = Files.get_file_by_id(db, id) +async def get_file_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) if file: return file @@ -175,8 +172,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(ge @router.get("/{id}/content", response_model=Optional[FileModel]) -async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): - file = Files.get_file_by_id(db, id) +async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) if file: file_path = Path(file.meta["path"]) @@ -226,11 +223,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.delete("/{id}") -async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): - file = Files.get_file_by_id(db, id) +async def delete_file_by_id(id: str, user=Depends(get_verified_user)): + file = Files.get_file_by_id(id) if file: - result = Files.delete_file_by_id(db, id) + result = Files.delete_file_by_id(id) if result: return {"message": "File deleted successfully"} else: diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index f15566702..4c89ca487 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -6,7 +6,6 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.webui.internal.db import get_db from apps.webui.models.functions import ( Functions, FunctionForm, @@ -32,8 +31,8 @@ router = APIRouter() @router.get("/", response_model=List[FunctionResponse]) -async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)): - return Functions.get_functions(db) +async def get_functions(user=Depends(get_verified_user)): + return Functions.get_functions() ############################ @@ -42,8 +41,8 @@ async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)): @router.get("/export", response_model=List[FunctionModel]) -async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)): - return Functions.get_functions(db) +async def get_functions(user=Depends(get_admin_user)): + return Functions.get_functions() ############################ @@ -53,7 +52,7 @@ async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)): @router.post("/create", response_model=Optional[FunctionResponse]) async def create_new_function( - request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) + request: Request, form_data: FunctionForm, user=Depends(get_admin_user) ): if not form_data.id.isidentifier(): raise HTTPException( @@ -63,7 +62,7 @@ async def create_new_function( form_data.id = form_data.id.lower() - function = Functions.get_function_by_id(db, form_data.id) + function = Functions.get_function_by_id(form_data.id) if function == None: function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") try: @@ -78,7 +77,7 @@ async def create_new_function( FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(db, user.id, function_type, form_data) + function = Functions.insert_new_function(user.id, function_type, form_data) function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) @@ -109,8 +108,8 @@ async def create_new_function( @router.get("/id/{id}", response_model=Optional[FunctionModel]) -async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): - function = Functions.get_function_by_id(db, id) +async def get_function_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) if function: return function @@ -155,7 +154,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): @router.post("/id/{id}/update", response_model=Optional[FunctionModel]) async def update_function_by_id( - request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) + request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) ): function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") @@ -172,7 +171,7 @@ async def update_function_by_id( updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} print(updated) - function = Functions.update_function_by_id(db, id, updated) + function = Functions.update_function_by_id(id, updated) if function: return function @@ -196,9 +195,9 @@ async def update_function_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_function_by_id( - request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) + request: Request, id: str, user=Depends(get_admin_user) ): - result = Functions.delete_function_by_id(db, id) + result = Functions.delete_function_by_id(id) if result: FUNCTIONS = request.app.state.FUNCTIONS diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index e7fafa37b..d6b2d0fcb 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -7,7 +7,6 @@ from fastapi import APIRouter from pydantic import BaseModel import logging -from apps.webui.internal.db import get_db from apps.webui.models.memories import Memories, MemoryModel from utils.utils import get_verified_user @@ -32,8 +31,8 @@ async def get_embeddings(request: Request): @router.get("/", response_model=List[MemoryModel]) -async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)): - return Memories.get_memories_by_user_id(db, user.id) +async def get_memories(user=Depends(get_verified_user)): + return Memories.get_memories_by_user_id(user.id) ############################ @@ -54,9 +53,8 @@ async def add_memory( request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user), - db=Depends(get_db), ): - memory = Memories.insert_new_memory(db, user.id, form_data.content) + memory = Memories.insert_new_memory(user.id, form_data.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") @@ -76,9 +74,8 @@ async def update_memory_by_id( request: Request, form_data: MemoryUpdateModel, user=Depends(get_verified_user), - db=Depends(get_db), ): - memory = Memories.update_memory_by_id(db, memory_id, form_data.content) + memory = Memories.update_memory_by_id(memory_id, form_data.content) if memory is None: raise HTTPException(status_code=404, detail="Memory not found") @@ -129,12 +126,12 @@ async def query_memory( ############################ @router.get("/reset", response_model=bool) async def reset_memory_from_vector_db( - request: Request, user=Depends(get_verified_user), db=Depends(get_db) + request: Request, user=Depends(get_verified_user) ): CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") - memories = Memories.get_memories_by_user_id(db, user.id) + memories = Memories.get_memories_by_user_id(user.id) for memory in memories: memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) collection.upsert( @@ -151,8 +148,8 @@ async def reset_memory_from_vector_db( @router.delete("/user", response_model=bool) -async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)): - result = Memories.delete_memories_by_user_id(db, user.id) +async def delete_memory_by_user_id(user=Depends(get_verified_user)): + result = Memories.delete_memories_by_user_id(user.id) if result: try: @@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g @router.delete("/{memory_id}", response_model=bool) async def delete_memory_by_id( - memory_id: str, user=Depends(get_verified_user), db=Depends(get_db) + memory_id: str, user=Depends(get_verified_user) ): - result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id) + result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: collection = CHROMA_CLIENT.get_or_create_collection( diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py index f151e8864..eaf459d73 100644 --- a/backend/apps/webui/routers/models.py +++ b/backend/apps/webui/routers/models.py @@ -6,7 +6,6 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.webui.internal.db import get_db from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from utils.utils import get_verified_user, get_admin_user @@ -20,8 +19,8 @@ router = APIRouter() @router.get("/", response_model=List[ModelResponse]) -async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): - return Models.get_all_models(db) +async def get_models(user=Depends(get_verified_user)): + return Models.get_all_models() ############################ @@ -34,7 +33,6 @@ async def add_new_model( request: Request, form_data: ModelForm, user=Depends(get_admin_user), - db=Depends(get_db), ): if form_data.id in request.app.state.MODELS: raise HTTPException( @@ -42,7 +40,7 @@ async def add_new_model( detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) else: - model = Models.insert_new_model(db, form_data, user.id) + model = Models.insert_new_model(form_data, user.id) if model: return model @@ -59,8 +57,8 @@ async def add_new_model( @router.get("/{id}", response_model=Optional[ModelModel]) -async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): - model = Models.get_model_by_id(db, id) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) if model: return model @@ -82,15 +80,14 @@ async def update_model_by_id( id: str, form_data: ModelForm, user=Depends(get_admin_user), - db=Depends(get_db), ): - model = Models.get_model_by_id(db, id) + model = Models.get_model_by_id(id) if model: - model = Models.update_model_by_id(db, id, form_data) + model = Models.update_model_by_id(id, form_data) return model else: if form_data.id in request.app.state.MODELS: - model = Models.insert_new_model(db, form_data, user.id) + model = Models.insert_new_model(form_data, user.id) if model: return model else: @@ -111,6 +108,6 @@ async def update_model_by_id( @router.delete("/delete", response_model=bool) -async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): - result = Models.delete_model_by_id(db, id) +async def delete_model_by_id(id: str, user=Depends(get_admin_user)): + result = Models.delete_model_by_id(id) return result diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index c8f173a1e..3912b1028 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -6,7 +6,6 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.webui.internal.db import get_db from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from utils.utils import get_current_user, get_admin_user @@ -20,8 +19,8 @@ router = APIRouter() @router.get("/", response_model=List[PromptModel]) -async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)): - return Prompts.get_prompts(db) +async def get_prompts(user=Depends(get_current_user)): + return Prompts.get_prompts() ############################ @@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)): @router.post("/create", response_model=Optional[PromptModel]) async def create_new_prompt( - form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db) + form_data: PromptForm, user=Depends(get_admin_user) ): - prompt = Prompts.get_prompt_by_command(db, form_data.command) + prompt = Prompts.get_prompt_by_command(form_data.command) if prompt == None: - prompt = Prompts.insert_new_prompt(db, user.id, form_data) + prompt = Prompts.insert_new_prompt(user.id, form_data) if prompt: return prompt @@ -56,9 +55,9 @@ async def create_new_prompt( @router.get("/command/{command}", response_model=Optional[PromptModel]) async def get_prompt_by_command( - command: str, user=Depends(get_current_user), db=Depends(get_db) + command: str, user=Depends(get_current_user) ): - prompt = Prompts.get_prompt_by_command(db, f"/{command}") + prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: return prompt @@ -79,9 +78,8 @@ async def update_prompt_by_command( command: str, form_data: PromptForm, user=Depends(get_admin_user), - db=Depends(get_db), ): - prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data) + prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) if prompt: return prompt else: @@ -98,7 +96,7 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) async def delete_prompt_by_command( - command: str, user=Depends(get_admin_user), db=Depends(get_db) + command: str, user=Depends(get_admin_user) ): - result = Prompts.delete_prompt_by_command(db, f"/{command}") + result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 4eb6d1caf..82a09477d 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -6,7 +6,6 @@ from fastapi import APIRouter from pydantic import BaseModel import json -from apps.webui.internal.db import get_db from apps.webui.models.users import Users from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.utils import load_toolkit_module_by_id @@ -34,7 +33,7 @@ router = APIRouter() @router.get("/", response_model=List[ToolResponse]) -async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)): +async def get_toolkits(user=Depends(get_verified_user)): toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits @@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)): @router.get("/export", response_model=List[ToolModel]) -async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)): - toolkits = [toolkit for toolkit in Tools.get_tools(db)] +async def get_toolkits(user=Depends(get_admin_user)): + toolkits = [toolkit for toolkit in Tools.get_tools()] return toolkits @@ -60,7 +59,6 @@ async def create_new_toolkit( request: Request, form_data: ToolForm, user=Depends(get_admin_user), - db=Depends(get_db), ): if not form_data.id.isidentifier(): raise HTTPException( @@ -70,7 +68,7 @@ async def create_new_toolkit( form_data.id = form_data.id.lower() - toolkit = Tools.get_tool_by_id(db, form_data.id) + toolkit = Tools.get_tool_by_id(form_data.id) if toolkit == None: toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") try: @@ -84,7 +82,7 @@ async def create_new_toolkit( TOOLS[form_data.id] = toolkit_module specs = get_tools_specs(TOOLS[form_data.id]) - toolkit = Tools.insert_new_tool(db, user.id, form_data, specs) + toolkit = Tools.insert_new_tool(user.id, form_data, specs) tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir.mkdir(parents=True, exist_ok=True) @@ -115,8 +113,8 @@ async def create_new_toolkit( @router.get("/id/{id}", response_model=Optional[ToolModel]) -async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): - toolkit = Tools.get_tool_by_id(db, id) +async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): + toolkit = Tools.get_tool_by_id(id) if toolkit: return toolkit @@ -138,7 +136,6 @@ async def update_toolkit_by_id( id: str, form_data: ToolForm, user=Depends(get_admin_user), - db=Depends(get_db), ): toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") @@ -160,7 +157,7 @@ async def update_toolkit_by_id( } print(updated) - toolkit = Tools.update_tool_by_id(db, id, updated) + toolkit = Tools.update_tool_by_id(id, updated) if toolkit: return toolkit @@ -184,9 +181,9 @@ async def update_toolkit_by_id( @router.delete("/id/{id}/delete", response_model=bool) async def delete_toolkit_by_id( - request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) + request: Request, id: str, user=Depends(get_admin_user) ): - result = Tools.delete_tool_by_id(db, id) + result = Tools.delete_tool_by_id(id) if result: TOOLS = request.app.state.TOOLS diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index 46a418fc1..8a38d5b9f 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -9,7 +9,6 @@ import time import uuid import logging -from apps.webui.internal.db import get_db from apps.webui.models.users import ( UserModel, UserUpdateForm, @@ -42,9 +41,9 @@ router = APIRouter() @router.get("/", response_model=List[UserModel]) async def get_users( - skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db) + skip: int = 0, limit: int = 50, user=Depends(get_admin_user) ): - return Users.get_users(db, skip, limit) + return Users.get_users(skip, limit) ############################ @@ -72,11 +71,11 @@ async def update_user_permissions( @router.post("/update/role", response_model=Optional[UserModel]) async def update_user_role( - form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db) + form_data: UserRoleUpdateForm, user=Depends(get_admin_user) ): - if user.id != form_data.id and form_data.id != Users.get_first_user(db).id: - return Users.update_user_role_by_id(db, form_data.id, form_data.role) + if user.id != form_data.id and form_data.id != Users.get_first_user().id: + return Users.update_user_role_by_id(form_data.id, form_data.role) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -91,9 +90,9 @@ async def update_user_role( @router.get("/user/settings", response_model=Optional[UserSettings]) async def get_user_settings_by_session_user( - user=Depends(get_verified_user), db=Depends(get_db) + user=Depends(get_verified_user) ): - user = Users.get_user_by_id(db, user.id) + user = Users.get_user_by_id(user.id) if user: return user.settings else: @@ -110,9 +109,9 @@ async def get_user_settings_by_session_user( @router.post("/user/settings/update", response_model=UserSettings) async def update_user_settings_by_session_user( - form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db) + form_data: UserSettings, user=Depends(get_verified_user) ): - user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()}) + user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()}) if user: return user.settings else: @@ -129,9 +128,9 @@ async def update_user_settings_by_session_user( @router.get("/user/info", response_model=Optional[dict]) async def get_user_info_by_session_user( - user=Depends(get_verified_user), db=Depends(get_db) + user=Depends(get_verified_user) ): - user = Users.get_user_by_id(db, user.id) + user = Users.get_user_by_id(user.id) if user: return user.info else: @@ -148,15 +147,15 @@ async def get_user_info_by_session_user( @router.post("/user/info/update", response_model=Optional[dict]) async def update_user_info_by_session_user( - form_data: dict, user=Depends(get_verified_user), db=Depends(get_db) + form_data: dict, user=Depends(get_verified_user) ): - user = Users.get_user_by_id(db, user.id) + user = Users.get_user_by_id(user.id) if user: if user.info is None: user.info = {} user = Users.update_user_by_id( - db, user.id, {"info": {**user.info, **form_data}} + user.id, {"info": {**user.info, **form_data}} ) if user: return user.info @@ -184,14 +183,14 @@ class UserResponse(BaseModel): @router.get("/{user_id}", response_model=UserResponse) async def get_user_by_id( - user_id: str, user=Depends(get_verified_user), db=Depends(get_db) + user_id: str, user=Depends(get_verified_user) ): # Check if user_id is a shared chat # If it is, get the user_id from the chat if user_id.startswith("shared-"): chat_id = user_id.replace("shared-", "") - chat = Chats.get_chat_by_id(db, chat_id) + chat = Chats.get_chat_by_id(chat_id) if chat: user_id = chat.user_id else: @@ -200,7 +199,7 @@ async def get_user_by_id( detail=ERROR_MESSAGES.USER_NOT_FOUND, ) - user = Users.get_user_by_id(db, user_id) + user = Users.get_user_by_id(user_id) if user: return UserResponse(name=user.name, profile_image_url=user.profile_image_url) @@ -221,13 +220,12 @@ async def update_user_by_id( user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user), - db=Depends(get_db), ): - user = Users.get_user_by_id(db, user_id) + user = Users.get_user_by_id(user_id) if user: if form_data.email.lower() != user.email: - email_user = Users.get_user_by_email(db, form_data.email.lower()) + email_user = Users.get_user_by_email(form_data.email.lower()) if email_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -237,11 +235,10 @@ async def update_user_by_id( if form_data.password: hashed = get_password_hash(form_data.password) log.debug(f"hashed: {hashed}") - Auths.update_user_password_by_id(db, user_id, hashed) + Auths.update_user_password_by_id(user_id, hashed) - Auths.update_email_by_id(db, user_id, form_data.email.lower()) + Auths.update_email_by_id(user_id, form_data.email.lower()) updated_user = Users.update_user_by_id( - db, user_id, { "name": form_data.name, @@ -271,10 +268,10 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) async def delete_user_by_id( - user_id: str, user=Depends(get_admin_user), db=Depends(get_db) + user_id: str, user=Depends(get_admin_user) ): if user.id != user_id: - result = Auths.delete_auth_by_id(db, user_id) + result = Auths.delete_auth_by_id(user_id) if result: return True diff --git a/backend/main.py b/backend/main.py index d80c6a729..6e44045f2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -57,7 +57,7 @@ from apps.webui.main import ( get_pipe_models, generate_function_chat_completion, ) -from apps.webui.internal.db import get_db, SessionLocal +from apps.webui.internal.db import get_session, SessionLocal from pydantic import BaseModel @@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), - SessionLocal(), ) # Flag to skip RAG completions if file_handler is present in tools/functions skip_files = False @@ -800,9 +799,7 @@ app.add_middleware( @app.middleware("http") async def check_url(request: Request, call_next): if len(app.state.MODELS) == 0: - db = SessionLocal() - await get_all_models(db) - db.commit() + await get_all_models() else: pass @@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION -async def get_all_models(db: Session): +async def get_all_models(): pipe_models = [] openai_models = [] ollama_models = [] - pipe_models = await get_pipe_models(db) + pipe_models = await get_pipe_models() if app.state.config.ENABLE_OPENAI_API: openai_models = await get_openai_models() @@ -863,7 +860,7 @@ async def get_all_models(db: Session): models = pipe_models + openai_models + ollama_models - custom_models = Models.get_all_models(db) + custom_models = Models.get_all_models() for custom_model in custom_models: if custom_model.base_model_id == None: for model in models: @@ -903,8 +900,8 @@ async def get_all_models(db: Session): @app.get("/api/models") -async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): - models = await get_all_models(db) +async def get_models(user=Depends(get_verified_user)): + models = await get_all_models() # Filter out filter pipelines models = [ @@ -1608,9 +1605,8 @@ async def get_pipeline_valves( urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user), - db=Depends(get_db), ): - models = await get_all_models(db) + models = await get_all_models() r = None try: @@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec( urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user), - db=Depends(get_db), ): - models = await get_all_models(db) + models = await get_all_models() r = None try: @@ -1690,9 +1685,8 @@ async def update_pipeline_valves( pipeline_id: str, form_data: dict, user=Depends(get_admin_user), - db=Depends(get_db), ): - models = await get_all_models(db) + models = await get_all_models() r = None try: @@ -2040,8 +2034,9 @@ async def healthcheck(): @app.get("/health/db") -async def healthcheck_with_db(db: Session = Depends(get_db)): - result = db.execute(text("SELECT 1;")).all() +async def healthcheck_with_db(): + with get_session() as db: + result = db.execute(text("SELECT 1;")).all() return {"status": True} diff --git a/backend/migrations/versions/22b5ab2667b8_init.py b/backend/migrations/versions/22b5ab2667b8_init.py deleted file mode 100644 index af10dc2cf..000000000 --- a/backend/migrations/versions/22b5ab2667b8_init.py +++ /dev/null @@ -1,188 +0,0 @@ -"""init - -Revision ID: 22b5ab2667b8 -Revises: -Create Date: 2024-06-20 13:22:40.397002 - -""" - -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -from sqlalchemy.engine.reflection import Inspector - -import apps.webui.internal.db - - -# revision identifiers, used by Alembic. -revision: str = "22b5ab2667b8" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - con = op.get_bind() - inspector = Inspector.from_engine(con) - tables = set(inspector.get_table_names()) - - # ### commands auto generated by Alembic - please adjust! ### - if not "auth" in tables: - op.create_table( - "auth", - sa.Column("id", sa.String(), nullable=False), - sa.Column("email", sa.String(), nullable=True), - sa.Column("password", sa.String(), nullable=True), - sa.Column("active", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - - if not "chat" in tables: - op.create_table( - "chat", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.String(), nullable=True), - sa.Column("chat", sa.String(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("share_id", sa.String(), nullable=True), - sa.Column("archived", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("share_id"), - ) - - if not "chatidtag" in tables: - op.create_table( - "chatidtag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("tag_name", sa.String(), nullable=True), - sa.Column("chat_id", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - - if not "document" in tables: - op.create_table( - "document", - sa.Column("collection_name", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("title", sa.String(), nullable=True), - sa.Column("filename", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("collection_name"), - sa.UniqueConstraint("name"), - ) - - if not "memory" in tables: - op.create_table( - "memory", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - - if not "model" in tables: - op.create_table( - "model", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("base_model_id", sa.String(), nullable=True), - sa.Column("name", sa.String(), nullable=True), - sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - - if not "prompt" in tables: - op.create_table( - "prompt", - sa.Column("command", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("command"), - ) - - if not "tag" in tables: - op.create_table( - "tag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("data", sa.String(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - - if not "tool" in tables: - op.create_table( - "tool", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - - if not "user" in tables: - op.create_table( - "user", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("email", sa.String(), nullable=True), - sa.Column("role", sa.String(), nullable=True), - sa.Column("profile_image_url", sa.String(), nullable=True), - sa.Column("last_active_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("api_key"), - ) - - if not "file" in tables: - op.create_table('file', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('filename', sa.String(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - - if not "function" in tables: - op.create_table('function', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('type', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - # do nothing as we assume we had previous migrations from peewee-migrate - pass - # ### end Alembic commands ### diff --git a/backend/migrations/versions/ba76b0bae648_init.py b/backend/migrations/versions/ba76b0bae648_init.py new file mode 100644 index 000000000..b1250662f --- /dev/null +++ b/backend/migrations/versions/ba76b0bae648_init.py @@ -0,0 +1,161 @@ +"""init + +Revision ID: ba76b0bae648 +Revises: +Create Date: 2024-06-24 09:09:11.636336 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import apps.webui.internal.db + + +# revision identifiers, used by Alembic. +revision: str = 'ba76b0bae648' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('auth', + sa.Column('id', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=True), + sa.Column('password', sa.String(), nullable=True), + sa.Column('active', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('chat', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.String(), nullable=True), + sa.Column('chat', sa.String(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('share_id', sa.String(), nullable=True), + sa.Column('archived', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('share_id') + ) + op.create_table('chatidtag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('tag_name', sa.String(), nullable=True), + sa.Column('chat_id', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('document', + sa.Column('collection_name', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('title', sa.String(), nullable=True), + sa.Column('filename', sa.String(), nullable=True), + sa.Column('content', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('collection_name'), + sa.UniqueConstraint('name') + ) + op.create_table('file', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('filename', sa.String(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('function', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('memory', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('content', sa.String(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('model', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('base_model_id', sa.String(), nullable=True), + sa.Column('name', sa.String(), nullable=True), + sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('prompt', + sa.Column('command', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.String(), nullable=True), + sa.Column('content', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('command') + ) + op.create_table('tag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('data', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('tool', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.String(), nullable=True), + sa.Column('content', sa.String(), nullable=True), + sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('user', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=True), + sa.Column('role', sa.String(), nullable=True), + sa.Column('profile_image_url', sa.String(), nullable=True), + sa.Column('last_active_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('api_key', sa.String(), nullable=True), + sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('api_key') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('user') + op.drop_table('tool') + op.drop_table('tag') + op.drop_table('prompt') + op.drop_table('model') + op.drop_table('memory') + op.drop_table('function') + op.drop_table('file') + op.drop_table('document') + op.drop_table('chatidtag') + op.drop_table('chat') + op.drop_table('auth') + # ### end Alembic commands ### diff --git a/backend/test/apps/webui/routers/test_auths.py b/backend/test/apps/webui/routers/test_auths.py index 3450f57c6..3a8695a69 100644 --- a/backend/test/apps/webui/routers/test_auths.py +++ b/backend/test/apps/webui/routers/test_auths.py @@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest): from utils.utils import get_password_hash user = self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password=get_password_hash("old_password"), name="John Doe", @@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest): json={"name": "John Doe 2", "profile_image_url": "/user2.png"}, ) assert response.status_code == 200 - db_user = self.users.get_user_by_id(self.db_session, user.id) + db_user = self.users.get_user_by_id(user.id) assert db_user.name == "John Doe 2" assert db_user.profile_image_url == "/user2.png" @@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest): from utils.utils import get_password_hash user = self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password=get_password_hash("old_password"), name="John Doe", @@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest): assert response.status_code == 200 old_auth = self.auths.authenticate_user( - self.db_session, "john.doe@openwebui.com", "old_password" + "john.doe@openwebui.com", "old_password" ) assert old_auth is None new_auth = self.auths.authenticate_user( - self.db_session, "john.doe@openwebui.com", "new_password" + "john.doe@openwebui.com", "new_password" ) assert new_auth is not None @@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest): from utils.utils import get_password_hash user = self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password=get_password_hash("password"), name="John Doe", @@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest): def test_get_admin_details(self): self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password="password", name="John Doe", @@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest): def test_create_api_key_(self): user = self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password="password", name="John Doe", @@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest): def test_delete_api_key(self): user = self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password="password", name="John Doe", profile_image_url="/user.png", role="admin", ) - self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") + self.users.update_user_api_key_by_id(user.id, "abc") with mock_webui_user(id=user.id): response = self.fast_api_client.delete(self.create_url("/api_key")) assert response.status_code == 200 assert response.json() == True - db_user = self.users.get_user_by_id(self.db_session, user.id) + db_user = self.users.get_user_by_id(user.id) assert db_user.api_key is None def test_get_api_key(self): user = self.auths.insert_new_auth( - self.db_session, email="john.doe@openwebui.com", password="password", name="John Doe", profile_image_url="/user.png", role="admin", ) - self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") + self.users.update_user_api_key_by_id(user.id, "abc") with mock_webui_user(id=user.id): response = self.fast_api_client.get(self.create_url("/api_key")) assert response.status_code == 200 diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py index 2d1145c06..ea4518eaf 100644 --- a/backend/test/apps/webui/routers/test_chats.py +++ b/backend/test/apps/webui/routers/test_chats.py @@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest): self.chats = Chats self.chats.insert_new_chat( - self.db_session, "2", ChatForm( **{ @@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest): with mock_webui_user(id="2"): response = self.fast_api_client.delete(self.create_url("/")) assert response.status_code == 200 - assert len(self.chats.get_chats(self.db_session)) == 0 + assert len(self.chats.get_chats()) == 0 def test_get_user_chat_list_by_user_id(self): with mock_webui_user(id="3"): @@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest): assert data["title"] == "New Chat" assert data["updated_at"] is not None assert data["created_at"] is not None - assert len(self.chats.get_chats(self.db_session)) == 2 + assert len(self.chats.get_chats()) == 2 def test_get_user_chats(self): self.test_get_session_user_chat_list() def test_get_user_archived_chats(self): - self.chats.archive_all_chats_by_user_id(self.db_session, "2") - self.db_session.commit() + self.chats.archive_all_chats_by_user_id("2") with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url("/all/archived")) assert response.status_code == 200 @@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest): with mock_webui_user(id="2"): response = self.fast_api_client.post(self.create_url("/archive/all")) assert response.status_code == 200 - assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1 + assert len(self.chats.get_archived_chats_by_user_id("2")) == 1 def test_get_shared_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id - self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id) - self.db_session.commit() + chat_id = self.chats.get_chats()[0].id + self.chats.update_chat_share_id_by_id(chat_id, chat_id) with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}")) assert response.status_code == 200 @@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest): assert data["title"] == "New Chat" def test_get_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url(f"/{chat_id}")) assert response.status_code == 200 @@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest): assert data["user_id"] == "2" def test_update_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id with mock_webui_user(id="2"): response = self.fast_api_client.post( self.create_url(f"/{chat_id}"), @@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest): assert data["user_id"] == "2" def test_delete_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id with mock_webui_user(id="2"): response = self.fast_api_client.delete(self.create_url(f"/{chat_id}")) assert response.status_code == 200 assert response.json() is True def test_clone_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone")) @@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest): assert data["user_id"] == "2" def test_archive_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive")) assert response.status_code == 200 - chat = self.chats.get_chat_by_id(self.db_session, chat_id) + chat = self.chats.get_chat_by_id(chat_id) assert chat.archived is True def test_share_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id with mock_webui_user(id="2"): response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share")) assert response.status_code == 200 - chat = self.chats.get_chat_by_id(self.db_session, chat_id) + chat = self.chats.get_chat_by_id(chat_id) assert chat.share_id is not None def test_delete_shared_chat_by_id(self): - chat_id = self.chats.get_chats(self.db_session)[0].id + chat_id = self.chats.get_chats()[0].id share_id = str(uuid.uuid4()) - self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id) - self.db_session.commit() + self.chats.update_chat_share_id_by_id(chat_id, share_id) with mock_webui_user(id="2"): response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share")) assert response.status_code - chat = self.chats.get_chat_by_id(self.db_session, chat_id) + chat = self.chats.get_chat_by_id(chat_id) assert chat.share_id is None diff --git a/backend/test/apps/webui/routers/test_documents.py b/backend/test/apps/webui/routers/test_documents.py index 53ef3d2aa..14ca339fd 100644 --- a/backend/test/apps/webui/routers/test_documents.py +++ b/backend/test/apps/webui/routers/test_documents.py @@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest): def test_documents(self): # Empty database - assert len(self.documents.get_docs(self.db_session)) == 0 + assert len(self.documents.get_docs()) == 0 with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url("/")) assert response.status_code == 200 @@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest): ) assert response.status_code == 200 assert response.json()["name"] == "doc_name" - assert len(self.documents.get_docs(self.db_session)) == 1 + assert len(self.documents.get_docs()) == 1 # Get the document with mock_webui_user(id="2"): @@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest): ) assert response.status_code == 200 assert response.json()["name"] == "doc_name 2" - assert len(self.documents.get_docs(self.db_session)) == 2 + assert len(self.documents.get_docs()) == 2 # Get all documents with mock_webui_user(id="2"): @@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest): assert data["content"] == { "tags": [{"name": "testing-tag"}, {"name": "another-tag"}] } - assert len(self.documents.get_docs(self.db_session)) == 2 + assert len(self.documents.get_docs()) == 2 # Delete the first document with mock_webui_user(id="2"): @@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest): self.create_url("/doc/delete?name=doc_name rework") ) assert response.status_code == 200 - assert len(self.documents.get_docs(self.db_session)) == 1 + assert len(self.documents.get_docs()) == 1 diff --git a/backend/test/apps/webui/routers/test_prompts.py b/backend/test/apps/webui/routers/test_prompts.py index cd2fcec87..9f47be992 100644 --- a/backend/test/apps/webui/routers/test_prompts.py +++ b/backend/test/apps/webui/routers/test_prompts.py @@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest): assert data["content"] == "description Updated" assert data["user_id"] == "3" + # Get prompt by command + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/command/my-command2")) + assert response.status_code == 200 + data = response.json() + assert data["command"] == "/my-command2" + assert data["title"] == "Hello World Updated" + assert data["content"] == "description Updated" + assert data["user_id"] == "3" + # Delete prompt with mock_webui_user(id="2"): response = self.fast_api_client.delete( diff --git a/backend/test/apps/webui/routers/test_users.py b/backend/test/apps/webui/routers/test_users.py index 35b662304..9736b4d32 100644 --- a/backend/test/apps/webui/routers/test_users.py +++ b/backend/test/apps/webui/routers/test_users.py @@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest): def setup_method(self): super().setup_method() self.users.insert_new_user( - self.db_session, id="1", name="user 1", email="user1@openwebui.com", @@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest): role="user", ) self.users.insert_new_user( - self.db_session, id="2", name="user 2", email="user2@openwebui.com", diff --git a/backend/utils/utils.py b/backend/utils/utils.py index f1225ec0e..6409fc7aa 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends, Request from sqlalchemy.orm import Session -from apps.webui.internal.db import get_db from apps.webui.models.users import Users from pydantic import BaseModel @@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str): def get_current_user( request: Request, auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), - db=Depends(get_db), ): token = None @@ -94,19 +92,19 @@ def get_current_user( # auth by api key if token.startswith("sk-"): - return get_current_user_by_api_key(db, token) + return get_current_user_by_api_key(token) # auth by jwt token data = decode_token(token) if data != None and "id" in data: - user = Users.get_user_by_id(db, data["id"]) + user = Users.get_user_by_id(data["id"]) if user is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.INVALID_TOKEN, ) else: - Users.update_user_last_active_by_id(db, user.id) + Users.update_user_last_active_by_id(user.id) return user else: raise HTTPException( From 070d9083d5dd515c32bd7bf60aeecc56f5bc059c Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 09:50:14 +0200 Subject: [PATCH 003/115] feat(sqlalchemy): use subprocess to do migrations --- backend/alembic.ini | 2 +- backend/main.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/backend/alembic.ini b/backend/alembic.ini index 72f2b762b..4eff85f0c 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -58,7 +58,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # are written from script.py.mako # output_encoding = utf-8 -sqlalchemy.url = REPLACE_WITH_DATABASE_URL +# sqlalchemy.url = REPLACE_WITH_DATABASE_URL [post_write_hooks] diff --git a/backend/main.py b/backend/main.py index 6e44045f2..8892d9bc7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -173,13 +173,11 @@ https://github.com/open-webui/open-webui def run_migrations(): - from alembic.config import Config - from alembic import command - - alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini") - alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL) - alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations") - command.upgrade(alembic_cfg, "head") + env = os.environ.copy() + env["DATABASE_URL"] = DATABASE_URL + migration_task = subprocess.run(["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env) + if migration_task.returncode > 0: + raise ValueError("Error running migrations") @asynccontextmanager From 320e658595918241c9bdab4f302017039d1ae694 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 09:56:42 +0200 Subject: [PATCH 004/115] feat(sqlalchemy): cleanup fixes --- backend/apps/socket/main.py | 4 +--- backend/main.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/apps/socket/main.py b/backend/apps/socket/main.py index bbbbccd79..123ff31cd 100644 --- a/backend/apps/socket/main.py +++ b/backend/apps/socket/main.py @@ -24,9 +24,7 @@ async def connect(sid, environ, auth): data = decode_token(auth["token"]) if data is not None and "id" in data: - from apps.webui.internal.db import SessionLocal - - user = Users.get_user_by_id(SessionLocal(), data["id"]) + user = Users.get_user_by_id(data["id"]) if user: SESSION_POOL[sid] = user.id diff --git a/backend/main.py b/backend/main.py index 8892d9bc7..2c4d5ecfd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -751,7 +751,6 @@ class PipelineMiddleware(BaseHTTPMiddleware): user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), - SessionLocal(), ) try: From c134eab27a929cbf678a60356a4c8f6c2e718201 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 09:57:08 +0200 Subject: [PATCH 005/115] feat(sqlalchemy): format backend --- backend/apps/webui/internal/db.py | 5 +- backend/apps/webui/models/auths.py | 20 +- backend/apps/webui/models/chats.py | 36 ++- backend/apps/webui/models/documents.py | 4 +- backend/apps/webui/models/files.py | 1 + backend/apps/webui/models/functions.py | 26 +- backend/apps/webui/models/models.py | 4 +- backend/apps/webui/models/prompts.py | 4 +- backend/apps/webui/models/tags.py | 18 +- backend/apps/webui/models/users.py | 16 +- backend/apps/webui/routers/auths.py | 14 +- backend/apps/webui/routers/chats.py | 40 +-- backend/apps/webui/routers/documents.py | 20 +- backend/apps/webui/routers/files.py | 5 +- backend/apps/webui/routers/memories.py | 4 +- backend/apps/webui/routers/prompts.py | 12 +- backend/apps/webui/routers/tools.py | 4 +- backend/apps/webui/routers/users.py | 28 +- backend/main.py | 4 +- .../migrations/versions/ba76b0bae648_init.py | 255 +++++++++--------- .../test/util/abstract_integration_test.py | 1 + 21 files changed, 232 insertions(+), 289 deletions(-) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 3c37bb09b..6fd541f4e 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -53,7 +53,9 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) +SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=engine, expire_on_commit=False +) Base = declarative_base() @@ -66,4 +68,3 @@ def get_session(): except Exception as e: db.rollback() raise e - diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index fd2934bb1..9f10e0fdd 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -126,9 +126,7 @@ class AuthsTable: else: return None - def authenticate_user( - self, email: str, password: str - ) -> Optional[UserModel]: + def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") with get_session() as db: try: @@ -144,9 +142,7 @@ class AuthsTable: except: return None - def authenticate_user_by_api_key( - self, api_key: str - ) -> Optional[UserModel]: + def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") with get_session() as db: # if no api_key, return None @@ -159,9 +155,7 @@ class AuthsTable: except: return False - def authenticate_user_by_trusted_header( - self, email: str - ) -> Optional[UserModel]: + def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") with get_session() as db: try: @@ -172,12 +166,12 @@ class AuthsTable: except: return None - def update_user_password_by_id( - self, id: str, new_password: str - ) -> bool: + def update_user_password_by_id(self, id: str, new_password: str) -> bool: with get_session() as db: try: - result = db.query(Auth).filter_by(id=id).update({"password": new_password}) + result = ( + db.query(Auth).filter_by(id=id).update({"password": new_password}) + ) return True if result == 1 else False except: return False diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d71ffd992..b0c983ade 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def insert_new_chat( - self, user_id: str, form_data: ChatForm - ) -> Optional[ChatModel]: + def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: with get_session() as db: id = str(uuid.uuid4()) chat = ChatModel( @@ -89,7 +87,9 @@ class ChatTable: "id": id, "user_id": user_id, "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" ), "chat": json.dumps(form_data.chat), "created_at": int(time.time()), @@ -103,9 +103,7 @@ class ChatTable: db.refresh(result) return ChatModel.model_validate(result) if result else None - def update_chat_by_id( - self, id: str, chat: dict - ) -> Optional[ChatModel]: + def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: with get_session() as db: try: chat_obj = db.get(Chat, id) @@ -119,9 +117,7 @@ class ChatTable: except Exception as e: return None - def insert_shared_chat_by_chat_id( - self, chat_id: str - ) -> Optional[ChatModel]: + def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: with get_session() as db: # Get the existing chat to share chat = db.get(Chat, chat_id) @@ -145,14 +141,14 @@ class ChatTable: db.refresh(shared_result) # Update the original chat with the share_id result = ( - db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id}) + db.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) ) return shared_chat if (shared_result and result) else None - def update_shared_chat_by_chat_id( - self, chat_id: str - ) -> Optional[ChatModel]: + def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: with get_session() as db: try: print("update_shared_chat_by_id") @@ -271,9 +267,7 @@ class ChatTable: except Exception as e: return None - def get_chat_by_id_and_user_id( - self, id: str, user_id: str - ) -> Optional[ChatModel]: + def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: with get_session() as db: chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() @@ -293,13 +287,13 @@ class ChatTable: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: with get_session() as db: all_chats = ( - db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) + db.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in all_chats] - def get_archived_chats_by_user_id( - self, user_id: str - ) -> List[ChatModel]: + def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: with get_session() as db: all_chats = ( db.query(Chat) diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 6348967db..897f182be 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -106,7 +106,9 @@ class DocumentsTable: def get_docs(self) -> List[DocumentModel]: with get_session() as db: - return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] + return [ + DocumentModel.model_validate(doc) for doc in db.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index d2565db3d..b7196d604 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -39,6 +39,7 @@ class FileModel(BaseModel): model_config = ConfigDict(from_attributes=True) + #################### # Forms #################### diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 417e52329..2343c9139 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -142,9 +142,9 @@ class FunctionsTable: with get_session() as db: return [ FunctionModel.model_validate(function) - for function in db.query(Function).filter_by( - type=type, is_active=True - ).all() + for function in db.query(Function) + .filter_by(type=type, is_active=True) + .all() ] else: with get_session() as db: @@ -220,10 +220,12 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: with get_session() as db: - db.query(Function).filter_by(id=id).update({ - **updated, - "updated_at": int(time.time()), - }) + db.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) db.commit() return self.get_function_by_id(id) except: @@ -232,10 +234,12 @@ class FunctionsTable: def deactivate_all_functions(self) -> Optional[bool]: try: with get_session() as db: - db.query(Function).update({ - "is_active": False, - "updated_at": int(time.time()), - }) + db.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) db.commit() return True except: diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 7641ee5a0..86b4fa49b 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -153,9 +153,7 @@ class ModelsTable: except: return None - def update_model_by_id( - self, id: str, model: ModelForm - ) -> Optional[ModelModel]: + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: # update only the fields that are present in the model with get_session() as db: diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index 2157153d8..029fd5e1b 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -83,7 +83,9 @@ class PromptsTable: def get_prompts(self) -> List[PromptModel]: with get_session() as db: - return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] + return [ + PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 5ad176c37..dfe63688e 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel): class TagTable: - def insert_new_tag( - self, name: str, user_id: str - ) -> Optional[TagModel]: + def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: @@ -201,11 +199,13 @@ class TagTable: self, tag_name: str, user_id: str ) -> int: with get_session() as db: - return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() + return ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) - def delete_tag_by_tag_name_and_user_id( - self, tag_name: str, user_id: str - ) -> bool: + def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: with get_session() as db: res = ( @@ -252,9 +252,7 @@ class TagTable: log.error(f"delete_tag: {e}") return False - def delete_tags_by_chat_id_and_user_id( - self, chat_id: str, user_id: str - ) -> bool: + def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id) for tag in tags: diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index bef15185b..796892927 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -165,9 +165,7 @@ class UsersTable: except: return None - def update_user_role_by_id( - self, id: str, role: str - ) -> Optional[UserModel]: + def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: with get_session() as db: try: db.query(User).filter_by(id=id).update({"role": role}) @@ -193,12 +191,12 @@ class UsersTable: except: return None - def update_user_last_active_by_id( - self, id: str - ) -> Optional[UserModel]: + def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: with get_session() as db: try: - db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) + db.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) @@ -217,9 +215,7 @@ class UsersTable: except: return None - def update_user_by_id( - self, id: str, updated: dict - ) -> Optional[UserModel]: + def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: with get_session() as db: try: db.query(User).filter_by(id=id).update(updated) diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py index f32b074b1..1be79d259 100644 --- a/backend/apps/webui/routers/auths.py +++ b/backend/apps/webui/routers/auths.py @@ -78,8 +78,7 @@ async def get_session_user( @router.post("/update/profile", response_model=UserResponse) async def update_profile( - form_data: UpdateProfileForm, - session_user=Depends(get_current_user) + form_data: UpdateProfileForm, session_user=Depends(get_current_user) ): if session_user: user = Users.update_user_by_id( @@ -101,8 +100,7 @@ async def update_profile( @router.post("/update/password", response_model=bool) async def update_password( - form_data: UpdatePasswordForm, - session_user=Depends(get_current_user) + form_data: UpdatePasswordForm, session_user=Depends(get_current_user) ): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) @@ -269,9 +267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.post("/add", response_model=SigninResponse) -async def add_user( - form_data: AddUserForm, user=Depends(get_admin_user) -): +async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): if not validate_email_format(form_data.email.lower()): raise HTTPException( @@ -316,9 +312,7 @@ async def add_user( @router.get("/admin/details") -async def get_admin_details( - request: Request, user=Depends(get_current_user) -): +async def get_admin_details(request: Request, user=Depends(get_current_user)): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py index 8b2b9987a..3070483f3 100644 --- a/backend/apps/webui/routers/chats.py +++ b/backend/apps/webui/routers/chats.py @@ -55,9 +55,7 @@ async def get_session_user_chat_list( @router.delete("/", response_model=bool) -async def delete_all_user_chats( - request: Request, user=Depends(get_current_user) -): +async def delete_all_user_chats(request: Request, user=Depends(get_current_user)): if ( user.role == "user" @@ -95,9 +93,7 @@ async def get_user_chat_list_by_user_id( @router.post("/new", response_model=Optional[ChatResponse]) -async def create_new_chat( - form_data: ChatForm, user=Depends(get_current_user) -): +async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): try: chat = Chats.insert_new_chat(user.id, form_data) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) @@ -180,9 +176,7 @@ async def archive_all_chats(user=Depends(get_current_user)): @router.get("/share/{share_id}", response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id( - share_id: str, user=Depends(get_current_user) -): +async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): if user.role == "pending": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND @@ -225,9 +219,7 @@ async def get_user_chat_list_by_tag_name( ) ] - chats = Chats.get_chat_list_by_chat_ids( - chat_ids, form_data.skip, form_data.limit - ) + chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit) if len(chats) == 0: Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id) @@ -297,9 +289,7 @@ async def update_chat_by_id( @router.delete("/{id}", response_model=bool) -async def delete_chat_by_id( - request: Request, id: str, user=Depends(get_current_user) -): +async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): if user.role == "admin": result = Chats.delete_chat_by_id(id) @@ -347,9 +337,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)): @router.get("/{id}/archive", response_model=Optional[ChatResponse]) -async def archive_chat_by_id( - id: str, user=Depends(get_current_user) -): +async def archive_chat_by_id(id: str, user=Depends(get_current_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: chat = Chats.toggle_chat_archive_by_id(id) @@ -398,9 +386,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)): @router.delete("/{id}/share", response_model=Optional[bool]) -async def delete_shared_chat_by_id( - id: str, user=Depends(get_current_user) -): +async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): chat = Chats.get_chat_by_id_and_user_id(id, user.id) if chat: if not chat.share_id: @@ -423,9 +409,7 @@ async def delete_shared_chat_by_id( @router.get("/{id}/tags", response_model=List[TagModel]) -async def get_chat_tags_by_id( - id: str, user=Depends(get_current_user) -): +async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) if tags != None: @@ -443,9 +427,7 @@ async def get_chat_tags_by_id( @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) async def add_chat_tag_by_id( - id: str, - form_data: ChatIdTagForm, - user=Depends(get_current_user) + id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) ): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) @@ -494,9 +476,7 @@ async def delete_chat_tag_by_id( @router.delete("/{id}/tags/all", response_model=Optional[bool]) -async def delete_all_chat_tags_by_id( - id: str, user=Depends(get_current_user) -): +async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id) if result: diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py index f358e033c..4e1111c07 100644 --- a/backend/apps/webui/routers/documents.py +++ b/backend/apps/webui/routers/documents.py @@ -44,9 +44,7 @@ async def get_documents(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[DocumentResponse]) -async def create_new_doc( - form_data: DocumentForm, user=Depends(get_admin_user) -): +async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)): doc = Documents.get_doc_by_name(form_data.name) if doc == None: doc = Documents.insert_new_doc(user.id, form_data) @@ -76,9 +74,7 @@ async def create_new_doc( @router.get("/doc", response_model=Optional[DocumentResponse]) -async def get_doc_by_name( - name: str, user=Depends(get_current_user) -): +async def get_doc_by_name(name: str, user=Depends(get_current_user)): doc = Documents.get_doc_by_name(name) if doc: @@ -110,12 +106,8 @@ class TagDocumentForm(BaseModel): @router.post("/doc/tags", response_model=Optional[DocumentResponse]) -async def tag_doc_by_name( - form_data: TagDocumentForm, user=Depends(get_current_user) -): - doc = Documents.update_doc_content_by_name( - form_data.name, {"tags": form_data.tags} - ) +async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): + doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) if doc: return DocumentResponse( @@ -163,8 +155,6 @@ async def update_doc_by_name( @router.delete("/doc/delete", response_model=bool) -async def delete_doc_by_name( - name: str, user=Depends(get_admin_user) -): +async def delete_doc_by_name(name: str, user=Depends(get_admin_user)): result = Documents.delete_doc_by_name(name) return result diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index e98d1da58..fffe0743c 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -50,10 +50,7 @@ router = APIRouter() @router.post("/") -def upload_file( - file: UploadFile = File(...), - user=Depends(get_verified_user) -): +def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): log.info(f"file.content_type: {file.content_type}") try: unsanitized_filename = file.filename diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py index d6b2d0fcb..2c473ebe8 100644 --- a/backend/apps/webui/routers/memories.py +++ b/backend/apps/webui/routers/memories.py @@ -167,9 +167,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)): @router.delete("/{memory_id}", response_model=bool) -async def delete_memory_by_id( - memory_id: str, user=Depends(get_verified_user) -): +async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)): result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id) if result: diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py index 3912b1028..0cbf3d366 100644 --- a/backend/apps/webui/routers/prompts.py +++ b/backend/apps/webui/routers/prompts.py @@ -29,9 +29,7 @@ async def get_prompts(user=Depends(get_current_user)): @router.post("/create", response_model=Optional[PromptModel]) -async def create_new_prompt( - form_data: PromptForm, user=Depends(get_admin_user) -): +async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)): prompt = Prompts.get_prompt_by_command(form_data.command) if prompt == None: prompt = Prompts.insert_new_prompt(user.id, form_data) @@ -54,9 +52,7 @@ async def create_new_prompt( @router.get("/command/{command}", response_model=Optional[PromptModel]) -async def get_prompt_by_command( - command: str, user=Depends(get_current_user) -): +async def get_prompt_by_command(command: str, user=Depends(get_current_user)): prompt = Prompts.get_prompt_by_command(f"/{command}") if prompt: @@ -95,8 +91,6 @@ async def update_prompt_by_command( @router.delete("/command/{command}/delete", response_model=bool) -async def delete_prompt_by_command( - command: str, user=Depends(get_admin_user) -): +async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)): result = Prompts.delete_prompt_by_command(f"/{command}") return result diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py index 82a09477d..ea9db8180 100644 --- a/backend/apps/webui/routers/tools.py +++ b/backend/apps/webui/routers/tools.py @@ -180,9 +180,7 @@ async def update_toolkit_by_id( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_toolkit_by_id( - request: Request, id: str, user=Depends(get_admin_user) -): +async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)): result = Tools.delete_tool_by_id(id) if result: diff --git a/backend/apps/webui/routers/users.py b/backend/apps/webui/routers/users.py index 8a38d5b9f..9627f0b06 100644 --- a/backend/apps/webui/routers/users.py +++ b/backend/apps/webui/routers/users.py @@ -40,9 +40,7 @@ router = APIRouter() @router.get("/", response_model=List[UserModel]) -async def get_users( - skip: int = 0, limit: int = 50, user=Depends(get_admin_user) -): +async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)): return Users.get_users(skip, limit) @@ -70,9 +68,7 @@ async def update_user_permissions( @router.post("/update/role", response_model=Optional[UserModel]) -async def update_user_role( - form_data: UserRoleUpdateForm, user=Depends(get_admin_user) -): +async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): if user.id != form_data.id and form_data.id != Users.get_first_user().id: return Users.update_user_role_by_id(form_data.id, form_data.role) @@ -89,9 +85,7 @@ async def update_user_role( @router.get("/user/settings", response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user( - user=Depends(get_verified_user) -): +async def get_user_settings_by_session_user(user=Depends(get_verified_user)): user = Users.get_user_by_id(user.id) if user: return user.settings @@ -127,9 +121,7 @@ async def update_user_settings_by_session_user( @router.get("/user/info", response_model=Optional[dict]) -async def get_user_info_by_session_user( - user=Depends(get_verified_user) -): +async def get_user_info_by_session_user(user=Depends(get_verified_user)): user = Users.get_user_by_id(user.id) if user: return user.info @@ -154,9 +146,7 @@ async def update_user_info_by_session_user( if user.info is None: user.info = {} - user = Users.update_user_by_id( - user.id, {"info": {**user.info, **form_data}} - ) + user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}}) if user: return user.info else: @@ -182,9 +172,7 @@ class UserResponse(BaseModel): @router.get("/{user_id}", response_model=UserResponse) -async def get_user_by_id( - user_id: str, user=Depends(get_verified_user) -): +async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): # Check if user_id is a shared chat # If it is, get the user_id from the chat @@ -267,9 +255,7 @@ async def update_user_by_id( @router.delete("/{user_id}", response_model=bool) -async def delete_user_by_id( - user_id: str, user=Depends(get_admin_user) -): +async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)): if user.id != user_id: result = Auths.delete_auth_by_id(user_id) diff --git a/backend/main.py b/backend/main.py index 2c4d5ecfd..f35095bf1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -175,7 +175,9 @@ https://github.com/open-webui/open-webui def run_migrations(): env = os.environ.copy() env["DATABASE_URL"] = DATABASE_URL - migration_task = subprocess.run(["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env) + migration_task = subprocess.run( + ["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env + ) if migration_task.returncode > 0: raise ValueError("Error running migrations") diff --git a/backend/migrations/versions/ba76b0bae648_init.py b/backend/migrations/versions/ba76b0bae648_init.py index b1250662f..c491ed46c 100644 --- a/backend/migrations/versions/ba76b0bae648_init.py +++ b/backend/migrations/versions/ba76b0bae648_init.py @@ -5,6 +5,7 @@ Revises: Create Date: 2024-06-24 09:09:11.636336 """ + from typing import Sequence, Union from alembic import op @@ -13,7 +14,7 @@ import apps.webui.internal.db # revision identifiers, used by Alembic. -revision: str = 'ba76b0bae648' +revision: str = "ba76b0bae648" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,141 +22,153 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('auth', - sa.Column('id', sa.String(), nullable=False), - sa.Column('email', sa.String(), nullable=True), - sa.Column('password', sa.String(), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "auth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=True), + sa.Column("password", sa.String(), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('chat', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('chat', sa.String(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('share_id', sa.String(), nullable=True), - sa.Column('archived', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('share_id') + op.create_table( + "chat", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("chat", sa.String(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("share_id", sa.String(), nullable=True), + sa.Column("archived", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("share_id"), ) - op.create_table('chatidtag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('tag_name', sa.String(), nullable=True), - sa.Column('chat_id', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "chatidtag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("tag_name", sa.String(), nullable=True), + sa.Column("chat_id", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('document', - sa.Column('collection_name', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('filename', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('collection_name'), - sa.UniqueConstraint('name') + op.create_table( + "document", + sa.Column("collection_name", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("filename", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("collection_name"), + sa.UniqueConstraint("name"), ) - op.create_table('file', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('filename', sa.String(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "file", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("filename", sa.String(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('function', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('type', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "function", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('memory', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "memory", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('model', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('base_model_id', sa.String(), nullable=True), - sa.Column('name', sa.String(), nullable=True), - sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "model", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("base_model_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('prompt', - sa.Column('command', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('command') + op.create_table( + "prompt", + sa.Column("command", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("command"), ) - op.create_table('tag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('data', sa.String(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("data", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('tool', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.String(), nullable=True), - sa.Column('content', sa.String(), nullable=True), - sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tool", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("content", sa.String(), nullable=True), + sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) - op.create_table('user', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('email', sa.String(), nullable=True), - sa.Column('role', sa.String(), nullable=True), - sa.Column('profile_image_url', sa.String(), nullable=True), - sa.Column('last_active_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('api_key', sa.String(), nullable=True), - sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('api_key') + op.create_table( + "user", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("profile_image_url", sa.String(), nullable=True), + sa.Column("last_active_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("api_key"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('user') - op.drop_table('tool') - op.drop_table('tag') - op.drop_table('prompt') - op.drop_table('model') - op.drop_table('memory') - op.drop_table('function') - op.drop_table('file') - op.drop_table('document') - op.drop_table('chatidtag') - op.drop_table('chat') - op.drop_table('auth') + op.drop_table("user") + op.drop_table("tool") + op.drop_table("tag") + op.drop_table("prompt") + op.drop_table("model") + op.drop_table("memory") + op.drop_table("function") + op.drop_table("file") + op.drop_table("document") + op.drop_table("chatidtag") + op.drop_table("chat") + op.drop_table("auth") # ### end Alembic commands ### diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index 9cbf42d47..781fbfff8 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -91,6 +91,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): while retries > 0: try: from config import BACKEND_DIR + db = create_engine(database_url, pool_pre_ping=True) db = db.connect() log.info("postgres is ready!") From eb01e8d2755a73f7c8db121d7b69b36bee1cae22 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 10:54:18 +0200 Subject: [PATCH 006/115] feat(sqlalchemy): use scoped session --- backend/apps/webui/internal/db.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 6fd541f4e..b9bfc8aff 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -7,7 +7,7 @@ from typing_extensions import Self from sqlalchemy import create_engine, types, Dialect from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.sql.type_api import _T from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR @@ -61,10 +61,10 @@ Base = declarative_base() @contextmanager def get_session(): - db = SessionLocal() + session = scoped_session(SessionLocal) try: - yield db - db.commit() + yield session + session.commit() except Exception as e: - db.rollback() + session.rollback() raise e From da403f3e3cf9ce700da2fdb477e0bdfc4794d37d Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 13:06:15 +0200 Subject: [PATCH 007/115] feat(sqlalchemy): use session factory instead of context manager --- backend/apps/webui/internal/db.py | 12 +- backend/apps/webui/models/auths.py | 131 ++++---- backend/apps/webui/models/chats.py | 292 ++++++++---------- backend/apps/webui/models/documents.py | 77 +++-- backend/apps/webui/models/files.py | 36 +-- backend/apps/webui/models/functions.py | 117 ++++--- backend/apps/webui/models/memories.py | 60 ++-- backend/apps/webui/models/models.py | 42 ++- backend/apps/webui/models/prompts.py | 91 +++--- backend/apps/webui/models/tags.py | 200 ++++++------ backend/apps/webui/models/tools.py | 59 ++-- backend/apps/webui/models/users.py | 245 +++++++-------- backend/main.py | 14 +- backend/test/apps/webui/routers/test_chats.py | 2 + .../test/util/abstract_integration_test.py | 21 +- 15 files changed, 640 insertions(+), 759 deletions(-) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index b9bfc8aff..320ab3e07 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -57,14 +57,4 @@ SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) Base = declarative_base() - - -@contextmanager -def get_session(): - session = scoped_session(SessionLocal) - try: - yield session - session.commit() - except Exception as e: - session.rollback() - raise e +Session = scoped_session(SessionLocal) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 9f10e0fdd..1858b2c0d 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -3,12 +3,11 @@ from typing import Optional import uuid import logging from sqlalchemy import String, Column, Boolean -from sqlalchemy.orm import Session from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session from config import SRC_LOG_LEVELS @@ -103,101 +102,93 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - with get_session() as db: - log.info("insert_new_auth") + log.info("insert_new_auth") - id = str(uuid.uuid4()) + id = str(uuid.uuid4()) - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) - result = Auth(**auth.model_dump()) - db.add(result) + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = Auth(**auth.model_dump()) + Session.add(result) - user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub - ) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub) - db.commit() - db.refresh(result) + Session.commit() + Session.refresh(result) - if result and user: - return user - else: - return None + if result and user: + return user + else: + return None def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") - with get_session() as db: - try: - auth = db.query(Auth).filter_by(email=email, active=True).first() - if auth: - if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) - return user - else: - return None + try: + auth = Session.query(Auth).filter_by(email=email, active=True).first() + if auth: + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + return user else: return None - except: + else: return None + except: + return None def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_api_key: {api_key}") - with get_session() as db: - # if no api_key, return None - if not api_key: - return None + # if no api_key, return None + if not api_key: + return None - try: - user = Users.get_user_by_api_key(api_key) - return user if user else None - except: - return False + try: + user = Users.get_user_by_api_key(api_key) + return user if user else None + except: + return False def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") - with get_session() as db: - try: - auth = db.query(Auth).filter(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id) - return user - except: - return None + try: + auth = Session.query(Auth).filter(email=email, active=True).first() + if auth: + user = Users.get_user_by_id(auth.id) + return user + except: + return None def update_user_password_by_id(self, id: str, new_password: str) -> bool: - with get_session() as db: - try: - result = ( - db.query(Auth).filter_by(id=id).update({"password": new_password}) - ) - return True if result == 1 else False - except: - return False + try: + result = ( + Session.query(Auth).filter_by(id=id).update({"password": new_password}) + ) + return True if result == 1 else False + except: + return False def update_email_by_id(self, id: str, email: str) -> bool: - with get_session() as db: - try: - result = db.query(Auth).filter_by(id=id).update({"email": email}) - return True if result == 1 else False - except: - return False + try: + result = Session.query(Auth).filter_by(id=id).update({"email": email}) + return True if result == 1 else False + except: + return False def delete_auth_by_id(self, id: str) -> bool: - with get_session() as db: - try: - # Delete User - result = Users.delete_user_by_id(id) + try: + # Delete User + result = Users.delete_user_by_id(id) - if result: - db.query(Auth).filter_by(id=id).delete() + if result: + Session.query(Auth).filter_by(id=id).delete() - return True - else: - return False - except: + return True + else: return False + except: + return False Auths = AuthsTable() diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index b0c983ade..abf5b544c 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -6,9 +6,8 @@ import uuid import time from sqlalchemy import Column, String, BigInteger, Boolean -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session #################### @@ -80,93 +79,88 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - with get_session() as db: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] - if "title" in form_data.chat - else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) - result = Chat(**chat.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - return ChatModel.model_validate(result) if result else None + result = Chat(**chat.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + return ChatModel.model_validate(result) if result else None def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: - with get_session() as db: - try: - chat_obj = db.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) - db.commit() - db.refresh(chat_obj) + try: + chat_obj = Session.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + Session.commit() + Session.refresh(chat_obj) - return ChatModel.model_validate(chat_obj) - except Exception as e: - return None + return ChatModel.model_validate(chat_obj) + except Exception as e: + return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_session() as db: - # Get the existing chat to share - chat = db.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - db.add(shared_result) - db.commit() - db.refresh(shared_result) - # Update the original chat with the share_id - result = ( - db.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + # Get the existing chat to share + chat = Session.get(Chat, chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + Session.add(shared_result) + Session.commit() + Session.refresh(shared_result) + # Update the original chat with the share_id + result = ( + Session.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) + ) - return shared_chat if (shared_result and result) else None + return shared_chat if (shared_result and result) else None def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - with get_session() as db: - try: - print("update_shared_chat_by_id") - chat = db.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - db.commit() - db.refresh(chat) + try: + print("update_shared_chat_by_id") + chat = Session.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + Session.commit() + Session.refresh(chat) - return self.get_chat_by_id(chat.share_id) - except: - return None + return self.get_chat_by_id(chat.share_id) + except: + return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() return True except: return False @@ -175,30 +169,27 @@ class ChatTable: self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.get(Chat, id) - chat.share_id = share_id - db.commit() - db.refresh(chat) - return chat + chat = Session.get(Chat, id) + chat.share_id = share_id + Session.commit() + Session.refresh(chat) + return ChatModel.model_validate(chat) except: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = self.get_chat_by_id(id) - db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) - - return self.get_chat_by_id(id) + chat = Session.get(Chat, id) + chat.archived = not chat.archived + Session.commit() + Session.refresh(chat) + return ChatModel.model_validate(chat) except: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - + Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) return True except: return False @@ -206,9 +197,8 @@ class ChatTable: def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - with get_session() as db: all_chats = ( - db.query(Chat) + Session.query(Chat) .filter_by(user_id=user_id, archived=True) .order_by(Chat.updated_at.desc()) # .limit(limit).offset(skip) @@ -223,120 +213,108 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - with get_session() as db: - query = db.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + query = Session.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.get(Chat, id) - return ChatModel.model_validate(chat) + chat = Session.get(Chat, id) + return ChatModel.model_validate(chat) except: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.query(Chat).filter_by(share_id=id).first() + chat = Session.query(Chat).filter_by(share_id=id).first() - if chat: - return self.get_chat_by_id(id) - else: - return None + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: - with get_session() as db: - chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - with get_session() as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def delete_chat_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(id=id).delete() + Session.query(Chat).filter_by(id=id).delete() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - with get_session() as db: - db.query(Chat).filter_by(id=id, user_id=user_id).delete() + Session.query(Chat).filter_by(id=id, user_id=user_id).delete() - return True and self.delete_shared_chat_by_chat_id(id) + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - self.delete_shared_chats_by_user_id(user_id) + self.delete_shared_chats_by_user_id(user_id) - db.query(Chat).filter_by(user_id=user_id).delete() + Session.query(Chat).filter_by(user_id=user_id).delete() return True except: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() return True except: diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 897f182be..f8e7153c5 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -4,9 +4,8 @@ import time import logging from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session import json @@ -84,46 +83,42 @@ class DocumentsTable: ) try: - with get_session() as db: - result = Document(**document.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: - return None + result = Document(**document.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None except: return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - with get_session() as db: - document = db.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + document = Session.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None def get_docs(self) -> List[DocumentModel]: - with get_session() as db: - return [ - DocumentModel.model_validate(doc) for doc in db.query(Document).all() - ] + return [ + DocumentModel.model_validate(doc) for doc in Session.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - with get_session() as db: - db.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(form_data.name) + Session.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + Session.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None @@ -132,27 +127,25 @@ class DocumentsTable: self, name: str, updated: dict ) -> Optional[DocumentModel]: try: - with get_session() as db: - doc = self.get_doc_by_name(name) - doc_content = json.loads(doc.content if doc.content else "{}") - doc_content = {**doc_content, **updated} + doc = self.get_doc_by_name(name) + doc_content = json.loads(doc.content if doc.content else "{}") + doc_content = {**doc_content, **updated} - db.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - db.commit() - return self.get_doc_by_name(name) + Session.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + Session.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None def delete_doc_by_name(self, name: str) -> bool: try: - with get_session() as db: - db.query(Document).filter_by(name=name).delete() + Session.query(Document).filter_by(name=name).delete() return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index b7196d604..7664bf4f1 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -4,9 +4,8 @@ import time import logging from sqlalchemy import Column, String, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import JSONField, Base, get_session +from apps.webui.internal.db import JSONField, Base, Session import json @@ -71,45 +70,38 @@ class FilesTable: ) try: - with get_session() as db: - result = File(**file.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return FileModel.model_validate(result) - else: - return None + result = File(**file.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_file_by_id(self, id: str) -> Optional[FileModel]: try: - with get_session() as db: - file = db.get(File, id) - return FileModel.model_validate(file) + file = Session.get(File, id) + return FileModel.model_validate(file) except: return None def get_files(self) -> List[FileModel]: - with get_session() as db: - return [FileModel.model_validate(file) for file in db.query(File).all()] + return [FileModel.model_validate(file) for file in Session.query(File).all()] def delete_file_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(File).filter_by(id=id).delete() - db.commit() + Session.query(File).filter_by(id=id).delete() return True except: return False def delete_all_files(self) -> bool: try: - with get_session() as db: - db.query(File).delete() - db.commit() + Session.query(File).delete() return True except: return False diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 2343c9139..b78ac9708 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -4,9 +4,8 @@ import time import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean -from sqlalchemy.orm import Session -from apps.webui.internal.db import JSONField, Base, get_session +from apps.webui.internal.db import JSONField, Base, Session from apps.webui.models.users import Users import json @@ -100,64 +99,57 @@ class FunctionsTable: ) try: - with get_session() as db: - result = Function(**function.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + result = Function(**function.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - with get_session() as db: - function = db.get(Function, id) - return FunctionModel.model_validate(function) + function = Session.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: if active_only: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(is_active=True).all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function).filter_by(is_active=True).all() + ] else: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function).all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: if active_only: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function) - .filter_by(type=type, is_active=True) - .all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function) + .filter_by(type=type, is_active=True) + .all() + ] else: - with get_session() as db: - return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(type=type).all() - ] + return [ + FunctionModel.model_validate(function) + for function in Session.query(Function).filter_by(type=type).all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: try: - with get_session() as db: - function = db.get(Function, id) - return function.valves if function.valves else {} + function = Session.get(Function, id) + return function.valves if function.valves else {} except Exception as e: print(f"An error occurred: {e}") return None @@ -166,12 +158,12 @@ class FunctionsTable: self, id: str, valves: dict ) -> Optional[FunctionValves]: try: - with get_session() as db: - db.query(Function).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - db.commit() - return self.get_function_by_id(id) + function = Session.get(Function, id) + function.valves = valves + function.updated_at = int(time.time()) + Session.commit() + Session.refresh(function) + return self.get_function_by_id(id) except: return None @@ -219,36 +211,33 @@ class FunctionsTable: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: try: - with get_session() as db: - db.query(Function).filter_by(id=id).update( - { - **updated, - "updated_at": int(time.time()), - } - ) - db.commit() - return self.get_function_by_id(id) + Session.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) + Session.commit() + return self.get_function_by_id(id) except: return None def deactivate_all_functions(self) -> Optional[bool]: try: - with get_session() as db: - db.query(Function).update( - { - "is_active": False, - "updated_at": int(time.time()), - } - ) - db.commit() + Session.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) + Session.commit() return True except: return None def delete_function_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Function).filter_by(id=id).delete() + Session.query(Function).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 941da5b26..263d1b5ab 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -2,10 +2,8 @@ from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session -from apps.webui.models.chats import Chats +from apps.webui.internal.db import Base, Session import time import uuid @@ -58,15 +56,14 @@ class MemoriesTable: "updated_at": int(time.time()), } ) - with get_session() as db: - result = Memory(**memory.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + result = Memory(**memory.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, @@ -74,62 +71,55 @@ class MemoriesTable: content: str, ) -> Optional[MemoryModel]: try: - with get_session() as db: - db.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - db.commit() - return self.get_memory_by_id(id) + Session.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + Session.commit() + return self.get_memory_by_id(id) except: return None def get_memories(self) -> List[MemoryModel]: try: - with get_session() as db: - memories = db.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] + memories = Session.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: try: - with get_session() as db: - memories = db.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] + memories = Session.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] except: return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: try: - with get_session() as db: - memory = db.get(Memory, id) - return MemoryModel.model_validate(memory) + memory = Session.get(Memory, id) + return MemoryModel.model_validate(memory) except: return None def delete_memory_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Memory).filter_by(id=id).delete() + Session.query(Memory).filter_by(id=id).delete() return True except: return False - def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: + def delete_memories_by_user_id(self, user_id: str) -> bool: try: - with get_session() as db: - db.query(Memory).filter_by(user_id=user_id).delete() + Session.query(Memory).filter_by(user_id=user_id).delete() return True except: return False def delete_memory_by_id_and_user_id( - self, db: Session, id: str, user_id: str + self, id: str, user_id: str ) -> bool: try: - with get_session() as db: - db.query(Memory).filter_by(id=id, user_id=user_id).delete() + Session.query(Memory).filter_by(id=id, user_id=user_id).delete() return True except: return False diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 86b4fa49b..dd736a73e 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -4,9 +4,8 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, JSONField, get_session +from apps.webui.internal.db import Base, JSONField, Session from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -127,41 +126,37 @@ class ModelsTable: } ) try: - with get_session() as db: - result = Model(**model.model_dump()) - db.add(result) - db.commit() - db.refresh(result) + result = Model(**model.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None def get_all_models(self) -> List[ModelModel]: - with get_session() as db: - return [ModelModel.model_validate(model) for model in db.query(Model).all()] + return [ModelModel.model_validate(model) for model in Session.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - with get_session() as db: - model = db.get(Model, id) - return ModelModel.model_validate(model) + model = Session.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: # update only the fields that are present in the model - with get_session() as db: - model = db.query(Model).get(id) - model.update(**model.model_dump()) - db.commit() - db.refresh(model) - return ModelModel.model_validate(model) + model = Session.query(Model).get(id) + model.update(**model.model_dump()) + Session.commit() + Session.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) @@ -169,8 +164,7 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Model).filter_by(id=id).delete() + Session.query(Model).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index 029fd5e1b..a2fd0366b 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -3,9 +3,8 @@ from typing import List, Optional import time from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session import json @@ -50,65 +49,59 @@ class PromptsTable: def insert_new_prompt( self, user_id: str, form_data: PromptForm ) -> Optional[PromptModel]: - with get_session() as db: - prompt = PromptModel( - **{ - "user_id": user_id, - "command": form_data.command, - "title": form_data.title, - "content": form_data.content, - "timestamp": int(time.time()), - } - ) + prompt = PromptModel( + **{ + "user_id": user_id, + "command": form_data.command, + "title": form_data.title, + "content": form_data.content, + "timestamp": int(time.time()), + } + ) - try: - result = Prompt(**prompt.dict()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None - except Exception as e: + try: + result = Prompt(**prompt.dict()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return PromptModel.model_validate(result) + else: return None + except Exception as e: + return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: - with get_session() as db: - try: - prompt = db.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) - except: - return None + try: + prompt = Session.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) + except: + return None def get_prompts(self) -> List[PromptModel]: - with get_session() as db: - return [ - PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() - ] + return [ + PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: - with get_session() as db: - try: - prompt = db.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title - prompt.content = form_data.content - prompt.timestamp = int(time.time()) - db.commit() - return prompt - # return self.get_prompt_by_command(command) - except: - return None + try: + prompt = Session.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + Session.commit() + return PromptModel.model_validate(prompt) + except: + return None def delete_prompt_by_command(self, command: str) -> bool: - with get_session() as db: - try: - db.query(Prompt).filter_by(command=command).delete() - return True - except: - return False + try: + Session.query(Prompt).filter_by(command=command).delete() + return True + except: + return False Prompts = PromptsTable() diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index dfe63688e..6cfe39d0c 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -7,9 +7,8 @@ import time import logging from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, get_session +from apps.webui.internal.db import Base, Session from config import SRC_LOG_LEVELS @@ -83,15 +82,14 @@ class TagTable: id = str(uuid.uuid4()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: - with get_session() as db: - result = Tag(**tag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return TagModel.model_validate(result) - else: - return None + result = Tag(**tag.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None except Exception as e: return None @@ -99,9 +97,8 @@ class TagTable: self, name: str, user_id: str ) -> Optional[TagModel]: try: - with get_session() as db: - tag = db.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + tag = Session.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None @@ -123,105 +120,99 @@ class TagTable: } ) try: - with get_session() as db: - result = ChatIdTag(**chatIdTag.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + result = ChatIdTag(**chatIdTag.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: - with get_session() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + Session.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + Session.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: - with get_session() as db: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + Session.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - db.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + Session.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - with get_session() as db: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - db.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + Session.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> int: - with get_session() as db: - return ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) + return ( + Session.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: - with get_session() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() + res = ( + Session.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: log.error(f"delete_tag: {e}") @@ -231,21 +222,20 @@ class TagTable: self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - with get_session() as db: - res = ( - db.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - db.commit() + res = ( + Session.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) - if tag_count == 0: - # Remove tag item from Tag col as well - db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() return True except Exception as e: diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 534a4e3e8..20c608921 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -3,9 +3,8 @@ from typing import List, Optional import time import logging from sqlalchemy import String, Column, BigInteger -from sqlalchemy.orm import Session -from apps.webui.internal.db import Base, JSONField, get_session +from apps.webui.internal.db import Base, JSONField, Session from apps.webui.models.users import Users import json @@ -95,48 +94,43 @@ class ToolsTable: ) try: - with get_session() as db: - result = Tool(**tool.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return ToolModel.model_validate(result) - else: - return None + result = Tool(**tool.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - with get_session() as db: - tool = db.get(Tool, id) - return ToolModel.model_validate(tool) + tool = Session.get(Tool, id) + return ToolModel.model_validate(tool) except: return None def get_tools(self) -> List[ToolModel]: - with get_session() as db: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - with get_session() as db: - tool = db.get(Tool, id) - return tool.valves if tool.valves else {} + tool = Session.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - with get_session() as db: - db.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - db.commit() - return self.get_tool_by_id(id) + Session.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + Session.commit() + return self.get_tool_by_id(id) except: return None @@ -183,19 +177,18 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - with get_session() as db: - db.query(Tool).filter_by(id=id).update( - {**updated, "updated_at": int(time.time())} - ) - db.commit() - return self.get_tool_by_id(id) + tool = Session.get(Tool, id) + tool.update(**updated) + tool.updated_at = int(time.time()) + Session.commit() + Session.refresh(tool) + return ToolModel.model_validate(tool) except: return None def delete_tool_by_id(self, id: str) -> bool: try: - with get_session() as db: - db.query(Tool).filter_by(id=id).delete() + Session.query(Tool).filter_by(id=id).delete() return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 796892927..252e3f122 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -3,11 +3,10 @@ from typing import List, Union, Optional import time from sqlalchemy import String, Column, BigInteger, Text -from sqlalchemy.orm import Session from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField, get_session +from apps.webui.internal.db import Base, JSONField, Session from apps.webui.models.chats import Chats #################### @@ -89,177 +88,161 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - with get_session() as db: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - db.add(result) - db.commit() - db.refresh(result) - if result: - return user - else: - return None + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + Session.add(result) + Session.commit() + Session.refresh(result) + if result: + return user + else: + return None def get_user_by_id(self, id: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except Exception as e: - return None + try: + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except Exception as e: + return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) + except: + return None def get_user_by_email(self, email: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) + except: + return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: - with get_session() as db: - try: - user = db.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) + except: + return None def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - with get_session() as db: - users = ( - db.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + users = ( + Session.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: - with get_session() as db: - return db.query(User).count() + return Session.query(User).count() def get_first_user(self) -> UserModel: - with get_session() as db: - try: - user = db.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) - except: - return None + try: + user = Session.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) + except: + return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update({"role": role}) - db.commit() + try: + Session.query(User).filter_by(id=id).update({"role": role}) + Session.commit() - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_profile_image_url_by_id( self, id: str, profile_image_url: str ) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - db.commit() + try: + Session.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + Session.commit() - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) + try: + Session.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_oauth_sub_by_id( self, id: str, oauth_sub: str ) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + try: + Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - except: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + except: + return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: - with get_session() as db: - try: - db.query(User).filter_by(id=id).update(updated) - db.commit() + try: + Session.query(User).filter_by(id=id).update(updated) + Session.commit() - user = db.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) - except Exception as e: - return None + user = Session.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) + except Exception as e: + return None def delete_user_by_id(self, id: str) -> bool: - with get_session() as db: - try: - # Delete User Chats - result = Chats.delete_chats_by_user_id(id) + try: + # Delete User Chats + result = Chats.delete_chats_by_user_id(id) - if result: - # Delete User - db.query(User).filter_by(id=id).delete() - db.commit() + if result: + # Delete User + Session.query(User).filter_by(id=id).delete() + Session.commit() - return True - else: - return False - except: + return True + else: return False + except: + return False def update_user_api_key_by_id(self, id: str, api_key: str) -> str: - with get_session() as db: - try: - result = db.query(User).filter_by(id=id).update({"api_key": api_key}) - db.commit() - return True if result == 1 else False - except: - return False + try: + result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) + Session.commit() + return True if result == 1 else False + except: + return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: - with get_session() as db: - try: - user = db.query(User).filter_by(id=id).first() - return user.api_key - except Exception as e: - return None + try: + user = Session.query(User).filter_by(id=id).first() + return user.api_key + except Exception as e: + return None Users = UsersTable() diff --git a/backend/main.py b/backend/main.py index f35095bf1..2120b499a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -29,7 +29,6 @@ from fastapi import HTTPException from fastapi.middleware.wsgi import WSGIMiddleware from fastapi.middleware.cors import CORSMiddleware from sqlalchemy import text -from sqlalchemy.orm import Session from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -57,7 +56,7 @@ from apps.webui.main import ( get_pipe_models, generate_function_chat_completion, ) -from apps.webui.internal.db import get_session, SessionLocal +from apps.webui.internal.db import Session, SessionLocal from pydantic import BaseModel @@ -794,6 +793,14 @@ app.add_middleware( allow_headers=["*"], ) +@app.middleware("http") +async def remove_session_after_request(request: Request, call_next): + response = await call_next(request) + log.debug("Removing session after request") + Session.commit() + Session.remove() + return response + @app.middleware("http") async def check_url(request: Request, call_next): @@ -2034,8 +2041,7 @@ async def healthcheck(): @app.get("/health/db") async def healthcheck_with_db(): - with get_session() as db: - result = db.execute(text("SELECT 1;")).all() + Session.execute(text("SELECT 1;")).all() return {"status": True} diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py index ea4518eaf..6d2dd35b1 100644 --- a/backend/test/apps/webui/routers/test_chats.py +++ b/backend/test/apps/webui/routers/test_chats.py @@ -90,6 +90,8 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") + from apps.webui.internal.db import Session + Session.commit() with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url("/all/archived")) assert response.status_code == 200 diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index 781fbfff8..f8d6d4ff7 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -9,6 +9,7 @@ from pytest_docker.plugin import get_docker_ip from fastapi.testclient import TestClient from sqlalchemy import text, create_engine + log = logging.getLogger(__name__) @@ -50,11 +51,6 @@ class AbstractPostgresTest(AbstractIntegrationTest): DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" docker_client: DockerClient - def get_db(self): - from apps.webui.internal.db import SessionLocal - - return SessionLocal() - @classmethod def _create_db_url(cls, env_vars_postgres: dict) -> str: host = get_docker_ip() @@ -113,21 +109,21 @@ class AbstractPostgresTest(AbstractIntegrationTest): pytest.fail(f"Could not setup test environment: {ex}") def _check_db_connection(self): + from apps.webui.internal.db import Session retries = 10 while retries > 0: try: - self.db_session.execute(text("SELECT 1")) - self.db_session.commit() + Session.execute(text("SELECT 1")) + Session.commit() break except Exception as e: - self.db_session.rollback() + Session.rollback() log.warning(e) time.sleep(3) retries -= 1 def setup_method(self): super().setup_method() - self.db_session = self.get_db() self._check_db_connection() @classmethod @@ -136,8 +132,9 @@ class AbstractPostgresTest(AbstractIntegrationTest): cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) def teardown_method(self): + from apps.webui.internal.db import Session # rollback everything not yet committed - self.db_session.commit() + Session.commit() # truncate all tables tables = [ @@ -152,5 +149,5 @@ class AbstractPostgresTest(AbstractIntegrationTest): '"user"', ] for table in tables: - self.db_session.execute(text(f"TRUNCATE TABLE {table}")) - self.db_session.commit() + Session.execute(text(f"TRUNCATE TABLE {table}")) + Session.commit() From a9b148791d982b9635935a41ca6bdc3aa47165c3 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 13:21:51 +0200 Subject: [PATCH 008/115] feat(sqlalchemy): fix wrong column types --- backend/apps/webui/models/auths.py | 4 +- backend/apps/webui/models/chats.py | 8 +- backend/apps/webui/models/documents.py | 8 +- backend/apps/webui/models/files.py | 4 +- backend/apps/webui/models/memories.py | 4 +- backend/apps/webui/models/models.py | 10 +- backend/apps/webui/models/prompts.py | 6 +- backend/apps/webui/models/tags.py | 4 +- backend/apps/webui/models/tools.py | 6 +- backend/apps/webui/models/users.py | 2 +- .../migrations/versions/7e5b5dc7342b_init.py | 186 ++++++++++++++++++ .../migrations/versions/ba76b0bae648_init.py | 174 ---------------- 12 files changed, 214 insertions(+), 202 deletions(-) create mode 100644 backend/migrations/versions/7e5b5dc7342b_init.py delete mode 100644 backend/migrations/versions/ba76b0bae648_init.py diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 1858b2c0d..aef895619 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -2,7 +2,7 @@ from pydantic import BaseModel from typing import Optional import uuid import logging -from sqlalchemy import String, Column, Boolean +from sqlalchemy import String, Column, Boolean, Text from apps.webui.models.users import UserModel, Users from utils.utils import verify_password @@ -24,7 +24,7 @@ class Auth(Base): id = Column(String, primary_key=True) email = Column(String) - password = Column(String) + password = Column(Text) active = Column(Boolean) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index abf5b544c..1cf56c351 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -5,7 +5,7 @@ import json import uuid import time -from sqlalchemy import Column, String, BigInteger, Boolean +from sqlalchemy import Column, String, BigInteger, Boolean, Text from apps.webui.internal.db import Base, Session @@ -20,13 +20,13 @@ class Chat(Base): id = Column(String, primary_key=True) user_id = Column(String) - title = Column(String) - chat = Column(String) # Save Chat JSON as Text + title = Column(Text) + chat = Column(Text) # Save Chat JSON as Text created_at = Column(BigInteger) updated_at = Column(BigInteger) - share_id = Column(String, unique=True, nullable=True) + share_id = Column(Text, unique=True, nullable=True) archived = Column(Boolean, default=False) diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index f8e7153c5..1b69d44a5 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -3,7 +3,7 @@ from typing import List, Optional import time import logging -from sqlalchemy import String, Column, BigInteger +from sqlalchemy import String, Column, BigInteger, Text from apps.webui.internal.db import Base, Session @@ -24,9 +24,9 @@ class Document(Base): collection_name = Column(String, primary_key=True) name = Column(String, unique=True) - title = Column(String) - filename = Column(String) - content = Column(String, nullable=True) + title = Column(Text) + filename = Column(Text) + content = Column(Text, nullable=True) user_id = Column(String) timestamp = Column(BigInteger) diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index 7664bf4f1..ce904215d 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -3,7 +3,7 @@ from typing import List, Union, Optional import time import logging -from sqlalchemy import Column, String, BigInteger +from sqlalchemy import Column, String, BigInteger, Text from apps.webui.internal.db import JSONField, Base, Session @@ -24,7 +24,7 @@ class File(Base): id = Column(String, primary_key=True) user_id = Column(String) - filename = Column(String) + filename = Column(Text) meta = Column(JSONField) created_at = Column(BigInteger) diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 263d1b5ab..f0bd6e291 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, ConfigDict from typing import List, Union, Optional -from sqlalchemy import Column, String, BigInteger +from sqlalchemy import Column, String, BigInteger, Text from apps.webui.internal.db import Base, Session @@ -18,7 +18,7 @@ class Memory(Base): id = Column(String, primary_key=True) user_id = Column(String) - content = Column(String) + content = Column(Text) updated_at = Column(BigInteger) created_at = Column(BigInteger) diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index dd736a73e..7d1da54ff 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -3,7 +3,7 @@ import logging from typing import Optional from pydantic import BaseModel, ConfigDict -from sqlalchemy import String, Column, BigInteger +from sqlalchemy import String, Column, BigInteger, Text from apps.webui.internal.db import Base, JSONField, Session @@ -46,18 +46,18 @@ class ModelMeta(BaseModel): class Model(Base): __tablename__ = "model" - id = Column(String, primary_key=True) + id = Column(Text, primary_key=True) """ The model's id as used in the API. If set to an existing model, it will override the model. """ - user_id = Column(String) + user_id = Column(Text) - base_model_id = Column(String, nullable=True) + base_model_id = Column(Text, nullable=True) """ An optional pointer to the actual model that should be used when proxying requests. """ - name = Column(String) + name = Column(Text) """ The human-readable display name of the model. """ diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index a2fd0366b..ab8cc04ce 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, ConfigDict from typing import List, Optional import time -from sqlalchemy import String, Column, BigInteger +from sqlalchemy import String, Column, BigInteger, Text from apps.webui.internal.db import Base, Session @@ -18,8 +18,8 @@ class Prompt(Base): command = Column(String, primary_key=True) user_id = Column(String) - title = Column(String) - content = Column(String) + title = Column(Text) + content = Column(Text) timestamp = Column(BigInteger) diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 6cfe39d0c..87238c2a3 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -6,7 +6,7 @@ import uuid import time import logging -from sqlalchemy import String, Column, BigInteger +from sqlalchemy import String, Column, BigInteger, Text from apps.webui.internal.db import Base, Session @@ -26,7 +26,7 @@ class Tag(Base): id = Column(String, primary_key=True) name = Column(String) user_id = Column(String) - data = Column(String, nullable=True) + data = Column(Text, nullable=True) class ChatIdTag(Base): diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 20c608921..f5df10637 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, ConfigDict from typing import List, Optional import time import logging -from sqlalchemy import String, Column, BigInteger +from sqlalchemy import String, Column, BigInteger, Text from apps.webui.internal.db import Base, JSONField, Session from apps.webui.models.users import Users @@ -26,8 +26,8 @@ class Tool(Base): id = Column(String, primary_key=True) user_id = Column(String) - name = Column(String) - content = Column(String) + name = Column(Text) + content = Column(Text) specs = Column(JSONField) meta = Column(JSONField) valves = Column(JSONField) diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 252e3f122..e1c5ca9f3 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -21,7 +21,7 @@ class User(Base): name = Column(String) email = Column(String) role = Column(String) - profile_image_url = Column(String) + profile_image_url = Column(Text) last_active_at = Column(BigInteger) updated_at = Column(BigInteger) diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py new file mode 100644 index 000000000..bd49d1b43 --- /dev/null +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -0,0 +1,186 @@ +"""init + +Revision ID: 7e5b5dc7342b +Revises: +Create Date: 2024-06-24 13:15:33.808998 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import apps.webui.internal.db +from migrations.util import get_existing_tables + +# revision identifiers, used by Alembic. +revision: str = '7e5b5dc7342b' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + existing_tables = set(get_existing_tables()) + + # ### commands auto generated by Alembic - please adjust! ### + if "auth" not in existing_tables: + op.create_table('auth', + sa.Column('id', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=True), + sa.Column('password', sa.Text(), nullable=True), + sa.Column('active', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "chat" not in existing_tables: + op.create_table('chat', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('chat', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('share_id', sa.Text(), nullable=True), + sa.Column('archived', sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('share_id') + ) + + if "chatidtag" not in existing_tables: + op.create_table('chatidtag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('tag_name', sa.String(), nullable=True), + sa.Column('chat_id', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "document" not in existing_tables: + op.create_table('document', + sa.Column('collection_name', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('collection_name'), + sa.UniqueConstraint('name') + ) + + if "file" not in existing_tables: + op.create_table('file', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "function" not in existing_tables: + op.create_table('function', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "memory" not in existing_tables: + op.create_table('memory', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "model" not in existing_tables: + op.create_table('model', + sa.Column('id', sa.Text(), nullable=False), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('base_model_id', sa.Text(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "prompt" not in existing_tables: + op.create_table('prompt', + sa.Column('command', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('command') + ) + + if "tag" not in existing_tables: + op.create_table('tag', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('data', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "tool" not in existing_tables: + op.create_table('tool', + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + + if "user" not in existing_tables: + op.create_table('user', + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=True), + sa.Column('role', sa.String(), nullable=True), + sa.Column('profile_image_url', sa.Text(), nullable=True), + sa.Column('last_active_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('api_key', sa.String(), nullable=True), + sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('api_key') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('user') + op.drop_table('tool') + op.drop_table('tag') + op.drop_table('prompt') + op.drop_table('model') + op.drop_table('memory') + op.drop_table('function') + op.drop_table('file') + op.drop_table('document') + op.drop_table('chatidtag') + op.drop_table('chat') + op.drop_table('auth') + # ### end Alembic commands ### diff --git a/backend/migrations/versions/ba76b0bae648_init.py b/backend/migrations/versions/ba76b0bae648_init.py deleted file mode 100644 index c491ed46c..000000000 --- a/backend/migrations/versions/ba76b0bae648_init.py +++ /dev/null @@ -1,174 +0,0 @@ -"""init - -Revision ID: ba76b0bae648 -Revises: -Create Date: 2024-06-24 09:09:11.636336 - -""" - -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -import apps.webui.internal.db - - -# revision identifiers, used by Alembic. -revision: str = "ba76b0bae648" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "auth", - sa.Column("id", sa.String(), nullable=False), - sa.Column("email", sa.String(), nullable=True), - sa.Column("password", sa.String(), nullable=True), - sa.Column("active", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "chat", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.String(), nullable=True), - sa.Column("chat", sa.String(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("share_id", sa.String(), nullable=True), - sa.Column("archived", sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("share_id"), - ) - op.create_table( - "chatidtag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("tag_name", sa.String(), nullable=True), - sa.Column("chat_id", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "document", - sa.Column("collection_name", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("title", sa.String(), nullable=True), - sa.Column("filename", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("collection_name"), - sa.UniqueConstraint("name"), - ) - op.create_table( - "file", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("filename", sa.String(), nullable=True), - sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "function", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("type", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "memory", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "model", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("base_model_id", sa.String(), nullable=True), - sa.Column("name", sa.String(), nullable=True), - sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "prompt", - sa.Column("command", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("command"), - ) - op.create_table( - "tag", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("data", sa.String(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "tool", - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.String(), nullable=True), - sa.Column("content", sa.String(), nullable=True), - sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "user", - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("email", sa.String(), nullable=True), - sa.Column("role", sa.String(), nullable=True), - sa.Column("profile_image_url", sa.String(), nullable=True), - sa.Column("last_active_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("api_key"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("user") - op.drop_table("tool") - op.drop_table("tag") - op.drop_table("prompt") - op.drop_table("model") - op.drop_table("memory") - op.drop_table("function") - op.drop_table("file") - op.drop_table("document") - op.drop_table("chatidtag") - op.drop_table("chat") - op.drop_table("auth") - # ### end Alembic commands ### From 8f939cf55bc4a4de63c859f033cf5da5378e2d30 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 13:45:33 +0200 Subject: [PATCH 009/115] feat(sqlalchemy): some fixes --- backend/apps/webui/models/users.py | 1 + backend/main.py | 5 ++--- backend/utils/utils.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index e1c5ca9f3..9e1e25ac6 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -185,6 +185,7 @@ class UsersTable: Session.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) + Session.commit() user = Session.query(User).filter_by(id=id).first() return UserModel.model_validate(user) diff --git a/backend/main.py b/backend/main.py index 2120b499a..ad519bdcb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -794,11 +794,10 @@ app.add_middleware( ) @app.middleware("http") -async def remove_session_after_request(request: Request, call_next): +async def commit_session_after_request(request: Request, call_next): response = await call_next(request) - log.debug("Removing session after request") + log.debug("Commit session after request") Session.commit() - Session.remove() return response diff --git a/backend/utils/utils.py b/backend/utils/utils.py index 6409fc7aa..fbc539af5 100644 --- a/backend/utils/utils.py +++ b/backend/utils/utils.py @@ -113,8 +113,8 @@ def get_current_user( ) -def get_current_user_by_api_key(db: Session, api_key: str): - user = Users.get_user_by_api_key(db, api_key) +def get_current_user_by_api_key(api_key: str): + user = Users.get_user_by_api_key(api_key) if user is None: raise HTTPException( @@ -122,7 +122,7 @@ def get_current_user_by_api_key(db: Session, api_key: str): detail=ERROR_MESSAGES.INVALID_TOKEN, ) else: - Users.update_user_last_active_by_id(db, user.id) + Users.update_user_last_active_by_id(user.id) return user From 2fb27adbf67c13d89ac652f3652f7a578a3bcb25 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 13:54:24 +0200 Subject: [PATCH 010/115] feat(sqlalchemy): add missing file --- backend/migrations/util.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 backend/migrations/util.py diff --git a/backend/migrations/util.py b/backend/migrations/util.py new file mode 100644 index 000000000..401bb94d0 --- /dev/null +++ b/backend/migrations/util.py @@ -0,0 +1,9 @@ +from alembic import op +from sqlalchemy import Inspector + + +def get_existing_tables(): + con = op.get_bind() + inspector = Inspector.from_engine(con) + tables = set(inspector.get_table_names()) + return tables From d88bd51e3c446383b37a65ad2119ea640f7df913 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Mon, 24 Jun 2024 13:55:18 +0200 Subject: [PATCH 011/115] feat(sqlalchemy): format backend --- backend/apps/webui/models/chats.py | 20 +- backend/apps/webui/models/memories.py | 4 +- backend/apps/webui/models/models.py | 4 +- backend/apps/webui/models/tags.py | 8 +- backend/main.py | 1 + .../migrations/versions/7e5b5dc7342b_init.py | 255 +++++++++--------- backend/test/apps/webui/routers/test_chats.py | 1 + .../test/util/abstract_integration_test.py | 2 + 8 files changed, 153 insertions(+), 142 deletions(-) diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index 1cf56c351..d6829ee7b 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -85,9 +85,7 @@ class ChatTable: "id": id, "user_id": user_id, "title": ( - form_data.chat["title"] - if "title" in form_data.chat - else "New Chat" + form_data.chat["title"] if "title" in form_data.chat else "New Chat" ), "chat": json.dumps(form_data.chat), "created_at": int(time.time()), @@ -197,14 +195,14 @@ class ChatTable: def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + all_chats = ( + Session.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index f0bd6e291..1f03318fd 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -115,9 +115,7 @@ class MemoriesTable: except: return False - def delete_memory_by_id_and_user_id( - self, id: str, user_id: str - ) -> bool: + def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: Session.query(Memory).filter_by(id=id, user_id=user_id).delete() return True diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 7d1da54ff..6543edefc 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -140,7 +140,9 @@ class ModelsTable: return None def get_all_models(self) -> List[ModelModel]: - return [ModelModel.model_validate(model) for model in Session.query(Model).all()] + return [ + ModelModel.model_validate(model) for model in Session.query(Model).all() + ] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 87238c2a3..7b0df6b6b 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -207,9 +207,7 @@ class TagTable: log.debug(f"res: {res}") Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) + tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) if tag_count == 0: # Remove tag item from Tag col as well Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() @@ -230,9 +228,7 @@ class TagTable: log.debug(f"res: {res}") Session.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id( - tag_name, user_id - ) + tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) if tag_count == 0: # Remove tag item from Tag col as well Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() diff --git a/backend/main.py b/backend/main.py index ad519bdcb..a29fde198 100644 --- a/backend/main.py +++ b/backend/main.py @@ -793,6 +793,7 @@ app.add_middleware( allow_headers=["*"], ) + @app.middleware("http") async def commit_session_after_request(request: Request, call_next): response = await call_next(request) diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py index bd49d1b43..50deac526 100644 --- a/backend/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -5,6 +5,7 @@ Revises: Create Date: 2024-06-24 13:15:33.808998 """ + from typing import Sequence, Union from alembic import op @@ -13,7 +14,7 @@ import apps.webui.internal.db from migrations.util import get_existing_tables # revision identifiers, used by Alembic. -revision: str = '7e5b5dc7342b' +revision: str = "7e5b5dc7342b" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -24,163 +25,175 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### if "auth" not in existing_tables: - op.create_table('auth', - sa.Column('id', sa.String(), nullable=False), - sa.Column('email', sa.String(), nullable=True), - sa.Column('password', sa.Text(), nullable=True), - sa.Column('active', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "auth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=True), + sa.Column("password", sa.Text(), nullable=True), + sa.Column("active", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "chat" not in existing_tables: - op.create_table('chat', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.Text(), nullable=True), - sa.Column('chat', sa.Text(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('share_id', sa.Text(), nullable=True), - sa.Column('archived', sa.Boolean(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('share_id') + op.create_table( + "chat", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("chat", sa.Text(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("share_id", sa.Text(), nullable=True), + sa.Column("archived", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("share_id"), ) if "chatidtag" not in existing_tables: - op.create_table('chatidtag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('tag_name', sa.String(), nullable=True), - sa.Column('chat_id', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "chatidtag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("tag_name", sa.String(), nullable=True), + sa.Column("chat_id", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "document" not in existing_tables: - op.create_table('document', - sa.Column('collection_name', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('title', sa.Text(), nullable=True), - sa.Column('filename', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('collection_name'), - sa.UniqueConstraint('name') + op.create_table( + "document", + sa.Column("collection_name", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("filename", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("collection_name"), + sa.UniqueConstraint("name"), ) if "file" not in existing_tables: - op.create_table('file', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('filename', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "file", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("filename", sa.Text(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "function" not in existing_tables: - op.create_table('function', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('type', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "function", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("type", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "memory" not in existing_tables: - op.create_table('memory', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "memory", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "model" not in existing_tables: - op.create_table('model', - sa.Column('id', sa.Text(), nullable=False), - sa.Column('user_id', sa.Text(), nullable=True), - sa.Column('base_model_id', sa.Text(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "model", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=True), + sa.Column("base_model_id", sa.Text(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "prompt" not in existing_tables: - op.create_table('prompt', - sa.Column('command', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('title', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('timestamp', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('command') + op.create_table( + "prompt", + sa.Column("command", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("title", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("command"), ) if "tag" not in existing_tables: - op.create_table('tag', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('data', sa.Text(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tag", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("data", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "tool" not in existing_tables: - op.create_table('tool', - sa.Column('id', sa.String(), nullable=False), - sa.Column('user_id', sa.String(), nullable=True), - sa.Column('name', sa.Text(), nullable=True), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.PrimaryKeyConstraint('id') + op.create_table( + "tool", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("name", sa.Text(), nullable=True), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.PrimaryKeyConstraint("id"), ) if "user" not in existing_tables: - op.create_table('user', - sa.Column('id', sa.String(), nullable=False), - sa.Column('name', sa.String(), nullable=True), - sa.Column('email', sa.String(), nullable=True), - sa.Column('role', sa.String(), nullable=True), - sa.Column('profile_image_url', sa.Text(), nullable=True), - sa.Column('last_active_at', sa.BigInteger(), nullable=True), - sa.Column('updated_at', sa.BigInteger(), nullable=True), - sa.Column('created_at', sa.BigInteger(), nullable=True), - sa.Column('api_key', sa.String(), nullable=True), - sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('api_key') + op.create_table( + "user", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("profile_image_url", sa.Text(), nullable=True), + sa.Column("last_active_at", sa.BigInteger(), nullable=True), + sa.Column("updated_at", sa.BigInteger(), nullable=True), + sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column("api_key", sa.String(), nullable=True), + sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("api_key"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('user') - op.drop_table('tool') - op.drop_table('tag') - op.drop_table('prompt') - op.drop_table('model') - op.drop_table('memory') - op.drop_table('function') - op.drop_table('file') - op.drop_table('document') - op.drop_table('chatidtag') - op.drop_table('chat') - op.drop_table('auth') + op.drop_table("user") + op.drop_table("tool") + op.drop_table("tag") + op.drop_table("prompt") + op.drop_table("model") + op.drop_table("memory") + op.drop_table("function") + op.drop_table("file") + op.drop_table("document") + op.drop_table("chatidtag") + op.drop_table("chat") + op.drop_table("auth") # ### end Alembic commands ### diff --git a/backend/test/apps/webui/routers/test_chats.py b/backend/test/apps/webui/routers/test_chats.py index 6d2dd35b1..f4661b625 100644 --- a/backend/test/apps/webui/routers/test_chats.py +++ b/backend/test/apps/webui/routers/test_chats.py @@ -91,6 +91,7 @@ class TestChats(AbstractPostgresTest): def test_get_user_archived_chats(self): self.chats.archive_all_chats_by_user_id("2") from apps.webui.internal.db import Session + Session.commit() with mock_webui_user(id="2"): response = self.fast_api_client.get(self.create_url("/all/archived")) diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index f8d6d4ff7..4e99dcc2f 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -110,6 +110,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): def _check_db_connection(self): from apps.webui.internal.db import Session + retries = 10 while retries > 0: try: @@ -133,6 +134,7 @@ class AbstractPostgresTest(AbstractIntegrationTest): def teardown_method(self): from apps.webui.internal.db import Session + # rollback everything not yet committed Session.commit() From 642c352c69ceadb118ed5347c091c761a5a65a8b Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Tue, 25 Jun 2024 08:15:59 +0200 Subject: [PATCH 012/115] feat(sqlalchemy): rebase --- .../internal/migrations/001_initial_schema.py | 254 ------------------ .../migrations/002_add_local_sharing.py | 48 ---- .../migrations/003_add_auth_api_key.py | 48 ---- .../internal/migrations/004_add_archived.py | 46 ---- .../internal/migrations/005_add_updated_at.py | 130 --------- .../006_migrate_timestamps_and_charfields.py | 130 --------- .../migrations/007_add_user_last_active_at.py | 79 ------ .../internal/migrations/008_add_memory.py | 53 ---- .../internal/migrations/009_add_models.py | 61 ----- .../010_migrate_modelfiles_to_models.py | 130 --------- .../migrations/011_add_user_settings.py | 48 ---- .../internal/migrations/012_add_tools.py | 61 ----- .../internal/migrations/013_add_user_info.py | 48 ---- .../internal/migrations/014_add_files.py | 55 ---- .../internal/migrations/015_add_functions.py | 61 ----- .../016_add_valves_and_is_active.py | 50 ---- .../migrations/017_add_user_oauth_sub.py | 49 ---- .../apps/webui/internal/migrations/README.md | 21 -- .../migrations/versions/7e5b5dc7342b_init.py | 2 + 19 files changed, 2 insertions(+), 1372 deletions(-) delete mode 100644 backend/apps/webui/internal/migrations/001_initial_schema.py delete mode 100644 backend/apps/webui/internal/migrations/002_add_local_sharing.py delete mode 100644 backend/apps/webui/internal/migrations/003_add_auth_api_key.py delete mode 100644 backend/apps/webui/internal/migrations/004_add_archived.py delete mode 100644 backend/apps/webui/internal/migrations/005_add_updated_at.py delete mode 100644 backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py delete mode 100644 backend/apps/webui/internal/migrations/007_add_user_last_active_at.py delete mode 100644 backend/apps/webui/internal/migrations/008_add_memory.py delete mode 100644 backend/apps/webui/internal/migrations/009_add_models.py delete mode 100644 backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py delete mode 100644 backend/apps/webui/internal/migrations/011_add_user_settings.py delete mode 100644 backend/apps/webui/internal/migrations/012_add_tools.py delete mode 100644 backend/apps/webui/internal/migrations/013_add_user_info.py delete mode 100644 backend/apps/webui/internal/migrations/014_add_files.py delete mode 100644 backend/apps/webui/internal/migrations/015_add_functions.py delete mode 100644 backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py delete mode 100644 backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py delete mode 100644 backend/apps/webui/internal/migrations/README.md diff --git a/backend/apps/webui/internal/migrations/001_initial_schema.py b/backend/apps/webui/internal/migrations/001_initial_schema.py deleted file mode 100644 index 93f278f15..000000000 --- a/backend/apps/webui/internal/migrations/001_initial_schema.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Peewee migrations -- 001_initial_schema.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - # We perform different migrations for SQLite and other databases - # This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite - # will require per-database SQL queries. - # Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base - # schema instead of trying to migrate from an older schema. - if isinstance(database, pw.SqliteDatabase): - migrate_sqlite(migrator, database, fake=fake) - else: - migrate_external(migrator, database, fake=fake) - - -def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): - @migrator.create_model - class Auth(pw.Model): - id = pw.CharField(max_length=255, unique=True) - email = pw.CharField(max_length=255) - password = pw.CharField(max_length=255) - active = pw.BooleanField() - - class Meta: - table_name = "auth" - - @migrator.create_model - class Chat(pw.Model): - id = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - title = pw.CharField() - chat = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "chat" - - @migrator.create_model - class ChatIdTag(pw.Model): - id = pw.CharField(max_length=255, unique=True) - tag_name = pw.CharField(max_length=255) - chat_id = pw.CharField(max_length=255) - user_id = pw.CharField(max_length=255) - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "chatidtag" - - @migrator.create_model - class Document(pw.Model): - id = pw.AutoField() - collection_name = pw.CharField(max_length=255, unique=True) - name = pw.CharField(max_length=255, unique=True) - title = pw.CharField() - filename = pw.CharField() - content = pw.TextField(null=True) - user_id = pw.CharField(max_length=255) - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "document" - - @migrator.create_model - class Modelfile(pw.Model): - id = pw.AutoField() - tag_name = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - modelfile = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "modelfile" - - @migrator.create_model - class Prompt(pw.Model): - id = pw.AutoField() - command = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - title = pw.CharField() - content = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "prompt" - - @migrator.create_model - class Tag(pw.Model): - id = pw.CharField(max_length=255, unique=True) - name = pw.CharField(max_length=255) - user_id = pw.CharField(max_length=255) - data = pw.TextField(null=True) - - class Meta: - table_name = "tag" - - @migrator.create_model - class User(pw.Model): - id = pw.CharField(max_length=255, unique=True) - name = pw.CharField(max_length=255) - email = pw.CharField(max_length=255) - role = pw.CharField(max_length=255) - profile_image_url = pw.CharField(max_length=255) - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "user" - - -def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): - @migrator.create_model - class Auth(pw.Model): - id = pw.CharField(max_length=255, unique=True) - email = pw.CharField(max_length=255) - password = pw.TextField() - active = pw.BooleanField() - - class Meta: - table_name = "auth" - - @migrator.create_model - class Chat(pw.Model): - id = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - title = pw.TextField() - chat = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "chat" - - @migrator.create_model - class ChatIdTag(pw.Model): - id = pw.CharField(max_length=255, unique=True) - tag_name = pw.CharField(max_length=255) - chat_id = pw.CharField(max_length=255) - user_id = pw.CharField(max_length=255) - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "chatidtag" - - @migrator.create_model - class Document(pw.Model): - id = pw.AutoField() - collection_name = pw.CharField(max_length=255, unique=True) - name = pw.CharField(max_length=255, unique=True) - title = pw.TextField() - filename = pw.TextField() - content = pw.TextField(null=True) - user_id = pw.CharField(max_length=255) - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "document" - - @migrator.create_model - class Modelfile(pw.Model): - id = pw.AutoField() - tag_name = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - modelfile = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "modelfile" - - @migrator.create_model - class Prompt(pw.Model): - id = pw.AutoField() - command = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - title = pw.TextField() - content = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "prompt" - - @migrator.create_model - class Tag(pw.Model): - id = pw.CharField(max_length=255, unique=True) - name = pw.CharField(max_length=255) - user_id = pw.CharField(max_length=255) - data = pw.TextField(null=True) - - class Meta: - table_name = "tag" - - @migrator.create_model - class User(pw.Model): - id = pw.CharField(max_length=255, unique=True) - name = pw.CharField(max_length=255) - email = pw.CharField(max_length=255) - role = pw.CharField(max_length=255) - profile_image_url = pw.TextField() - timestamp = pw.BigIntegerField() - - class Meta: - table_name = "user" - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_model("user") - - migrator.remove_model("tag") - - migrator.remove_model("prompt") - - migrator.remove_model("modelfile") - - migrator.remove_model("document") - - migrator.remove_model("chatidtag") - - migrator.remove_model("chat") - - migrator.remove_model("auth") diff --git a/backend/apps/webui/internal/migrations/002_add_local_sharing.py b/backend/apps/webui/internal/migrations/002_add_local_sharing.py deleted file mode 100644 index e93501aee..000000000 --- a/backend/apps/webui/internal/migrations/002_add_local_sharing.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - migrator.add_fields( - "chat", share_id=pw.CharField(max_length=255, null=True, unique=True) - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_fields("chat", "share_id") diff --git a/backend/apps/webui/internal/migrations/003_add_auth_api_key.py b/backend/apps/webui/internal/migrations/003_add_auth_api_key.py deleted file mode 100644 index 07144f3ac..000000000 --- a/backend/apps/webui/internal/migrations/003_add_auth_api_key.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - migrator.add_fields( - "user", api_key=pw.CharField(max_length=255, null=True, unique=True) - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_fields("user", "api_key") diff --git a/backend/apps/webui/internal/migrations/004_add_archived.py b/backend/apps/webui/internal/migrations/004_add_archived.py deleted file mode 100644 index d01c06b4e..000000000 --- a/backend/apps/webui/internal/migrations/004_add_archived.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - migrator.add_fields("chat", archived=pw.BooleanField(default=False)) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_fields("chat", "archived") diff --git a/backend/apps/webui/internal/migrations/005_add_updated_at.py b/backend/apps/webui/internal/migrations/005_add_updated_at.py deleted file mode 100644 index 950866ef0..000000000 --- a/backend/apps/webui/internal/migrations/005_add_updated_at.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - if isinstance(database, pw.SqliteDatabase): - migrate_sqlite(migrator, database, fake=fake) - else: - migrate_external(migrator, database, fake=fake) - - -def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): - # Adding fields created_at and updated_at to the 'chat' table - migrator.add_fields( - "chat", - created_at=pw.DateTimeField(null=True), # Allow null for transition - updated_at=pw.DateTimeField(null=True), # Allow null for transition - ) - - # Populate the new fields from an existing 'timestamp' field - migrator.sql( - "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" - ) - - # Now that the data has been copied, remove the original 'timestamp' field - migrator.remove_fields("chat", "timestamp") - - # Update the fields to be not null now that they are populated - migrator.change_fields( - "chat", - created_at=pw.DateTimeField(null=False), - updated_at=pw.DateTimeField(null=False), - ) - - -def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): - # Adding fields created_at and updated_at to the 'chat' table - migrator.add_fields( - "chat", - created_at=pw.BigIntegerField(null=True), # Allow null for transition - updated_at=pw.BigIntegerField(null=True), # Allow null for transition - ) - - # Populate the new fields from an existing 'timestamp' field - migrator.sql( - "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" - ) - - # Now that the data has been copied, remove the original 'timestamp' field - migrator.remove_fields("chat", "timestamp") - - # Update the fields to be not null now that they are populated - migrator.change_fields( - "chat", - created_at=pw.BigIntegerField(null=False), - updated_at=pw.BigIntegerField(null=False), - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - if isinstance(database, pw.SqliteDatabase): - rollback_sqlite(migrator, database, fake=fake) - else: - rollback_external(migrator, database, fake=fake) - - -def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): - # Recreate the timestamp field initially allowing null values for safe transition - migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) - - # Copy the earliest created_at date back into the new timestamp field - # This assumes created_at was originally a copy of timestamp - migrator.sql("UPDATE chat SET timestamp = created_at") - - # Remove the created_at and updated_at fields - migrator.remove_fields("chat", "created_at", "updated_at") - - # Finally, alter the timestamp field to not allow nulls if that was the original setting - migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) - - -def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False): - # Recreate the timestamp field initially allowing null values for safe transition - migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True)) - - # Copy the earliest created_at date back into the new timestamp field - # This assumes created_at was originally a copy of timestamp - migrator.sql("UPDATE chat SET timestamp = created_at") - - # Remove the created_at and updated_at fields - migrator.remove_fields("chat", "created_at", "updated_at") - - # Finally, alter the timestamp field to not allow nulls if that was the original setting - migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py deleted file mode 100644 index caca14d32..000000000 --- a/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - # Alter the tables with timestamps - migrator.change_fields( - "chatidtag", - timestamp=pw.BigIntegerField(), - ) - migrator.change_fields( - "document", - timestamp=pw.BigIntegerField(), - ) - migrator.change_fields( - "modelfile", - timestamp=pw.BigIntegerField(), - ) - migrator.change_fields( - "prompt", - timestamp=pw.BigIntegerField(), - ) - migrator.change_fields( - "user", - timestamp=pw.BigIntegerField(), - ) - # Alter the tables with varchar to text where necessary - migrator.change_fields( - "auth", - password=pw.TextField(), - ) - migrator.change_fields( - "chat", - title=pw.TextField(), - ) - migrator.change_fields( - "document", - title=pw.TextField(), - filename=pw.TextField(), - ) - migrator.change_fields( - "prompt", - title=pw.TextField(), - ) - migrator.change_fields( - "user", - profile_image_url=pw.TextField(), - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - if isinstance(database, pw.SqliteDatabase): - # Alter the tables with timestamps - migrator.change_fields( - "chatidtag", - timestamp=pw.DateField(), - ) - migrator.change_fields( - "document", - timestamp=pw.DateField(), - ) - migrator.change_fields( - "modelfile", - timestamp=pw.DateField(), - ) - migrator.change_fields( - "prompt", - timestamp=pw.DateField(), - ) - migrator.change_fields( - "user", - timestamp=pw.DateField(), - ) - migrator.change_fields( - "auth", - password=pw.CharField(max_length=255), - ) - migrator.change_fields( - "chat", - title=pw.CharField(), - ) - migrator.change_fields( - "document", - title=pw.CharField(), - filename=pw.CharField(), - ) - migrator.change_fields( - "prompt", - title=pw.CharField(), - ) - migrator.change_fields( - "user", - profile_image_url=pw.CharField(), - ) diff --git a/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py deleted file mode 100644 index dd176ba73..000000000 --- a/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - # Adding fields created_at and updated_at to the 'user' table - migrator.add_fields( - "user", - created_at=pw.BigIntegerField(null=True), # Allow null for transition - updated_at=pw.BigIntegerField(null=True), # Allow null for transition - last_active_at=pw.BigIntegerField(null=True), # Allow null for transition - ) - - # Populate the new fields from an existing 'timestamp' field - migrator.sql( - 'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL' - ) - - # Now that the data has been copied, remove the original 'timestamp' field - migrator.remove_fields("user", "timestamp") - - # Update the fields to be not null now that they are populated - migrator.change_fields( - "user", - created_at=pw.BigIntegerField(null=False), - updated_at=pw.BigIntegerField(null=False), - last_active_at=pw.BigIntegerField(null=False), - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - # Recreate the timestamp field initially allowing null values for safe transition - migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True)) - - # Copy the earliest created_at date back into the new timestamp field - # This assumes created_at was originally a copy of timestamp - migrator.sql('UPDATE "user" SET timestamp = created_at') - - # Remove the created_at and updated_at fields - migrator.remove_fields("user", "created_at", "updated_at", "last_active_at") - - # Finally, alter the timestamp field to not allow nulls if that was the original setting - migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/apps/webui/internal/migrations/008_add_memory.py b/backend/apps/webui/internal/migrations/008_add_memory.py deleted file mode 100644 index 9307aa4d5..000000000 --- a/backend/apps/webui/internal/migrations/008_add_memory.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - @migrator.create_model - class Memory(pw.Model): - id = pw.CharField(max_length=255, unique=True) - user_id = pw.CharField(max_length=255) - content = pw.TextField(null=False) - updated_at = pw.BigIntegerField(null=False) - created_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "memory" - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_model("memory") diff --git a/backend/apps/webui/internal/migrations/009_add_models.py b/backend/apps/webui/internal/migrations/009_add_models.py deleted file mode 100644 index 548ec7cdc..000000000 --- a/backend/apps/webui/internal/migrations/009_add_models.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Peewee migrations -- 009_add_models.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - @migrator.create_model - class Model(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - base_model_id = pw.TextField(null=True) - - name = pw.TextField() - - meta = pw.TextField() - params = pw.TextField() - - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "model" - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py deleted file mode 100644 index 2ef814c06..000000000 --- a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Peewee migrations -- 009_add_models.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator -import json - -from utils.misc import parse_ollama_modelfile - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - # Fetch data from 'modelfile' table and insert into 'model' table - migrate_modelfile_to_model(migrator, database) - # Drop the 'modelfile' table - migrator.remove_model("modelfile") - - -def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): - ModelFile = migrator.orm["modelfile"] - Model = migrator.orm["model"] - - modelfiles = ModelFile.select() - - for modelfile in modelfiles: - # Extract and transform data in Python - - modelfile.modelfile = json.loads(modelfile.modelfile) - meta = json.dumps( - { - "description": modelfile.modelfile.get("desc"), - "profile_image_url": modelfile.modelfile.get("imageUrl"), - "ollama": {"modelfile": modelfile.modelfile.get("content")}, - "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), - "categories": modelfile.modelfile.get("categories"), - "user": {**modelfile.modelfile.get("user", {}), "community": True}, - } - ) - - info = parse_ollama_modelfile(modelfile.modelfile.get("content")) - - # Insert the processed data into the 'model' table - Model.create( - id=f"ollama-{modelfile.tag_name}", - user_id=modelfile.user_id, - base_model_id=info.get("base_model_id"), - name=modelfile.modelfile.get("title"), - meta=meta, - params=json.dumps(info.get("params", {})), - created_at=modelfile.timestamp, - updated_at=modelfile.timestamp, - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - recreate_modelfile_table(migrator, database) - move_data_back_to_modelfile(migrator, database) - migrator.remove_model("model") - - -def recreate_modelfile_table(migrator: Migrator, database: pw.Database): - query = """ - CREATE TABLE IF NOT EXISTS modelfile ( - user_id TEXT, - tag_name TEXT, - modelfile JSON, - timestamp BIGINT - ) - """ - migrator.sql(query) - - -def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): - Model = migrator.orm["model"] - Modelfile = migrator.orm["modelfile"] - - models = Model.select() - - for model in models: - # Extract and transform data in Python - meta = json.loads(model.meta) - - modelfile_data = { - "title": model.name, - "desc": meta.get("description"), - "imageUrl": meta.get("profile_image_url"), - "content": meta.get("ollama", {}).get("modelfile"), - "suggestionPrompts": meta.get("suggestion_prompts"), - "categories": meta.get("categories"), - "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, - } - - # Insert the processed data back into the 'modelfile' table - Modelfile.create( - user_id=model.user_id, - tag_name=model.id, - modelfile=modelfile_data, - timestamp=model.created_at, - ) diff --git a/backend/apps/webui/internal/migrations/011_add_user_settings.py b/backend/apps/webui/internal/migrations/011_add_user_settings.py deleted file mode 100644 index a1620dcad..000000000 --- a/backend/apps/webui/internal/migrations/011_add_user_settings.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - # Adding fields settings to the 'user' table - migrator.add_fields("user", settings=pw.TextField(null=True)) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - # Remove the settings field - migrator.remove_fields("user", "settings") diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py deleted file mode 100644 index 4a68eea55..000000000 --- a/backend/apps/webui/internal/migrations/012_add_tools.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Peewee migrations -- 009_add_models.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - @migrator.create_model - class Tool(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - - name = pw.TextField() - content = pw.TextField() - specs = pw.TextField() - - meta = pw.TextField() - - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "tool" - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_model("tool") diff --git a/backend/apps/webui/internal/migrations/013_add_user_info.py b/backend/apps/webui/internal/migrations/013_add_user_info.py deleted file mode 100644 index 0f68669cc..000000000 --- a/backend/apps/webui/internal/migrations/013_add_user_info.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Peewee migrations -- 002_add_local_sharing.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - # Adding fields info to the 'user' table - migrator.add_fields("user", info=pw.TextField(null=True)) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - # Remove the settings field - migrator.remove_fields("user", "info") diff --git a/backend/apps/webui/internal/migrations/014_add_files.py b/backend/apps/webui/internal/migrations/014_add_files.py deleted file mode 100644 index 5e1acf0ad..000000000 --- a/backend/apps/webui/internal/migrations/014_add_files.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Peewee migrations -- 009_add_models.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - @migrator.create_model - class File(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - filename = pw.TextField() - meta = pw.TextField() - created_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "file" - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_model("file") diff --git a/backend/apps/webui/internal/migrations/015_add_functions.py b/backend/apps/webui/internal/migrations/015_add_functions.py deleted file mode 100644 index 8316a9333..000000000 --- a/backend/apps/webui/internal/migrations/015_add_functions.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Peewee migrations -- 009_add_models.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - @migrator.create_model - class Function(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - - name = pw.TextField() - type = pw.TextField() - - content = pw.TextField() - meta = pw.TextField() - - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "function" - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_model("function") diff --git a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py deleted file mode 100644 index e3af521b7..000000000 --- a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Peewee migrations -- 009_add_models.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - migrator.add_fields("tool", valves=pw.TextField(null=True)) - migrator.add_fields("function", valves=pw.TextField(null=True)) - migrator.add_fields("function", is_active=pw.BooleanField(default=False)) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_fields("tool", "valves") - migrator.remove_fields("function", "valves") - migrator.remove_fields("function", "is_active") diff --git a/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py b/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py deleted file mode 100644 index fd1d9b560..000000000 --- a/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Peewee migrations -- 017_add_user_oauth_sub.py. - -Some examples (model - class or model name):: - - > Model = migrator.orm['table_name'] # Return model in current state by name - > Model = migrator.ModelClass # Return model in current state by name - - > migrator.sql(sql) # Run custom SQL - > migrator.run(func, *args, **kwargs) # Run python function with the given args - > migrator.create_model(Model) # Create a model (could be used as decorator) - > migrator.remove_model(model, cascade=True) # Remove a model - > migrator.add_fields(model, **fields) # Add fields to a model - > migrator.change_fields(model, **fields) # Change fields - > migrator.remove_fields(model, *field_names, cascade=True) - > migrator.rename_field(model, old_field_name, new_field_name) - > migrator.rename_table(model, new_table_name) - > migrator.add_index(model, *col_names, unique=False) - > migrator.add_not_null(model, *field_names) - > migrator.add_default(model, field_name, default) - > migrator.add_constraint(model, name, sql) - > migrator.drop_index(model, *col_names) - > migrator.drop_not_null(model, *field_names) - > migrator.drop_constraints(model, *constraints) - -""" - -from contextlib import suppress - -import peewee as pw -from peewee_migrate import Migrator - - -with suppress(ImportError): - import playhouse.postgres_ext as pw_pext - - -def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - - migrator.add_fields( - "user", - oauth_sub=pw.TextField(null=True, unique=True), - ) - - -def rollback(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your rollback migrations here.""" - - migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/internal/migrations/README.md b/backend/apps/webui/internal/migrations/README.md deleted file mode 100644 index 260214113..000000000 --- a/backend/apps/webui/internal/migrations/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# Database Migrations - -This directory contains all the database migrations for the web app. -Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library. - -Migrations are automatically ran at app startup. - -## Creating a migration - -Have you made a change to the schema of an existing model? -You will need to create a migration file to ensure that existing databases are updated for backwards compatibility. - -1. Have a database file (`webui.db`) that has the old schema prior to any of your changes. -2. Make your changes to the models. -3. From the `backend` directory, run the following command: - ```bash - pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} - ``` - - `$SQLITE_DB` should be the path to the database file. - - `$MIGRATION_NAME` should be a descriptive name for the migration. -4. The migration file will be created in the `apps/web/internal/migrations` directory. diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py index 50deac526..8f197ce5b 100644 --- a/backend/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -176,8 +176,10 @@ def upgrade() -> None: sa.Column("api_key", sa.String(), nullable=True), sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), + sa.Column('oauth_sub', sa.Text(), nullable=True), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("api_key"), + sa.UniqueConstraint("oauth_sub"), ) # ### end Alembic commands ### From d4b6b7c4e8c9930003290c15d624d1f2c5bcd8a6 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Tue, 25 Jun 2024 08:29:18 +0200 Subject: [PATCH 013/115] feat(sqlalchemy): reverted not needed api change --- backend/apps/webui/routers/models.py | 2 +- backend/test/apps/webui/routers/test_models.py | 4 ++-- backend/test/util/abstract_integration_test.py | 10 ++++++++-- src/lib/apis/models/index.ts | 5 ++++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/backend/apps/webui/routers/models.py b/backend/apps/webui/routers/models.py index eaf459d73..eeae9e1c4 100644 --- a/backend/apps/webui/routers/models.py +++ b/backend/apps/webui/routers/models.py @@ -56,7 +56,7 @@ async def add_new_model( ############################ -@router.get("/{id}", response_model=Optional[ModelModel]) +@router.get("/", response_model=Optional[ModelModel]) async def get_model_by_id(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) diff --git a/backend/test/apps/webui/routers/test_models.py b/backend/test/apps/webui/routers/test_models.py index 991c83bee..a8495403b 100644 --- a/backend/test/apps/webui/routers/test_models.py +++ b/backend/test/apps/webui/routers/test_models.py @@ -42,9 +42,9 @@ class TestModels(AbstractPostgresTest): assert len(response.json()) == 1 with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url("/my-model")) + response = self.fast_api_client.get(self.create_url(query_params={"id": "my-model"})) assert response.status_code == 200 - data = response.json() + data = response.json()[0] assert data["id"] == "my-model" assert data["name"] == "Hello World" diff --git a/backend/test/util/abstract_integration_test.py b/backend/test/util/abstract_integration_test.py index 4e99dcc2f..8535221a8 100644 --- a/backend/test/util/abstract_integration_test.py +++ b/backend/test/util/abstract_integration_test.py @@ -23,14 +23,20 @@ def get_fast_api_client(): class AbstractIntegrationTest: BASE_PATH = None - def create_url(self, path): + def create_url(self, path="", query_params=None): if self.BASE_PATH is None: raise Exception("BASE_PATH is not set") parts = self.BASE_PATH.split("/") parts = [part.strip() for part in parts if part.strip() != ""] path_parts = path.split("/") path_parts = [part.strip() for part in path_parts if part.strip() != ""] - return "/".join(parts + path_parts) + query_parts = "" + if query_params: + query_parts = "&".join( + [f"{key}={value}" for key, value in query_params.items()] + ) + query_parts = f"?{query_parts}" + return "/".join(parts + path_parts) + query_parts @classmethod def setup_class(cls): diff --git a/src/lib/apis/models/index.ts b/src/lib/apis/models/index.ts index 17d11d816..9faa358d3 100644 --- a/src/lib/apis/models/index.ts +++ b/src/lib/apis/models/index.ts @@ -63,7 +63,10 @@ export const getModelInfos = async (token: string = '') => { export const getModelById = async (token: string, id: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, { + const searchParams = new URLSearchParams(); + searchParams.append('id', id); + + const res = await fetch(`${WEBUI_API_BASE_URL}/models?${searchParams.toString()}`, { method: 'GET', headers: { Accept: 'application/json', From 23e4d9daff157d18442374fdd52bc8acbc1cb81c Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Tue, 25 Jun 2024 08:35:55 +0200 Subject: [PATCH 014/115] feat(sqlalchemy): formatting --- backend/apps/webui/models/auths.py | 3 ++- backend/migrations/versions/7e5b5dc7342b_init.py | 2 +- backend/test/apps/webui/routers/test_models.py | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index aef895619..560d9a686 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -113,7 +113,8 @@ class AuthsTable: Session.add(result) user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub) + id, name, email, profile_image_url, role, oauth_sub + ) Session.commit() Session.refresh(result) diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py index 8f197ce5b..90597ec9b 100644 --- a/backend/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -176,7 +176,7 @@ def upgrade() -> None: sa.Column("api_key", sa.String(), nullable=True), sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True), sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True), - sa.Column('oauth_sub', sa.Text(), nullable=True), + sa.Column("oauth_sub", sa.Text(), nullable=True), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("api_key"), sa.UniqueConstraint("oauth_sub"), diff --git a/backend/test/apps/webui/routers/test_models.py b/backend/test/apps/webui/routers/test_models.py index a8495403b..34d3e30bd 100644 --- a/backend/test/apps/webui/routers/test_models.py +++ b/backend/test/apps/webui/routers/test_models.py @@ -42,7 +42,9 @@ class TestModels(AbstractPostgresTest): assert len(response.json()) == 1 with mock_webui_user(id="2"): - response = self.fast_api_client.get(self.create_url(query_params={"id": "my-model"})) + response = self.fast_api_client.get( + self.create_url(query_params={"id": "my-model"}) + ) assert response.status_code == 200 data = response.json()[0] assert data["id"] == "my-model" From 827b1e58e96e76ef1d7d150e8a984178f2caf923 Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Tue, 25 Jun 2024 09:06:04 +0200 Subject: [PATCH 015/115] feat(sqlalchemy): execute tests in github actions --- .github/workflows/integration-test.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index c8e7c1672..3b455820d 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -67,6 +67,28 @@ jobs: path: compose-logs.txt if-no-files-found: ignore + pytest: + name: Run backend tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r backend/requirements.txt + + - name: pytest run + run: | + ls -al + cd backend + PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO + migration_test: name: Run Migration Tests runs-on: ubuntu-latest From 5391f4c1f7e09516b0878f504efe673821be4e7c Mon Sep 17 00:00:00 2001 From: Jonathan Rohde Date: Fri, 28 Jun 2024 09:21:07 +0200 Subject: [PATCH 016/115] feat(sqlalchemy): add new column --- backend/migrations/versions/7e5b5dc7342b_init.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/migrations/versions/7e5b5dc7342b_init.py b/backend/migrations/versions/7e5b5dc7342b_init.py index 90597ec9b..b82627f5b 100644 --- a/backend/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/migrations/versions/7e5b5dc7342b_init.py @@ -96,6 +96,7 @@ def upgrade() -> None: sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True), sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True), sa.Column("is_active", sa.Boolean(), nullable=True), + sa.Column("is_global", sa.Boolean(), nullable=True), sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column("created_at", sa.BigInteger(), nullable=True), sa.PrimaryKeyConstraint("id"), From 0c3f9a16e3c3ecd882b69bea2363902889a3c4c8 Mon Sep 17 00:00:00 2001 From: Sergey Mihaylin Date: Fri, 28 Jun 2024 16:31:40 +0300 Subject: [PATCH 017/115] custom env for set custom claims for openid --- backend/apps/webui/main.py | 5 +++++ backend/config.py | 12 ++++++++++++ backend/main.py | 6 ++++-- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 28b1b4aac..e7f0683c6 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -39,6 +39,8 @@ from config import ( WEBUI_BANNERS, ENABLE_COMMUNITY_SHARING, AppConfig, + OAUTH_USERNAME_CLAIM, + OAUTH_PICTURE_CLAIM ) import inspect @@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/config.py b/backend/config.py index 3a825f53a..cd184aab8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -395,6 +395,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig( os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), ) +OAUTH_USERNAME_CLAIM = PersistentConfig( + "OAUTH_USERNAME_CLAIM", + "oauth.oidc.username_claim", + os.environ.get("OAUTH_USERNAME_CLAIM", "name"), +) + +OAUTH_PICTURE_CLAIM = PersistentConfig( + "OAUTH_USERNAME_CLAIM", + "oauth.oidc.avatar_claim", + os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/main.py b/backend/main.py index aae305c5e..b4fd10c21 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1920,11 +1920,13 @@ async def oauth_callback(provider: str, request: Request, response: Response): # If the user does not exist, check if signups are enabled if ENABLE_OAUTH_SIGNUP.value: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) + email_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM + existing_user = Users.get_user_by_email(user_data.get(email_claim, "").lower()) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - picture_url = user_data.get("picture", "") + picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") if picture_url: # Download the profile image into a base64 string try: From 9f32e9ef602fdb25e69a95d41ae4a96358ed88f2 Mon Sep 17 00:00:00 2001 From: Sergey Mihaylin Date: Fri, 28 Jun 2024 17:08:32 +0300 Subject: [PATCH 018/115] fix username claim --- backend/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/main.py b/backend/main.py index b4fd10c21..72527c310 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1920,8 +1920,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): # If the user does not exist, check if signups are enabled if ENABLE_OAUTH_SIGNUP.value: # Check if an existing user with the same email already exists - email_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - existing_user = Users.get_user_by_email(user_data.get(email_claim, "").lower()) + existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -1946,12 +1945,13 @@ async def oauth_callback(provider: str, request: Request, response: Response): picture_url = "" if not picture_url: picture_url = "/user.png" + username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM user = Auths.insert_new_auth( email=email, password=get_password_hash( str(uuid.uuid4()) ), # Random password, not used - name=user_data.get("name", "User"), + name=user_data.get(username_claim, "User"), profile_image_url=picture_url, role=webui_app.state.config.DEFAULT_USER_ROLE, oauth_sub=provider_sub, From e475f025b74250213b71b4c4853bfeae66573890 Mon Sep 17 00:00:00 2001 From: Sergey Mihaylin Date: Mon, 1 Jul 2024 10:25:25 +0300 Subject: [PATCH 019/115] fix: merge request fail (remove picture_claim) --- backend/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/main.py b/backend/main.py index 3c4ffbbd2..edb7c74ae 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1920,7 +1920,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - picture_url = user_data.get("picture", "") + picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") if picture_url: # Download the profile image into a base64 string try: From a94c7e5c0973811b82ff8443220286525d0b1929 Mon Sep 17 00:00:00 2001 From: Sergey Mihaylin Date: Mon, 1 Jul 2024 10:36:21 +0300 Subject: [PATCH 020/115] fix lint --- backend/apps/webui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index e7f0683c6..8f1d8e334 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -40,7 +40,7 @@ from config import ( ENABLE_COMMUNITY_SHARING, AppConfig, OAUTH_USERNAME_CLAIM, - OAUTH_PICTURE_CLAIM + OAUTH_PICTURE_CLAIM, ) import inspect From 647aa1966f72ebbc659a1b255045cea29eb648c6 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 2 Jul 2024 16:51:30 -0700 Subject: [PATCH 021/115] chore: format --- backend/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/config.py b/backend/config.py index 2fcc0ba64..9a189c685 100644 --- a/backend/config.py +++ b/backend/config.py @@ -766,6 +766,7 @@ class BannerModel(BaseModel): dismissible: bool timestamp: int + try: banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]")) banners = [BannerModel(**banner) for banner in banners] From 44a9b86eece98acbf1d25e9a00ec9104dafa2d18 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 2 Jul 2024 21:46:56 -0700 Subject: [PATCH 022/115] fix: functions --- backend/apps/webui/models/functions.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 7e3ac92cd..d5f220ce7 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -107,7 +107,7 @@ class FunctionsTable: Session.commit() Session.refresh(result) if result: - return FunctionModel.model_validate(result) + return FunctionModel(**result.model_dump()) else: return None except Exception as e: @@ -117,20 +117,19 @@ class FunctionsTable: def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: function = Session.get(Function, id) - return FunctionModel.model_validate(function) + return FunctionModel(**function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: if active_only: return [ - FunctionModel.model_validate(function) + FunctionModel(**function) for function in Session.query(Function).filter_by(is_active=True).all() ] else: return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).all() + FunctionModel(**function) for function in Session.query(Function).all() ] def get_functions_by_type( @@ -138,25 +137,23 @@ class FunctionsTable: ) -> List[FunctionModel]: if active_only: return [ - FunctionModel.model_validate(function) + FunctionModel(**function) for function in Session.query(Function) .filter_by(type=type, is_active=True) .all() ] else: return [ - FunctionModel.model_validate(function) + FunctionModel(**function) for function in Session.query(Function).filter_by(type=type).all() ] def get_global_filter_functions(self) -> List[FunctionModel]: return [ - FunctionModel(**model_to_dict(function)) - for function in Function.select().where( - Function.type == "filter", - Function.is_active == True, - Function.is_global == True, - ) + FunctionModel(**function) + for function in Session.query(Function) + .filter_by(type="filter", is_active=True, is_global=True) + .all() ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: From aa8802262486922bf74442f6d49bd34d630e2561 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 2 Jul 2024 21:50:53 -0700 Subject: [PATCH 023/115] fix: functions --- backend/apps/webui/models/functions.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index d5f220ce7..a7d06eddc 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -107,7 +107,7 @@ class FunctionsTable: Session.commit() Session.refresh(result) if result: - return FunctionModel(**result.model_dump()) + return FunctionModel(**result.__dict__) else: return None except Exception as e: @@ -117,19 +117,20 @@ class FunctionsTable: def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: function = Session.get(Function, id) - return FunctionModel(**function) + return FunctionModel(**function.__dict__) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: if active_only: return [ - FunctionModel(**function) + FunctionModel(**function.__dict__) for function in Session.query(Function).filter_by(is_active=True).all() ] else: return [ - FunctionModel(**function) for function in Session.query(Function).all() + FunctionModel(**function.__dict__) + for function in Session.query(Function).all() ] def get_functions_by_type( @@ -137,20 +138,20 @@ class FunctionsTable: ) -> List[FunctionModel]: if active_only: return [ - FunctionModel(**function) + FunctionModel(**function.__dict__) for function in Session.query(Function) .filter_by(type=type, is_active=True) .all() ] else: return [ - FunctionModel(**function) + FunctionModel(**function.__dict__) for function in Session.query(Function).filter_by(type=type).all() ] def get_global_filter_functions(self) -> List[FunctionModel]: return [ - FunctionModel(**function) + FunctionModel(**function.__dict__) for function in Session.query(Function) .filter_by(type="filter", is_active=True, is_global=True) .all() From 4d23957035a6371287d8f0416de46263d7b4f1de Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Tue, 2 Jul 2024 21:56:32 -0700 Subject: [PATCH 024/115] revert: model_validate --- backend/apps/webui/models/functions.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index a7d06eddc..64ed4f3cc 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -107,7 +107,7 @@ class FunctionsTable: Session.commit() Session.refresh(result) if result: - return FunctionModel(**result.__dict__) + return FunctionModel.model_validate(result) else: return None except Exception as e: @@ -117,19 +117,19 @@ class FunctionsTable: def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: function = Session.get(Function, id) - return FunctionModel(**function.__dict__) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: if active_only: return [ - FunctionModel(**function.__dict__) + FunctionModel.model_validate(function) for function in Session.query(Function).filter_by(is_active=True).all() ] else: return [ - FunctionModel(**function.__dict__) + FunctionModel.model_validate(function) for function in Session.query(Function).all() ] @@ -138,20 +138,20 @@ class FunctionsTable: ) -> List[FunctionModel]: if active_only: return [ - FunctionModel(**function.__dict__) + FunctionModel.model_validate(function) for function in Session.query(Function) .filter_by(type=type, is_active=True) .all() ] else: return [ - FunctionModel(**function.__dict__) + FunctionModel.model_validate(function) for function in Session.query(Function).filter_by(type=type).all() ] def get_global_filter_functions(self) -> List[FunctionModel]: return [ - FunctionModel(**function.__dict__) + FunctionModel.model_validate(function) for function in Session.query(Function) .filter_by(type="filter", is_active=True, is_global=True) .all() From 1f026a181107de92d08b4129346758c05be2afab Mon Sep 17 00:00:00 2001 From: bannert <58707896+bannert1337@users.noreply.github.com> Date: Wed, 3 Jul 2024 15:16:54 +0200 Subject: [PATCH 025/115] i18n(de_DE): added translations for new entries, updated old entries --- src/lib/i18n/locales/de-DE/translation.json | 70 ++++++++++----------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/lib/i18n/locales/de-DE/translation.json b/src/lib/i18n/locales/de-DE/translation.json index 5485eee01..2db580424 100644 --- a/src/lib/i18n/locales/de-DE/translation.json +++ b/src/lib/i18n/locales/de-DE/translation.json @@ -33,7 +33,7 @@ "admin": "Administrator", "Admin": "Administrator", "Admin Panel": "Administrationsbereich", - "Admin Settings": "Administrator-Einstellungen", + "Admin Settings": "Administrationsbereich", "Admins have access to all tools at all times; users need tools assigned per model in the workspace.": "Administratoren haben jederzeit Zugriff auf alle Werkzeuge. Benutzer können im Arbeitsbereich zugewiesen.", "Advanced Parameters": "Erweiterte Parameter", "Advanced Params": "Erweiterte Parameter", @@ -86,7 +86,7 @@ "Capabilities": "Fähigkeiten", "Change Password": "Passwort ändern", "Chat": "Gespräch", - "Chat Background Image": "Unterhaltungs-Hintergrundbild", + "Chat Background Image": "Hintergrundbild des Unterhaltungsfensters", "Chat Bubble UI": "Chat Bubble UI", "Chat direction": "Textrichtung", "Chat History": "Unterhaltungsverlauf", @@ -100,7 +100,7 @@ "Chunk Params": "Blockparameter", "Chunk Size": "Blockgröße", "Citation": "Zitate", - "Clear memory": "Erinnerungen löschen", + "Clear memory": "Alle Erinnerungen entfernen", "Click here for help.": "Klicken Sie hier für Hilfe.", "Click here to": "Klicke Sie hier, um", "Click here to download user import template file.": "Klicken Sie hier, um die Vorlage für den Benutzerimport herunterzuladen.", @@ -119,14 +119,14 @@ "ComfyUI Base URL": "ComfyUI-Basis-URL", "ComfyUI Base URL is required.": "ComfyUI-Basis-URL wird benötigt.", "Command": "Befehl", - "Concurrent Requests": "Gleichzeitige Anforderungen", + "Concurrent Requests": "Anzahl gleichzeitiger Anfragen", "Confirm": "Bestätigen", "Confirm Password": "Passwort bestätigen", "Confirm your action": "Bestätigen Sie Ihre Aktion.", "Connections": "Verbindungen", "Contact Admin for WebUI Access": "Kontaktieren Sie den Administrator für den Zugriff auf die Weboberfläche", "Content": "Info", - "Content Extraction": "", + "Content Extraction": "Inhaltsextraktion", "Context Length": "Kontextlänge", "Continue Response": "Antwort fortsetzen", "Continue with {{provider}}": "Mit {{Anbieter}} fortfahren", @@ -191,7 +191,7 @@ "Documentation": "Dokumentation", "Documents": "Dokumente", "does not make any external connections, and your data stays securely on your locally hosted server.": "stellt keine externen Verbindungen her, und Ihre Daten bleiben sicher auf Ihrem lokal gehosteten Server.", - "Don't Allow": "Nicht erlauben", + "Don't Allow": "Verbieten", "Don't have an account?": "Haben Sie noch kein Benutzerkonto?", "Don't like the style": "schlechter Schreibstil", "Done": "Erledigt", @@ -205,7 +205,7 @@ "Edit Memory": "Erinnerungen bearbeiten", "Edit User": "Benutzer bearbeiten", "Email": "E-Mail", - "Embedding Batch Size": "Embedding Batch Größe", + "Embedding Batch Size": "Embedding-Stapelgröße", "Embedding Model": "Embedding-Modell", "Embedding Model Engine": "Embedding-Modell-Engine", "Embedding model set to \"{{embedding_model}}\"": "Embedding-Modell auf \"{{embedding_model}}\" gesetzt", @@ -213,21 +213,21 @@ "Enable Community Sharing": "Community-Freigabe aktivieren", "Enable New Sign Ups": "Registrierung erlauben", "Enable Web Search": "Websuche aktivieren", - "Engine": "", + "Engine": "Engine", "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "Stellen Sie sicher, dass Ihre CSV-Datei 4 Spalten in dieser Reihenfolge enthält: Name, E-Mail, Passwort, Rolle.", "Enter {{role}} message here": "Geben Sie die {{role}}-Nachricht hier ein", "Enter a detail about yourself for your LLMs to recall": "Geben Sie ein Detail über sich selbst ein, das Ihre Sprachmodelle (LLMs) sich merken sollen", "Enter api auth string (e.g. username:password)": "Geben Sie die API-Authentifizierungszeichenfolge ein (z. B. Benutzername:Passwort)", "Enter Brave Search API Key": "Geben Sie den Brave Search API-Schlüssel ein", - "Enter Chunk Overlap": "Gib den Chunk Overlap ein", - "Enter Chunk Size": "Gib die Chunk Size ein", + "Enter Chunk Overlap": "Geben Sie die Blocküberlappung ein", + "Enter Chunk Size": "Geben Sie die Blockgröße ein", "Enter Github Raw URL": "Geben Sie die Github Raw-URL ein", "Enter Google PSE API Key": "Geben Sie den Google PSE-API-Schlüssel ein", "Enter Google PSE Engine Id": "Geben Sie die Google PSE-Engine-ID ein", - "Enter Image Size (e.g. 512x512)": "Gib die Bildgröße ein (z.B. 512x512)", + "Enter Image Size (e.g. 512x512)": "Geben Sie die Bildgröße ein (z. B. 512x512)", "Enter language codes": "Geben Sie die Sprachcodes ein", - "Enter model tag (e.g. {{modelTag}})": "Gib den Model-Tag ein", - "Enter Number of Steps (e.g. 50)": "Gib die Anzahl an Schritten ein (z.B. 50)", + "Enter model tag (e.g. {{modelTag}})": "Gebn Sie den Model-Tag ein", + "Enter Number of Steps (e.g. 50)": "Geben Sie die Anzahl an Schritten ein (z. B. 50)", "Enter Score": "Punktzahl eingeben", "Enter Searxng Query URL": "Geben Sie die Searxng-Abfrage-URL ein", "Enter Serper API Key": "Geben Sie den Serper-API-Schlüssel ein", @@ -235,7 +235,7 @@ "Enter Serpstack API Key": "Geben Sie den Serpstack-API-Schlüssel ein", "Enter stop sequence": "Stop-Sequenz eingeben", "Enter Tavily API Key": "Geben Sie den Tavily-API-Schlüssel ein", - "Enter Tika Server URL": "", + "Enter Tika Server URL": "Geben Sie die Tika-Server-URL ein", "Enter Top K": "Geben Sie Top K ein", "Enter URL (e.g. http://127.0.0.1:7860/)": "Geben Sie die URL ein (z. B. http://127.0.0.1:7860/)", "Enter URL (e.g. http://localhost:11434)": "Geben Sie die URL ein (z. B. http://localhost:11434)", @@ -300,7 +300,7 @@ "Image Generation Engine": "Bildgenerierungs-Engine", "Image Settings": "Bildeinstellungen", "Images": "Bilder", - "Import Chats": "Chats importieren", + "Import Chats": "Unterhaltungen importieren", "Import Documents Mapping": "Dokumentenzuordnung importieren", "Import Functions": "Funktionen importieren", "Import Models": "Modelle importieren", @@ -315,7 +315,7 @@ "Interface": "Benutzeroberfläche", "Invalid Tag": "Ungültiger Tag", "January": "Januar", - "join our Discord for help.": "Trete unserem Discord bei, um Hilfe zu erhalten.", + "join our Discord for help.": "Treten Sie unserem Discord bei, um Hilfe zu erhalten.", "JSON": "JSON", "JSON Preview": "JSON-Vorschau", "July": "Juli", @@ -345,7 +345,7 @@ "Maximum of 3 models can be downloaded simultaneously. Please try again later.": "Es können maximal 3 Modelle gleichzeitig heruntergeladen werden. Bitte versuchen Sie es später erneut.", "May": "Mai", "Memories accessible by LLMs will be shown here.": "Erinnerungen, die für Modelle zugänglich sind, werden hier angezeigt.", - "Memory": "Erinnerung", + "Memory": "Erinnerungen", "Memory added successfully": "Erinnerung erfolgreich hinzugefügt", "Memory cleared successfully": "Erinnerung erfolgreich gelöscht", "Memory deleted successfully": "Erinnerung erfolgreich gelöscht", @@ -377,7 +377,7 @@ "Name": "Name", "Name Tag": "Namens-Tag", "Name your model": "Benennen Sie Ihr Modell", - "New Chat": "Neuer Chat", + "New Chat": "Neue Unterhaltung", "New Password": "Neues Passwort", "No content to speak": "Kein Inhalt zum Vorlesen", "No documents found": "Keine Dokumente gefunden", @@ -412,7 +412,7 @@ "Open": "Öffne", "Open AI (Dall-E)": "Open AI (Dall-E)", "Open new chat": "Neuen Chat öffnen", - "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "", + "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "Die installierte Open-WebUI-Version (v{{OPEN_WEBUI_VERSION}}) ist niedriger als die erforderliche Version (v{{REQUIRED_VERSION}})", "OpenAI": "OpenAI", "OpenAI API": "OpenAI-API", "OpenAI API Config": "OpenAI-API-Konfiguration", @@ -428,8 +428,8 @@ "Permission denied when accessing microphone": "Zugriff auf das Mikrofon verweigert", "Permission denied when accessing microphone: {{error}}": "Zugriff auf das Mikrofon verweigert: {{error}}", "Personalization": "Personalisierung", - "Pin": "", - "Pinned": "", + "Pin": "Anheften", + "Pinned": "Angeheftet", "Pipeline deleted successfully": "Pipeline erfolgreich gelöscht", "Pipeline downloaded successfully": "Pipeline erfolgreich heruntergeladen", "Pipelines": "Pipelines", @@ -447,7 +447,7 @@ "Prompt suggestions": "Prompt-Vorschläge", "Prompts": "Prompts", "Pull \"{{searchValue}}\" from Ollama.com": "\"{{searchValue}}\" von Ollama.com beziehen", - "Pull a model from Ollama.com": "Modell von Ollama.com beziehn", + "Pull a model from Ollama.com": "Modell von Ollama.com beziehen", "Query Params": "Abfrageparameter", "RAG Template": "RAG-Vorlage", "Read Aloud": "Vorlesen", @@ -489,10 +489,10 @@ "Search Functions": "Funktionen durchsuchen...", "Search Models": "Modelle durchsuchen...", "Search Prompts": "Prompts durchsuchen...", - "Search Query Generation Prompt": "Suchanfragen-Generierungs-Prompt", - "Search Query Generation Prompt Length Threshold": "Suchanfragen-Generierungs-Prompt-Längenschwellenwert", + "Search Query Generation Prompt": "Suchanfragengenerierungsvorlage", + "Search Query Generation Prompt Length Threshold": "Längenschwelle für Suchanfragengenerierung", "Search Result Count": "Anzahl der Suchergebnisse", - "Search Tools": "Suchwerkzeuge", + "Search Tools": "Werkzeuge durchsuchen...", "Searched {{count}} sites_one": "{{count}} Seite durchsucht", "Searched {{count}} sites_other": "{{count}} Seiten durchsucht", "Searching \"{{searchQuery}}\"": "Suche nach \"{{searchQuery}}\"", @@ -579,22 +579,22 @@ "This setting does not sync across browsers or devices.": "Diese Einstellung wird nicht zwischen Browsern oder Geräten synchronisiert.", "This will delete": "Dies löscht", "Thorough explanation": "Ausführliche Erklärung", - "Tika": "", - "Tika Server URL required.": "", + "Tika": "Tika", + "Tika Server URL required.": "Tika-Server-URL erforderlich.", "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "Tipp: Aktualisieren Sie mehrere Variablenfelder nacheinander, indem Sie nach jedem Ersetzen die Tabulatortaste im Eingabefeld der Unterhaltung drücken.", "Title": "Titel", "Title (e.g. Tell me a fun fact)": "Titel (z. B. Erzähl mir einen lustigen Fakt)", - "Title Auto-Generation": "Automatische Titelerstellung", + "Title Auto-Generation": "Unterhaltungstitel automatisch generieren", "Title cannot be an empty string.": "Titel darf nicht leer sein.", - "Title Generation Prompt": "Titelerstellung-Prompt", + "Title Generation Prompt": "Prompt für Titelgenerierung", "to": "für", "To access the available model names for downloading,": "Um auf die verfügbaren Modellnamen zuzugreifen,", "To access the GGUF models available for downloading,": "Um auf die verfügbaren GGUF-Modelle zuzugreifen,", - "To access the WebUI, please reach out to the administrator. Admins can manage user statuses from the Admin Panel.": "Um auf das WebUI zugreifen zu könnrn, wenden Sie sich bitte an einen Administrator. Administratoren können den Benutzerstatus über das Admin-Panel verwalten.", + "To access the WebUI, please reach out to the administrator. Admins can manage user statuses from the Admin Panel.": "Um auf das WebUI zugreifen zu können, wenden Sie sich bitte an einen Administrator. Administratoren können den Benutzerstatus über das Admin-Panel verwalten.", "To add documents here, upload them to the \"Documents\" workspace first.": "Um Dokumente hinzuzufügen, laden Sie sie zuerst im Arbeitsbereich „Dokumente“ hoch.", "to chat input.": "zum Eingabefeld der Unterhaltung.", - "To select filters here, add them to the \"Functions\" workspace first.": "Um Filter auszuwählen, fügen Sie diese zuerst dem Arbeitsbereich „Funktionen“ hinzu.", - "To select toolkits here, add them to the \"Tools\" workspace first.": "Um Toolkits auszuwählen, fügen Sie sie zuerst zum Arbeitsbereich „Werkzeuge“ hinzu.", + "To select filters here, add them to the \"Functions\" workspace first.": "Um Filter auszuwählen, fügen Sie diese zunächst dem Arbeitsbereich „Funktionen“ hinzu.", + "To select toolkits here, add them to the \"Tools\" workspace first.": "Um Toolkits auszuwählen, fügen Sie sie zunächst dem Arbeitsbereich „Werkzeuge“ hinzu.", "Today": "Heute", "Toggle settings": "Einstellungen umschalten", "Toggle sidebar": "Seitenleiste umschalten", @@ -611,11 +611,11 @@ "TTS Settings": "TTS-Einstellungen", "TTS Voice": "TTS-Stimme", "Type": "Art", - "Type Hugging Face Resolve (Download) URL": "Gib die Hugging Face Resolve (Download) URL ein", + "Type Hugging Face Resolve (Download) URL": "Geben Sie die Hugging Face Resolve-URL ein", "Uh-oh! There was an issue connecting to {{provider}}.": "Ups! Es gab ein Problem bei der Verbindung mit {{provider}}.", "UI": "Oberfläche", "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "Unbekannter Dateityp '{{file_type}}'. Der Datei-Upload wird trotzdem fortgesetzt.", - "Unpin": "", + "Unpin": "Lösen", "Update": "Aktualisieren", "Update and Copy Link": "Aktualisieren und Link kopieren", "Update password": "Passwort aktualisieren", @@ -654,7 +654,7 @@ "Web Search Engine": "Suchmaschine", "Webhook URL": "Webhook URL", "WebUI Settings": "WebUI-Einstellungen", - "WebUI will make requests to": "Wenn aktiviert sendet WebUI externe Anfragen an", + "WebUI will make requests to": "WebUI sendet Anfragen an:", "What’s New in": "Neuigkeiten von", "When history is turned off, new chats on this browser won't appear in your history on any of your devices.": "Wenn der Verlauf deaktiviert ist, werden neue Unterhaltungen in diesem Browser nicht im Verlauf Ihrer anderen Geräte erscheinen.", "Whisper (Local)": "Whisper (lokal)", From 3cd0b1077dde10455b2f010d3a5f7a7a03326ffb Mon Sep 17 00:00:00 2001 From: Morgan Blangeois Date: Wed, 3 Jul 2024 17:09:53 +0200 Subject: [PATCH 026/115] i18n: Improve French translations --- src/lib/i18n/locales/fr-FR/translation.json | 924 ++++++++++---------- 1 file changed, 462 insertions(+), 462 deletions(-) diff --git a/src/lib/i18n/locales/fr-FR/translation.json b/src/lib/i18n/locales/fr-FR/translation.json index b088eeb30..5d358dc7e 100644 --- a/src/lib/i18n/locales/fr-FR/translation.json +++ b/src/lib/i18n/locales/fr-FR/translation.json @@ -1,668 +1,668 @@ { - "'s', 'm', 'h', 'd', 'w' or '-1' for no expiration.": "'s', 'm', 'h', 'd', 'w' ou '-1' pour aucune expiration.", - "(Beta)": "(Bêta)", - "(e.g. `sh webui.sh --api --api-auth username_password`)": "", - "(e.g. `sh webui.sh --api`)": "(par ex. `sh webui.sh --api`)", - "(latest)": "(plus récent)", - "{{ models }}": "{{ models }}", - "{{ owner }}: You cannot delete a base model": "{{ owner }}: Vous ne pouvez pas supprimer un modèle de base", - "{{modelName}} is thinking...": "{{modelName}} réfléchit...", - "{{user}}'s Chats": "Chats de {{user}}", + "'s', 'm', 'h', 'd', 'w' or '-1' for no expiration.": " 's', 'm', 'h', 'd', 'w' ou '-1' pour une durée illimitée.", + "(Beta)": "(Version bêta)", + "(e.g. `sh webui.sh --api --api-auth username_password`)": "(par ex. `sh webui.sh --api --api-auth username_password`)", + "(e.g. `sh webui.sh --api`)": "(par exemple `sh webui.sh --api`)", + "(latest)": "(dernier)", + "{{ models }}": "{{ modèles }}", + "{{ owner }}: You cannot delete a base model": "{{ propriétaire }} : Vous ne pouvez pas supprimer un modèle de base.", + "{{modelName}} is thinking...": "{{modelName}} est en train de réfléchir...", + "{{user}}'s Chats": "Discussions de {{user}}", "{{webUIName}} Backend Required": "Backend {{webUIName}} requis", - "A task model is used when performing tasks such as generating titles for chats and web search queries": "Un modèle de tâche est utilisé lors de l’exécution de tâches telles que la génération de titres pour les chats et les requêtes de recherche sur le Web", + "A task model is used when performing tasks such as generating titles for chats and web search queries": "Un modèle de tâche est utilisé lors de l’exécution de tâches telles que la génération de titres pour les conversations et les requêtes de recherche sur le web.", "a user": "un utilisateur", - "About": "À Propos", + "About": "À propos", "Account": "Compte", - "Account Activation Pending": "", - "Accurate information": "Information précise", - "Active Users": "", + "Account Activation Pending": "Activation du compte en attente", + "Accurate information": "Information exacte", + "Active Users": "Utilisateurs actifs", "Add": "Ajouter", - "Add a model id": "Ajouter un identifiant modèle", - "Add a short description about what this model does": "Ajouter une courte description de ce que fait ce modèle", - "Add a short title for this prompt": "Ajouter un court titre pour ce prompt", - "Add a tag": "Ajouter un tag", - "Add custom prompt": "Ajouter un prompt personnalisé", - "Add Docs": "Ajouter des Documents", - "Add Files": "Ajouter des Fichiers", - "Add Memory": "Ajouter de la Mémoire", + "Add a model id": "Ajouter un identifiant de modèle", + "Add a short description about what this model does": "Ajoutez une brève description de ce que fait ce modèle.", + "Add a short title for this prompt": "Ajoutez un bref titre pour cette prompt.", + "Add a tag": "Ajouter une balise", + "Add custom prompt": "Ajouter une prompt personnalisée", + "Add Docs": "Ajouter de la documentation", + "Add Files": "Ajouter des fichiers", + "Add Memory": "Ajouter de la mémoire", "Add message": "Ajouter un message", - "Add Model": "Ajouter un Modèle", - "Add Tags": "Ajouter des Tags", + "Add Model": "Ajouter un modèle", + "Add Tags": "Ajouter des balises", "Add User": "Ajouter un Utilisateur", - "Adjusting these settings will apply changes universally to all users.": "L'ajustement de ces paramètres appliquera les changements à tous les utilisateurs.", - "admin": "admin", - "Admin": "", - "Admin Panel": "Panneau d'Administration", - "Admin Settings": "Paramètres d'Administration", - "Admins have access to all tools at all times; users need tools assigned per model in the workspace.": "", - "Advanced Parameters": "Paramètres Avancés", - "Advanced Params": "Params Avancés", - "all": "tous", - "All Documents": "Tous les Documents", + "Adjusting these settings will apply changes universally to all users.": "L'ajustement de ces paramètres appliquera universellement les changements à tous les utilisateurs.", + "admin": "administrateur", + "Admin": "Administrateur", + "Admin Panel": "Tableau de bord administrateur", + "Admin Settings": "Paramètres d'administration", + "Admins have access to all tools at all times; users need tools assigned per model in the workspace.": "Les administrateurs ont accès à tous les outils en tout temps ; les utilisateurs ont besoin d'outils affectés par modèle dans l'espace de travail.", + "Advanced Parameters": "Paramètres avancés", + "Advanced Params": "Paramètres avancés", + "all": "toutes", + "All Documents": "Tous les documents", "All Users": "Tous les Utilisateurs", "Allow": "Autoriser", - "Allow Chat Deletion": "Autoriser la suppression du chat", - "Allow non-local voices": "", - "Allow User Location": "", - "Allow Voice Interruption in Call": "", + "Allow Chat Deletion": "Autoriser la suppression de l'historique de chat", + "Allow non-local voices": "Autoriser les voix non locales", + "Allow User Location": "Autoriser l'emplacement de l'utilisateur", + "Allow Voice Interruption in Call": "Autoriser l'interruption vocale pendant un appel", "alphanumeric characters and hyphens": "caractères alphanumériques et tirets", - "Already have an account?": "Vous avez déjà un compte ?", + "Already have an account?": "Avez-vous déjà un compte ?", "an assistant": "un assistant", "and": "et", "and create a new shared link.": "et créer un nouveau lien partagé.", "API Base URL": "URL de base de l'API", - "API Key": "Clé API", - "API Key created.": "Clé d'API créée.", - "API keys": "Clés API", + "API Key": "Clé d'API", + "API Key created.": "Clé d'API générée.", + "API keys": "Clés d'API", "April": "Avril", - "Archive": "Archiver", - "Archive All Chats": "Archiver toutes les discussions", - "Archived Chats": "Chats Archivés", - "are allowed - Activate this command by typing": "sont autorisés - Activez cette commande en tapant", - "Are you sure?": "Êtes-vous sûr ?", - "Attach file": "Joindre un fichier", + "Archive": "Archivage", + "Archive All Chats": "Archiver toutes les conversations", + "Archived Chats": "Conversations archivées", + "are allowed - Activate this command by typing": "sont autorisés - Activer cette commande en tapant", + "Are you sure?": "Êtes-vous certain ?", + "Attach file": "Joindre un document", "Attention to detail": "Attention aux détails", "Audio": "Audio", - "Audio settings updated successfully": "", + "Audio settings updated successfully": "Les paramètres audio ont été mis à jour avec succès", "August": "Août", - "Auto-playback response": "Réponse en lecture automatique", - "AUTOMATIC1111 Api Auth String": "", + "Auto-playback response": "Réponse de lecture automatique", + "AUTOMATIC1111 Api Auth String": "AUTOMATIC1111 Chaîne d'authentification de l'API", "AUTOMATIC1111 Base URL": "URL de base AUTOMATIC1111", - "AUTOMATIC1111 Base URL is required.": "L'URL de base AUTOMATIC1111 est requise.", + "AUTOMATIC1111 Base URL is required.": "L'URL de base {AUTOMATIC1111} est requise.", "available!": "disponible !", - "Back": "Retour", - "Bad Response": "Mauvaise Réponse", - "Banners": "Bannières", - "Base Model (From)": "Modèle de Base (De)", - "Batch Size (num_batch)": "", + "Back": "Retour en arrière", + "Bad Response": "Mauvaise réponse", + "Banners": "Banniers", + "Base Model (From)": "Modèle de base (à partir de)", + "Batch Size (num_batch)": "Taille du lot (num_batch)", "before": "avant", - "Being lazy": "Est paresseux", + "Being lazy": "Être fainéant", "Brave Search API Key": "Clé API Brave Search", - "Bypass SSL verification for Websites": "Contourner la vérification SSL pour les sites Web.", - "Call": "", - "Call feature is not supported when using Web STT engine": "", - "Camera": "", + "Bypass SSL verification for Websites": "Bypasser la vérification SSL pour les sites web", + "Call": "Appeler", + "Call feature is not supported when using Web STT engine": "La fonction d'appel n'est pas prise en charge lors de l'utilisation du moteur Web STT", + "Camera": "Appareil photo", "Cancel": "Annuler", "Capabilities": "Capacités", "Change Password": "Changer le mot de passe", "Chat": "Chat", - "Chat Background Image": "", - "Chat Bubble UI": "UI Bulles de Chat", + "Chat Background Image": "Image d'arrière-plan de la fenêtre de chat", + "Chat Bubble UI": "Bulles de discussion", "Chat direction": "Direction du chat", - "Chat History": "Historique du chat", - "Chat History is off for this browser.": "L'historique du chat est désactivé pour ce navigateur.", - "Chats": "Chats", - "Check Again": "Vérifier à nouveau", - "Check for updates": "Vérifier les mises à jour", - "Checking for updates...": "Vérification des mises à jour...", - "Choose a model before saving...": "Choisissez un modèle avant d'enregistrer...", - "Chunk Overlap": "Chevauchement de bloc", - "Chunk Params": "Paramètres de bloc", + "Chat History": "Historique de discussion", + "Chat History is off for this browser.": "L'historique de chat est désactivé pour ce navigateur", + "Chats": "Conversations", + "Check Again": "Vérifiez à nouveau.", + "Check for updates": "Vérifier les mises à jour disponibles", + "Checking for updates...": "Recherche de mises à jour...", + "Choose a model before saving...": "Choisissez un modèle avant de sauvegarder...", + "Chunk Overlap": "Chevauchement de blocs", + "Chunk Params": "Paramètres d'encombrement", "Chunk Size": "Taille de bloc", "Citation": "Citation", - "Clear memory": "", - "Click here for help.": "Cliquez ici pour de l'aide.", + "Clear memory": "Libérer la mémoire", + "Click here for help.": "Cliquez ici pour obtenir de l'aide.", "Click here to": "Cliquez ici pour", - "Click here to download user import template file.": "", + "Click here to download user import template file.": "Cliquez ici pour télécharger le fichier modèle d'importation utilisateur.", "Click here to select": "Cliquez ici pour sélectionner", - "Click here to select a csv file.": "Cliquez ici pour sélectionner un fichier csv.", - "Click here to select a py file.": "", - "Click here to select documents.": "Cliquez ici pour sélectionner des documents.", + "Click here to select a csv file.": "Cliquez ici pour sélectionner un fichier CSV.", + "Click here to select a py file.": "Cliquez ici pour sélectionner un fichier .py.", + "Click here to select documents.": "Cliquez ici pour sélectionner les documents.", "click here.": "cliquez ici.", - "Click on the user role button to change a user's role.": "Cliquez sur le bouton de rôle d'utilisateur pour changer le rôle d'un utilisateur.", - "Clipboard write permission denied. Please check your browser settings to grant the necessary access.": "", - "Clone": "Clone", + "Click on the user role button to change a user's role.": "Cliquez sur le bouton de rôle d'utilisateur pour modifier le rôle d'un utilisateur.", + "Clipboard write permission denied. Please check your browser settings to grant the necessary access.": "L'autorisation d'écriture du presse-papier a été refusée. Veuillez vérifier les paramètres de votre navigateur pour accorder l'accès nécessaire.", + "Clone": "Copie conforme", "Close": "Fermer", - "Code formatted successfully": "", + "Code formatted successfully": "Le code a été formaté avec succès", "Collection": "Collection", "ComfyUI": "ComfyUI", "ComfyUI Base URL": "URL de base ComfyUI", "ComfyUI Base URL is required.": "L'URL de base ComfyUI est requise.", "Command": "Commande", - "Concurrent Requests": "Demandes simultanées", - "Confirm": "", + "Concurrent Requests": "Demandes concurrentes", + "Confirm": "Confirmer", "Confirm Password": "Confirmer le mot de passe", - "Confirm your action": "", + "Confirm your action": "Confirmez votre action", "Connections": "Connexions", - "Contact Admin for WebUI Access": "", + "Contact Admin for WebUI Access": "Contacter l'administrateur pour l'accès à l'interface Web", "Content": "Contenu", "Context Length": "Longueur du contexte", - "Continue Response": "Continuer la Réponse", - "Continue with {{provider}}": "", - "Copied shared chat URL to clipboard!": "URL du chat copié dans le presse-papiers !", - "Copy": "Copier", + "Continue Response": "Continuer la réponse", + "Continue with {{provider}}": "Continuer avec {{provider}}", + "Copied shared chat URL to clipboard!": "URL du chat copiée dans le presse-papiers !", + "Copy": "Copie", "Copy last code block": "Copier le dernier bloc de code", "Copy last response": "Copier la dernière réponse", - "Copy Link": "Copier le Lien", + "Copy Link": "Copier le lien", "Copying to clipboard was successful!": "La copie dans le presse-papiers a réussi !", "Create a model": "Créer un modèle", "Create Account": "Créer un compte", - "Create new key": "Créer une nouvelle clé", + "Create new key": "Créer une nouvelle clé principale", "Create new secret key": "Créer une nouvelle clé secrète", - "Created at": "Créé le", - "Created At": "Crée Le", - "Created by": "", - "CSV Import": "", - "Current Model": "Modèle actuel", + "Created at": "Créé à", + "Created At": "Créé le", + "Created by": "Créé par", + "CSV Import": "Import CSV", + "Current Model": "Modèle actuel amélioré", "Current Password": "Mot de passe actuel", - "Custom": "Personnalisé", - "Customize models for a specific purpose": "Personnaliser les modèles pour un objectif spécifique", - "Dark": "Sombre", - "Dashboard": "", + "Custom": "Sur mesure", + "Customize models for a specific purpose": "Personnaliser les modèles pour une fonction spécifique", + "Dark": "Obscur", + "Dashboard": "Tableau de bord", "Database": "Base de données", "December": "Décembre", "Default": "Par défaut", "Default (Automatic1111)": "Par défaut (Automatic1111)", - "Default (SentenceTransformers)": "Par défaut (SentenceTransformers)", - "Default Model": "Modèle par défaut", + "Default (SentenceTransformers)": "Par défaut (Sentence Transformers)", + "Default Model": "Modèle standard", "Default model updated": "Modèle par défaut mis à jour", - "Default Prompt Suggestions": "Suggestions de prompt par défaut", - "Default User Role": "Rôle d'utilisateur par défaut", + "Default Prompt Suggestions": "Suggestions de prompts par défaut", + "Default User Role": "Rôle utilisateur par défaut", "delete": "supprimer", "Delete": "Supprimer", "Delete a model": "Supprimer un modèle", - "Delete All Chats": "Supprimer toutes les discussions", - "Delete chat": "Supprimer le chat", - "Delete Chat": "Supprimer le Chat", - "Delete chat?": "", - "Delete function?": "", - "Delete prompt?": "", + "Delete All Chats": "Supprimer toutes les conversations", + "Delete chat": "Supprimer la conversation", + "Delete Chat": "Supprimer la Conversation", + "Delete chat?": "Supprimer la conversation ?", + "Delete function?": "Supprimer la fonction ?", + "Delete prompt?": "Supprimer la prompt ?", "delete this link": "supprimer ce lien", - "Delete tool?": "", - "Delete User": "Supprimer l'Utilisateur", - "Deleted {{deleteModelTag}}": "{{deleteModelTag}} supprimé", - "Deleted {{name}}": "{{name}} supprimé", + "Delete tool?": "Effacer l'outil ?", + "Delete User": "Supprimer le compte d'utilisateur", + "Deleted {{deleteModelTag}}": "Supprimé {{deleteModelTag}}", + "Deleted {{name}}": "Supprimé {{name}}", "Description": "Description", - "Didn't fully follow instructions": "N'a pas suivi entièrement les instructions", - "Discover a function": "", + "Didn't fully follow instructions": "N'a pas entièrement respecté les instructions", + "Discover a function": "Découvrez une fonction", "Discover a model": "Découvrir un modèle", - "Discover a prompt": "Découvrir un prompt", - "Discover a tool": "", - "Discover, download, and explore custom functions": "", - "Discover, download, and explore custom prompts": "Découvrir, télécharger et explorer des prompts personnalisés", - "Discover, download, and explore custom tools": "", - "Discover, download, and explore model presets": "Découvrir, télécharger et explorer des préconfigurations de modèles", - "Dismissible": "", - "Display Emoji in Call": "", - "Display the username instead of You in the Chat": "Afficher le nom d'utilisateur au lieu de 'Vous' dans le Chat", + "Discover a prompt": "Découvrir une suggestion", + "Discover a tool": "Découvrez un outil", + "Discover, download, and explore custom functions": "Découvrez, téléchargez et explorez des fonctions personnalisées", + "Discover, download, and explore custom prompts": "Découvrez, téléchargez et explorez des prompts personnalisés", + "Discover, download, and explore custom tools": "Découvrez, téléchargez et explorez des outils personnalisés", + "Discover, download, and explore model presets": "Découvrir, télécharger et explorer des préréglages de modèles", + "Dismissible": "Fermeture", + "Display Emoji in Call": "Afficher les emojis pendant l'appel", + "Display the username instead of You in the Chat": "Afficher le nom d'utilisateur à la place de \"Vous\" dans le Chat", "Document": "Document", "Document Settings": "Paramètres du document", - "Documentation": "", + "Documentation": "Documentation", "Documents": "Documents", - "does not make any external connections, and your data stays securely on your locally hosted server.": "ne fait aucune connexion externe, et vos données restent en sécurité sur votre serveur hébergé localement.", + "does not make any external connections, and your data stays securely on your locally hosted server.": "ne fait aucune connexion externe et garde vos données en sécurité sur votre serveur local.", "Don't Allow": "Ne pas autoriser", "Don't have an account?": "Vous n'avez pas de compte ?", - "Don't like the style": "N'aime pas le style", - "Done": "", + "Don't like the style": "N'apprécie pas le style", + "Done": "Terminé", "Download": "Télécharger", "Download canceled": "Téléchargement annulé", "Download Database": "Télécharger la base de données", "Drop any files here to add to the conversation": "Déposez des fichiers ici pour les ajouter à la conversation", - "e.g. '30s','10m'. Valid time units are 's', 'm', 'h'.": "par ex. '30s', '10m'. Les unités de temps valides sont 's', 'm', 'h'.", - "Edit": "Éditer", - "Edit Doc": "Éditer le document", - "Edit Memory": "", - "Edit User": "Éditer l'utilisateur", - "Email": "Email", - "Embedding Batch Size": "", - "Embedding Model": "Modèle pour l'Embedding", - "Embedding Model Engine": "Moteur du Modèle d'Embedding", - "Embedding model set to \"{{embedding_model}}\"": "Modèle d'embedding défini sur \"{{embedding_model}}\"", - "Enable Chat History": "Activer l'historique du chat", - "Enable Community Sharing": "Activer le partage de communauté", + "e.g. '30s','10m'. Valid time units are 's', 'm', 'h'.": "par ex. '30s', '10 min'. Les unités de temps valides sont 's', 'm', 'h'.", + "Edit": "Modifier", + "Edit Doc": "Modifier le document", + "Edit Memory": "Modifier la mémoire", + "Edit User": "Modifier l'utilisateur", + "Email": "E-mail", + "Embedding Batch Size": "Taille du lot d'encodage", + "Embedding Model": "Modèle d'embedding", + "Embedding Model Engine": "Moteur de modèle d'encodage", + "Embedding model set to \"{{embedding_model}}\"": "Modèle d'encodage défini sur « {{embedding_model}} »", + "Enable Chat History": "Activer l'historique de conversation", + "Enable Community Sharing": "Activer le partage communautaire", "Enable New Sign Ups": "Activer les nouvelles inscriptions", - "Enable Web Search": "Activer la recherche sur le Web", - "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "Vérifiez que le fichier CSV contienne 4 colonnes dans cet ordre : Name (Nom), Email, Password (Mot de passe), Role (Rôle).", + "Enable Web Search": "Activer la recherche web", + "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "Vérifiez que votre fichier CSV comprenne les 4 colonnes dans cet ordre : Name, Email, Password, Role.", "Enter {{role}} message here": "Entrez le message {{role}} ici", - "Enter a detail about yourself for your LLMs to recall": "Saisissez une donnée vous concernant pour que vos LLMs s'en souviennent", - "Enter api auth string (e.g. username:password)": "", + "Enter a detail about yourself for your LLMs to recall": "Saisissez un détail sur vous-même que vos LLMs pourront se rappeler", + "Enter api auth string (e.g. username:password)": "Entrez la chaîne d'authentification de l'API (par ex. nom d'utilisateur:mot de passe)", "Enter Brave Search API Key": "Entrez la clé API Brave Search", - "Enter Chunk Overlap": "Entrez le chevauchement de bloc", - "Enter Chunk Size": "Entrez la taille du bloc", - "Enter Github Raw URL": "Entrez l’URL brute Github", + "Enter Chunk Overlap": "Entrez le chevauchement de chunk", + "Enter Chunk Size": "Entrez la taille de bloc", + "Enter Github Raw URL": "Entrez l'URL brute de GitHub", "Enter Google PSE API Key": "Entrez la clé API Google PSE", - "Enter Google PSE Engine Id": "Entrez l’ID du moteur Google PSE", - "Enter Image Size (e.g. 512x512)": "Entrez la taille de l'image (p. ex. 512x512)", - "Enter language codes": "Entrez les codes du language", - "Enter model tag (e.g. {{modelTag}})": "Entrez le tag du modèle (p. ex. {{modelTag}})", - "Enter Number of Steps (e.g. 50)": "Entrez le nombre d'étapes (p. ex. 50)", - "Enter Score": "Entrez le Score", - "Enter Searxng Query URL": "Entrez l’URL de la requête Searxng", + "Enter Google PSE Engine Id": "Entrez l'identifiant du moteur Google PSE", + "Enter Image Size (e.g. 512x512)": "Entrez la taille de l'image (par ex. 512x512)", + "Enter language codes": "Entrez les codes de langue", + "Enter model tag (e.g. {{modelTag}})": "Entrez l'étiquette du modèle (par ex. {{modelTag}})", + "Enter Number of Steps (e.g. 50)": "Entrez le nombre de pas (par ex. 50)", + "Enter Score": "Entrez votre score", + "Enter Searxng Query URL": "Entrez l'URL de la requête Searxng", "Enter Serper API Key": "Entrez la clé API Serper", - "Enter Serply API Key": "", + "Enter Serply API Key": "Entrez la clé API Serply", "Enter Serpstack API Key": "Entrez la clé API Serpstack", - "Enter stop sequence": "Entrez la séquence de fin", - "Enter Tavily API Key": "", - "Enter Top K": "Entrez Top K", - "Enter URL (e.g. http://127.0.0.1:7860/)": "Entrez l'URL (p. ex. http://127.0.0.1:7860/)", - "Enter URL (e.g. http://localhost:11434)": "Entrez l'URL (p. ex. http://localhost:11434)", - "Enter Your Email": "Entrez Votre Email", - "Enter Your Full Name": "Entrez Votre Nom Complet", - "Enter Your Password": "Entrez Votre Mot De Passe", - "Enter Your Role": "Entrez Votre Rôle", + "Enter stop sequence": "Entrez la séquence d'arrêt", + "Enter Tavily API Key": "Entrez la clé API Tavily", + "Enter Top K": "Entrez les Top K", + "Enter URL (e.g. http://127.0.0.1:7860/)": "Entrez l'URL (par ex. {http://127.0.0.1:7860/})", + "Enter URL (e.g. http://localhost:11434)": "Entrez l'URL (par ex. http://localhost:11434)", + "Enter Your Email": "Entrez votre adresse e-mail", + "Enter Your Full Name": "Entrez votre nom complet", + "Enter Your Password": "Entrez votre mot de passe", + "Enter Your Role": "Entrez votre rôle", "Error": "Erreur", "Experimental": "Expérimental", "Export": "Exportation", - "Export All Chats (All Users)": "Exporter Tous les Chats (Tous les Utilisateurs)", - "Export chat (.json)": "", - "Export Chats": "Exporter les Chats", - "Export Documents Mapping": "Exporter la Correspondance des Documents", - "Export Functions": "", - "Export LiteLLM config.yaml": "", - "Export Models": "Exporter les Modèles", + "Export All Chats (All Users)": "Exporter toutes les conversations (tous les utilisateurs)", + "Export chat (.json)": "Exporter la discussion (.json)", + "Export Chats": "Exporter les conversations", + "Export Documents Mapping": "Exportez la correspondance des documents", + "Export Functions": "Exportez les Fonctions", + "Export LiteLLM config.yaml": "Exportez le fichier LiteLLM config.yaml", + "Export Models": "Exporter les modèles", "Export Prompts": "Exporter les Prompts", - "Export Tools": "", - "External Models": "", - "Failed to create API Key.": "Échec de la création de la clé d'API.", + "Export Tools": "Outils d'exportation", + "External Models": "Modèles externes", + "Failed to create API Key.": "Échec de la création de la clé API.", "Failed to read clipboard contents": "Échec de la lecture du contenu du presse-papiers", - "Failed to update settings": "", + "Failed to update settings": "Échec de la mise à jour des paramètres", "February": "Février", "Feel free to add specific details": "N'hésitez pas à ajouter des détails spécifiques", - "File": "", - "File Mode": "Mode Fichier", - "File not found.": "Fichier non trouvé.", - "Filter is now globally disabled": "", - "Filter is now globally enabled": "", - "Filters": "", - "Fingerprint spoofing detected: Unable to use initials as avatar. Defaulting to default profile image.": "Usurpation d'empreinte digitale détectée : Impossible d'utiliser les initiales comme avatar. L'image de profil par défaut sera utilisée.", - "Fluidly stream large external response chunks": "Diffusez de manière fluide de gros morceaux de réponses externes", - "Focus chat input": "Concentrer sur l'entrée du chat", - "Followed instructions perfectly": "A suivi les instructions parfaitement", - "Form": "", - "Format your variables using square brackets like this:": "Formatez vos variables en utilisant des crochets comme ceci :", + "File": "Fichier", + "File Mode": "Mode fichier", + "File not found.": "Fichier introuvable.", + "Filter is now globally disabled": "Le filtre est maintenant désactivé globalement", + "Filter is now globally enabled": "Le filtre est désormais activé globalement", + "Filters": "Filtres", + "Fingerprint spoofing detected: Unable to use initials as avatar. Defaulting to default profile image.": "Spoofing détecté : impossible d'utiliser les initiales comme avatar. Retour à l'image de profil par défaut.", + "Fluidly stream large external response chunks": "Diffuser de manière fluide de larges portions de réponses externes", + "Focus chat input": "Se concentrer sur le chat en entrée", + "Followed instructions perfectly": "A parfaitement suivi les instructions", + "Form": "Formulaire", + "Format your variables using square brackets like this:": "Formatez vos variables en utilisant des crochets comme suit :", "Frequency Penalty": "Pénalité de fréquence", - "Function created successfully": "", - "Function deleted successfully": "", - "Function updated successfully": "", - "Functions": "", - "Functions imported successfully": "", + "Function created successfully": "La fonction a été créée avec succès", + "Function deleted successfully": "Fonction supprimée avec succès", + "Function updated successfully": "La fonction a été mise à jour avec succès", + "Functions": "Fonctions", + "Functions imported successfully": "Fonctions importées avec succès", "General": "Général", "General Settings": "Paramètres Généraux", - "Generate Image": "", - "Generating search query": "Génération d’une requête de recherche", - "Generation Info": "Informations de la Génération", - "Global": "", - "Good Response": "Bonne Réponse", + "Generate Image": "Générer une image", + "Generating search query": "Génération d'une requête de recherche", + "Generation Info": "Informations sur la génération", + "Global": "Mondial", + "Good Response": "Bonne réponse", "Google PSE API Key": "Clé API Google PSE", - "Google PSE Engine Id": "ID du moteur Google PSE", + "Google PSE Engine Id": "ID du moteur de recherche personnalisé de Google", "h:mm a": "h:mm a", - "has no conversations.": "n'a pas de conversations.", - "Hello, {{name}}": "Bonjour, {{name}}", + "has no conversations.": "n'a aucune conversation.", + "Hello, {{name}}": "Bonjour, {{name}}.", "Help": "Aide", "Hide": "Cacher", - "Hide Model": "", - "How can I help you today?": "Comment puis-je vous aider aujourd'hui ?", - "Hybrid Search": "Recherche Hybride", - "Image Generation (Experimental)": "Génération d'Image (Expérimental)", - "Image Generation Engine": "Moteur de Génération d'Image", - "Image Settings": "Paramètres d'Image", + "Hide Model": "Masquer le modèle", + "How can I help you today?": "Comment puis-je vous être utile aujourd'hui ?", + "Hybrid Search": "Recherche hybride", + "Image Generation (Experimental)": "Génération d'images (expérimental)", + "Image Generation Engine": "Moteur de génération d'images", + "Image Settings": "Paramètres de l'image", "Images": "Images", - "Import Chats": "Importer les Chats", - "Import Documents Mapping": "Importer la Correspondance des Documents", - "Import Functions": "", - "Import Models": "Importer des Modèles", - "Import Prompts": "Importer des Prompts", - "Import Tools": "", - "Include `--api-auth` flag when running stable-diffusion-webui": "", - "Include `--api` flag when running stable-diffusion-webui": "Inclure le drapeau `--api` lors de l'exécution de stable-diffusion-webui", + "Import Chats": "Importer les discussions", + "Import Documents Mapping": "Import de la correspondance des documents", + "Import Functions": "Import de fonctions", + "Import Models": "Importer des modèles", + "Import Prompts": "Importer des Enseignes", + "Import Tools": "Outils d'importation", + "Include `--api-auth` flag when running stable-diffusion-webui": "Inclure le drapeau `--api-auth` lors de l'exécution de stable-diffusion-webui", + "Include `--api` flag when running stable-diffusion-webui": "Inclure le drapeau `--api` lorsque vous exécutez stable-diffusion-webui", "Info": "Info", - "Input commands": "Entrez les commandes d'entrée", - "Install from Github URL": "Installer à partir de l’URL Github", - "Instant Auto-Send After Voice Transcription": "", - "Interface": "Interface", - "Invalid Tag": "Tag Invalide", + "Input commands": "Entrez les commandes", + "Install from Github URL": "Installer depuis l'URL GitHub", + "Instant Auto-Send After Voice Transcription": "Envoi automatique instantané après transcription vocale", + "Interface": "Interface utilisateur", + "Invalid Tag": "Étiquette non valide", "January": "Janvier", - "join our Discord for help.": "rejoignez notre Discord pour obtenir de l'aide.", + "join our Discord for help.": "Rejoignez notre Discord pour obtenir de l'aide.", "JSON": "JSON", "JSON Preview": "Aperçu JSON", "July": "Juillet", "June": "Juin", - "JWT Expiration": "Expiration JWT", + "JWT Expiration": "Expiration du jeton JWT", "JWT Token": "Jeton JWT", - "Keep Alive": "Rester en vie", + "Keep Alive": "Rester connecté", "Keyboard shortcuts": "Raccourcis clavier", - "Knowledge": "", + "Knowledge": "Connaissance", "Language": "Langue", - "Last Active": "Dernier Activité", - "Last Modified": "", - "Light": "Clair", - "Listening...": "", - "LLMs can make mistakes. Verify important information.": "Les LLMs peuvent faire des erreurs. Vérifiez les informations importantes.", - "Local Models": "", + "Last Active": "Dernière activité", + "Last Modified": "Dernière modification", + "Light": "Lumineux", + "Listening...": "En train d'écouter...", + "LLMs can make mistakes. Verify important information.": "Les LLM peuvent faire des erreurs. Vérifiez les informations importantes.", + "Local Models": "Modèles locaux", "LTR": "LTR", "Made by OpenWebUI Community": "Réalisé par la communauté OpenWebUI", - "Make sure to enclose them with": "Assurez-vous de les entourer avec", - "Manage": "", - "Manage Models": "Gérer les modèles", + "Make sure to enclose them with": "Assurez-vous de les inclure dans", + "Manage": "Gérer", + "Manage Models": "Gérer les Modèles", "Manage Ollama Models": "Gérer les modèles Ollama", "Manage Pipelines": "Gérer les pipelines", - "Manage Valves": "", + "Manage Valves": "Gérer les vannes", "March": "Mars", "Max Tokens (num_predict)": "Tokens maximaux (num_predict)", - "Maximum of 3 models can be downloaded simultaneously. Please try again later.": "Un maximum de 3 modèles peut être téléchargé simultanément. Veuillez réessayer plus tard.", + "Maximum of 3 models can be downloaded simultaneously. Please try again later.": "Un maximum de 3 modèles peut être téléchargé en même temps. Veuillez réessayer ultérieurement.", "May": "Mai", - "Memories accessible by LLMs will be shown here.": "Les Mémoires des LLMs apparaîtront ici.", + "Memories accessible by LLMs will be shown here.": "Les mémoires accessibles par les LLMs seront affichées ici.", "Memory": "Mémoire", - "Memory added successfully": "", - "Memory cleared successfully": "", - "Memory deleted successfully": "", - "Memory updated successfully": "", - "Messages you send after creating your link won't be shared. Users with the URL will be able to view the shared chat.": "Les messages que vous envoyéz après la création du lien ne seront pas partagés. Les utilisateurs disposant de l'URL pourront voir le chat partagé.", - "Minimum Score": "Score Minimum", + "Memory added successfully": "Mémoire ajoutée avec succès", + "Memory cleared successfully": "La mémoire a été effacée avec succès", + "Memory deleted successfully": "La mémoire a été supprimée avec succès", + "Memory updated successfully": "La mémoire a été mise à jour avec succès", + "Messages you send after creating your link won't be shared. Users with the URL will be able to view the shared chat.": "Les messages que vous envoyez après avoir créé votre lien ne seront pas partagés. Les utilisateurs disposant de l'URL pourront voir le chat partagé.", + "Minimum Score": "Score minimal", "Mirostat": "Mirostat", "Mirostat Eta": "Mirostat Eta", "Mirostat Tau": "Mirostat Tau", - "MMMM DD, YYYY": "MMMM DD, YYYY", - "MMMM DD, YYYY HH:mm": "MMMM DD, YYYY HH:mm", - "MMMM DD, YYYY hh:mm:ss A": "", + "MMMM DD, YYYY": "MM DD, AAAA", + "MMMM DD, YYYY HH:mm": "MM MDDD, AAAA HH:mm", + "MMMM DD, YYYY hh:mm:ss A": "jj MM, aaaa HH:mm:ss", "Model '{{modelName}}' has been successfully downloaded.": "Le modèle '{{modelName}}' a été téléchargé avec succès.", "Model '{{modelTag}}' is already in queue for downloading.": "Le modèle '{{modelTag}}' est déjà dans la file d'attente pour le téléchargement.", - "Model {{modelId}} not found": "Modèle {{modelId}} non trouvé", - "Model {{modelName}} is not vision capable": "Modèle {{modelName}} n'est pas capable de voir", - "Model {{name}} is now {{status}}": "Le modèle {{name}} est maintenant {{status}}", - "Model created successfully!": "", - "Model filesystem path detected. Model shortname is required for update, cannot continue.": "Chemin du système de fichier du modèle détecté. Le nom court du modèle est requis pour la mise à jour, ne peut pas continuer.", - "Model ID": "ID du Modèle", + "Model {{modelId}} not found": "Modèle {{modelId}} introuvable", + "Model {{modelName}} is not vision capable": "Le modèle {{modelName}} n'a pas de capacités visuelles", + "Model {{name}} is now {{status}}": "Le modèle {{name}} est désormais {{status}}.", + "Model created successfully!": "Le modèle a été créé avec succès !", + "Model filesystem path detected. Model shortname is required for update, cannot continue.": "Chemin du système de fichiers de modèle détecté. Le nom court du modèle est requis pour la mise à jour, l'opération ne peut pas être poursuivie.", + "Model ID": "ID du modèle", "Model not selected": "Modèle non sélectionné", - "Model Params": "Paramètres du Modèle", - "Model updated successfully": "", - "Model Whitelisting": "Liste Blanche de Modèle", - "Model(s) Whitelisted": "Modèle(s) sur Liste Blanche", + "Model Params": "Paramètres du modèle", + "Model updated successfully": "Le modèle a été mis à jour avec succès", + "Model Whitelisting": "Liste blanche de modèles", + "Model(s) Whitelisted": "Modèle(s) Autorisé(s)", "Modelfile Content": "Contenu du Fichier de Modèle", "Models": "Modèles", - "More": "Plus", + "More": "Plus de", "Name": "Nom", - "Name Tag": "Tag de Nom", + "Name Tag": "Étiquette de nom", "Name your model": "Nommez votre modèle", - "New Chat": "Nouveau chat", + "New Chat": "Nouvelle conversation", "New Password": "Nouveau mot de passe", - "No content to speak": "", - "No documents found": "", - "No file selected": "", - "No results found": "Aucun résultat", + "No content to speak": "Rien à signaler", + "No documents found": "Aucun document trouvé", + "No file selected": "Aucun fichier sélectionné", + "No results found": "Aucun résultat trouvé", "No search query generated": "Aucune requête de recherche générée", - "No source available": "Aucune source disponible", - "No valves to update": "", + "No source available": "Aucune source n'est disponible", + "No valves to update": "Aucune vanne à mettre à jour", "None": "Aucun", - "Not factually correct": "Faits incorrects", - "Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.": "Note : Si vous définissez un score minimum, la recherche ne renverra que les documents ayant un score supérieur ou égal au score minimum.", - "Notifications": "Notifications de bureau", + "Not factually correct": "Non factuellement correct", + "Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.": "Note : Si vous définissez un score minimum, seuls les documents ayant un score supérieur ou égal à ce score minimum seront retournés par la recherche.", + "Notifications": "Notifications", "November": "Novembre", "num_thread (Ollama)": "num_thread (Ollama)", - "OAuth ID": "", + "OAuth ID": "ID OAuth", "October": "Octobre", "Off": "Désactivé", - "Okay, Let's Go!": "D'accord, allons-y !", - "OLED Dark": "Sombre OLED", + "Okay, Let's Go!": "D'accord, on y va !", + "OLED Dark": "Noir OLED", "Ollama": "Ollama", "Ollama API": "API Ollama", "Ollama API disabled": "API Ollama désactivée", - "Ollama API is disabled": "", - "Ollama Version": "Version Ollama", + "Ollama API is disabled": "L'API Ollama est désactivée", + "Ollama Version": "Version Ollama améliorée", "On": "Activé", "Only": "Seulement", "Only alphanumeric characters and hyphens are allowed in the command string.": "Seuls les caractères alphanumériques et les tirets sont autorisés dans la chaîne de commande.", - "Oops! Hold tight! Your files are still in the processing oven. We're cooking them up to perfection. Please be patient and we'll let you know once they're ready.": "Oups ! Tenez bon ! Vos fichiers sont encore dans le four. Nous les cuisinons à la perfection. Soyez patient et nous vous informerons dès qu'ils seront prêts.", - "Oops! Looks like the URL is invalid. Please double-check and try again.": "Oups ! On dirait que l'URL est invalide. Vérifiez et réessayez.", - "Oops! There was an error in the previous response. Please try again or contact admin.": "", - "Oops! You're using an unsupported method (frontend only). Please serve the WebUI from the backend.": "Oups ! Vous utilisez une méthode non-supportée (frontend uniquement). Veuillez également servir WebUI depuis le backend.", - "Open": "Ouvrir", + "Oops! Hold tight! Your files are still in the processing oven. We're cooking them up to perfection. Please be patient and we'll let you know once they're ready.": "Oups ! Un instant ! Vos fichiers sont toujours en train d'être traités. Nous les perfectionnons pour vous. Veuillez patienter, nous vous informerons dès qu'ils seront prêts.", + "Oops! Looks like the URL is invalid. Please double-check and try again.": "Oups ! Il semble que l'URL soit invalide. Veuillez vérifier à nouveau et réessayer.", + "Oops! There was an error in the previous response. Please try again or contact admin.": "Oops ! Il y a eu une erreur dans la réponse précédente. Veuillez réessayer ou contacter l'administrateur.", + "Oops! You're using an unsupported method (frontend only). Please serve the WebUI from the backend.": "Oups ! Vous utilisez une méthode non prise en charge (frontend uniquement). Veuillez servir l'interface Web à partir du backend.", + "Open": "Ouvrez", "Open AI (Dall-E)": "Open AI (Dall-E)", - "Open new chat": "Ouvrir un nouveau chat", + "Open new chat": "Ouvrir une nouvelle discussion", "OpenAI": "OpenAI", "OpenAI API": "API OpenAI", - "OpenAI API Config": "Config API OpenAI", - "OpenAI API Key is required.": "La clé d'API OpenAI est requise.", + "OpenAI API Config": "Configuration de l'API OpenAI", + "OpenAI API Key is required.": "Une clé API OpenAI est requise.", "OpenAI URL/Key required.": "URL/Clé OpenAI requise.", "or": "ou", "Other": "Autre", "Password": "Mot de passe", - "PDF document (.pdf)": "Document PDF (.pdf)", + "PDF document (.pdf)": "Document au format PDF (.pdf)", "PDF Extract Images (OCR)": "Extraction d'images PDF (OCR)", "pending": "en attente", - "Permission denied when accessing media devices": "", - "Permission denied when accessing microphone": "", + "Permission denied when accessing media devices": "Accès aux appareils multimédias refusé", + "Permission denied when accessing microphone": "Autorisation refusée lors de l'accès au micro", "Permission denied when accessing microphone: {{error}}": "Permission refusée lors de l'accès au microphone : {{error}}", "Personalization": "Personnalisation", - "Pipeline deleted successfully": "", - "Pipeline downloaded successfully": "", + "Pipeline deleted successfully": "Le pipeline a été supprimé avec succès", + "Pipeline downloaded successfully": "Le pipeline a été téléchargé avec succès", "Pipelines": "Pipelines", - "Pipelines Not Detected": "", - "Pipelines Valves": "Vannes de pipelines", - "Plain text (.txt)": "Texte Brute (.txt)", - "Playground": "Aire de jeu", - "Positive attitude": "Attitude Positive", - "Previous 30 days": "30 jours précédents", - "Previous 7 days": "7 jours précédents", - "Profile Image": "Image du Profil", + "Pipelines Not Detected": "Aucun pipelines détecté", + "Pipelines Valves": "Vannes de Pipelines", + "Plain text (.txt)": "Texte simple (.txt)", + "Playground": "Aire de jeux", + "Positive attitude": "Attitude positive", + "Previous 30 days": "30 derniers jours", + "Previous 7 days": "7 derniers jours", + "Profile Image": "Image de profil", "Prompt": "Prompt", - "Prompt (e.g. Tell me a fun fact about the Roman Empire)": "Prompt (p. ex. Raconte moi un fait amusant sur l'Empire Romain)", + "Prompt (e.g. Tell me a fun fact about the Roman Empire)": "Prompt (par ex. Dites-moi un fait amusant à propos de l'Empire romain)", "Prompt Content": "Contenu du prompt", - "Prompt suggestions": "Suggestions de prompt", + "Prompt suggestions": "Suggestions pour le prompt", "Prompts": "Prompts", - "Pull \"{{searchValue}}\" from Ollama.com": "Récupérer \"{{searchValue}}\" de Ollama.com", - "Pull a model from Ollama.com": "Récupérer un modèle de Ollama.com", - "Query Params": "Paramètres de Requête", + "Pull \"{{searchValue}}\" from Ollama.com": "Récupérer « {{searchValue}} » depuis Ollama.com", + "Pull a model from Ollama.com": "Télécharger un modèle depuis Ollama.com", + "Query Params": "Paramètres de requête", "RAG Template": "Modèle RAG", - "Read Aloud": "Lire à Voix Haute", + "Read Aloud": "Lire à haute voix", "Record voice": "Enregistrer la voix", "Redirecting you to OpenWebUI Community": "Redirection vers la communauté OpenWebUI", - "Refer to yourself as \"User\" (e.g., \"User is learning Spanish\")": "", - "Refused when it shouldn't have": "Refuse quand il ne devrait pas", + "Refer to yourself as \"User\" (e.g., \"User is learning Spanish\")": "Désignez-vous comme « Utilisateur » (par ex. « L'utilisateur apprend l'espagnol »)", + "Refused when it shouldn't have": "Refusé alors qu'il n'aurait pas dû l'être", "Regenerate": "Regénérer", - "Release Notes": "Notes de Version", + "Release Notes": "Notes de publication", "Remove": "Retirer", - "Remove Model": "Retirer le Modèle", + "Remove Model": "Retirer le modèle", "Rename": "Renommer", - "Repeat Last N": "Répéter les Derniers N", - "Request Mode": "Mode de Demande", - "Reranking Model": "Modèle de Reclassement", - "Reranking model disabled": "Modèle de Reclassement Désactivé", - "Reranking model set to \"{{reranking_model}}\"": "Modèle de reclassement défini sur \"{{reranking_model}}\"", - "Reset": "", - "Reset Upload Directory": "", - "Reset Vector Storage": "Réinitialiser le Stockage de Vecteur", - "Response AutoCopy to Clipboard": "Copie Automatique de la Réponse dans le Presse-papiers", - "Response notifications cannot be activated as the website permissions have been denied. Please visit your browser settings to grant the necessary access.": "", + "Repeat Last N": "Répéter les N derniers", + "Request Mode": "Mode de Requête", + "Reranking Model": "Modèle de ré-ranking", + "Reranking model disabled": "Modèle de ré-ranking désactivé", + "Reranking model set to \"{{reranking_model}}\"": "Modèle de ré-ranking défini sur « {{reranking_model}} »", + "Reset": "Réinitialiser", + "Reset Upload Directory": "Répertoire de téléchargement réinitialisé", + "Reset Vector Storage": "Réinitialiser le stockage des vecteurs", + "Response AutoCopy to Clipboard": "Copie automatique de la réponse vers le presse-papiers", + "Response notifications cannot be activated as the website permissions have been denied. Please visit your browser settings to grant the necessary access.": "Les notifications de réponse ne peuvent pas être activées car les autorisations du site web ont été refusées. Veuillez visiter les paramètres de votre navigateur pour accorder l'accès nécessaire.", "Role": "Rôle", - "Rosé Pine": "Pin Rosé", - "Rosé Pine Dawn": "Aube Pin Rosé", + "Rosé Pine": "Pin rosé", + "Rosé Pine Dawn": "Aube de Pin Rosé", "RTL": "RTL", - "Running": "", + "Running": "Courir", "Save": "Enregistrer", "Save & Create": "Enregistrer & Créer", "Save & Update": "Enregistrer & Mettre à jour", - "Saving chat logs directly to your browser's storage is no longer supported. Please take a moment to download and delete your chat logs by clicking the button below. Don't worry, you can easily re-import your chat logs to the backend through": "La sauvegarde des chat directement dans le stockage de votre navigateur n'est plus prise en charge. Veuillez prendre un moment pour télécharger et supprimer vos journaux de chat en cliquant sur le bouton ci-dessous. Ne vous inquiétez pas, vous pouvez facilement importer vos sauvegardes de chat via", + "Saving chat logs directly to your browser's storage is no longer supported. Please take a moment to download and delete your chat logs by clicking the button below. Don't worry, you can easily re-import your chat logs to the backend through": "La sauvegarde des journaux de discussion directement dans le stockage de votre navigateur n'est plus prise en charge. Veuillez prendre un instant pour télécharger et supprimer vos journaux de discussion en cliquant sur le bouton ci-dessous. Pas de soucis, vous pouvez facilement les réimporter depuis le backend via l'interface ci-dessous", "Scan": "Scanner", "Scan complete!": "Scan terminé !", "Scan for documents from {{path}}": "Scanner des documents depuis {{path}}", "Search": "Recherche", "Search a model": "Rechercher un modèle", - "Search Chats": "Rechercher des chats", - "Search Documents": "Rechercher des Documents", - "Search Functions": "", + "Search Chats": "Rechercher des conversations", + "Search Documents": "Recherche de documents", + "Search Functions": "Fonctions de recherche", "Search Models": "Rechercher des modèles", - "Search Prompts": "Rechercher des Prompts", - "Search Query Generation Prompt": "", - "Search Query Generation Prompt Length Threshold": "", + "Search Prompts": "Recherche de prompts", + "Search Query Generation Prompt": "Génération d'interrogation de recherche", + "Search Query Generation Prompt Length Threshold": "Seuil de longueur de prompt de génération de requête de recherche", "Search Result Count": "Nombre de résultats de recherche", - "Search Tools": "", - "Searched {{count}} sites_one": "Recherché {{count}} sites_one", + "Search Tools": "Outils de recherche", + "Searched {{count}} sites_one": "Recherché {{count}} site(s)_one", "Searched {{count}} sites_many": "Recherché {{count}} sites_many", - "Searched {{count}} sites_other": "Recherché {{count}} sites_other", - "Searching \"{{searchQuery}}\"": "", - "Searxng Query URL": "URL de requête Searxng", - "See readme.md for instructions": "Voir readme.md pour les instructions", - "See what's new": "Voir les nouveautés", + "Searched {{count}} sites_other": "Recherché {{count}} sites_autres", + "Searching \"{{searchQuery}}\"": "Recherche de « {{searchQuery}} »", + "Searxng Query URL": "URL de recherche Searxng", + "See readme.md for instructions": "Voir le fichier readme.md pour les instructions", + "See what's new": "Découvrez les nouvelles fonctionnalités", "Seed": "Graine", - "Select a base model": "Sélectionner un modèle de base", - "Select a engine": "", - "Select a function": "", - "Select a mode": "Sélectionner un mode", - "Select a model": "Sélectionner un modèle", - "Select a pipeline": "Sélectionner un pipeline", - "Select a pipeline url": "Sélectionnez une URL de pipeline", - "Select a tool": "", - "Select an Ollama instance": "Sélectionner une instance Ollama", - "Select Documents": "", - "Select model": "Sélectionner un modèle", - "Select only one model to call": "", - "Selected model(s) do not support image inputs": "Modèle(s) séléctionés ne supportent pas les entrées images", + "Select a base model": "Sélectionnez un modèle de base", + "Select a engine": "Sélectionnez un moteur", + "Select a function": "Sélectionnez une fonction", + "Select a mode": "Choisissez un mode", + "Select a model": "Sélectionnez un modèle", + "Select a pipeline": "Sélectionnez un pipeline", + "Select a pipeline url": "Sélectionnez l'URL du pipeline", + "Select a tool": "Sélectionnez un outil", + "Select an Ollama instance": "Sélectionnez une instance Ollama", + "Select Documents": "Sélectionnez des documents", + "Select model": "Sélectionnez un modèle", + "Select only one model to call": "Sélectionnez seulement un modèle pour appeler", + "Selected model(s) do not support image inputs": "Les modèle(s) sélectionné(s) ne prennent pas en charge les entrées d'images", "Send": "Envoyer", "Send a Message": "Envoyer un message", "Send message": "Envoyer un message", "September": "Septembre", "Serper API Key": "Clé API Serper", - "Serply API Key": "", + "Serply API Key": "Clé API Serply", "Serpstack API Key": "Clé API Serpstack", "Server connection verified": "Connexion au serveur vérifiée", - "Set as default": "Définir par défaut", - "Set Default Model": "Définir le Modèle par Défaut", - "Set embedding model (e.g. {{model}})": "Définir le modèle d'embedding (p. ex. {{model}})", - "Set Image Size": "Définir la Taille de l'Image", - "Set reranking model (e.g. {{model}})": "Définir le modèle de reclassement (p. ex. {{model}})", - "Set Steps": "Définir les Étapes", + "Set as default": "Définir comme valeur par défaut", + "Set Default Model": "Définir le modèle par défaut", + "Set embedding model (e.g. {{model}})": "Définir le modèle d'encodage (par ex. {{model}})", + "Set Image Size": "Définir la taille de l'image", + "Set reranking model (e.g. {{model}})": "Définir le modèle de reclassement (par ex. {{model}})", + "Set Steps": "Définir les étapes", "Set Task Model": "Définir le modèle de tâche", - "Set Voice": "Définir la Voix", + "Set Voice": "Définir la voix", "Settings": "Paramètres", "Settings saved successfully!": "Paramètres enregistrés avec succès !", - "Settings updated successfully": "", + "Settings updated successfully": "Les paramètres ont été mis à jour avec succès", "Share": "Partager", - "Share Chat": "Partager le Chat", + "Share Chat": "Partage de conversation", "Share to OpenWebUI Community": "Partager avec la communauté OpenWebUI", - "short-summary": "résumé court", + "short-summary": "résumé concis", "Show": "Montrer", - "Show Admin Details in Account Pending Overlay": "", - "Show Model": "", + "Show Admin Details in Account Pending Overlay": "Afficher les détails de l'administrateur dans la superposition en attente du compte", + "Show Model": "Montrer le modèle", "Show shortcuts": "Afficher les raccourcis", - "Show your support!": "", - "Showcased creativity": "Créativité affichée", + "Show your support!": "Montre ton soutien !", + "Showcased creativity": "Créativité mise en avant", "sidebar": "barre latérale", - "Sign in": "Se connecter", - "Sign Out": "Se déconnecter", - "Sign up": "S'inscrire", + "Sign in": "S'identifier", + "Sign Out": "Déconnexion", + "Sign up": "Inscrivez-vous", "Signing in": "Connexion en cours", "Source": "Source", - "Speech recognition error: {{error}}": "Erreur de reconnaissance vocale : {{error}}", - "Speech-to-Text Engine": "Moteur de Reconnaissance Vocale", - "Stop Sequence": "Séquence d'Arrêt", - "STT Model": "", - "STT Settings": "Paramètres STT", - "Submit": "Envoyer", - "Subtitle (e.g. about the Roman Empire)": "Sous-Titres (p. ex. à propos de l'Empire Romain)", - "Success": "Succès", - "Successfully updated.": "Mis à jour avec succès.", - "Suggested": "Suggéré", + "Speech recognition error: {{error}}": "Erreur de reconnaissance vocale : {{error}}", + "Speech-to-Text Engine": "Moteur de reconnaissance vocale", + "Stop Sequence": "Séquence d'arrêt", + "STT Model": "Modèle de STT", + "STT Settings": "Paramètres de STT", + "Submit": "Soumettre", + "Subtitle (e.g. about the Roman Empire)": "Sous-titres (par ex. sur l'Empire romain)", + "Success": "Réussite", + "Successfully updated.": "Mise à jour réussie.", + "Suggested": "Sugéré", "System": "Système", - "System Prompt": "Prompt du Système", - "Tags": "Tags", - "Tap to interrupt": "", - "Tavily API Key": "", - "Tell us more:": "Dites-nous en plus :", + "System Prompt": "Prompt du système", + "Tags": "Balises", + "Tap to interrupt": "Appuyez pour interrompre", + "Tavily API Key": "Clé API Tavily", + "Tell us more:": "Dites-nous en plus à ce sujet : ", "Temperature": "Température", - "Template": "Modèle", - "Text Completion": "Complétion de Texte", - "Text-to-Speech Engine": "Moteur de Synthèse Vocale", + "Template": "Template", + "Text Completion": "Complétion de texte", + "Text-to-Speech Engine": "Moteur de synthèse vocale", "Tfs Z": "Tfs Z", - "Thanks for your feedback!": "Merci pour votre avis !", - "The score should be a value between 0.0 (0%) and 1.0 (100%).": "Le score devrait avoir une valeur entre 0.0 (0%) et 1.0 (100%).", + "Thanks for your feedback!": "Merci pour vos commentaires !", + "The score should be a value between 0.0 (0%) and 1.0 (100%).": "Le score doit être une valeur comprise entre 0,0 (0 %) et 1,0 (100 %).", "Theme": "Thème", - "Thinking...": "", - "This action cannot be undone. Do you wish to continue?": "", - "This ensures that your valuable conversations are securely saved to your backend database. Thank you!": "Cela garantit que vos précieuses conversations sont en sécurité dans votre base de données. Merci !", - "This is an experimental feature, it may not function as expected and is subject to change at any time.": "", + "Thinking...": "En train de réfléchir...", + "This action cannot be undone. Do you wish to continue?": "Cette action ne peut pas être annulée. Souhaitez-vous continuer ?", + "This ensures that your valuable conversations are securely saved to your backend database. Thank you!": "Cela garantit que vos conversations précieuses soient sauvegardées en toute sécurité dans votre base de données backend. Merci !", + "This is an experimental feature, it may not function as expected and is subject to change at any time.": "Il s'agit d'une fonctionnalité expérimentale, elle peut ne pas fonctionner comme prévu et est sujette à modification à tout moment.", "This setting does not sync across browsers or devices.": "Ce paramètre ne se synchronise pas entre les navigateurs ou les appareils.", - "This will delete": "", - "Thorough explanation": "Explication détaillée", - "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "Conseil : Mettez à jour plusieurs emplacements de variables consécutivement en appuyant sur la touche tab dans l'entrée de chat après chaque remplacement", + "This will delete": "Cela supprimera", + "Thorough explanation": "Explication approfondie", + "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "Conseil : mettez à jour plusieurs emplacements de variables consécutivement en appuyant sur la touche Tab dans l’entrée de chat après chaque remplacement.", "Title": "Titre", - "Title (e.g. Tell me a fun fact)": "Titre (p. ex. Donne moi un fait amusant)", - "Title Auto-Generation": "Génération Automatique du Titre", - "Title cannot be an empty string.": "Le Titre ne peut pas être vide.", - "Title Generation Prompt": "Prompt de Génération du Titre", + "Title (e.g. Tell me a fun fact)": "Titre (par ex. raconte-moi un fait amusant)", + "Title Auto-Generation": "Génération automatique de titres", + "Title cannot be an empty string.": "Le titre ne peut pas être une chaîne de caractères vide.", + "Title Generation Prompt": "Prompt de génération de titre", "to": "à", - "To access the available model names for downloading,": "Pour accéder aux noms de modèles disponibles pour le téléchargement,", - "To access the GGUF models available for downloading,": "Pour accéder aux modèles GGUF disponibles pour le téléchargement,", - "To access the WebUI, please reach out to the administrator. Admins can manage user statuses from the Admin Panel.": "", - "To add documents here, upload them to the \"Documents\" workspace first.": "", - "to chat input.": "à l'entrée du chat.", - "To select filters here, add them to the \"Functions\" workspace first.": "", - "To select toolkits here, add them to the \"Tools\" workspace first.": "", + "To access the available model names for downloading,": "Pour accéder aux noms des modèles disponibles en téléchargement,", + "To access the GGUF models available for downloading,": "Pour accéder aux modèles GGUF disponibles en téléchargement,", + "To access the WebUI, please reach out to the administrator. Admins can manage user statuses from the Admin Panel.": "Pour accéder à l'interface Web, veuillez contacter l'administrateur. Les administrateurs peuvent gérer les statuts des utilisateurs depuis le panneau d'administration.", + "To add documents here, upload them to the \"Documents\" workspace first.": "Pour ajouter des documents ici, téléchargez-les d'abord dans l'espace de travail « Documents ». ", + "to chat input.": "à l'entrée de discussion.", + "To select filters here, add them to the \"Functions\" workspace first.": "Pour sélectionner des filtres ici, ajoutez-les d'abord à l'espace de travail « Fonctions ». ", + "To select toolkits here, add them to the \"Tools\" workspace first.": "Pour sélectionner des toolkits ici, ajoutez-les d'abord à l'espace de travail « Outils ». ", "Today": "Aujourd'hui", "Toggle settings": "Basculer les paramètres", "Toggle sidebar": "Basculer la barre latérale", - "Tokens To Keep On Context Refresh (num_keep)": "", - "Tool created successfully": "", - "Tool deleted successfully": "", - "Tool imported successfully": "", - "Tool updated successfully": "", - "Tools": "", + "Tokens To Keep On Context Refresh (num_keep)": "Jeton à conserver pour l'actualisation du contexte (num_keep)", + "Tool created successfully": "L'outil a été créé avec succès", + "Tool deleted successfully": "Outil supprimé avec succès", + "Tool imported successfully": "Outil importé avec succès", + "Tool updated successfully": "L'outil a été mis à jour avec succès", + "Tools": "Outils", "Top K": "Top K", "Top P": "Top P", - "Trouble accessing Ollama?": "Problèmes d'accès à Ollama ?", - "TTS Model": "", - "TTS Settings": "Paramètres TTS", - "TTS Voice": "", + "Trouble accessing Ollama?": "Rencontrez-vous des difficultés pour accéder à Ollama ?", + "TTS Model": "Modèle de synthèse vocale", + "TTS Settings": "Paramètres de synthèse vocale", + "TTS Voice": "Voix TTS", "Type": "Type", - "Type Hugging Face Resolve (Download) URL": "Entrez l'URL de Résolution (Téléchargement) Hugging Face", - "Uh-oh! There was an issue connecting to {{provider}}.": "Uh-oh ! Il y a eu un problème de connexion à {{provider}}.", - "UI": "", - "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "", - "Update": "", - "Update and Copy Link": "Mettre à Jour et Copier le Lien", - "Update password": "Mettre à Jour le Mot de Passe", - "Updated at": "", - "Upload": "", + "Type Hugging Face Resolve (Download) URL": "Entrez l'URL de Téléchargement Hugging Face Resolve", + "Uh-oh! There was an issue connecting to {{provider}}.": "Oh non ! Un problème est survenu lors de la connexion à {{provider}}.", + "UI": "Interface utilisateur", + "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "Type de fichier inconnu '{{file_type}}'. Continuons tout de même le téléchargement du fichier.", + "Update": "Mise à jour", + "Update and Copy Link": "Mettre à jour et copier le lien", + "Update password": "Mettre à jour le mot de passe", + "Updated at": "Mise à jour le", + "Upload": "Télécharger", "Upload a GGUF model": "Téléverser un modèle GGUF", - "Upload Files": "Téléverser des fichiers", - "Upload Pipeline": "", - "Upload Progress": "Progression du Téléversement", - "URL Mode": "Mode URL", - "Use '#' in the prompt input to load and select your documents.": "Utilisez '#' dans l'entrée du prompt pour charger et sélectionner vos documents.", - "Use Gravatar": "Utiliser Gravatar", - "Use Initials": "Utiliser les Initiales", + "Upload Files": "Télécharger des fichiers", + "Upload Pipeline": "Pipeline de téléchargement", + "Upload Progress": "Progression de l'envoi", + "URL Mode": "Mode d'URL", + "Use '#' in the prompt input to load and select your documents.": "Utilisez '#' dans l'entrée de prompt pour charger et sélectionner vos documents.", + "Use Gravatar": "Utilisez Gravatar", + "Use Initials": "Utiliser les initiales", "use_mlock (Ollama)": "use_mlock (Ollama)", - "use_mmap (Ollama)": "use_mmap (Ollama)", + "use_mmap (Ollama)": "utiliser mmap (Ollama)", "user": "utilisateur", - "User location successfully retrieved.": "", - "User Permissions": "Permissions d'utilisateur", + "User location successfully retrieved.": "L'emplacement de l'utilisateur a été récupéré avec succès.", + "User Permissions": "Permissions utilisateur", "Users": "Utilisateurs", - "Utilize": "Utiliser", - "Valid time units:": "Unités de temps valides :", - "Valves": "", - "Valves updated": "", - "Valves updated successfully": "", + "Utilize": "Utilisez", + "Valid time units:": "Unités de temps valides :", + "Valves": "Vannes", + "Valves updated": "Vannes mises à jour", + "Valves updated successfully": "Les vannes ont été mises à jour avec succès", "variable": "variable", - "variable to have them replaced with clipboard content.": "variable pour les remplacer par le contenu du presse-papiers.", - "Version": "Version", - "Voice": "", - "Warning": "Avertissement", - "Warning: If you update or change your embedding model, you will need to re-import all documents.": "Avertissement : Si vous mettez à jour ou modifier votre modèle d'embedding, vous devrez réimporter tous les documents.", + "variable to have them replaced with clipboard content.": "variable pour qu'elles soient remplacées par le contenu du presse-papiers.", + "Version": "Version améliorée", + "Voice": "Voix", + "Warning": "Avertissement !", + "Warning: If you update or change your embedding model, you will need to re-import all documents.": "Avertissement : Si vous mettez à jour ou modifiez votre modèle d'encodage, vous devrez réimporter tous les documents.", "Web": "Web", - "Web API": "", - "Web Loader Settings": "Paramètres du Chargeur Web", + "Web API": "API Web", + "Web Loader Settings": "Paramètres du chargeur web", "Web Params": "Paramètres Web", - "Web Search": "Recherche sur le Web", + "Web Search": "Recherche Web", "Web Search Engine": "Moteur de recherche Web", - "Webhook URL": "URL du Webhook", - "WebUI Settings": "Paramètres WebUI", - "WebUI will make requests to": "WebUI effectuera des demandes à", - "What’s New in": "Quoi de neuf dans", - "When history is turned off, new chats on this browser won't appear in your history on any of your devices.": "Lorsque l'historique est désactivé, les nouveaux chats sur ce navigateur n'apparaîtront pas dans votre historique sur aucun de vos appareils.", - "Whisper (Local)": "", - "Widescreen Mode": "", - "Workspace": "Espace de Travail", - "Write a prompt suggestion (e.g. Who are you?)": "Écrivez une suggestion de prompt (e.x. Qui est-tu ?)", - "Write a summary in 50 words that summarizes [topic or keyword].": "Ecrivez un résumé en 50 mots qui résume [sujet ou mot-clé]", + "Webhook URL": "URL du webhook", + "WebUI Settings": "Paramètres de WebUI", + "WebUI will make requests to": "WebUI effectuera des requêtes vers", + "What’s New in": "Quoi de neuf", + "When history is turned off, new chats on this browser won't appear in your history on any of your devices.": "Lorsque l'historique est désactivé, les nouvelles conversations sur ce navigateur ne seront pas enregistrés dans votre historique sur aucun de vos appareils.", + "Whisper (Local)": "Whisper (local)", + "Widescreen Mode": "Mode Grand Écran", + "Workspace": "Espace de travail", + "Write a prompt suggestion (e.g. Who are you?)": "Écrivez une suggestion de prompt (par exemple : Qui êtes-vous ?)", + "Write a summary in 50 words that summarizes [topic or keyword].": "Rédigez un résumé de 50 mots qui résume [sujet ou mot-clé].", "Yesterday": "Hier", "You": "Vous", - "You can personalize your interactions with LLMs by adding memories through the 'Manage' button below, making them more helpful and tailored to you.": "", + "You can personalize your interactions with LLMs by adding memories through the 'Manage' button below, making them more helpful and tailored to you.": "Vous pouvez personnaliser vos interactions avec les LLM en ajoutant des souvenirs via le bouton 'Gérer' ci-dessous, ce qui les rendra plus utiles et adaptés à vos besoins.", "You cannot clone a base model": "Vous ne pouvez pas cloner un modèle de base", - "You have no archived conversations.": "Vous n'avez pas de conversations archivées", - "You have shared this chat": "Vous avez partagé ce chat", - "You're a helpful assistant.": "Vous êtes un assistant utile.", - "You're now logged in.": "Vous êtes maintenant connecté.", - "Your account status is currently pending activation.": "", - "Youtube": "Youtube", - "Youtube Loader Settings": "Paramètres du Chargeur YouTube" -} + "You have no archived conversations.": "Vous n'avez aucune conversation archivée", + "You have shared this chat": "Vous avez partagé cette conversation.", + "You're a helpful assistant.": "Vous êtes un assistant serviable.", + "You're now logged in.": "Vous êtes désormais connecté.", + "Your account status is currently pending activation.": "Votre statut de compte est actuellement en attente d'activation.", + "Youtube": "YouTube", + "Youtube Loader Settings": "Paramètres de l'outil de téléchargement YouTube" + } \ No newline at end of file From 2fedd91ed94d5a6bbadcac2214d8c5748559604e Mon Sep 17 00:00:00 2001 From: Morgan Blangeois Date: Wed, 3 Jul 2024 17:20:57 +0200 Subject: [PATCH 027/115] feat: Improve French Canadian (fr-ca) translations --- src/lib/i18n/locales/fr-CA/translation.json | 890 ++++++++++---------- 1 file changed, 445 insertions(+), 445 deletions(-) diff --git a/src/lib/i18n/locales/fr-CA/translation.json b/src/lib/i18n/locales/fr-CA/translation.json index a5754f724..5d358dc7e 100644 --- a/src/lib/i18n/locales/fr-CA/translation.json +++ b/src/lib/i18n/locales/fr-CA/translation.json @@ -1,668 +1,668 @@ { - "'s', 'm', 'h', 'd', 'w' or '-1' for no expiration.": "'s', 'm', 'h', 'd', 'w' ou '-1' pour aucune expiration.", - "(Beta)": "(Bêta)", - "(e.g. `sh webui.sh --api --api-auth username_password`)": "", - "(e.g. `sh webui.sh --api`)": "(par ex. `sh webui.sh --api`)", - "(latest)": "(dernière)", + "'s', 'm', 'h', 'd', 'w' or '-1' for no expiration.": " 's', 'm', 'h', 'd', 'w' ou '-1' pour une durée illimitée.", + "(Beta)": "(Version bêta)", + "(e.g. `sh webui.sh --api --api-auth username_password`)": "(par ex. `sh webui.sh --api --api-auth username_password`)", + "(e.g. `sh webui.sh --api`)": "(par exemple `sh webui.sh --api`)", + "(latest)": "(dernier)", "{{ models }}": "{{ modèles }}", - "{{ owner }}: You cannot delete a base model": "{{ propriétaire }} : vous ne pouvez pas supprimer un modèle de base", - "{{modelName}} is thinking...": "{{modelName}} réfléchit...", - "{{user}}'s Chats": "{{user}}'s Chats", + "{{ owner }}: You cannot delete a base model": "{{ propriétaire }} : Vous ne pouvez pas supprimer un modèle de base.", + "{{modelName}} is thinking...": "{{modelName}} est en train de réfléchir...", + "{{user}}'s Chats": "Discussions de {{user}}", "{{webUIName}} Backend Required": "Backend {{webUIName}} requis", - "A task model is used when performing tasks such as generating titles for chats and web search queries": "Un modèle de tâche est utilisé lors de l’exécution de tâches telles que la génération de titres pour les chats et les requêtes de recherche Web", + "A task model is used when performing tasks such as generating titles for chats and web search queries": "Un modèle de tâche est utilisé lors de l’exécution de tâches telles que la génération de titres pour les conversations et les requêtes de recherche sur le web.", "a user": "un utilisateur", "About": "À propos", "Account": "Compte", - "Account Activation Pending": "", - "Accurate information": "Information précise", - "Active Users": "", + "Account Activation Pending": "Activation du compte en attente", + "Accurate information": "Information exacte", + "Active Users": "Utilisateurs actifs", "Add": "Ajouter", - "Add a model id": "Ajouter un id de modèle", - "Add a short description about what this model does": "Ajoutez une brève description de ce que fait ce modèle", - "Add a short title for this prompt": "Ajouter un court titre pour ce prompt", - "Add a tag": "Ajouter un tag", - "Add custom prompt": "Ajouter un prompt personnalisé", - "Add Docs": "Ajouter des documents", + "Add a model id": "Ajouter un identifiant de modèle", + "Add a short description about what this model does": "Ajoutez une brève description de ce que fait ce modèle.", + "Add a short title for this prompt": "Ajoutez un bref titre pour cette prompt.", + "Add a tag": "Ajouter une balise", + "Add custom prompt": "Ajouter une prompt personnalisée", + "Add Docs": "Ajouter de la documentation", "Add Files": "Ajouter des fichiers", - "Add Memory": "Ajouter une mémoire", + "Add Memory": "Ajouter de la mémoire", "Add message": "Ajouter un message", "Add Model": "Ajouter un modèle", - "Add Tags": "ajouter des tags", - "Add User": "Ajouter un utilisateur", - "Adjusting these settings will apply changes universally to all users.": "L'ajustement de ces paramètres appliquera les changements à tous les utilisateurs.", - "admin": "Administrateur", - "Admin": "", - "Admin Panel": "Panneau d'administration", + "Add Tags": "Ajouter des balises", + "Add User": "Ajouter un Utilisateur", + "Adjusting these settings will apply changes universally to all users.": "L'ajustement de ces paramètres appliquera universellement les changements à tous les utilisateurs.", + "admin": "administrateur", + "Admin": "Administrateur", + "Admin Panel": "Tableau de bord administrateur", "Admin Settings": "Paramètres d'administration", - "Admins have access to all tools at all times; users need tools assigned per model in the workspace.": "", + "Admins have access to all tools at all times; users need tools assigned per model in the workspace.": "Les administrateurs ont accès à tous les outils en tout temps ; les utilisateurs ont besoin d'outils affectés par modèle dans l'espace de travail.", "Advanced Parameters": "Paramètres avancés", - "Advanced Params": "Params avancés", - "all": "tous", + "Advanced Params": "Paramètres avancés", + "all": "toutes", "All Documents": "Tous les documents", - "All Users": "Tous les utilisateurs", + "All Users": "Tous les Utilisateurs", "Allow": "Autoriser", - "Allow Chat Deletion": "Autoriser la suppression des discussions", - "Allow non-local voices": "", - "Allow User Location": "", - "Allow Voice Interruption in Call": "", + "Allow Chat Deletion": "Autoriser la suppression de l'historique de chat", + "Allow non-local voices": "Autoriser les voix non locales", + "Allow User Location": "Autoriser l'emplacement de l'utilisateur", + "Allow Voice Interruption in Call": "Autoriser l'interruption vocale pendant un appel", "alphanumeric characters and hyphens": "caractères alphanumériques et tirets", - "Already have an account?": "Vous avez déjà un compte ?", + "Already have an account?": "Avez-vous déjà un compte ?", "an assistant": "un assistant", "and": "et", "and create a new shared link.": "et créer un nouveau lien partagé.", "API Base URL": "URL de base de l'API", - "API Key": "Clé API", - "API Key created.": "Clé API créée.", - "API keys": "Clés API", + "API Key": "Clé d'API", + "API Key created.": "Clé d'API générée.", + "API keys": "Clés d'API", "April": "Avril", - "Archive": "Archiver", - "Archive All Chats": "Archiver tous les chats", - "Archived Chats": "enregistrement du chat", - "are allowed - Activate this command by typing": "sont autorisés - Activez cette commande en tapant", - "Are you sure?": "Êtes-vous sûr ?", - "Attach file": "Joindre un fichier", + "Archive": "Archivage", + "Archive All Chats": "Archiver toutes les conversations", + "Archived Chats": "Conversations archivées", + "are allowed - Activate this command by typing": "sont autorisés - Activer cette commande en tapant", + "Are you sure?": "Êtes-vous certain ?", + "Attach file": "Joindre un document", "Attention to detail": "Attention aux détails", "Audio": "Audio", - "Audio settings updated successfully": "", + "Audio settings updated successfully": "Les paramètres audio ont été mis à jour avec succès", "August": "Août", - "Auto-playback response": "Réponse en lecture automatique", - "AUTOMATIC1111 Api Auth String": "", + "Auto-playback response": "Réponse de lecture automatique", + "AUTOMATIC1111 Api Auth String": "AUTOMATIC1111 Chaîne d'authentification de l'API", "AUTOMATIC1111 Base URL": "URL de base AUTOMATIC1111", - "AUTOMATIC1111 Base URL is required.": "L'URL de base AUTOMATIC1111 est requise.", + "AUTOMATIC1111 Base URL is required.": "L'URL de base {AUTOMATIC1111} est requise.", "available!": "disponible !", - "Back": "Retour", + "Back": "Retour en arrière", "Bad Response": "Mauvaise réponse", - "Banners": "Bannières", + "Banners": "Banniers", "Base Model (From)": "Modèle de base (à partir de)", - "Batch Size (num_batch)": "", + "Batch Size (num_batch)": "Taille du lot (num_batch)", "before": "avant", - "Being lazy": "En manque de temps", - "Brave Search API Key": "Clé d’API de recherche brave", - "Bypass SSL verification for Websites": "Parcourir la vérification SSL pour les sites Web", - "Call": "", - "Call feature is not supported when using Web STT engine": "", - "Camera": "", + "Being lazy": "Être fainéant", + "Brave Search API Key": "Clé API Brave Search", + "Bypass SSL verification for Websites": "Bypasser la vérification SSL pour les sites web", + "Call": "Appeler", + "Call feature is not supported when using Web STT engine": "La fonction d'appel n'est pas prise en charge lors de l'utilisation du moteur Web STT", + "Camera": "Appareil photo", "Cancel": "Annuler", "Capabilities": "Capacités", "Change Password": "Changer le mot de passe", - "Chat": "Discussion", - "Chat Background Image": "", - "Chat Bubble UI": "Bubble UI de discussion", - "Chat direction": "Direction de discussion", - "Chat History": "Historique des discussions", - "Chat History is off for this browser.": "L'historique des discussions est désactivé pour ce navigateur.", - "Chats": "Discussions", - "Check Again": "Vérifier à nouveau", - "Check for updates": "Vérifier les mises à jour", - "Checking for updates...": "Vérification des mises à jour...", - "Choose a model before saving...": "Choisissez un modèle avant d'enregistrer...", - "Chunk Overlap": "Chevauchement de bloc", - "Chunk Params": "Paramètres de bloc", + "Chat": "Chat", + "Chat Background Image": "Image d'arrière-plan de la fenêtre de chat", + "Chat Bubble UI": "Bulles de discussion", + "Chat direction": "Direction du chat", + "Chat History": "Historique de discussion", + "Chat History is off for this browser.": "L'historique de chat est désactivé pour ce navigateur", + "Chats": "Conversations", + "Check Again": "Vérifiez à nouveau.", + "Check for updates": "Vérifier les mises à jour disponibles", + "Checking for updates...": "Recherche de mises à jour...", + "Choose a model before saving...": "Choisissez un modèle avant de sauvegarder...", + "Chunk Overlap": "Chevauchement de blocs", + "Chunk Params": "Paramètres d'encombrement", "Chunk Size": "Taille de bloc", - "Citation": "Citations", - "Clear memory": "", - "Click here for help.": "Cliquez ici pour de l'aide.", + "Citation": "Citation", + "Clear memory": "Libérer la mémoire", + "Click here for help.": "Cliquez ici pour obtenir de l'aide.", "Click here to": "Cliquez ici pour", - "Click here to download user import template file.": "", + "Click here to download user import template file.": "Cliquez ici pour télécharger le fichier modèle d'importation utilisateur.", "Click here to select": "Cliquez ici pour sélectionner", - "Click here to select a csv file.": "Cliquez ici pour sélectionner un fichier csv.", - "Click here to select a py file.": "", - "Click here to select documents.": "Cliquez ici pour sélectionner des documents.", + "Click here to select a csv file.": "Cliquez ici pour sélectionner un fichier CSV.", + "Click here to select a py file.": "Cliquez ici pour sélectionner un fichier .py.", + "Click here to select documents.": "Cliquez ici pour sélectionner les documents.", "click here.": "cliquez ici.", - "Click on the user role button to change a user's role.": "Cliquez sur le bouton de rôle d'utilisateur pour changer le rôle d'un utilisateur.", - "Clipboard write permission denied. Please check your browser settings to grant the necessary access.": "", - "Clone": "Cloner", + "Click on the user role button to change a user's role.": "Cliquez sur le bouton de rôle d'utilisateur pour modifier le rôle d'un utilisateur.", + "Clipboard write permission denied. Please check your browser settings to grant the necessary access.": "L'autorisation d'écriture du presse-papier a été refusée. Veuillez vérifier les paramètres de votre navigateur pour accorder l'accès nécessaire.", + "Clone": "Copie conforme", "Close": "Fermer", - "Code formatted successfully": "", + "Code formatted successfully": "Le code a été formaté avec succès", "Collection": "Collection", "ComfyUI": "ComfyUI", - "ComfyUI Base URL": "ComfyUI Base URL", - "ComfyUI Base URL is required.": "ComfyUI Base URL est requis.", + "ComfyUI Base URL": "URL de base ComfyUI", + "ComfyUI Base URL is required.": "L'URL de base ComfyUI est requise.", "Command": "Commande", - "Concurrent Requests": "Demandes simultanées", - "Confirm": "", + "Concurrent Requests": "Demandes concurrentes", + "Confirm": "Confirmer", "Confirm Password": "Confirmer le mot de passe", - "Confirm your action": "", + "Confirm your action": "Confirmez votre action", "Connections": "Connexions", - "Contact Admin for WebUI Access": "", + "Contact Admin for WebUI Access": "Contacter l'administrateur pour l'accès à l'interface Web", "Content": "Contenu", "Context Length": "Longueur du contexte", "Continue Response": "Continuer la réponse", - "Continue with {{provider}}": "", - "Copied shared chat URL to clipboard!": "URL de chat partagé copié dans le presse-papier !", - "Copy": "Copier", + "Continue with {{provider}}": "Continuer avec {{provider}}", + "Copied shared chat URL to clipboard!": "URL du chat copiée dans le presse-papiers !", + "Copy": "Copie", "Copy last code block": "Copier le dernier bloc de code", "Copy last response": "Copier la dernière réponse", "Copy Link": "Copier le lien", "Copying to clipboard was successful!": "La copie dans le presse-papiers a réussi !", "Create a model": "Créer un modèle", "Create Account": "Créer un compte", - "Create new key": "Créer une nouvelle clé", + "Create new key": "Créer une nouvelle clé principale", "Create new secret key": "Créer une nouvelle clé secrète", - "Created at": "Créé le", + "Created at": "Créé à", "Created At": "Créé le", - "Created by": "", - "CSV Import": "", - "Current Model": "Modèle actuel", + "Created by": "Créé par", + "CSV Import": "Import CSV", + "Current Model": "Modèle actuel amélioré", "Current Password": "Mot de passe actuel", - "Custom": "Personnalisé", - "Customize models for a specific purpose": "Personnaliser les modèles dans un but spécifique", - "Dark": "Sombre", - "Dashboard": "", + "Custom": "Sur mesure", + "Customize models for a specific purpose": "Personnaliser les modèles pour une fonction spécifique", + "Dark": "Obscur", + "Dashboard": "Tableau de bord", "Database": "Base de données", "December": "Décembre", "Default": "Par défaut", "Default (Automatic1111)": "Par défaut (Automatic1111)", - "Default (SentenceTransformers)": "Par défaut (SentenceTransformers)", - "Default Model": "Modèle par défaut", + "Default (SentenceTransformers)": "Par défaut (Sentence Transformers)", + "Default Model": "Modèle standard", "Default model updated": "Modèle par défaut mis à jour", - "Default Prompt Suggestions": "Suggestions de prompt par défaut", - "Default User Role": "Rôle d'utilisateur par défaut", + "Default Prompt Suggestions": "Suggestions de prompts par défaut", + "Default User Role": "Rôle utilisateur par défaut", "delete": "supprimer", "Delete": "Supprimer", "Delete a model": "Supprimer un modèle", - "Delete All Chats": "Supprimer tous les chats", - "Delete chat": "Supprimer la discussion", - "Delete Chat": "Supprimer la discussion", - "Delete chat?": "", - "Delete function?": "", - "Delete prompt?": "", + "Delete All Chats": "Supprimer toutes les conversations", + "Delete chat": "Supprimer la conversation", + "Delete Chat": "Supprimer la Conversation", + "Delete chat?": "Supprimer la conversation ?", + "Delete function?": "Supprimer la fonction ?", + "Delete prompt?": "Supprimer la prompt ?", "delete this link": "supprimer ce lien", - "Delete tool?": "", - "Delete User": "Supprimer l'utilisateur", - "Deleted {{deleteModelTag}}": "{{deleteModelTag}} supprimé", - "Deleted {{name}}": "Supprimé {{nom}}", + "Delete tool?": "Effacer l'outil ?", + "Delete User": "Supprimer le compte d'utilisateur", + "Deleted {{deleteModelTag}}": "Supprimé {{deleteModelTag}}", + "Deleted {{name}}": "Supprimé {{name}}", "Description": "Description", - "Didn't fully follow instructions": "Ne suit pas les instructions", - "Discover a function": "", - "Discover a model": "Découvrez un modèle", - "Discover a prompt": "Découvrir un prompt", - "Discover a tool": "", - "Discover, download, and explore custom functions": "", - "Discover, download, and explore custom prompts": "Découvrir, télécharger et explorer des prompts personnalisés", - "Discover, download, and explore custom tools": "", - "Discover, download, and explore model presets": "Découvrir, télécharger et explorer des préconfigurations de modèles", - "Dismissible": "", - "Display Emoji in Call": "", - "Display the username instead of You in the Chat": "Afficher le nom d'utilisateur au lieu de 'Vous' dans la Discussion", + "Didn't fully follow instructions": "N'a pas entièrement respecté les instructions", + "Discover a function": "Découvrez une fonction", + "Discover a model": "Découvrir un modèle", + "Discover a prompt": "Découvrir une suggestion", + "Discover a tool": "Découvrez un outil", + "Discover, download, and explore custom functions": "Découvrez, téléchargez et explorez des fonctions personnalisées", + "Discover, download, and explore custom prompts": "Découvrez, téléchargez et explorez des prompts personnalisés", + "Discover, download, and explore custom tools": "Découvrez, téléchargez et explorez des outils personnalisés", + "Discover, download, and explore model presets": "Découvrir, télécharger et explorer des préréglages de modèles", + "Dismissible": "Fermeture", + "Display Emoji in Call": "Afficher les emojis pendant l'appel", + "Display the username instead of You in the Chat": "Afficher le nom d'utilisateur à la place de \"Vous\" dans le Chat", "Document": "Document", "Document Settings": "Paramètres du document", - "Documentation": "", + "Documentation": "Documentation", "Documents": "Documents", - "does not make any external connections, and your data stays securely on your locally hosted server.": "ne fait aucune connexion externe, et vos données restent en sécurité sur votre serveur hébergé localement.", + "does not make any external connections, and your data stays securely on your locally hosted server.": "ne fait aucune connexion externe et garde vos données en sécurité sur votre serveur local.", "Don't Allow": "Ne pas autoriser", "Don't have an account?": "Vous n'avez pas de compte ?", - "Don't like the style": "Vous n'aimez pas le style ?", - "Done": "", + "Don't like the style": "N'apprécie pas le style", + "Done": "Terminé", "Download": "Télécharger", "Download canceled": "Téléchargement annulé", "Download Database": "Télécharger la base de données", - "Drop any files here to add to the conversation": "Déposez n'importe quel fichier ici pour les ajouter à la conversation", - "e.g. '30s','10m'. Valid time units are 's', 'm', 'h'.": "p. ex. '30s', '10m'. Les unités de temps valides sont 's', 'm', 'h'.", - "Edit": "Éditer", - "Edit Doc": "Éditer le document", - "Edit Memory": "", - "Edit User": "Éditer l'utilisateur", - "Email": "Email", - "Embedding Batch Size": "", + "Drop any files here to add to the conversation": "Déposez des fichiers ici pour les ajouter à la conversation", + "e.g. '30s','10m'. Valid time units are 's', 'm', 'h'.": "par ex. '30s', '10 min'. Les unités de temps valides sont 's', 'm', 'h'.", + "Edit": "Modifier", + "Edit Doc": "Modifier le document", + "Edit Memory": "Modifier la mémoire", + "Edit User": "Modifier l'utilisateur", + "Email": "E-mail", + "Embedding Batch Size": "Taille du lot d'encodage", "Embedding Model": "Modèle d'embedding", - "Embedding Model Engine": "Moteur du modèle d'embedding", - "Embedding model set to \"{{embedding_model}}\"": "Modèle d'embedding défini sur \"{{embedding_model}}\"", - "Enable Chat History": "Activer l'historique des discussions", - "Enable Community Sharing": "Permettre le partage communautaire", + "Embedding Model Engine": "Moteur de modèle d'encodage", + "Embedding model set to \"{{embedding_model}}\"": "Modèle d'encodage défini sur « {{embedding_model}} »", + "Enable Chat History": "Activer l'historique de conversation", + "Enable Community Sharing": "Activer le partage communautaire", "Enable New Sign Ups": "Activer les nouvelles inscriptions", - "Enable Web Search": "Activer la recherche sur le Web", - "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "Assurez-vous que votre fichier CSV inclut 4 colonnes dans cet ordre : Nom, Email, Mot de passe, Rôle.", + "Enable Web Search": "Activer la recherche web", + "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "Vérifiez que votre fichier CSV comprenne les 4 colonnes dans cet ordre : Name, Email, Password, Role.", "Enter {{role}} message here": "Entrez le message {{role}} ici", - "Enter a detail about yourself for your LLMs to recall": "Entrez un détail sur vous pour que vos LLMs puissent le rappeler", - "Enter api auth string (e.g. username:password)": "", - "Enter Brave Search API Key": "Entrez la clé API de recherche brave", - "Enter Chunk Overlap": "Entrez le chevauchement de bloc", - "Enter Chunk Size": "Entrez la taille du bloc", - "Enter Github Raw URL": "Entrez l’URL Github Raw", + "Enter a detail about yourself for your LLMs to recall": "Saisissez un détail sur vous-même que vos LLMs pourront se rappeler", + "Enter api auth string (e.g. username:password)": "Entrez la chaîne d'authentification de l'API (par ex. nom d'utilisateur:mot de passe)", + "Enter Brave Search API Key": "Entrez la clé API Brave Search", + "Enter Chunk Overlap": "Entrez le chevauchement de chunk", + "Enter Chunk Size": "Entrez la taille de bloc", + "Enter Github Raw URL": "Entrez l'URL brute de GitHub", "Enter Google PSE API Key": "Entrez la clé API Google PSE", - "Enter Google PSE Engine Id": "Entrez l’id du moteur Google PSE", - "Enter Image Size (e.g. 512x512)": "Entrez la taille de l'image (p. ex. 512x512)", + "Enter Google PSE Engine Id": "Entrez l'identifiant du moteur Google PSE", + "Enter Image Size (e.g. 512x512)": "Entrez la taille de l'image (par ex. 512x512)", "Enter language codes": "Entrez les codes de langue", - "Enter model tag (e.g. {{modelTag}})": "Entrez le tag du modèle (p. ex. {{modelTag}})", - "Enter Number of Steps (e.g. 50)": "Entrez le nombre d'étapes (p. ex. 50)", - "Enter Score": "Entrez le score", - "Enter Searxng Query URL": "Entrez l’URL de requête Searxng", + "Enter model tag (e.g. {{modelTag}})": "Entrez l'étiquette du modèle (par ex. {{modelTag}})", + "Enter Number of Steps (e.g. 50)": "Entrez le nombre de pas (par ex. 50)", + "Enter Score": "Entrez votre score", + "Enter Searxng Query URL": "Entrez l'URL de la requête Searxng", "Enter Serper API Key": "Entrez la clé API Serper", - "Enter Serply API Key": "", - "Enter Serpstack API Key": "Entrez dans la clé API Serpstack", - "Enter stop sequence": "Entrez la séquence de fin", - "Enter Tavily API Key": "", - "Enter Top K": "Entrez Top K", - "Enter URL (e.g. http://127.0.0.1:7860/)": "Entrez l'URL (p. ex. http://127.0.0.1:7860/)", - "Enter URL (e.g. http://localhost:11434)": "Entrez l'URL (p. ex. http://localhost:11434)", - "Enter Your Email": "Entrez votre adresse email", + "Enter Serply API Key": "Entrez la clé API Serply", + "Enter Serpstack API Key": "Entrez la clé API Serpstack", + "Enter stop sequence": "Entrez la séquence d'arrêt", + "Enter Tavily API Key": "Entrez la clé API Tavily", + "Enter Top K": "Entrez les Top K", + "Enter URL (e.g. http://127.0.0.1:7860/)": "Entrez l'URL (par ex. {http://127.0.0.1:7860/})", + "Enter URL (e.g. http://localhost:11434)": "Entrez l'URL (par ex. http://localhost:11434)", + "Enter Your Email": "Entrez votre adresse e-mail", "Enter Your Full Name": "Entrez votre nom complet", "Enter Your Password": "Entrez votre mot de passe", "Enter Your Role": "Entrez votre rôle", "Error": "Erreur", "Experimental": "Expérimental", "Export": "Exportation", - "Export All Chats (All Users)": "Exporter toutes les discussions (Tous les utilisateurs)", - "Export chat (.json)": "", - "Export Chats": "Exporter les discussions", - "Export Documents Mapping": "Exporter le mappage des documents", - "Export Functions": "", - "Export LiteLLM config.yaml": "", - "Export Models": "Modèles d’exportation", - "Export Prompts": "Exporter les prompts", - "Export Tools": "", - "External Models": "", - "Failed to create API Key.": "Impossible de créer la clé API.", + "Export All Chats (All Users)": "Exporter toutes les conversations (tous les utilisateurs)", + "Export chat (.json)": "Exporter la discussion (.json)", + "Export Chats": "Exporter les conversations", + "Export Documents Mapping": "Exportez la correspondance des documents", + "Export Functions": "Exportez les Fonctions", + "Export LiteLLM config.yaml": "Exportez le fichier LiteLLM config.yaml", + "Export Models": "Exporter les modèles", + "Export Prompts": "Exporter les Prompts", + "Export Tools": "Outils d'exportation", + "External Models": "Modèles externes", + "Failed to create API Key.": "Échec de la création de la clé API.", "Failed to read clipboard contents": "Échec de la lecture du contenu du presse-papiers", - "Failed to update settings": "", + "Failed to update settings": "Échec de la mise à jour des paramètres", "February": "Février", - "Feel free to add specific details": "Vous pouvez ajouter des détails spécifiques", - "File": "", + "Feel free to add specific details": "N'hésitez pas à ajouter des détails spécifiques", + "File": "Fichier", "File Mode": "Mode fichier", "File not found.": "Fichier introuvable.", - "Filter is now globally disabled": "", - "Filter is now globally enabled": "", - "Filters": "", - "Fingerprint spoofing detected: Unable to use initials as avatar. Defaulting to default profile image.": "Détection de falsification de empreinte digitale\u00a0: impossible d'utiliser les initiales comme avatar. Par défaut, l'image de profil par défaut est utilisée.", - "Fluidly stream large external response chunks": "Diffusez de manière fluide de gros morceaux de réponses externes", - "Focus chat input": "Se concentrer sur l'entrée de la discussion", - "Followed instructions perfectly": "Suivi des instructions parfaitement", - "Form": "", - "Format your variables using square brackets like this:": "Formatez vos variables en utilisant des crochets comme ceci :", + "Filter is now globally disabled": "Le filtre est maintenant désactivé globalement", + "Filter is now globally enabled": "Le filtre est désormais activé globalement", + "Filters": "Filtres", + "Fingerprint spoofing detected: Unable to use initials as avatar. Defaulting to default profile image.": "Spoofing détecté : impossible d'utiliser les initiales comme avatar. Retour à l'image de profil par défaut.", + "Fluidly stream large external response chunks": "Diffuser de manière fluide de larges portions de réponses externes", + "Focus chat input": "Se concentrer sur le chat en entrée", + "Followed instructions perfectly": "A parfaitement suivi les instructions", + "Form": "Formulaire", + "Format your variables using square brackets like this:": "Formatez vos variables en utilisant des crochets comme suit :", "Frequency Penalty": "Pénalité de fréquence", - "Function created successfully": "", - "Function deleted successfully": "", - "Function updated successfully": "", - "Functions": "", - "Functions imported successfully": "", + "Function created successfully": "La fonction a été créée avec succès", + "Function deleted successfully": "Fonction supprimée avec succès", + "Function updated successfully": "La fonction a été mise à jour avec succès", + "Functions": "Fonctions", + "Functions imported successfully": "Fonctions importées avec succès", "General": "Général", - "General Settings": "Paramètres généraux", - "Generate Image": "", - "Generating search query": "Génération d’une requête de recherche", - "Generation Info": "Informations de génération", - "Global": "", + "General Settings": "Paramètres Généraux", + "Generate Image": "Générer une image", + "Generating search query": "Génération d'une requête de recherche", + "Generation Info": "Informations sur la génération", + "Global": "Mondial", "Good Response": "Bonne réponse", - "Google PSE API Key": "Clé d’API Google PSE", - "Google PSE Engine Id": "Id du moteur Google PSE", + "Google PSE API Key": "Clé API Google PSE", + "Google PSE Engine Id": "ID du moteur de recherche personnalisé de Google", "h:mm a": "h:mm a", - "has no conversations.": "n'a pas de conversations.", - "Hello, {{name}}": "Bonjour, {{name}}", + "has no conversations.": "n'a aucune conversation.", + "Hello, {{name}}": "Bonjour, {{name}}.", "Help": "Aide", "Hide": "Cacher", - "Hide Model": "", - "How can I help you today?": "Comment puis-je vous aider aujourd'hui ?", + "Hide Model": "Masquer le modèle", + "How can I help you today?": "Comment puis-je vous être utile aujourd'hui ?", "Hybrid Search": "Recherche hybride", - "Image Generation (Experimental)": "Génération d'image (Expérimental)", - "Image Generation Engine": "Moteur de génération d'image", + "Image Generation (Experimental)": "Génération d'images (expérimental)", + "Image Generation Engine": "Moteur de génération d'images", "Image Settings": "Paramètres de l'image", "Images": "Images", "Import Chats": "Importer les discussions", - "Import Documents Mapping": "Importer le mappage des documents", - "Import Functions": "", + "Import Documents Mapping": "Import de la correspondance des documents", + "Import Functions": "Import de fonctions", "Import Models": "Importer des modèles", - "Import Prompts": "Importer les prompts", - "Import Tools": "", - "Include `--api-auth` flag when running stable-diffusion-webui": "", - "Include `--api` flag when running stable-diffusion-webui": "Inclure l'indicateur `--api` lors de l'exécution de stable-diffusion-webui", - "Info": "L’info", - "Input commands": "Entrez des commandes d'entrée", - "Install from Github URL": "Installer à partir de l’URL Github", - "Instant Auto-Send After Voice Transcription": "", - "Interface": "Interface", - "Invalid Tag": "Tag invalide", + "Import Prompts": "Importer des Enseignes", + "Import Tools": "Outils d'importation", + "Include `--api-auth` flag when running stable-diffusion-webui": "Inclure le drapeau `--api-auth` lors de l'exécution de stable-diffusion-webui", + "Include `--api` flag when running stable-diffusion-webui": "Inclure le drapeau `--api` lorsque vous exécutez stable-diffusion-webui", + "Info": "Info", + "Input commands": "Entrez les commandes", + "Install from Github URL": "Installer depuis l'URL GitHub", + "Instant Auto-Send After Voice Transcription": "Envoi automatique instantané après transcription vocale", + "Interface": "Interface utilisateur", + "Invalid Tag": "Étiquette non valide", "January": "Janvier", - "join our Discord for help.": "rejoignez notre Discord pour obtenir de l'aide.", + "join our Discord for help.": "Rejoignez notre Discord pour obtenir de l'aide.", "JSON": "JSON", "JSON Preview": "Aperçu JSON", "July": "Juillet", "June": "Juin", - "JWT Expiration": "Expiration du JWT", + "JWT Expiration": "Expiration du jeton JWT", "JWT Token": "Jeton JWT", - "Keep Alive": "Garder actif", + "Keep Alive": "Rester connecté", "Keyboard shortcuts": "Raccourcis clavier", - "Knowledge": "", + "Knowledge": "Connaissance", "Language": "Langue", "Last Active": "Dernière activité", - "Last Modified": "", - "Light": "Lumière", - "Listening...": "", - "LLMs can make mistakes. Verify important information.": "Les LLMs peuvent faire des erreurs. Vérifiez les informations importantes.", - "Local Models": "", + "Last Modified": "Dernière modification", + "Light": "Lumineux", + "Listening...": "En train d'écouter...", + "LLMs can make mistakes. Verify important information.": "Les LLM peuvent faire des erreurs. Vérifiez les informations importantes.", + "Local Models": "Modèles locaux", "LTR": "LTR", "Made by OpenWebUI Community": "Réalisé par la communauté OpenWebUI", - "Make sure to enclose them with": "Assurez-vous de les entourer avec", - "Manage": "", - "Manage Models": "Gérer les modèles", + "Make sure to enclose them with": "Assurez-vous de les inclure dans", + "Manage": "Gérer", + "Manage Models": "Gérer les Modèles", "Manage Ollama Models": "Gérer les modèles Ollama", "Manage Pipelines": "Gérer les pipelines", - "Manage Valves": "", + "Manage Valves": "Gérer les vannes", "March": "Mars", - "Max Tokens (num_predict)": "Max Tokens (num_predict)", - "Maximum of 3 models can be downloaded simultaneously. Please try again later.": "Un maximum de 3 modèles peut être téléchargé simultanément. Veuillez réessayer plus tard.", + "Max Tokens (num_predict)": "Tokens maximaux (num_predict)", + "Maximum of 3 models can be downloaded simultaneously. Please try again later.": "Un maximum de 3 modèles peut être téléchargé en même temps. Veuillez réessayer ultérieurement.", "May": "Mai", - "Memories accessible by LLMs will be shown here.": "Les mémoires accessibles par les LLM seront affichées ici.", + "Memories accessible by LLMs will be shown here.": "Les mémoires accessibles par les LLMs seront affichées ici.", "Memory": "Mémoire", - "Memory added successfully": "", - "Memory cleared successfully": "", - "Memory deleted successfully": "", - "Memory updated successfully": "", - "Messages you send after creating your link won't be shared. Users with the URL will be able to view the shared chat.": "Les messages que vous envoyez après la création de votre lien ne seront pas partagés. Les utilisateurs avec l'URL pourront voir le chat partagé.", - "Minimum Score": "Score minimum", + "Memory added successfully": "Mémoire ajoutée avec succès", + "Memory cleared successfully": "La mémoire a été effacée avec succès", + "Memory deleted successfully": "La mémoire a été supprimée avec succès", + "Memory updated successfully": "La mémoire a été mise à jour avec succès", + "Messages you send after creating your link won't be shared. Users with the URL will be able to view the shared chat.": "Les messages que vous envoyez après avoir créé votre lien ne seront pas partagés. Les utilisateurs disposant de l'URL pourront voir le chat partagé.", + "Minimum Score": "Score minimal", "Mirostat": "Mirostat", "Mirostat Eta": "Mirostat Eta", "Mirostat Tau": "Mirostat Tau", - "MMMM DD, YYYY": "MMMM DD, YYYY", - "MMMM DD, YYYY HH:mm": "MMMM DD, YYYY HH:mm", - "MMMM DD, YYYY hh:mm:ss A": "", + "MMMM DD, YYYY": "MM DD, AAAA", + "MMMM DD, YYYY HH:mm": "MM MDDD, AAAA HH:mm", + "MMMM DD, YYYY hh:mm:ss A": "jj MM, aaaa HH:mm:ss", "Model '{{modelName}}' has been successfully downloaded.": "Le modèle '{{modelName}}' a été téléchargé avec succès.", "Model '{{modelTag}}' is already in queue for downloading.": "Le modèle '{{modelTag}}' est déjà dans la file d'attente pour le téléchargement.", - "Model {{modelId}} not found": "Modèle {{modelId}} non trouvé", - "Model {{modelName}} is not vision capable": "Le modèle {{modelName}} n’est pas capable de vision", - "Model {{name}} is now {{status}}": "Le modèle {{nom}} est maintenant {{statut}}", - "Model created successfully!": "", - "Model filesystem path detected. Model shortname is required for update, cannot continue.": "Le chemin du système de fichiers du modèle a été détecté. Le nom court du modèle est nécessaire pour la mise à jour, impossible de continuer.", - "Model ID": "ID de modèle", + "Model {{modelId}} not found": "Modèle {{modelId}} introuvable", + "Model {{modelName}} is not vision capable": "Le modèle {{modelName}} n'a pas de capacités visuelles", + "Model {{name}} is now {{status}}": "Le modèle {{name}} est désormais {{status}}.", + "Model created successfully!": "Le modèle a été créé avec succès !", + "Model filesystem path detected. Model shortname is required for update, cannot continue.": "Chemin du système de fichiers de modèle détecté. Le nom court du modèle est requis pour la mise à jour, l'opération ne peut pas être poursuivie.", + "Model ID": "ID du modèle", "Model not selected": "Modèle non sélectionné", - "Model Params": "Paramètres modèles", - "Model updated successfully": "", - "Model Whitelisting": "Liste blanche de modèle", - "Model(s) Whitelisted": "Modèle(s) sur liste blanche", - "Modelfile Content": "Contenu du fichier de modèle", + "Model Params": "Paramètres du modèle", + "Model updated successfully": "Le modèle a été mis à jour avec succès", + "Model Whitelisting": "Liste blanche de modèles", + "Model(s) Whitelisted": "Modèle(s) Autorisé(s)", + "Modelfile Content": "Contenu du Fichier de Modèle", "Models": "Modèles", - "More": "Plus", + "More": "Plus de", "Name": "Nom", - "Name Tag": "Tag de nom", + "Name Tag": "Étiquette de nom", "Name your model": "Nommez votre modèle", - "New Chat": "Nouvelle discussion", + "New Chat": "Nouvelle conversation", "New Password": "Nouveau mot de passe", - "No content to speak": "", - "No documents found": "", - "No file selected": "", + "No content to speak": "Rien à signaler", + "No documents found": "Aucun document trouvé", + "No file selected": "Aucun fichier sélectionné", "No results found": "Aucun résultat trouvé", "No search query generated": "Aucune requête de recherche générée", - "No source available": "Aucune source disponible", - "No valves to update": "", - "None": "Aucune", - "Not factually correct": "Non, pas exactement correct", - "Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.": "Note: Si vous définissez un score minimum, la recherche ne retournera que les documents avec un score supérieur ou égal au score minimum.", - "Notifications": "Notifications de bureau", + "No source available": "Aucune source n'est disponible", + "No valves to update": "Aucune vanne à mettre à jour", + "None": "Aucun", + "Not factually correct": "Non factuellement correct", + "Note: If you set a minimum score, the search will only return documents with a score greater than or equal to the minimum score.": "Note : Si vous définissez un score minimum, seuls les documents ayant un score supérieur ou égal à ce score minimum seront retournés par la recherche.", + "Notifications": "Notifications", "November": "Novembre", "num_thread (Ollama)": "num_thread (Ollama)", - "OAuth ID": "", + "OAuth ID": "ID OAuth", "October": "Octobre", - "Off": "Éteint", - "Okay, Let's Go!": "Okay, Allons-y !", - "OLED Dark": "OLED Sombre", + "Off": "Désactivé", + "Okay, Let's Go!": "D'accord, on y va !", + "OLED Dark": "Noir OLED", "Ollama": "Ollama", - "Ollama API": "Ollama API", + "Ollama API": "API Ollama", "Ollama API disabled": "API Ollama désactivée", - "Ollama API is disabled": "", - "Ollama Version": "Version Ollama", + "Ollama API is disabled": "L'API Ollama est désactivée", + "Ollama Version": "Version Ollama améliorée", "On": "Activé", "Only": "Seulement", "Only alphanumeric characters and hyphens are allowed in the command string.": "Seuls les caractères alphanumériques et les tirets sont autorisés dans la chaîne de commande.", - "Oops! Hold tight! Your files are still in the processing oven. We're cooking them up to perfection. Please be patient and we'll let you know once they're ready.": "Oups ! Tenez bon ! Vos fichiers sont encore dans le four de traitement. Nous les préparons jusqu'à la perfection. Soyez patient et nous vous informerons dès qu'ils seront prêts.", - "Oops! Looks like the URL is invalid. Please double-check and try again.": "Oups ! Il semble que l'URL soit invalide. Merci de vérifier et réessayer.", - "Oops! There was an error in the previous response. Please try again or contact admin.": "", - "Oops! You're using an unsupported method (frontend only). Please serve the WebUI from the backend.": "Oups ! Vous utilisez une méthode non prise en charge (frontal uniquement). Veuillez servir WebUI depuis le backend.", - "Open": "Ouvrir", + "Oops! Hold tight! Your files are still in the processing oven. We're cooking them up to perfection. Please be patient and we'll let you know once they're ready.": "Oups ! Un instant ! Vos fichiers sont toujours en train d'être traités. Nous les perfectionnons pour vous. Veuillez patienter, nous vous informerons dès qu'ils seront prêts.", + "Oops! Looks like the URL is invalid. Please double-check and try again.": "Oups ! Il semble que l'URL soit invalide. Veuillez vérifier à nouveau et réessayer.", + "Oops! There was an error in the previous response. Please try again or contact admin.": "Oops ! Il y a eu une erreur dans la réponse précédente. Veuillez réessayer ou contacter l'administrateur.", + "Oops! You're using an unsupported method (frontend only). Please serve the WebUI from the backend.": "Oups ! Vous utilisez une méthode non prise en charge (frontend uniquement). Veuillez servir l'interface Web à partir du backend.", + "Open": "Ouvrez", "Open AI (Dall-E)": "Open AI (Dall-E)", "Open new chat": "Ouvrir une nouvelle discussion", "OpenAI": "OpenAI", "OpenAI API": "API OpenAI", - "OpenAI API Config": "Configuration API OpenAI", - "OpenAI API Key is required.": "La clé API OpenAI est requise.", - "OpenAI URL/Key required.": "L'URL/Clé OpenAI est requise.", + "OpenAI API Config": "Configuration de l'API OpenAI", + "OpenAI API Key is required.": "Une clé API OpenAI est requise.", + "OpenAI URL/Key required.": "URL/Clé OpenAI requise.", "or": "ou", "Other": "Autre", "Password": "Mot de passe", - "PDF document (.pdf)": "Document PDF (.pdf)", + "PDF document (.pdf)": "Document au format PDF (.pdf)", "PDF Extract Images (OCR)": "Extraction d'images PDF (OCR)", "pending": "en attente", - "Permission denied when accessing media devices": "", - "Permission denied when accessing microphone": "", + "Permission denied when accessing media devices": "Accès aux appareils multimédias refusé", + "Permission denied when accessing microphone": "Autorisation refusée lors de l'accès au micro", "Permission denied when accessing microphone: {{error}}": "Permission refusée lors de l'accès au microphone : {{error}}", "Personalization": "Personnalisation", - "Pipeline deleted successfully": "", - "Pipeline downloaded successfully": "", + "Pipeline deleted successfully": "Le pipeline a été supprimé avec succès", + "Pipeline downloaded successfully": "Le pipeline a été téléchargé avec succès", "Pipelines": "Pipelines", - "Pipelines Not Detected": "", - "Pipelines Valves": "Vannes de pipelines", - "Plain text (.txt)": "Texte brut (.txt)", - "Playground": "Aire de jeu", + "Pipelines Not Detected": "Aucun pipelines détecté", + "Pipelines Valves": "Vannes de Pipelines", + "Plain text (.txt)": "Texte simple (.txt)", + "Playground": "Aire de jeux", "Positive attitude": "Attitude positive", "Previous 30 days": "30 derniers jours", "Previous 7 days": "7 derniers jours", "Profile Image": "Image de profil", "Prompt": "Prompt", - "Prompt (e.g. Tell me a fun fact about the Roman Empire)": "Prompt (par exemple, Dites-moi un fait amusant sur l'Imperium romain)", + "Prompt (e.g. Tell me a fun fact about the Roman Empire)": "Prompt (par ex. Dites-moi un fait amusant à propos de l'Empire romain)", "Prompt Content": "Contenu du prompt", - "Prompt suggestions": "Suggestions de prompt", + "Prompt suggestions": "Suggestions pour le prompt", "Prompts": "Prompts", - "Pull \"{{searchValue}}\" from Ollama.com": "Tirer \"{{searchValue}}\" de Ollama.com", - "Pull a model from Ollama.com": "Tirer un modèle de Ollama.com", + "Pull \"{{searchValue}}\" from Ollama.com": "Récupérer « {{searchValue}} » depuis Ollama.com", + "Pull a model from Ollama.com": "Télécharger un modèle depuis Ollama.com", "Query Params": "Paramètres de requête", "RAG Template": "Modèle RAG", - "Read Aloud": "Lire à l'échelle", + "Read Aloud": "Lire à haute voix", "Record voice": "Enregistrer la voix", - "Redirecting you to OpenWebUI Community": "Vous redirige vers la communauté OpenWebUI", - "Refer to yourself as \"User\" (e.g., \"User is learning Spanish\")": "", - "Refused when it shouldn't have": "Refusé quand il ne devrait pas l'être", - "Regenerate": "Régénérer", - "Release Notes": "Notes de version", - "Remove": "Supprimer", - "Remove Model": "Supprimer le modèle", + "Redirecting you to OpenWebUI Community": "Redirection vers la communauté OpenWebUI", + "Refer to yourself as \"User\" (e.g., \"User is learning Spanish\")": "Désignez-vous comme « Utilisateur » (par ex. « L'utilisateur apprend l'espagnol »)", + "Refused when it shouldn't have": "Refusé alors qu'il n'aurait pas dû l'être", + "Regenerate": "Regénérer", + "Release Notes": "Notes de publication", + "Remove": "Retirer", + "Remove Model": "Retirer le modèle", "Rename": "Renommer", "Repeat Last N": "Répéter les N derniers", - "Request Mode": "Mode de requête", - "Reranking Model": "Modèle de reranking", - "Reranking model disabled": "Modèle de reranking désactivé", - "Reranking model set to \"{{reranking_model}}\"": "Modèle de reranking défini sur \"{{reranking_model}}\"", - "Reset": "", - "Reset Upload Directory": "", - "Reset Vector Storage": "Réinitialiser le stockage vectoriel", + "Request Mode": "Mode de Requête", + "Reranking Model": "Modèle de ré-ranking", + "Reranking model disabled": "Modèle de ré-ranking désactivé", + "Reranking model set to \"{{reranking_model}}\"": "Modèle de ré-ranking défini sur « {{reranking_model}} »", + "Reset": "Réinitialiser", + "Reset Upload Directory": "Répertoire de téléchargement réinitialisé", + "Reset Vector Storage": "Réinitialiser le stockage des vecteurs", "Response AutoCopy to Clipboard": "Copie automatique de la réponse vers le presse-papiers", - "Response notifications cannot be activated as the website permissions have been denied. Please visit your browser settings to grant the necessary access.": "", + "Response notifications cannot be activated as the website permissions have been denied. Please visit your browser settings to grant the necessary access.": "Les notifications de réponse ne peuvent pas être activées car les autorisations du site web ont été refusées. Veuillez visiter les paramètres de votre navigateur pour accorder l'accès nécessaire.", "Role": "Rôle", - "Rosé Pine": "Pin Rosé", - "Rosé Pine Dawn": "Aube Pin Rosé", + "Rosé Pine": "Pin rosé", + "Rosé Pine Dawn": "Aube de Pin Rosé", "RTL": "RTL", - "Running": "", + "Running": "Courir", "Save": "Enregistrer", "Save & Create": "Enregistrer & Créer", "Save & Update": "Enregistrer & Mettre à jour", - "Saving chat logs directly to your browser's storage is no longer supported. Please take a moment to download and delete your chat logs by clicking the button below. Don't worry, you can easily re-import your chat logs to the backend through": "La sauvegarde des journaux de discussion directement dans le stockage de votre navigateur n'est plus prise en charge. Veuillez prendre un moment pour télécharger et supprimer vos journaux de discussion en cliquant sur le bouton ci-dessous. Ne vous inquiétez pas, vous pouvez facilement réimporter vos journaux de discussion dans le backend via", + "Saving chat logs directly to your browser's storage is no longer supported. Please take a moment to download and delete your chat logs by clicking the button below. Don't worry, you can easily re-import your chat logs to the backend through": "La sauvegarde des journaux de discussion directement dans le stockage de votre navigateur n'est plus prise en charge. Veuillez prendre un instant pour télécharger et supprimer vos journaux de discussion en cliquant sur le bouton ci-dessous. Pas de soucis, vous pouvez facilement les réimporter depuis le backend via l'interface ci-dessous", "Scan": "Scanner", "Scan complete!": "Scan terminé !", "Scan for documents from {{path}}": "Scanner des documents depuis {{path}}", "Search": "Recherche", "Search a model": "Rechercher un modèle", - "Search Chats": "Rechercher des chats", - "Search Documents": "Rechercher des documents", - "Search Functions": "", - "Search Models": "Modèles de recherche", - "Search Prompts": "Rechercher des prompts", - "Search Query Generation Prompt": "", - "Search Query Generation Prompt Length Threshold": "", - "Search Result Count": "Nombre de résultats de la recherche", - "Search Tools": "", - "Searched {{count}} sites_one": "Recherche dans {{count}} sites_one", - "Searched {{count}} sites_many": "Recherche dans {{count}} sites_many", - "Searched {{count}} sites_other": "Recherche dans {{count}} sites_other", - "Searching \"{{searchQuery}}\"": "", - "Searxng Query URL": "URL de la requête Searxng", - "See readme.md for instructions": "Voir readme.md pour les instructions", - "See what's new": "Voir les nouveautés", + "Search Chats": "Rechercher des conversations", + "Search Documents": "Recherche de documents", + "Search Functions": "Fonctions de recherche", + "Search Models": "Rechercher des modèles", + "Search Prompts": "Recherche de prompts", + "Search Query Generation Prompt": "Génération d'interrogation de recherche", + "Search Query Generation Prompt Length Threshold": "Seuil de longueur de prompt de génération de requête de recherche", + "Search Result Count": "Nombre de résultats de recherche", + "Search Tools": "Outils de recherche", + "Searched {{count}} sites_one": "Recherché {{count}} site(s)_one", + "Searched {{count}} sites_many": "Recherché {{count}} sites_many", + "Searched {{count}} sites_other": "Recherché {{count}} sites_autres", + "Searching \"{{searchQuery}}\"": "Recherche de « {{searchQuery}} »", + "Searxng Query URL": "URL de recherche Searxng", + "See readme.md for instructions": "Voir le fichier readme.md pour les instructions", + "See what's new": "Découvrez les nouvelles fonctionnalités", "Seed": "Graine", - "Select a base model": "Sélectionner un modèle de base", - "Select a engine": "", - "Select a function": "", - "Select a mode": "Sélectionnez un mode", + "Select a base model": "Sélectionnez un modèle de base", + "Select a engine": "Sélectionnez un moteur", + "Select a function": "Sélectionnez une fonction", + "Select a mode": "Choisissez un mode", "Select a model": "Sélectionnez un modèle", - "Select a pipeline": "Sélectionner un pipeline", - "Select a pipeline url": "Sélectionnez une URL de pipeline", - "Select a tool": "", - "Select an Ollama instance": "Sélectionner une instance Ollama", - "Select Documents": "", + "Select a pipeline": "Sélectionnez un pipeline", + "Select a pipeline url": "Sélectionnez l'URL du pipeline", + "Select a tool": "Sélectionnez un outil", + "Select an Ollama instance": "Sélectionnez une instance Ollama", + "Select Documents": "Sélectionnez des documents", "Select model": "Sélectionnez un modèle", - "Select only one model to call": "", - "Selected model(s) do not support image inputs": "Les modèles sélectionnés ne prennent pas en charge les entrées d’image", + "Select only one model to call": "Sélectionnez seulement un modèle pour appeler", + "Selected model(s) do not support image inputs": "Les modèle(s) sélectionné(s) ne prennent pas en charge les entrées d'images", "Send": "Envoyer", "Send a Message": "Envoyer un message", "Send message": "Envoyer un message", "September": "Septembre", "Serper API Key": "Clé API Serper", - "Serply API Key": "", + "Serply API Key": "Clé API Serply", "Serpstack API Key": "Clé API Serpstack", "Server connection verified": "Connexion au serveur vérifiée", - "Set as default": "Définir par défaut", + "Set as default": "Définir comme valeur par défaut", "Set Default Model": "Définir le modèle par défaut", - "Set embedding model (e.g. {{model}})": "Définir le modèle d'embedding (par exemple {{model}})", + "Set embedding model (e.g. {{model}})": "Définir le modèle d'encodage (par ex. {{model}})", "Set Image Size": "Définir la taille de l'image", - "Set reranking model (e.g. {{model}})": "Définir le modèle de reranking (par exemple {{model}})", + "Set reranking model (e.g. {{model}})": "Définir le modèle de reclassement (par ex. {{model}})", "Set Steps": "Définir les étapes", "Set Task Model": "Définir le modèle de tâche", "Set Voice": "Définir la voix", "Settings": "Paramètres", "Settings saved successfully!": "Paramètres enregistrés avec succès !", - "Settings updated successfully": "", + "Settings updated successfully": "Les paramètres ont été mis à jour avec succès", "Share": "Partager", - "Share Chat": "Partager le chat", + "Share Chat": "Partage de conversation", "Share to OpenWebUI Community": "Partager avec la communauté OpenWebUI", - "short-summary": "résumé court", - "Show": "Afficher", - "Show Admin Details in Account Pending Overlay": "", - "Show Model": "", + "short-summary": "résumé concis", + "Show": "Montrer", + "Show Admin Details in Account Pending Overlay": "Afficher les détails de l'administrateur dans la superposition en attente du compte", + "Show Model": "Montrer le modèle", "Show shortcuts": "Afficher les raccourcis", - "Show your support!": "", - "Showcased creativity": "Créativité affichée", + "Show your support!": "Montre ton soutien !", + "Showcased creativity": "Créativité mise en avant", "sidebar": "barre latérale", - "Sign in": "Se connecter", - "Sign Out": "Se déconnecter", - "Sign up": "S'inscrire", - "Signing in": "Connexion", + "Sign in": "S'identifier", + "Sign Out": "Déconnexion", + "Sign up": "Inscrivez-vous", + "Signing in": "Connexion en cours", "Source": "Source", - "Speech recognition error: {{error}}": "Erreur de reconnaissance vocale : {{error}}", - "Speech-to-Text Engine": "Moteur reconnaissance vocale", + "Speech recognition error: {{error}}": "Erreur de reconnaissance vocale : {{error}}", + "Speech-to-Text Engine": "Moteur de reconnaissance vocale", "Stop Sequence": "Séquence d'arrêt", - "STT Model": "", + "STT Model": "Modèle de STT", "STT Settings": "Paramètres de STT", "Submit": "Soumettre", - "Subtitle (e.g. about the Roman Empire)": "Sous-titre (par exemple, sur l'empire romain)", - "Success": "Succès", - "Successfully updated.": "Mis à jour avec succès.", - "Suggested": "Suggéré", + "Subtitle (e.g. about the Roman Empire)": "Sous-titres (par ex. sur l'Empire romain)", + "Success": "Réussite", + "Successfully updated.": "Mise à jour réussie.", + "Suggested": "Sugéré", "System": "Système", - "System Prompt": "Prompt Système", - "Tags": "Tags", - "Tap to interrupt": "", - "Tavily API Key": "", - "Tell us more:": "Donnez-nous plus:", + "System Prompt": "Prompt du système", + "Tags": "Balises", + "Tap to interrupt": "Appuyez pour interrompre", + "Tavily API Key": "Clé API Tavily", + "Tell us more:": "Dites-nous en plus à ce sujet : ", "Temperature": "Température", - "Template": "Modèle", + "Template": "Template", "Text Completion": "Complétion de texte", - "Text-to-Speech Engine": "Moteur de texte à la parole", + "Text-to-Speech Engine": "Moteur de synthèse vocale", "Tfs Z": "Tfs Z", - "Thanks for your feedback!": "Merci pour votre feedback!", - "The score should be a value between 0.0 (0%) and 1.0 (100%).": "Le score doit être une valeur entre 0.0 (0%) et 1.0 (100%).", + "Thanks for your feedback!": "Merci pour vos commentaires !", + "The score should be a value between 0.0 (0%) and 1.0 (100%).": "Le score doit être une valeur comprise entre 0,0 (0 %) et 1,0 (100 %).", "Theme": "Thème", - "Thinking...": "", - "This action cannot be undone. Do you wish to continue?": "", - "This ensures that your valuable conversations are securely saved to your backend database. Thank you!": "Cela garantit que vos précieuses conversations sont enregistrées en toute sécurité dans votre base de données backend. Merci !", - "This is an experimental feature, it may not function as expected and is subject to change at any time.": "", - "This setting does not sync across browsers or devices.": "Ce réglage ne se synchronise pas entre les navigateurs ou les appareils.", - "This will delete": "", + "Thinking...": "En train de réfléchir...", + "This action cannot be undone. Do you wish to continue?": "Cette action ne peut pas être annulée. Souhaitez-vous continuer ?", + "This ensures that your valuable conversations are securely saved to your backend database. Thank you!": "Cela garantit que vos conversations précieuses soient sauvegardées en toute sécurité dans votre base de données backend. Merci !", + "This is an experimental feature, it may not function as expected and is subject to change at any time.": "Il s'agit d'une fonctionnalité expérimentale, elle peut ne pas fonctionner comme prévu et est sujette à modification à tout moment.", + "This setting does not sync across browsers or devices.": "Ce paramètre ne se synchronise pas entre les navigateurs ou les appareils.", + "This will delete": "Cela supprimera", "Thorough explanation": "Explication approfondie", - "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "Astuce : Mettez à jour plusieurs emplacements de variables consécutivement en appuyant sur la touche tabulation dans l'entrée de chat après chaque remplacement.", + "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "Conseil : mettez à jour plusieurs emplacements de variables consécutivement en appuyant sur la touche Tab dans l’entrée de chat après chaque remplacement.", "Title": "Titre", - "Title (e.g. Tell me a fun fact)": "Titre (par exemple, Dites-moi un fait amusant)", - "Title Auto-Generation": "Génération automatique de titre", - "Title cannot be an empty string.": "Le titre ne peut pas être une chaîne vide.", + "Title (e.g. Tell me a fun fact)": "Titre (par ex. raconte-moi un fait amusant)", + "Title Auto-Generation": "Génération automatique de titres", + "Title cannot be an empty string.": "Le titre ne peut pas être une chaîne de caractères vide.", "Title Generation Prompt": "Prompt de génération de titre", "to": "à", - "To access the available model names for downloading,": "Pour accéder aux noms de modèles disponibles pour le téléchargement,", - "To access the GGUF models available for downloading,": "Pour accéder aux modèles GGUF disponibles pour le téléchargement,", - "To access the WebUI, please reach out to the administrator. Admins can manage user statuses from the Admin Panel.": "", - "To add documents here, upload them to the \"Documents\" workspace first.": "", - "to chat input.": "à l'entrée du chat.", - "To select filters here, add them to the \"Functions\" workspace first.": "", - "To select toolkits here, add them to the \"Tools\" workspace first.": "", + "To access the available model names for downloading,": "Pour accéder aux noms des modèles disponibles en téléchargement,", + "To access the GGUF models available for downloading,": "Pour accéder aux modèles GGUF disponibles en téléchargement,", + "To access the WebUI, please reach out to the administrator. Admins can manage user statuses from the Admin Panel.": "Pour accéder à l'interface Web, veuillez contacter l'administrateur. Les administrateurs peuvent gérer les statuts des utilisateurs depuis le panneau d'administration.", + "To add documents here, upload them to the \"Documents\" workspace first.": "Pour ajouter des documents ici, téléchargez-les d'abord dans l'espace de travail « Documents ». ", + "to chat input.": "à l'entrée de discussion.", + "To select filters here, add them to the \"Functions\" workspace first.": "Pour sélectionner des filtres ici, ajoutez-les d'abord à l'espace de travail « Fonctions ». ", + "To select toolkits here, add them to the \"Tools\" workspace first.": "Pour sélectionner des toolkits ici, ajoutez-les d'abord à l'espace de travail « Outils ». ", "Today": "Aujourd'hui", "Toggle settings": "Basculer les paramètres", "Toggle sidebar": "Basculer la barre latérale", - "Tokens To Keep On Context Refresh (num_keep)": "", - "Tool created successfully": "", - "Tool deleted successfully": "", - "Tool imported successfully": "", - "Tool updated successfully": "", - "Tools": "", + "Tokens To Keep On Context Refresh (num_keep)": "Jeton à conserver pour l'actualisation du contexte (num_keep)", + "Tool created successfully": "L'outil a été créé avec succès", + "Tool deleted successfully": "Outil supprimé avec succès", + "Tool imported successfully": "Outil importé avec succès", + "Tool updated successfully": "L'outil a été mis à jour avec succès", + "Tools": "Outils", "Top K": "Top K", "Top P": "Top P", - "Trouble accessing Ollama?": "Des problèmes pour accéder à Ollama ?", - "TTS Model": "", - "TTS Settings": "Paramètres TTS", - "TTS Voice": "", + "Trouble accessing Ollama?": "Rencontrez-vous des difficultés pour accéder à Ollama ?", + "TTS Model": "Modèle de synthèse vocale", + "TTS Settings": "Paramètres de synthèse vocale", + "TTS Voice": "Voix TTS", "Type": "Type", - "Type Hugging Face Resolve (Download) URL": "Entrez l'URL de résolution (téléchargement) Hugging Face", - "Uh-oh! There was an issue connecting to {{provider}}.": "Uh-oh ! Il y a eu un problème de connexion à {{provider}}.", - "UI": "", - "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "", - "Update": "", + "Type Hugging Face Resolve (Download) URL": "Entrez l'URL de Téléchargement Hugging Face Resolve", + "Uh-oh! There was an issue connecting to {{provider}}.": "Oh non ! Un problème est survenu lors de la connexion à {{provider}}.", + "UI": "Interface utilisateur", + "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "Type de fichier inconnu '{{file_type}}'. Continuons tout de même le téléchargement du fichier.", + "Update": "Mise à jour", "Update and Copy Link": "Mettre à jour et copier le lien", "Update password": "Mettre à jour le mot de passe", - "Updated at": "", - "Upload": "", + "Updated at": "Mise à jour le", + "Upload": "Télécharger", "Upload a GGUF model": "Téléverser un modèle GGUF", - "Upload Files": "Téléverser des fichiers", - "Upload Pipeline": "", - "Upload Progress": "Progression du Téléversement", - "URL Mode": "Mode URL", + "Upload Files": "Télécharger des fichiers", + "Upload Pipeline": "Pipeline de téléchargement", + "Upload Progress": "Progression de l'envoi", + "URL Mode": "Mode d'URL", "Use '#' in the prompt input to load and select your documents.": "Utilisez '#' dans l'entrée de prompt pour charger et sélectionner vos documents.", - "Use Gravatar": "Utiliser Gravatar", - "Use Initials": "Utiliser les Initiales", + "Use Gravatar": "Utilisez Gravatar", + "Use Initials": "Utiliser les initiales", "use_mlock (Ollama)": "use_mlock (Ollama)", - "use_mmap (Ollama)": "use_mmap (Ollama)", + "use_mmap (Ollama)": "utiliser mmap (Ollama)", "user": "utilisateur", - "User location successfully retrieved.": "", - "User Permissions": "Permissions de l'utilisateur", + "User location successfully retrieved.": "L'emplacement de l'utilisateur a été récupéré avec succès.", + "User Permissions": "Permissions utilisateur", "Users": "Utilisateurs", - "Utilize": "Utiliser", - "Valid time units:": "Unités de temps valides :", - "Valves": "", - "Valves updated": "", - "Valves updated successfully": "", + "Utilize": "Utilisez", + "Valid time units:": "Unités de temps valides :", + "Valves": "Vannes", + "Valves updated": "Vannes mises à jour", + "Valves updated successfully": "Les vannes ont été mises à jour avec succès", "variable": "variable", - "variable to have them replaced with clipboard content.": "variable pour les remplacer par le contenu du presse-papiers.", - "Version": "Version", - "Voice": "", - "Warning": "Avertissement", - "Warning: If you update or change your embedding model, you will need to re-import all documents.": "Attention : Si vous mettez à jour ou changez votre modèle d'intégration, vous devrez réimporter tous les documents.", + "variable to have them replaced with clipboard content.": "variable pour qu'elles soient remplacées par le contenu du presse-papiers.", + "Version": "Version améliorée", + "Voice": "Voix", + "Warning": "Avertissement !", + "Warning: If you update or change your embedding model, you will need to re-import all documents.": "Avertissement : Si vous mettez à jour ou modifiez votre modèle d'encodage, vous devrez réimporter tous les documents.", "Web": "Web", - "Web API": "", - "Web Loader Settings": "Paramètres du chargeur Web", + "Web API": "API Web", + "Web Loader Settings": "Paramètres du chargeur web", "Web Params": "Paramètres Web", - "Web Search": "Recherche sur le Web", + "Web Search": "Recherche Web", "Web Search Engine": "Moteur de recherche Web", - "Webhook URL": "URL Webhook", - "WebUI Settings": "Paramètres WebUI", - "WebUI will make requests to": "WebUI effectuera des demandes à", - "What’s New in": "Quoi de neuf dans", - "When history is turned off, new chats on this browser won't appear in your history on any of your devices.": "Lorsque l'historique est désactivé, les nouvelles discussions sur ce navigateur n'apparaîtront pas dans votre historique sur aucun de vos appareils.", - "Whisper (Local)": "", - "Widescreen Mode": "", + "Webhook URL": "URL du webhook", + "WebUI Settings": "Paramètres de WebUI", + "WebUI will make requests to": "WebUI effectuera des requêtes vers", + "What’s New in": "Quoi de neuf", + "When history is turned off, new chats on this browser won't appear in your history on any of your devices.": "Lorsque l'historique est désactivé, les nouvelles conversations sur ce navigateur ne seront pas enregistrés dans votre historique sur aucun de vos appareils.", + "Whisper (Local)": "Whisper (local)", + "Widescreen Mode": "Mode Grand Écran", "Workspace": "Espace de travail", - "Write a prompt suggestion (e.g. Who are you?)": "Rédigez une suggestion de prompt (p. ex. Qui êtes-vous ?)", - "Write a summary in 50 words that summarizes [topic or keyword].": "Rédigez un résumé en 50 mots qui résume [sujet ou mot-clé].", - "Yesterday": "hier", + "Write a prompt suggestion (e.g. Who are you?)": "Écrivez une suggestion de prompt (par exemple : Qui êtes-vous ?)", + "Write a summary in 50 words that summarizes [topic or keyword].": "Rédigez un résumé de 50 mots qui résume [sujet ou mot-clé].", + "Yesterday": "Hier", "You": "Vous", - "You can personalize your interactions with LLMs by adding memories through the 'Manage' button below, making them more helpful and tailored to you.": "", + "You can personalize your interactions with LLMs by adding memories through the 'Manage' button below, making them more helpful and tailored to you.": "Vous pouvez personnaliser vos interactions avec les LLM en ajoutant des souvenirs via le bouton 'Gérer' ci-dessous, ce qui les rendra plus utiles et adaptés à vos besoins.", "You cannot clone a base model": "Vous ne pouvez pas cloner un modèle de base", - "You have no archived conversations.": "Vous n'avez aucune conversation archivée.", - "You have shared this chat": "Vous avez partagé cette conversation", - "You're a helpful assistant.": "Vous êtes un assistant utile", - "You're now logged in.": "Vous êtes maintenant connecté.", - "Your account status is currently pending activation.": "", - "Youtube": "Youtube", - "Youtube Loader Settings": "Paramètres du chargeur Youtube" -} + "You have no archived conversations.": "Vous n'avez aucune conversation archivée", + "You have shared this chat": "Vous avez partagé cette conversation.", + "You're a helpful assistant.": "Vous êtes un assistant serviable.", + "You're now logged in.": "Vous êtes désormais connecté.", + "Your account status is currently pending activation.": "Votre statut de compte est actuellement en attente d'activation.", + "Youtube": "YouTube", + "Youtube Loader Settings": "Paramètres de l'outil de téléchargement YouTube" + } \ No newline at end of file From 4e433d9015b2d744bc0efdc504d2a8865f0bc5e1 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Wed, 3 Jul 2024 18:18:33 +0100 Subject: [PATCH 028/115] wip: citations via __event_emitter__ --- src/lib/components/chat/Chat.svelte | 32 ++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 3d03246b7..87bd9b4de 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -132,15 +132,33 @@ console.log(data); let message = history.messages[data.message_id]; - const status = { - done: data?.data?.done ?? null, - description: data?.data?.status ?? null - }; + const type = data?.data?.type ?? null; + if (type === "status") { + const status = { + done: data?.data?.done ?? null, + description: data?.data?.status ?? null + }; - if (message.statusHistory) { - message.statusHistory.push(status); + if (message.statusHistory) { + message.statusHistory.push(status); + } else { + message.statusHistory = [status]; + } + } else if (type === "citation") { + console.log(data); + const citation = { + document: data?.data?.document ?? null, + metadata: data?.data?.metadata ?? null, + source: data?.data?.source ?? null + }; + + if (message.citations) { + message.citations.push(citation); + } else { + message.citations = [citation]; + } } else { - message.statusHistory = [status]; + console.log("Unknown message type", data); } messages = messages; From c83704d6ca6584448209e9a787754239dc73641d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 15:46:56 -0700 Subject: [PATCH 029/115] refac: task flag Co-Authored-By: Michael Poluektov <78477503+michaelpoluektov@users.noreply.github.com> --- backend/constants.py | 11 +++++++++++ backend/main.py | 13 +++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/backend/constants.py b/backend/constants.py index f1eed43d3..7c366c222 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum): OLLAMA_API_DISABLED = ( "The Ollama API is disabled. Please enable it to use this feature." ) + + +class TASKS(str, Enum): + def __str__(self) -> str: + return super().__str__() + + DEFAULT = lambda task="": f"{task if task else 'default'}" + TITLE_GENERATION = "Title Generation" + EMOJI_GENERATION = "Emoji Generation" + QUERY_GENERATION = "Query Generation" + FUNCTION_CALLING = "Function Calling" diff --git a/backend/main.py b/backend/main.py index 0e3986f21..49e068a75 100644 --- a/backend/main.py +++ b/backend/main.py @@ -126,7 +126,7 @@ from config import ( WEBUI_SESSION_COOKIE_SECURE, AppConfig, ) -from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from utils.webhook import post_webhook if SAFE_MODE: @@ -311,6 +311,7 @@ async def get_function_call_response( {"role": "user", "content": f"Query: {prompt}"}, ], "stream": False, + "task": TASKS.FUNCTION_CALLING, } try: @@ -323,7 +324,6 @@ async def get_function_call_response( response = None try: response = await generate_chat_completions(form_data=payload, user=user) - content = None if hasattr(response, "body_iterator"): @@ -833,9 +833,6 @@ def filter_pipeline(payload, user): pass if "pipeline" not in app.state.MODELS[model_id]: - if "title" in payload: - del payload["title"] - if "task" in payload: del payload["task"] @@ -1338,7 +1335,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): "stream": False, "max_tokens": 50, "chat_id": form_data.get("chat_id", None), - "title": True, + "task": TASKS.TITLE_GENERATION, } log.debug(payload) @@ -1401,7 +1398,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) "messages": [{"role": "user", "content": content}], "stream": False, "max_tokens": 30, - "task": True, + "task": TASKS.QUERY_GENERATION, } print(payload) @@ -1468,7 +1465,7 @@ Message: """{{prompt}}""" "stream": False, "max_tokens": 4, "chat_id": form_data.get("chat_id", None), - "task": True, + "task": TASKS.EMOJI_GENERATION, } log.debug(payload) From 15f6f7bd15b642e422f47f8d97b2056f043b15b2 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 21:12:16 -0700 Subject: [PATCH 030/115] revert: peewee migrations --- .../internal/migrations/001_initial_schema.py | 254 ++++++++++++++++++ .../migrations/002_add_local_sharing.py | 48 ++++ .../migrations/003_add_auth_api_key.py | 48 ++++ .../internal/migrations/004_add_archived.py | 61 +++++ .../internal/migrations/005_add_updated_at.py | 61 +++++ .../006_migrate_timestamps_and_charfields.py | 61 +++++ .../migrations/007_add_user_last_active_at.py | 61 +++++ .../internal/migrations/008_add_memory.py | 61 +++++ .../internal/migrations/009_add_models.py | 61 +++++ .../010_migrate_modelfiles_to_models.py | 45 ++++ .../migrations/011_add_user_settings.py | 45 ++++ .../internal/migrations/012_add_tools.py | 45 ++++ .../internal/migrations/013_add_user_info.py | 45 ++++ .../internal/migrations/014_add_files.py | 55 ++++ .../internal/migrations/015_add_functions.py | 61 +++++ .../016_add_valves_and_is_active.py | 45 ++++ .../migrations/017_add_user_oauth_sub.py | 45 ++++ backend/requirements.txt | 4 +- 18 files changed, 1104 insertions(+), 2 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/001_initial_schema.py create mode 100644 backend/apps/webui/internal/migrations/002_add_local_sharing.py create mode 100644 backend/apps/webui/internal/migrations/003_add_auth_api_key.py create mode 100644 backend/apps/webui/internal/migrations/004_add_archived.py create mode 100644 backend/apps/webui/internal/migrations/005_add_updated_at.py create mode 100644 backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py create mode 100644 backend/apps/webui/internal/migrations/007_add_user_last_active_at.py create mode 100644 backend/apps/webui/internal/migrations/008_add_memory.py create mode 100644 backend/apps/webui/internal/migrations/009_add_models.py create mode 100644 backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py create mode 100644 backend/apps/webui/internal/migrations/011_add_user_settings.py create mode 100644 backend/apps/webui/internal/migrations/012_add_tools.py create mode 100644 backend/apps/webui/internal/migrations/013_add_user_info.py create mode 100644 backend/apps/webui/internal/migrations/014_add_files.py create mode 100644 backend/apps/webui/internal/migrations/015_add_functions.py create mode 100644 backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py create mode 100644 backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py diff --git a/backend/apps/webui/internal/migrations/001_initial_schema.py b/backend/apps/webui/internal/migrations/001_initial_schema.py new file mode 100644 index 000000000..93f278f15 --- /dev/null +++ b/backend/apps/webui/internal/migrations/001_initial_schema.py @@ -0,0 +1,254 @@ +"""Peewee migrations -- 001_initial_schema.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # We perform different migrations for SQLite and other databases + # This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite + # will require per-database SQL queries. + # Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base + # schema instead of trying to migrate from an older schema. + if isinstance(database, pw.SqliteDatabase): + migrate_sqlite(migrator, database, fake=fake) + else: + migrate_external(migrator, database, fake=fake) + + +def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): + @migrator.create_model + class Auth(pw.Model): + id = pw.CharField(max_length=255, unique=True) + email = pw.CharField(max_length=255) + password = pw.CharField(max_length=255) + active = pw.BooleanField() + + class Meta: + table_name = "auth" + + @migrator.create_model + class Chat(pw.Model): + id = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.CharField() + chat = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "chat" + + @migrator.create_model + class ChatIdTag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + tag_name = pw.CharField(max_length=255) + chat_id = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "chatidtag" + + @migrator.create_model + class Document(pw.Model): + id = pw.AutoField() + collection_name = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255, unique=True) + title = pw.CharField() + filename = pw.CharField() + content = pw.TextField(null=True) + user_id = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "document" + + @migrator.create_model + class Modelfile(pw.Model): + id = pw.AutoField() + tag_name = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + modelfile = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "modelfile" + + @migrator.create_model + class Prompt(pw.Model): + id = pw.AutoField() + command = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.CharField() + content = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "prompt" + + @migrator.create_model + class Tag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + data = pw.TextField(null=True) + + class Meta: + table_name = "tag" + + @migrator.create_model + class User(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + email = pw.CharField(max_length=255) + role = pw.CharField(max_length=255) + profile_image_url = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "user" + + +def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): + @migrator.create_model + class Auth(pw.Model): + id = pw.CharField(max_length=255, unique=True) + email = pw.CharField(max_length=255) + password = pw.TextField() + active = pw.BooleanField() + + class Meta: + table_name = "auth" + + @migrator.create_model + class Chat(pw.Model): + id = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.TextField() + chat = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "chat" + + @migrator.create_model + class ChatIdTag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + tag_name = pw.CharField(max_length=255) + chat_id = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "chatidtag" + + @migrator.create_model + class Document(pw.Model): + id = pw.AutoField() + collection_name = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255, unique=True) + title = pw.TextField() + filename = pw.TextField() + content = pw.TextField(null=True) + user_id = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "document" + + @migrator.create_model + class Modelfile(pw.Model): + id = pw.AutoField() + tag_name = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + modelfile = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "modelfile" + + @migrator.create_model + class Prompt(pw.Model): + id = pw.AutoField() + command = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.TextField() + content = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "prompt" + + @migrator.create_model + class Tag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + data = pw.TextField(null=True) + + class Meta: + table_name = "tag" + + @migrator.create_model + class User(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + email = pw.CharField(max_length=255) + role = pw.CharField(max_length=255) + profile_image_url = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "user" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("user") + + migrator.remove_model("tag") + + migrator.remove_model("prompt") + + migrator.remove_model("modelfile") + + migrator.remove_model("document") + + migrator.remove_model("chatidtag") + + migrator.remove_model("chat") + + migrator.remove_model("auth") diff --git a/backend/apps/webui/internal/migrations/002_add_local_sharing.py b/backend/apps/webui/internal/migrations/002_add_local_sharing.py new file mode 100644 index 000000000..e93501aee --- /dev/null +++ b/backend/apps/webui/internal/migrations/002_add_local_sharing.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "chat", share_id=pw.CharField(max_length=255, null=True, unique=True) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("chat", "share_id") diff --git a/backend/apps/webui/internal/migrations/003_add_auth_api_key.py b/backend/apps/webui/internal/migrations/003_add_auth_api_key.py new file mode 100644 index 000000000..07144f3ac --- /dev/null +++ b/backend/apps/webui/internal/migrations/003_add_auth_api_key.py @@ -0,0 +1,48 @@ +"""Peewee migrations -- 002_add_local_sharing.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", api_key=pw.CharField(max_length=255, null=True, unique=True) + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "api_key") diff --git a/backend/apps/webui/internal/migrations/004_add_archived.py b/backend/apps/webui/internal/migrations/004_add_archived.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/004_add_archived.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/005_add_updated_at.py b/backend/apps/webui/internal/migrations/005_add_updated_at.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/005_add_updated_at.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/008_add_memory.py b/backend/apps/webui/internal/migrations/008_add_memory.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/008_add_memory.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/009_add_models.py b/backend/apps/webui/internal/migrations/009_add_models.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/webui/internal/migrations/009_add_models.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py new file mode 100644 index 000000000..eaa3fa5fe --- /dev/null +++ b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py @@ -0,0 +1,45 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. +Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/internal/migrations/011_add_user_settings.py b/backend/apps/webui/internal/migrations/011_add_user_settings.py new file mode 100644 index 000000000..eaa3fa5fe --- /dev/null +++ b/backend/apps/webui/internal/migrations/011_add_user_settings.py @@ -0,0 +1,45 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. +Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py new file mode 100644 index 000000000..eaa3fa5fe --- /dev/null +++ b/backend/apps/webui/internal/migrations/012_add_tools.py @@ -0,0 +1,45 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. +Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/internal/migrations/013_add_user_info.py b/backend/apps/webui/internal/migrations/013_add_user_info.py new file mode 100644 index 000000000..eaa3fa5fe --- /dev/null +++ b/backend/apps/webui/internal/migrations/013_add_user_info.py @@ -0,0 +1,45 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. +Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/internal/migrations/014_add_files.py b/backend/apps/webui/internal/migrations/014_add_files.py new file mode 100644 index 000000000..5e1acf0ad --- /dev/null +++ b/backend/apps/webui/internal/migrations/014_add_files.py @@ -0,0 +1,55 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class File(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + filename = pw.TextField() + meta = pw.TextField() + created_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "file" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("file") diff --git a/backend/apps/webui/internal/migrations/015_add_functions.py b/backend/apps/webui/internal/migrations/015_add_functions.py new file mode 100644 index 000000000..8316a9333 --- /dev/null +++ b/backend/apps/webui/internal/migrations/015_add_functions.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Function(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + type = pw.TextField() + + content = pw.TextField() + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "function" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("function") diff --git a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py new file mode 100644 index 000000000..eaa3fa5fe --- /dev/null +++ b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py @@ -0,0 +1,45 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. +Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py b/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py new file mode 100644 index 000000000..eaa3fa5fe --- /dev/null +++ b/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py @@ -0,0 +1,45 @@ +"""Peewee migrations -- 017_add_user_oauth_sub.py. +Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + migrator.add_fields( + "user", + oauth_sub=pw.TextField(null=True, unique=True), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_fields("user", "oauth_sub") diff --git a/backend/requirements.txt b/backend/requirements.txt index 7c6d62903..77585a2f2 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,8 +14,8 @@ requests==2.32.3 aiohttp==3.9.5 sqlalchemy==2.0.30 alembic==1.13.1 -# peewee==3.17.5 -# peewee-migrate==1.12.2 +peewee==3.17.5 +peewee-migrate==1.12.2 psycopg2-binary==2.9.9 PyMySQL==1.1.1 bcrypt==4.1.3 From f6dcffab135bc0be3036970ab1bec6c9ed70a91d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 21:18:40 -0700 Subject: [PATCH 031/115] fix: pinned chat delete issue --- src/lib/components/layout/Sidebar.svelte | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 193fe41ff..e6cd45c1f 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -186,6 +186,7 @@ goto('/'); } await chats.set(await getChatList(localStorage.token)); + await pinnedChats.set(await getChatListByTagName(localStorage.token, 'pinned')); } }; From bfc53b49fd05e033537415db020972f89be1bd22 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 21:28:14 -0700 Subject: [PATCH 032/115] revert --- .../internal/migrations/004_add_archived.py | 21 +--- .../internal/migrations/005_add_updated_at.py | 97 +++++++++++++--- .../006_migrate_timestamps_and_charfields.py | 105 +++++++++++++++--- .../migrations/007_add_user_last_active_at.py | 48 +++++--- .../internal/migrations/008_add_memory.py | 24 ++-- .../010_migrate_modelfiles_to_models.py | 97 +++++++++++++++- .../migrations/011_add_user_settings.py | 15 ++- .../internal/migrations/012_add_tools.py | 28 ++++- .../internal/migrations/013_add_user_info.py | 15 ++- 9 files changed, 345 insertions(+), 105 deletions(-) diff --git a/backend/apps/webui/internal/migrations/004_add_archived.py b/backend/apps/webui/internal/migrations/004_add_archived.py index 548ec7cdc..d01c06b4e 100644 --- a/backend/apps/webui/internal/migrations/004_add_archived.py +++ b/backend/apps/webui/internal/migrations/004_add_archived.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 009_add_models.py. +"""Peewee migrations -- 002_add_local_sharing.py. Some examples (model - class or model name):: @@ -37,25 +37,10 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - @migrator.create_model - class Model(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - base_model_id = pw.TextField(null=True) - - name = pw.TextField() - - meta = pw.TextField() - params = pw.TextField() - - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "model" + migrator.add_fields("chat", archived=pw.BooleanField(default=False)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("model") + migrator.remove_fields("chat", "archived") diff --git a/backend/apps/webui/internal/migrations/005_add_updated_at.py b/backend/apps/webui/internal/migrations/005_add_updated_at.py index 548ec7cdc..950866ef0 100644 --- a/backend/apps/webui/internal/migrations/005_add_updated_at.py +++ b/backend/apps/webui/internal/migrations/005_add_updated_at.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 009_add_models.py. +"""Peewee migrations -- 002_add_local_sharing.py. Some examples (model - class or model name):: @@ -37,25 +37,94 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - @migrator.create_model - class Model(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - base_model_id = pw.TextField(null=True) + if isinstance(database, pw.SqliteDatabase): + migrate_sqlite(migrator, database, fake=fake) + else: + migrate_external(migrator, database, fake=fake) - name = pw.TextField() - meta = pw.TextField() - params = pw.TextField() +def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): + # Adding fields created_at and updated_at to the 'chat' table + migrator.add_fields( + "chat", + created_at=pw.DateTimeField(null=True), # Allow null for transition + updated_at=pw.DateTimeField(null=True), # Allow null for transition + ) - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) + # Populate the new fields from an existing 'timestamp' field + migrator.sql( + "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" + ) - class Meta: - table_name = "model" + # Now that the data has been copied, remove the original 'timestamp' field + migrator.remove_fields("chat", "timestamp") + + # Update the fields to be not null now that they are populated + migrator.change_fields( + "chat", + created_at=pw.DateTimeField(null=False), + updated_at=pw.DateTimeField(null=False), + ) + + +def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): + # Adding fields created_at and updated_at to the 'chat' table + migrator.add_fields( + "chat", + created_at=pw.BigIntegerField(null=True), # Allow null for transition + updated_at=pw.BigIntegerField(null=True), # Allow null for transition + ) + + # Populate the new fields from an existing 'timestamp' field + migrator.sql( + "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" + ) + + # Now that the data has been copied, remove the original 'timestamp' field + migrator.remove_fields("chat", "timestamp") + + # Update the fields to be not null now that they are populated + migrator.change_fields( + "chat", + created_at=pw.BigIntegerField(null=False), + updated_at=pw.BigIntegerField(null=False), + ) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("model") + if isinstance(database, pw.SqliteDatabase): + rollback_sqlite(migrator, database, fake=fake) + else: + rollback_external(migrator, database, fake=fake) + + +def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): + # Recreate the timestamp field initially allowing null values for safe transition + migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) + + # Copy the earliest created_at date back into the new timestamp field + # This assumes created_at was originally a copy of timestamp + migrator.sql("UPDATE chat SET timestamp = created_at") + + # Remove the created_at and updated_at fields + migrator.remove_fields("chat", "created_at", "updated_at") + + # Finally, alter the timestamp field to not allow nulls if that was the original setting + migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) + + +def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False): + # Recreate the timestamp field initially allowing null values for safe transition + migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True)) + + # Copy the earliest created_at date back into the new timestamp field + # This assumes created_at was originally a copy of timestamp + migrator.sql("UPDATE chat SET timestamp = created_at") + + # Remove the created_at and updated_at fields + migrator.remove_fields("chat", "created_at", "updated_at") + + # Finally, alter the timestamp field to not allow nulls if that was the original setting + migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py index 548ec7cdc..caca14d32 100644 --- a/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py +++ b/backend/apps/webui/internal/migrations/006_migrate_timestamps_and_charfields.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 009_add_models.py. +"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py. Some examples (model - class or model name):: @@ -37,25 +37,94 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - @migrator.create_model - class Model(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - base_model_id = pw.TextField(null=True) - - name = pw.TextField() - - meta = pw.TextField() - params = pw.TextField() - - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "model" + # Alter the tables with timestamps + migrator.change_fields( + "chatidtag", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "document", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "modelfile", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "prompt", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "user", + timestamp=pw.BigIntegerField(), + ) + # Alter the tables with varchar to text where necessary + migrator.change_fields( + "auth", + password=pw.TextField(), + ) + migrator.change_fields( + "chat", + title=pw.TextField(), + ) + migrator.change_fields( + "document", + title=pw.TextField(), + filename=pw.TextField(), + ) + migrator.change_fields( + "prompt", + title=pw.TextField(), + ) + migrator.change_fields( + "user", + profile_image_url=pw.TextField(), + ) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("model") + if isinstance(database, pw.SqliteDatabase): + # Alter the tables with timestamps + migrator.change_fields( + "chatidtag", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "document", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "modelfile", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "prompt", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "user", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "auth", + password=pw.CharField(max_length=255), + ) + migrator.change_fields( + "chat", + title=pw.CharField(), + ) + migrator.change_fields( + "document", + title=pw.CharField(), + filename=pw.CharField(), + ) + migrator.change_fields( + "prompt", + title=pw.CharField(), + ) + migrator.change_fields( + "user", + profile_image_url=pw.CharField(), + ) diff --git a/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py index 548ec7cdc..dd176ba73 100644 --- a/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py +++ b/backend/apps/webui/internal/migrations/007_add_user_last_active_at.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 009_add_models.py. +"""Peewee migrations -- 002_add_local_sharing.py. Some examples (model - class or model name):: @@ -37,25 +37,43 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - @migrator.create_model - class Model(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - base_model_id = pw.TextField(null=True) + # Adding fields created_at and updated_at to the 'user' table + migrator.add_fields( + "user", + created_at=pw.BigIntegerField(null=True), # Allow null for transition + updated_at=pw.BigIntegerField(null=True), # Allow null for transition + last_active_at=pw.BigIntegerField(null=True), # Allow null for transition + ) - name = pw.TextField() + # Populate the new fields from an existing 'timestamp' field + migrator.sql( + 'UPDATE "user" SET created_at = timestamp, updated_at = timestamp, last_active_at = timestamp WHERE timestamp IS NOT NULL' + ) - meta = pw.TextField() - params = pw.TextField() + # Now that the data has been copied, remove the original 'timestamp' field + migrator.remove_fields("user", "timestamp") - created_at = pw.BigIntegerField(null=False) - updated_at = pw.BigIntegerField(null=False) - - class Meta: - table_name = "model" + # Update the fields to be not null now that they are populated + migrator.change_fields( + "user", + created_at=pw.BigIntegerField(null=False), + updated_at=pw.BigIntegerField(null=False), + last_active_at=pw.BigIntegerField(null=False), + ) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("model") + # Recreate the timestamp field initially allowing null values for safe transition + migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True)) + + # Copy the earliest created_at date back into the new timestamp field + # This assumes created_at was originally a copy of timestamp + migrator.sql('UPDATE "user" SET timestamp = created_at') + + # Remove the created_at and updated_at fields + migrator.remove_fields("user", "created_at", "updated_at", "last_active_at") + + # Finally, alter the timestamp field to not allow nulls if that was the original setting + migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/apps/webui/internal/migrations/008_add_memory.py b/backend/apps/webui/internal/migrations/008_add_memory.py index 548ec7cdc..9307aa4d5 100644 --- a/backend/apps/webui/internal/migrations/008_add_memory.py +++ b/backend/apps/webui/internal/migrations/008_add_memory.py @@ -1,4 +1,4 @@ -"""Peewee migrations -- 009_add_models.py. +"""Peewee migrations -- 002_add_local_sharing.py. Some examples (model - class or model name):: @@ -35,27 +35,19 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): - """Write your migrations here.""" - @migrator.create_model - class Model(pw.Model): - id = pw.TextField(unique=True) - user_id = pw.TextField() - base_model_id = pw.TextField(null=True) - - name = pw.TextField() - - meta = pw.TextField() - params = pw.TextField() - - created_at = pw.BigIntegerField(null=False) + class Memory(pw.Model): + id = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + content = pw.TextField(null=False) updated_at = pw.BigIntegerField(null=False) + created_at = pw.BigIntegerField(null=False) class Meta: - table_name = "model" + table_name = "memory" def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_model("model") + migrator.remove_model("memory") diff --git a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py index eaa3fa5fe..2ef814c06 100644 --- a/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py +++ b/backend/apps/webui/internal/migrations/010_migrate_modelfiles_to_models.py @@ -1,7 +1,10 @@ -"""Peewee migrations -- 017_add_user_oauth_sub.py. +"""Peewee migrations -- 009_add_models.py. + Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL > migrator.run(func, *args, **kwargs) # Run python function with the given args > migrator.create_model(Model) # Create a model (could be used as decorator) @@ -18,13 +21,16 @@ Some examples (model - class or model name):: > migrator.drop_index(model, *col_names) > migrator.drop_not_null(model, *field_names) > migrator.drop_constraints(model, *constraints) + """ from contextlib import suppress import peewee as pw from peewee_migrate import Migrator +import json +from utils.misc import parse_ollama_modelfile with suppress(ImportError): import playhouse.postgres_ext as pw_pext @@ -33,13 +39,92 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "user", - oauth_sub=pw.TextField(null=True, unique=True), - ) + # Fetch data from 'modelfile' table and insert into 'model' table + migrate_modelfile_to_model(migrator, database) + # Drop the 'modelfile' table + migrator.remove_model("modelfile") + + +def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): + ModelFile = migrator.orm["modelfile"] + Model = migrator.orm["model"] + + modelfiles = ModelFile.select() + + for modelfile in modelfiles: + # Extract and transform data in Python + + modelfile.modelfile = json.loads(modelfile.modelfile) + meta = json.dumps( + { + "description": modelfile.modelfile.get("desc"), + "profile_image_url": modelfile.modelfile.get("imageUrl"), + "ollama": {"modelfile": modelfile.modelfile.get("content")}, + "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), + "categories": modelfile.modelfile.get("categories"), + "user": {**modelfile.modelfile.get("user", {}), "community": True}, + } + ) + + info = parse_ollama_modelfile(modelfile.modelfile.get("content")) + + # Insert the processed data into the 'model' table + Model.create( + id=f"ollama-{modelfile.tag_name}", + user_id=modelfile.user_id, + base_model_id=info.get("base_model_id"), + name=modelfile.modelfile.get("title"), + meta=meta, + params=json.dumps(info.get("params", {})), + created_at=modelfile.timestamp, + updated_at=modelfile.timestamp, + ) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "oauth_sub") + recreate_modelfile_table(migrator, database) + move_data_back_to_modelfile(migrator, database) + migrator.remove_model("model") + + +def recreate_modelfile_table(migrator: Migrator, database: pw.Database): + query = """ + CREATE TABLE IF NOT EXISTS modelfile ( + user_id TEXT, + tag_name TEXT, + modelfile JSON, + timestamp BIGINT + ) + """ + migrator.sql(query) + + +def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): + Model = migrator.orm["model"] + Modelfile = migrator.orm["modelfile"] + + models = Model.select() + + for model in models: + # Extract and transform data in Python + meta = json.loads(model.meta) + + modelfile_data = { + "title": model.name, + "desc": meta.get("description"), + "imageUrl": meta.get("profile_image_url"), + "content": meta.get("ollama", {}).get("modelfile"), + "suggestionPrompts": meta.get("suggestion_prompts"), + "categories": meta.get("categories"), + "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, + } + + # Insert the processed data back into the 'modelfile' table + Modelfile.create( + user_id=model.user_id, + tag_name=model.id, + modelfile=modelfile_data, + timestamp=model.created_at, + ) diff --git a/backend/apps/webui/internal/migrations/011_add_user_settings.py b/backend/apps/webui/internal/migrations/011_add_user_settings.py index eaa3fa5fe..a1620dcad 100644 --- a/backend/apps/webui/internal/migrations/011_add_user_settings.py +++ b/backend/apps/webui/internal/migrations/011_add_user_settings.py @@ -1,7 +1,10 @@ -"""Peewee migrations -- 017_add_user_oauth_sub.py. +"""Peewee migrations -- 002_add_local_sharing.py. + Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL > migrator.run(func, *args, **kwargs) # Run python function with the given args > migrator.create_model(Model) # Create a model (could be used as decorator) @@ -18,6 +21,7 @@ Some examples (model - class or model name):: > migrator.drop_index(model, *col_names) > migrator.drop_not_null(model, *field_names) > migrator.drop_constraints(model, *constraints) + """ from contextlib import suppress @@ -33,13 +37,12 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "user", - oauth_sub=pw.TextField(null=True, unique=True), - ) + # Adding fields settings to the 'user' table + migrator.add_fields("user", settings=pw.TextField(null=True)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "oauth_sub") + # Remove the settings field + migrator.remove_fields("user", "settings") diff --git a/backend/apps/webui/internal/migrations/012_add_tools.py b/backend/apps/webui/internal/migrations/012_add_tools.py index eaa3fa5fe..4a68eea55 100644 --- a/backend/apps/webui/internal/migrations/012_add_tools.py +++ b/backend/apps/webui/internal/migrations/012_add_tools.py @@ -1,7 +1,10 @@ -"""Peewee migrations -- 017_add_user_oauth_sub.py. +"""Peewee migrations -- 009_add_models.py. + Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL > migrator.run(func, *args, **kwargs) # Run python function with the given args > migrator.create_model(Model) # Create a model (could be used as decorator) @@ -18,6 +21,7 @@ Some examples (model - class or model name):: > migrator.drop_index(model, *col_names) > migrator.drop_not_null(model, *field_names) > migrator.drop_constraints(model, *constraints) + """ from contextlib import suppress @@ -33,13 +37,25 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "user", - oauth_sub=pw.TextField(null=True, unique=True), - ) + @migrator.create_model + class Tool(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + content = pw.TextField() + specs = pw.TextField() + + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "tool" def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "oauth_sub") + migrator.remove_model("tool") diff --git a/backend/apps/webui/internal/migrations/013_add_user_info.py b/backend/apps/webui/internal/migrations/013_add_user_info.py index eaa3fa5fe..0f68669cc 100644 --- a/backend/apps/webui/internal/migrations/013_add_user_info.py +++ b/backend/apps/webui/internal/migrations/013_add_user_info.py @@ -1,7 +1,10 @@ -"""Peewee migrations -- 017_add_user_oauth_sub.py. +"""Peewee migrations -- 002_add_local_sharing.py. + Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL > migrator.run(func, *args, **kwargs) # Run python function with the given args > migrator.create_model(Model) # Create a model (could be used as decorator) @@ -18,6 +21,7 @@ Some examples (model - class or model name):: > migrator.drop_index(model, *col_names) > migrator.drop_not_null(model, *field_names) > migrator.drop_constraints(model, *constraints) + """ from contextlib import suppress @@ -33,13 +37,12 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "user", - oauth_sub=pw.TextField(null=True, unique=True), - ) + # Adding fields info to the 'user' table + migrator.add_fields("user", info=pw.TextField(null=True)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "oauth_sub") + # Remove the settings field + migrator.remove_fields("user", "info") From 1b65df3acc4b1a12b9063518b120fcca50efef56 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 21:28:51 -0700 Subject: [PATCH 033/115] revert --- .../migrations/016_add_valves_and_is_active.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py index eaa3fa5fe..e3af521b7 100644 --- a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py +++ b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py @@ -1,7 +1,10 @@ -"""Peewee migrations -- 017_add_user_oauth_sub.py. +"""Peewee migrations -- 009_add_models.py. + Some examples (model - class or model name):: + > Model = migrator.orm['table_name'] # Return model in current state by name > Model = migrator.ModelClass # Return model in current state by name + > migrator.sql(sql) # Run custom SQL > migrator.run(func, *args, **kwargs) # Run python function with the given args > migrator.create_model(Model) # Create a model (could be used as decorator) @@ -18,6 +21,7 @@ Some examples (model - class or model name):: > migrator.drop_index(model, *col_names) > migrator.drop_not_null(model, *field_names) > migrator.drop_constraints(model, *constraints) + """ from contextlib import suppress @@ -33,13 +37,14 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" - migrator.add_fields( - "user", - oauth_sub=pw.TextField(null=True, unique=True), - ) + migrator.add_fields("tool", valves=pw.TextField(null=True)) + migrator.add_fields("function", valves=pw.TextField(null=True)) + migrator.add_fields("function", is_active=pw.BooleanField(default=False)) def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" - migrator.remove_fields("user", "oauth_sub") + migrator.remove_fields("tool", "valves") + migrator.remove_fields("function", "valves") + migrator.remove_fields("function", "is_active") From 864646094e248d1ee3ed9f09e12312ec241b3217 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 23:32:39 -0700 Subject: [PATCH 034/115] refac --- backend/apps/webui/models/auths.py | 20 +- backend/apps/webui/models/chats.py | 322 ++++++++++++++----------- backend/apps/webui/models/documents.py | 96 ++++---- backend/apps/webui/models/files.py | 82 ++++--- backend/apps/webui/models/functions.py | 184 +++++++------- backend/apps/webui/models/memories.py | 130 +++++----- backend/apps/webui/models/models.py | 53 ++-- backend/apps/webui/models/prompts.py | 54 +++-- backend/apps/webui/models/tags.py | 212 ++++++++-------- backend/apps/webui/models/tools.py | 89 ++++--- backend/apps/webui/models/users.py | 163 +++++++------ 11 files changed, 789 insertions(+), 616 deletions(-) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 560d9a686..48c8b543a 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -7,7 +7,7 @@ from sqlalchemy import String, Column, Boolean, Text from apps.webui.models.users import UserModel, Users from utils.utils import verify_password -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -110,14 +110,14 @@ class AuthsTable: **{"id": id, "email": email, "password": password, "active": True} ) result = Auth(**auth.model_dump()) - Session.add(result) + db.add(result) user = Users.insert_new_user( id, name, email, profile_image_url, role, oauth_sub ) - Session.commit() - Session.refresh(result) + db.commit() + db.refresh(result) if result and user: return user @@ -127,7 +127,7 @@ class AuthsTable: def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = Session.query(Auth).filter_by(email=email, active=True).first() + auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: if verify_password(password, auth.password): user = Users.get_user_by_id(auth.id) @@ -154,7 +154,7 @@ class AuthsTable: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = Session.query(Auth).filter(email=email, active=True).first() + auth = db.query(Auth).filter(email=email, active=True).first() if auth: user = Users.get_user_by_id(auth.id) return user @@ -163,16 +163,14 @@ class AuthsTable: def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - result = ( - Session.query(Auth).filter_by(id=id).update({"password": new_password}) - ) + result = db.query(Auth).filter_by(id=id).update({"password": new_password}) return True if result == 1 else False except: return False def update_email_by_id(self, id: str, email: str) -> bool: try: - result = Session.query(Auth).filter_by(id=id).update({"email": email}) + result = db.query(Auth).filter_by(id=id).update({"email": email}) return True if result == 1 else False except: return False @@ -183,7 +181,7 @@ class AuthsTable: result = Users.delete_user_by_id(id) if result: - Session.query(Auth).filter_by(id=id).delete() + db.query(Auth).filter_by(id=id).delete() return True else: diff --git a/backend/apps/webui/models/chats.py b/backend/apps/webui/models/chats.py index d6829ee7b..8d2e6b104 100644 --- a/backend/apps/webui/models/chats.py +++ b/backend/apps/webui/models/chats.py @@ -7,7 +7,7 @@ import time from sqlalchemy import Column, String, BigInteger, Boolean, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db #################### @@ -79,87 +79,99 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: - id = str(uuid.uuid4()) - chat = ChatModel( - **{ - "id": id, - "user_id": user_id, - "title": ( - form_data.chat["title"] if "title" in form_data.chat else "New Chat" - ), - "chat": json.dumps(form_data.chat), - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) + with get_db() as db: - result = Chat(**chat.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - return ChatModel.model_validate(result) if result else None + id = str(uuid.uuid4()) + chat = ChatModel( + **{ + "id": id, + "user_id": user_id, + "title": ( + form_data.chat["title"] + if "title" in form_data.chat + else "New Chat" + ), + "chat": json.dumps(form_data.chat), + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + + result = Chat(**chat.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + return ChatModel.model_validate(result) if result else None def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]: try: - chat_obj = Session.get(Chat, id) - chat_obj.chat = json.dumps(chat) - chat_obj.title = chat["title"] if "title" in chat else "New Chat" - chat_obj.updated_at = int(time.time()) - Session.commit() - Session.refresh(chat_obj) + with get_db() as db: - return ChatModel.model_validate(chat_obj) + chat_obj = db.get(Chat, id) + chat_obj.chat = json.dumps(chat) + chat_obj.title = chat["title"] if "title" in chat else "New Chat" + chat_obj.updated_at = int(time.time()) + db.commit() + db.refresh(chat_obj) + + return ChatModel.model_validate(chat_obj) except Exception as e: return None def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: - # Get the existing chat to share - chat = Session.get(Chat, chat_id) - # Check if the chat is already shared - if chat.share_id: - return self.get_chat_by_id_and_user_id(chat.share_id, "shared") - # Create a new chat with the same data, but with a new ID - shared_chat = ChatModel( - **{ - "id": str(uuid.uuid4()), - "user_id": f"shared-{chat_id}", - "title": chat.title, - "chat": chat.chat, - "created_at": chat.created_at, - "updated_at": int(time.time()), - } - ) - shared_result = Chat(**shared_chat.model_dump()) - Session.add(shared_result) - Session.commit() - Session.refresh(shared_result) - # Update the original chat with the share_id - result = ( - Session.query(Chat) - .filter_by(id=chat_id) - .update({"share_id": shared_chat.id}) - ) + with get_db() as db: - return shared_chat if (shared_result and result) else None + # Get the existing chat to share + chat = db.get(Chat, chat_id) + # Check if the chat is already shared + if chat.share_id: + return self.get_chat_by_id_and_user_id(chat.share_id, "shared") + # Create a new chat with the same data, but with a new ID + shared_chat = ChatModel( + **{ + "id": str(uuid.uuid4()), + "user_id": f"shared-{chat_id}", + "title": chat.title, + "chat": chat.chat, + "created_at": chat.created_at, + "updated_at": int(time.time()), + } + ) + shared_result = Chat(**shared_chat.model_dump()) + db.add(shared_result) + db.commit() + db.refresh(shared_result) + # Update the original chat with the share_id + result = ( + db.query(Chat) + .filter_by(id=chat_id) + .update({"share_id": shared_chat.id}) + ) + + return shared_chat if (shared_result and result) else None def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: try: - print("update_shared_chat_by_id") - chat = Session.get(Chat, chat_id) - print(chat) - chat.title = chat.title - chat.chat = chat.chat - Session.commit() - Session.refresh(chat) + with get_db() as db: - return self.get_chat_by_id(chat.share_id) + print("update_shared_chat_by_id") + chat = db.get(Chat, chat_id) + print(chat) + chat.title = chat.title + chat.chat = chat.chat + db.commit() + db.refresh(chat) + + return self.get_chat_by_id(chat.share_id) except: return None def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() + return True except: return False @@ -167,42 +179,50 @@ class ChatTable: self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.share_id = share_id - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.share_id = share_id + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - chat.archived = not chat.archived - Session.commit() - Session.refresh(chat) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + chat.archived = not chat.archived + db.commit() + db.refresh(chat) + return ChatModel.model_validate(chat) except: return None def archive_all_chats_by_user_id(self, user_id: str) -> bool: try: - Session.query(Chat).filter_by(user_id=user_id).update({"archived": True}) - return True + with get_db() as db: + + db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) + return True except: return False def get_archived_chat_list_by_user_id( self, user_id: str, skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, @@ -211,110 +231,136 @@ class ChatTable: skip: int = 0, limit: int = 50, ) -> List[ChatModel]: - query = Session.query(Chat).filter_by(user_id=user_id) - if not include_archived: - query = query.filter_by(archived=False) - all_chats = ( - query.order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + if not include_archived: + query = query.filter_by(archived=False) + all_chats = ( + query.order_by(Chat.updated_at.desc()) + # .limit(limit).offset(skip) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_chat_ids( self, chat_ids: List[str], skip: int = 0, limit: int = 50 ) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) - .all() - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter(Chat.id.in_(chat_ids)) + .filter_by(archived=False) + .order_by(Chat.updated_at.desc()) + .all() + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_by_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.get(Chat, id) - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.get(Chat, id) + return ChatModel.model_validate(chat) except: return None def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(share_id=id).first() + with get_db() as db: - if chat: - return self.get_chat_by_id(id) - else: - return None + chat = db.query(Chat).filter_by(share_id=id).first() + + if chat: + return self.get_chat_by_id(id) + else: + return None except Exception as e: return None def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: - chat = Session.query(Chat).filter_by(id=id, user_id=user_id).first() - return ChatModel.model_validate(chat) + with get_db() as db: + + chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() + return ChatModel.model_validate(chat) except: return None def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - # .limit(limit).offset(skip) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + # .limit(limit).offset(skip) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]: - all_chats = ( - Session.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - ) - return [ChatModel.model_validate(chat) for chat in all_chats] + with get_db() as db: + + all_chats = ( + db.query(Chat) + .filter_by(user_id=user_id, archived=True) + .order_by(Chat.updated_at.desc()) + ) + return [ChatModel.model_validate(chat) for chat in all_chats] def delete_chat_by_id(self, id: str) -> bool: try: - Session.query(Chat).filter_by(id=id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - Session.query(Chat).filter_by(id=id, user_id=user_id).delete() + with get_db() as db: - return True and self.delete_shared_chat_by_chat_id(id) + db.query(Chat).filter_by(id=id, user_id=user_id).delete() + + return True and self.delete_shared_chat_by_chat_id(id) except: return False def delete_chats_by_user_id(self, user_id: str) -> bool: try: - self.delete_shared_chats_by_user_id(user_id) - Session.query(Chat).filter_by(user_id=user_id).delete() - return True + with get_db() as db: + + self.delete_shared_chats_by_user_id(user_id) + + db.query(Chat).filter_by(user_id=user_id).delete() + return True except: return False def delete_shared_chats_by_user_id(self, user_id: str) -> bool: try: - chats_by_user = Session.query(Chat).filter_by(user_id=user_id).all() - shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] - Session.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + with get_db() as db: - return True + chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() + shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] + + db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() + + return True except: return False diff --git a/backend/apps/webui/models/documents.py b/backend/apps/webui/models/documents.py index 1b69d44a5..16145c4ac 100644 --- a/backend/apps/webui/models/documents.py +++ b/backend/apps/webui/models/documents.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -74,51 +74,59 @@ class DocumentsTable: def insert_new_doc( self, user_id: str, form_data: DocumentForm ) -> Optional[DocumentModel]: - document = DocumentModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "timestamp": int(time.time()), - } - ) + with get_db() as db: - try: - result = Document(**document.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return DocumentModel.model_validate(result) - else: + document = DocumentModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "timestamp": int(time.time()), + } + ) + + try: + result = Document(**document.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return DocumentModel.model_validate(result) + else: + return None + except: return None - except: - return None def get_doc_by_name(self, name: str) -> Optional[DocumentModel]: try: - document = Session.query(Document).filter_by(name=name).first() - return DocumentModel.model_validate(document) if document else None + with get_db() as db: + + document = db.query(Document).filter_by(name=name).first() + return DocumentModel.model_validate(document) if document else None except: return None def get_docs(self) -> List[DocumentModel]: - return [ - DocumentModel.model_validate(doc) for doc in Session.query(Document).all() - ] + with get_db() as db: + + return [ + DocumentModel.model_validate(doc) for doc in db.query(Document).all() + ] def update_doc_by_name( self, name: str, form_data: DocumentUpdateForm ) -> Optional[DocumentModel]: try: - Session.query(Document).filter_by(name=name).update( - { - "title": form_data.title, - "name": form_data.name, - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(form_data.name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "title": form_data.title, + "name": form_data.name, + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(form_data.name) except Exception as e: log.exception(e) return None @@ -131,22 +139,26 @@ class DocumentsTable: doc_content = json.loads(doc.content if doc.content else "{}") doc_content = {**doc_content, **updated} - Session.query(Document).filter_by(name=name).update( - { - "content": json.dumps(doc_content), - "timestamp": int(time.time()), - } - ) - Session.commit() - return self.get_doc_by_name(name) + with get_db() as db: + + db.query(Document).filter_by(name=name).update( + { + "content": json.dumps(doc_content), + "timestamp": int(time.time()), + } + ) + db.commit() + return self.get_doc_by_name(name) except Exception as e: log.exception(e) return None def delete_doc_by_name(self, name: str) -> bool: try: - Session.query(Document).filter_by(name=name).delete() - return True + with get_db() as db: + + db.query(Document).filter_by(name=name).delete() + return True except: return False diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py index ce904215d..58058e907 100644 --- a/backend/apps/webui/models/files.py +++ b/backend/apps/webui/models/files.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db import json @@ -61,50 +61,62 @@ class FileForm(BaseModel): class FilesTable: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]: - file = FileModel( - **{ - **form_data.model_dump(), - "user_id": user_id, - "created_at": int(time.time()), - } - ) + with get_db() as db: - try: - result = File(**file.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FileModel.model_validate(result) - else: + file = FileModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + } + ) + + try: + result = File(**file.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FileModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_file_by_id(self, id: str) -> Optional[FileModel]: - try: - file = Session.get(File, id) - return FileModel.model_validate(file) - except: - return None + with get_db() as db: + + try: + file = db.get(File, id) + return FileModel.model_validate(file) + except: + return None def get_files(self) -> List[FileModel]: - return [FileModel.model_validate(file) for file in Session.query(File).all()] + with get_db() as db: + + return [FileModel.model_validate(file) for file in db.query(File).all()] def delete_file_by_id(self, id: str) -> bool: - try: - Session.query(File).filter_by(id=id).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).filter_by(id=id).delete() + return True + except: + return False def delete_all_files(self) -> bool: - try: - Session.query(File).delete() - return True - except: - return False + + with get_db() as db: + + try: + db.query(File).delete() + return True + except: + return False Files = FilesTable() diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 64ed4f3cc..5718833d3 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -5,7 +5,7 @@ import logging from sqlalchemy import Column, String, Text, BigInteger, Boolean -from apps.webui.internal.db import JSONField, Base, Session +from apps.webui.internal.db import JSONField, Base, get_db from apps.webui.models.users import Users import json @@ -91,6 +91,7 @@ class FunctionsTable: def insert_new_function( self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: + function = FunctionModel( **{ **form_data.model_dump(), @@ -102,85 +103,99 @@ class FunctionsTable: ) try: - result = Function(**function.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return FunctionModel.model_validate(result) - else: - return None + with get_db() as db: + result = Function(**function.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return FunctionModel.model_validate(result) + else: + return None except Exception as e: print(f"Error creating tool: {e}") return None def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: - function = Session.get(Function, id) - return FunctionModel.model_validate(function) + with get_db() as db: + + function = db.get(Function, id) + return FunctionModel.model_validate(function) except: return None def get_functions(self, active_only=False) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(is_active=True).all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(is_active=True).all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).all() + ] def get_functions_by_type( self, type: str, active_only=False ) -> List[FunctionModel]: - if active_only: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type=type, is_active=True) - .all() - ] - else: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function).filter_by(type=type).all() - ] + with get_db() as db: + + if active_only: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type=type, is_active=True) + .all() + ] + else: + return [ + FunctionModel.model_validate(function) + for function in db.query(Function).filter_by(type=type).all() + ] def get_global_filter_functions(self) -> List[FunctionModel]: - return [ - FunctionModel.model_validate(function) - for function in Session.query(Function) - .filter_by(type="filter", is_active=True, is_global=True) - .all() - ] + with get_db() as db: + + return [ + FunctionModel.model_validate(function) + for function in db.query(Function) + .filter_by(type="filter", is_active=True, is_global=True) + .all() + ] def get_function_valves_by_id(self, id: str) -> Optional[dict]: - try: - function = Session.get(Function, id) - return function.valves if function.valves else {} - except Exception as e: - print(f"An error occurred: {e}") - return None + with get_db() as db: + + try: + function = db.get(Function, id) + return function.valves if function.valves else {} + except Exception as e: + print(f"An error occurred: {e}") + return None def update_function_valves_by_id( self, id: str, valves: dict ) -> Optional[FunctionValves]: - try: - function = Session.get(Function, id) - function.valves = valves - function.updated_at = int(time.time()) - Session.commit() - Session.refresh(function) - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + function = db.get(Function, id) + function.valves = valves + function.updated_at = int(time.time()) + db.commit() + db.refresh(function) + return self.get_function_by_id(id) + except: + return None def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() @@ -199,6 +214,7 @@ class FunctionsTable: def update_user_valves_by_id_and_user_id( self, id: str, user_id: str, valves: dict ) -> Optional[dict]: + try: user = Users.get_user_by_id(user_id) user_settings = user.settings.model_dump() @@ -220,37 +236,43 @@ class FunctionsTable: return None def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: - try: - Session.query(Function).filter_by(id=id).update( - { - **updated, - "updated_at": int(time.time()), - } - ) - Session.commit() - return self.get_function_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).update( + { + **updated, + "updated_at": int(time.time()), + } + ) + db.commit() + return self.get_function_by_id(id) + except: + return None def deactivate_all_functions(self) -> Optional[bool]: - try: - Session.query(Function).update( - { - "is_active": False, - "updated_at": int(time.time()), - } - ) - Session.commit() - return True - except: - return None + with get_db() as db: + + try: + db.query(Function).update( + { + "is_active": False, + "updated_at": int(time.time()), + } + ) + db.commit() + return True + except: + return None def delete_function_by_id(self, id: str) -> bool: - try: - Session.query(Function).filter_by(id=id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Function).filter_by(id=id).delete() + return True + except: + return False Functions = FunctionsTable() diff --git a/backend/apps/webui/models/memories.py b/backend/apps/webui/models/memories.py index 1f03318fd..662bbedfe 100644 --- a/backend/apps/webui/models/memories.py +++ b/backend/apps/webui/models/memories.py @@ -3,7 +3,7 @@ from typing import List, Union, Optional from sqlalchemy import Column, String, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import time import uuid @@ -45,82 +45,98 @@ class MemoriesTable: user_id: str, content: str, ) -> Optional[MemoryModel]: - id = str(uuid.uuid4()) - memory = MemoryModel( - **{ - "id": id, - "user_id": user_id, - "content": content, - "created_at": int(time.time()), - "updated_at": int(time.time()), - } - ) - result = Memory(**memory.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + with get_db() as db: + id = str(uuid.uuid4()) + + memory = MemoryModel( + **{ + "id": id, + "user_id": user_id, + "content": content, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + result = Memory(**memory.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return MemoryModel.model_validate(result) + else: + return None def update_memory_by_id( self, id: str, content: str, ) -> Optional[MemoryModel]: - try: - Session.query(Memory).filter_by(id=id).update( - {"content": content, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_memory_by_id(id) - except: - return None + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id).update( + {"content": content, "updated_at": int(time.time())} + ) + db.commit() + return self.get_memory_by_id(id) + except: + return None def get_memories(self) -> List[MemoryModel]: - try: - memories = Session.query(Memory).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]: - try: - memories = Session.query(Memory).filter_by(user_id=user_id).all() - return [MemoryModel.model_validate(memory) for memory in memories] - except: - return None + with get_db() as db: + + try: + memories = db.query(Memory).filter_by(user_id=user_id).all() + return [MemoryModel.model_validate(memory) for memory in memories] + except: + return None def get_memory_by_id(self, id: str) -> Optional[MemoryModel]: - try: - memory = Session.get(Memory, id) - return MemoryModel.model_validate(memory) - except: - return None + with get_db() as db: + + try: + memory = db.get(Memory, id) + return MemoryModel.model_validate(memory) + except: + return None def delete_memory_by_id(self, id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id).delete() - return True + with get_db() as db: - except: - return False + try: + db.query(Memory).filter_by(id=id).delete() + return True + + except: + return False def delete_memories_by_user_id(self, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(user_id=user_id).delete() + return True + except: + return False def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: - try: - Session.query(Memory).filter_by(id=id, user_id=user_id).delete() - return True - except: - return False + with get_db() as db: + + try: + db.query(Memory).filter_by(id=id, user_id=user_id).delete() + return True + except: + return False Memories = MemoriesTable() diff --git a/backend/apps/webui/models/models.py b/backend/apps/webui/models/models.py index 6543edefc..c95c36c7d 100644 --- a/backend/apps/webui/models/models.py +++ b/backend/apps/webui/models/models.py @@ -5,7 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from typing import List, Union, Optional from config import SRC_LOG_LEVELS @@ -126,39 +126,46 @@ class ModelsTable: } ) try: - result = Model(**model.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ModelModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Model(**model.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + + if result: + return ModelModel.model_validate(result) + else: + return None except Exception as e: print(e) return None def get_all_models(self) -> List[ModelModel]: - return [ - ModelModel.model_validate(model) for model in Session.query(Model).all() - ] + with get_db() as db: + + return [ModelModel.model_validate(model) for model in db.query(Model).all()] def get_model_by_id(self, id: str) -> Optional[ModelModel]: try: - model = Session.get(Model, id) - return ModelModel.model_validate(model) + with get_db() as db: + + model = db.get(Model, id) + return ModelModel.model_validate(model) except: return None def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: try: - # update only the fields that are present in the model - model = Session.query(Model).get(id) - model.update(**model.model_dump()) - Session.commit() - Session.refresh(model) - return ModelModel.model_validate(model) + with get_db() as db: + + # update only the fields that are present in the model + model = db.query(Model).get(id) + model.update(**model.model_dump()) + db.commit() + db.refresh(model) + return ModelModel.model_validate(model) except Exception as e: print(e) @@ -166,8 +173,10 @@ class ModelsTable: def delete_model_by_id(self, id: str) -> bool: try: - Session.query(Model).filter_by(id=id).delete() - return True + with get_db() as db: + + db.query(Model).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/prompts.py b/backend/apps/webui/models/prompts.py index ab8cc04ce..2af2ce22c 100644 --- a/backend/apps/webui/models/prompts.py +++ b/backend/apps/webui/models/prompts.py @@ -4,7 +4,7 @@ import time from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db import json @@ -60,46 +60,56 @@ class PromptsTable: ) try: - result = Prompt(**prompt.dict()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return PromptModel.model_validate(result) - else: - return None + with get_db() as db: + + result = Prompt(**prompt.dict()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return PromptModel.model_validate(result) + else: + return None except Exception as e: return None def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + return PromptModel.model_validate(prompt) except: return None def get_prompts(self) -> List[PromptModel]: - return [ - PromptModel.model_validate(prompt) for prompt in Session.query(Prompt).all() - ] + with get_db() as db: + + return [ + PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all() + ] def update_prompt_by_command( self, command: str, form_data: PromptForm ) -> Optional[PromptModel]: try: - prompt = Session.query(Prompt).filter_by(command=command).first() - prompt.title = form_data.title - prompt.content = form_data.content - prompt.timestamp = int(time.time()) - Session.commit() - return PromptModel.model_validate(prompt) + with get_db() as db: + + prompt = db.query(Prompt).filter_by(command=command).first() + prompt.title = form_data.title + prompt.content = form_data.content + prompt.timestamp = int(time.time()) + db.commit() + return PromptModel.model_validate(prompt) except: return None def delete_prompt_by_command(self, command: str) -> bool: try: - Session.query(Prompt).filter_by(command=command).delete() - return True + with get_db() as db: + + db.query(Prompt).filter_by(command=command).delete() + return True except: return False diff --git a/backend/apps/webui/models/tags.py b/backend/apps/webui/models/tags.py index 7b0df6b6b..bbbc95ed2 100644 --- a/backend/apps/webui/models/tags.py +++ b/backend/apps/webui/models/tags.py @@ -8,7 +8,7 @@ import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, Session +from apps.webui.internal.db import Base, get_db from config import SRC_LOG_LEVELS @@ -79,26 +79,29 @@ class ChatTagsResponse(BaseModel): class TagTable: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - id = str(uuid.uuid4()) - tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) - try: - result = Tag(**tag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return TagModel.model_validate(result) - else: + with get_db() as db: + + id = str(uuid.uuid4()) + tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) + try: + result = Tag(**tag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return TagModel.model_validate(result) + else: + return None + except Exception as e: return None - except Exception as e: - return None def get_tag_by_name_and_user_id( self, name: str, user_id: str ) -> Optional[TagModel]: try: - tag = Session.query(Tag).filter(name=name, user_id=user_id).first() - return TagModel.model_validate(tag) + with get_db() as db: + tag = db.query(Tag).filter(name=name, user_id=user_id).first() + return TagModel.model_validate(tag) except Exception as e: return None @@ -120,98 +123,109 @@ class TagTable: } ) try: - result = ChatIdTag(**chatIdTag.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ChatIdTagModel.model_validate(result) - else: - return None + with get_db() as db: + result = ChatIdTag(**chatIdTag.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ChatIdTagModel.model_validate(result) + else: + return None except: return None def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: - tag_names = [ - chat_id_tag.tag_name - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, chat_id=chat_id) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: - return [ - TagModel.model_validate(tag) - for tag in ( - Session.query(Tag) - .filter_by(user_id=user_id) - .filter(Tag.name.in_(tag_names)) - .all() - ) - ] + tag_names = [ + chat_id_tag.tag_name + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, chat_id=chat_id) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] + + return [ + TagModel.model_validate(tag) + for tag in ( + db.query(Tag) + .filter_by(user_id=user_id) + .filter(Tag.name.in_(tag_names)) + .all() + ) + ] def get_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> List[ChatIdTagModel]: - return [ - ChatIdTagModel.model_validate(chat_id_tag) - for chat_id_tag in ( - Session.query(ChatIdTag) - .filter_by(user_id=user_id, tag_name=tag_name) - .order_by(ChatIdTag.timestamp.desc()) - .all() - ) - ] + with get_db() as db: + + return [ + ChatIdTagModel.model_validate(chat_id_tag) + for chat_id_tag in ( + db.query(ChatIdTag) + .filter_by(user_id=user_id, tag_name=tag_name) + .order_by(ChatIdTag.timestamp.desc()) + .all() + ) + ] def count_chat_ids_by_tag_name_and_user_id( self, tag_name: str, user_id: str ) -> int: - return ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .count() - ) + with get_db() as db: + + return ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .count() + ) def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + return True except Exception as e: log.error(f"delete_tag: {e}") return False @@ -220,20 +234,24 @@ class TagTable: self, tag_name: str, chat_id: str, user_id: str ) -> bool: try: - res = ( - Session.query(ChatIdTag) - .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) - .delete() - ) - log.debug(f"res: {res}") - Session.commit() + with get_db() as db: - tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) - if tag_count == 0: - # Remove tag item from Tag col as well - Session.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + res = ( + db.query(ChatIdTag) + .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) + .delete() + ) + log.debug(f"res: {res}") + db.commit() - return True + tag_count = self.count_chat_ids_by_tag_name_and_user_id( + tag_name, user_id + ) + if tag_count == 0: + # Remove tag item from Tag col as well + db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() + + return True except Exception as e: log.error(f"delete_tag: {e}") return False diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index f5df10637..dc0fe01c5 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -4,7 +4,7 @@ import time import logging from sqlalchemy import String, Column, BigInteger, Text -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.models.users import Users import json @@ -83,54 +83,63 @@ class ToolsTable: def insert_new_tool( self, user_id: str, form_data: ToolForm, specs: List[dict] ) -> Optional[ToolModel]: - tool = ToolModel( - **{ - **form_data.model_dump(), - "specs": specs, - "user_id": user_id, - "updated_at": int(time.time()), - "created_at": int(time.time()), - } - ) - try: - result = Tool(**tool.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return ToolModel.model_validate(result) - else: + with get_db() as db: + + tool = ToolModel( + **{ + **form_data.model_dump(), + "specs": specs, + "user_id": user_id, + "updated_at": int(time.time()), + "created_at": int(time.time()), + } + ) + + try: + result = Tool(**tool.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return ToolModel.model_validate(result) + else: + return None + except Exception as e: + print(f"Error creating tool: {e}") return None - except Exception as e: - print(f"Error creating tool: {e}") - return None def get_tool_by_id(self, id: str) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - return ToolModel.model_validate(tool) + with get_db() as db: + + tool = db.get(Tool, id) + return ToolModel.model_validate(tool) except: return None def get_tools(self) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in Session.query(Tool).all()] + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: - tool = Session.get(Tool, id) - return tool.valves if tool.valves else {} + with get_db() as db: + + tool = db.get(Tool, id) + return tool.valves if tool.valves else {} except Exception as e: print(f"An error occurred: {e}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: try: - Session.query(Tool).filter_by(id=id).update( - {"valves": valves, "updated_at": int(time.time())} - ) - Session.commit() - return self.get_tool_by_id(id) + with get_db() as db: + + db.query(Tool).filter_by(id=id).update( + {"valves": valves, "updated_at": int(time.time())} + ) + db.commit() + return self.get_tool_by_id(id) except: return None @@ -177,19 +186,21 @@ class ToolsTable: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: try: - tool = Session.get(Tool, id) - tool.update(**updated) - tool.updated_at = int(time.time()) - Session.commit() - Session.refresh(tool) - return ToolModel.model_validate(tool) + with get_db() as db: + tool = db.get(Tool, id) + tool.update(**updated) + tool.updated_at = int(time.time()) + db.commit() + db.refresh(tool) + return ToolModel.model_validate(tool) except: return None def delete_tool_by_id(self, id: str) -> bool: try: - Session.query(Tool).filter_by(id=id).delete() - return True + with get_db() as db: + db.query(Tool).filter_by(id=id).delete() + return True except: return False diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py index 9e1e25ac6..8e3b57bba 100644 --- a/backend/apps/webui/models/users.py +++ b/backend/apps/webui/models/users.py @@ -6,7 +6,7 @@ from sqlalchemy import String, Column, BigInteger, Text from utils.misc import get_gravatar_url -from apps.webui.internal.db import Base, JSONField, Session +from apps.webui.internal.db import Base, JSONField, Session, get_db from apps.webui.models.chats import Chats #################### @@ -88,81 +88,92 @@ class UsersTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - user = UserModel( - **{ - "id": id, - "name": name, - "email": email, - "role": role, - "profile_image_url": profile_image_url, - "last_active_at": int(time.time()), - "created_at": int(time.time()), - "updated_at": int(time.time()), - "oauth_sub": oauth_sub, - } - ) - result = User(**user.model_dump()) - Session.add(result) - Session.commit() - Session.refresh(result) - if result: - return user - else: - return None + with get_db() as db: + user = UserModel( + **{ + "id": id, + "name": name, + "email": email, + "role": role, + "profile_image_url": profile_image_url, + "last_active_at": int(time.time()), + "created_at": int(time.time()), + "updated_at": int(time.time()), + "oauth_sub": oauth_sub, + } + ) + result = User(**user.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + if result: + return user + else: + return None def get_user_by_id(self, id: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except Exception as e: return None def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(api_key=api_key).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(api_key=api_key).first() + return UserModel.model_validate(user) except: return None def get_user_by_email(self, email: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(email=email).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(email=email).first() + return UserModel.model_validate(user) except: return None def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: try: - user = Session.query(User).filter_by(oauth_sub=sub).first() - return UserModel.model_validate(user) + with get_db() as db: + + user = db.query(User).filter_by(oauth_sub=sub).first() + return UserModel.model_validate(user) except: return None def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: - users = ( - Session.query(User) - # .offset(skip).limit(limit) - .all() - ) - return [UserModel.model_validate(user) for user in users] + with get_db() as db: + users = ( + db.query(User) + # .offset(skip).limit(limit) + .all() + ) + return [UserModel.model_validate(user) for user in users] def get_num_users(self) -> Optional[int]: - return Session.query(User).count() + with get_db() as db: + return db.query(User).count() def get_first_user(self) -> UserModel: try: - user = Session.query(User).order_by(User.created_at).first() - return UserModel.model_validate(user) + with get_db() as db: + user = db.query(User).order_by(User.created_at).first() + return UserModel.model_validate(user) except: return None def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"role": role}) - Session.commit() - - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + with get_db() as db: + db.query(User).filter_by(id=id).update({"role": role}) + db.commit() + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -170,25 +181,28 @@ class UsersTable: self, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"profile_image_url": profile_image_url} - ) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update( + {"profile_image_url": profile_image_url} + ) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update( - {"last_active_at": int(time.time())} - ) - Session.commit() + with get_db() as db: - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + db.query(User).filter_by(id=id).update( + {"last_active_at": int(time.time())} + ) + db.commit() + + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None @@ -196,21 +210,23 @@ class UsersTable: self, id: str, oauth_sub: str ) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) + with get_db() as db: + db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) except: return None def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: try: - Session.query(User).filter_by(id=id).update(updated) - Session.commit() + with get_db() as db: + db.query(User).filter_by(id=id).update(updated) + db.commit() - user = Session.query(User).filter_by(id=id).first() - return UserModel.model_validate(user) - # return UserModel(**user.dict()) + user = db.query(User).filter_by(id=id).first() + return UserModel.model_validate(user) + # return UserModel(**user.dict()) except Exception as e: return None @@ -220,9 +236,10 @@ class UsersTable: result = Chats.delete_chats_by_user_id(id) if result: - # Delete User - Session.query(User).filter_by(id=id).delete() - Session.commit() + with get_db() as db: + # Delete User + db.query(User).filter_by(id=id).delete() + db.commit() return True else: @@ -232,16 +249,18 @@ class UsersTable: def update_user_api_key_by_id(self, id: str, api_key: str) -> str: try: - result = Session.query(User).filter_by(id=id).update({"api_key": api_key}) - Session.commit() - return True if result == 1 else False + with get_db() as db: + result = db.query(User).filter_by(id=id).update({"api_key": api_key}) + db.commit() + return True if result == 1 else False except: return False def get_user_api_key_by_id(self, id: str) -> Optional[str]: try: - user = Session.query(User).filter_by(id=id).first() - return user.api_key + with get_db() as db: + user = db.query(User).filter_by(id=id).first() + return user.api_key except Exception as e: return None From 37a5d2c06b78098ed70f52e9fefdc824ad96d531 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 23:32:46 -0700 Subject: [PATCH 035/115] Update db.py --- backend/apps/webui/internal/db.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 320ab3e07..bfdc52c11 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -53,8 +53,19 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ) else: engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) + + SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) Base = declarative_base() Session = scoped_session(SessionLocal) + + +# Dependency +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() From 8fe2a7bb75e222f49f177437a0e1b5279b23a37e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Wed, 3 Jul 2024 23:39:16 -0700 Subject: [PATCH 036/115] fix --- backend/apps/webui/internal/db.py | 8 +++++++- backend/apps/webui/models/tools.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index bfdc52c11..333e215ea 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -62,10 +62,16 @@ Base = declarative_base() Session = scoped_session(SessionLocal) +from contextlib import contextmanager + + # Dependency -def get_db(): +def get_session(): db = SessionLocal() try: yield db finally: db.close() + + +get_db = contextmanager(get_session) diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index dc0fe01c5..4cc06826a 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -119,7 +119,8 @@ class ToolsTable: return None def get_tools(self) -> List[ToolModel]: - return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] + with get_db() as db: + return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: try: From 8b13755d5634d76840077c8bdcac6def93d86a70 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 00:25:45 -0700 Subject: [PATCH 037/115] Update auths.py --- backend/apps/webui/models/auths.py | 89 +++++++++++++++++------------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py index 48c8b543a..7698359f9 100644 --- a/backend/apps/webui/models/auths.py +++ b/backend/apps/webui/models/auths.py @@ -102,40 +102,44 @@ class AuthsTable: role: str = "pending", oauth_sub: Optional[str] = None, ) -> Optional[UserModel]: - log.info("insert_new_auth") + with get_db() as db: - id = str(uuid.uuid4()) + log.info("insert_new_auth") - auth = AuthModel( - **{"id": id, "email": email, "password": password, "active": True} - ) - result = Auth(**auth.model_dump()) - db.add(result) + id = str(uuid.uuid4()) - user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth_sub - ) + auth = AuthModel( + **{"id": id, "email": email, "password": password, "active": True} + ) + result = Auth(**auth.model_dump()) + db.add(result) - db.commit() - db.refresh(result) + user = Users.insert_new_user( + id, name, email, profile_image_url, role, oauth_sub + ) - if result and user: - return user - else: - return None + db.commit() + db.refresh(result) + + if result and user: + return user + else: + return None def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") try: - auth = db.query(Auth).filter_by(email=email, active=True).first() - if auth: - if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) - return user + with get_db() as db: + + auth = db.query(Auth).filter_by(email=email, active=True).first() + if auth: + if verify_password(password, auth.password): + user = Users.get_user_by_id(auth.id) + return user + else: + return None else: return None - else: - return None except: return None @@ -154,38 +158,47 @@ class AuthsTable: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: log.info(f"authenticate_user_by_trusted_header: {email}") try: - auth = db.query(Auth).filter(email=email, active=True).first() - if auth: - user = Users.get_user_by_id(auth.id) - return user + with get_db() as db: + auth = db.query(Auth).filter(email=email, active=True).first() + if auth: + user = Users.get_user_by_id(auth.id) + return user except: return None def update_user_password_by_id(self, id: str, new_password: str) -> bool: try: - result = db.query(Auth).filter_by(id=id).update({"password": new_password}) - return True if result == 1 else False + with get_db() as db: + + result = ( + db.query(Auth).filter_by(id=id).update({"password": new_password}) + ) + return True if result == 1 else False except: return False def update_email_by_id(self, id: str, email: str) -> bool: try: - result = db.query(Auth).filter_by(id=id).update({"email": email}) - return True if result == 1 else False + with get_db() as db: + + result = db.query(Auth).filter_by(id=id).update({"email": email}) + return True if result == 1 else False except: return False def delete_auth_by_id(self, id: str) -> bool: try: - # Delete User - result = Users.delete_user_by_id(id) + with get_db() as db: - if result: - db.query(Auth).filter_by(id=id).delete() + # Delete User + result = Users.delete_user_by_id(id) - return True - else: - return False + if result: + db.query(Auth).filter_by(id=id).delete() + + return True + else: + return False except: return False From 9a6cbafdef7a1a44c7e3ad914996204d07c4a77e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 00:37:05 -0700 Subject: [PATCH 038/115] fix: user valves --- backend/apps/webui/models/functions.py | 4 ++-- backend/apps/webui/models/tools.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 677f022f6..33a9d1297 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -185,7 +185,7 @@ class FunctionsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings if "functions" not in user_settings: @@ -203,7 +203,7 @@ class FunctionsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "functions" and "valves" settings if "functions" not in user_settings: diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py index 950972c2d..e7830e214 100644 --- a/backend/apps/webui/models/tools.py +++ b/backend/apps/webui/models/tools.py @@ -141,7 +141,7 @@ class ToolsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings if "tools" not in user_settings: @@ -159,7 +159,7 @@ class ToolsTable: ) -> Optional[dict]: try: user = Users.get_user_by_id(user_id) - user_settings = user.settings.model_dump() + user_settings = user.settings.model_dump() if user.settings else {} # Check if user has "tools" and "valves" settings if "tools" not in user_settings: From 740b6f5c17533350ae002f62e0097d8730350c04 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 00:42:18 -0700 Subject: [PATCH 039/115] fix: pull model --- src/lib/components/admin/Settings/Models.svelte | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index 57d0be135..b95829826 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -158,12 +158,14 @@ return; } - const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch( - (error) => { - toast.error(error); - return null; - } - ); + const [res, controller] = await pullModel( + localStorage.token, + sanitizedModelTag, + selectedOllamaUrlIdx + ).catch((error) => { + toast.error(error); + return null; + }); if (res) { const reader = res.body From 05277556005230847f552b55c2d896ecd57fe281 Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 4 Jul 2024 12:21:09 +0100 Subject: [PATCH 040/115] use data field --- src/lib/components/chat/Chat.svelte | 31 ++++++++++++----------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 87bd9b4de..de64d2681 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -133,29 +133,24 @@ let message = history.messages[data.message_id]; const type = data?.data?.type ?? null; - if (type === "status") { - const status = { - done: data?.data?.done ?? null, - description: data?.data?.status ?? null - }; - + const payload = data?.data?.data ?? null; + if (!type || !payload) { + console.log("Data and type fields must be provided.", data); + return; + } + const status_keys = ["done", "description"]; + const citation_keys = ["document", "metadata", "source"]; + if (type === "status" && status_keys.every(key => key in payload)) { if (message.statusHistory) { - message.statusHistory.push(status); + message.statusHistory.push(payload); } else { - message.statusHistory = [status]; + message.statusHistory = [payload]; } - } else if (type === "citation") { - console.log(data); - const citation = { - document: data?.data?.document ?? null, - metadata: data?.data?.metadata ?? null, - source: data?.data?.source ?? null - }; - + } else if (type === "citation" && citation_keys.every(key => key in payload)) { if (message.citations) { - message.citations.push(citation); + message.citations.push(payload); } else { - message.citations = [citation]; + message.citations = [payload]; } } else { console.log("Unknown message type", data); From d20601dc475034d51a7617ba9ceedb84fdbacabf Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 13:53:28 +0000 Subject: [PATCH 041/115] feat: Add custom Collapsible component for collapsible content --- src/lib/components/common/Collapsible.svelte | 37 ++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/lib/components/common/Collapsible.svelte diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte new file mode 100644 index 000000000..c87ffe8ba --- /dev/null +++ b/src/lib/components/common/Collapsible.svelte @@ -0,0 +1,37 @@ + + + + +
+ +
+ +
+
\ No newline at end of file From 2389c36a70d55ee6da4164b9e085a322e488a194 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 13:55:37 +0000 Subject: [PATCH 042/115] refactor: Update WebSearchResults.svelte to use new CollapsibleComponent --- .../ResponseMessage/WebSearchResults.svelte | 146 +++++++++--------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte index 528108036..25001730e 100644 --- a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte @@ -2,17 +2,18 @@ import ChevronDown from '$lib/components/icons/ChevronDown.svelte'; import ChevronUp from '$lib/components/icons/ChevronUp.svelte'; import MagnifyingGlass from '$lib/components/icons/MagnifyingGlass.svelte'; - import { Collapsible } from 'bits-ui'; - import { slide } from 'svelte/transition'; + import Collapsible from '$lib/components/common/Collapsible.svelte'; + export let status = { urls: [], query: '' }; let state = false; - - +
+
@@ -22,76 +23,75 @@ {/if}
- - - - {#if status?.query} - -
- - -
- {status.query} + - -
- - - - -
-
- {/if} - - {#each status.urls as url, urlIdx} - -
- {url} -
- -
+ + + +
+
+ {/if} + + {#each status.urls as url, urlIdx} + - - + {url} +
+ +
- - -
-
- {/each} - - + + + + +
+ + {/each} +
+ + \ No newline at end of file From d5c0876a0b180cfc413a7dfb55ae4fe34f2f5d52 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 14:02:26 +0000 Subject: [PATCH 043/115] refactor: fixed new Collapsible Component to allow passed in classes chore: format --- .../ResponseMessage/WebSearchResults.svelte | 160 +++++++++--------- src/lib/components/common/Collapsible.svelte | 31 ++-- 2 files changed, 92 insertions(+), 99 deletions(-) diff --git a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte index 25001730e..4523c8482 100644 --- a/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte +++ b/src/lib/components/chat/Messages/ResponseMessage/WebSearchResults.svelte @@ -4,94 +4,88 @@ import MagnifyingGlass from '$lib/components/icons/MagnifyingGlass.svelte'; import Collapsible from '$lib/components/common/Collapsible.svelte'; - export let status = { urls: [], query: '' }; let state = false; -
- -
- + +
+ + + {#if state} + + {:else} + + {/if} +
+
+ {#if status?.query} + +
+ - {#if state} - - {:else} - - {/if} -
-
\ No newline at end of file + + +
+ + {/if} + + {#each status.urls as url, urlIdx} + +
+ {url} +
+ +
+ + + + +
+
+ {/each} +
+
diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index c87ffe8ba..b681143a6 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -2,11 +2,11 @@ import { afterUpdate } from 'svelte'; export let open = false; - + export let className = ''; // Manage the max-height of the collapsible content for snappy transitions let contentElement: HTMLElement; - let maxHeight = '0px'; // Initial max-height + let maxHeight = '0px'; // Initial max-height // After any state update, adjust the max-height for the transition afterUpdate(() => { if (open) { @@ -15,23 +15,22 @@ } else { maxHeight = '0px'; } - }); - + }); +
+ +
+ +
+
+ - -
- -
- -
-
\ No newline at end of file From db58bb5f0f51521fa5c52e1b4e8107e6275904ad Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 14:15:16 +0000 Subject: [PATCH 044/115] refactor: Removed dependency --- src/lib/components/common/Collapsible.svelte | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index b681143a6..0a140d9dd 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -1,21 +1,19 @@
From 78ba18a680f9cce4c895279282ecf60fc581f382 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 14:55:48 +0000 Subject: [PATCH 045/115] refactor: Update Collapsible component to include dynamic margin for open state --- src/lib/components/common/Collapsible.svelte | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index 0a140d9dd..14e5785a4 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -20,7 +20,7 @@ -
+
@@ -28,7 +28,7 @@ From f611533764ece12128d5a3daaa4a0ee53e0e3b64 Mon Sep 17 00:00:00 2001 From: Karl Lee <61072264+KarlLee830@users.noreply.github.com> Date: Thu, 4 Jul 2024 22:57:32 +0800 Subject: [PATCH 046/115] i18n: Update Chinese translation --- src/lib/i18n/locales/zh-CN/translation.json | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/lib/i18n/locales/zh-CN/translation.json b/src/lib/i18n/locales/zh-CN/translation.json index d5887e2ff..366b717f0 100644 --- a/src/lib/i18n/locales/zh-CN/translation.json +++ b/src/lib/i18n/locales/zh-CN/translation.json @@ -126,7 +126,7 @@ "Connections": "外部连接", "Contact Admin for WebUI Access": "请联系管理员以获取访问权限", "Content": "内容", - "Content Extraction": "", + "Content Extraction": "内容提取", "Context Length": "上下文长度", "Continue Response": "继续生成", "Continue with {{provider}}": "使用 {{provider}} 继续", @@ -213,7 +213,7 @@ "Enable Community Sharing": "启用分享至社区", "Enable New Sign Ups": "允许新用户注册", "Enable Web Search": "启用网络搜索", - "Engine": "", + "Engine": "引擎", "Ensure your CSV file includes 4 columns in this order: Name, Email, Password, Role.": "确保您的 CSV 文件按以下顺序包含 4 列: 姓名、电子邮箱、密码、角色。", "Enter {{role}} message here": "在此处输入 {{role}} 信息", "Enter a detail about yourself for your LLMs to recall": "输入一个关于你自己的详细信息,方便你的大语言模型记住这些内容", @@ -235,7 +235,7 @@ "Enter Serpstack API Key": "输入 Serpstack API 密钥", "Enter stop sequence": "输入停止序列 (Stop Sequence)", "Enter Tavily API Key": "输入 Tavily API 密钥", - "Enter Tika Server URL": "", + "Enter Tika Server URL": "输入 Tika 服务器地址", "Enter Top K": "输入 Top K", "Enter URL (e.g. http://127.0.0.1:7860/)": "输入地址 (例如:http://127.0.0.1:7860/)", "Enter URL (e.g. http://localhost:11434)": "输入地址 (例如:http://localhost:11434)", @@ -412,7 +412,7 @@ "Open": "打开", "Open AI (Dall-E)": "Open AI (Dall-E)", "Open new chat": "打开新对话", - "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "", + "Open WebUI version (v{{OPEN_WEBUI_VERSION}}) is lower than required version (v{{REQUIRED_VERSION}})": "当前 Open WebUI 版本 (v{{OPEN_WEBUI_VERSION}}) 低于所需的版本 (v{{REQUIRED_VERSION}})", "OpenAI": "OpenAI", "OpenAI API": "OpenAI API", "OpenAI API Config": "OpenAI API 配置", @@ -428,8 +428,8 @@ "Permission denied when accessing microphone": "申请麦克风权限被拒绝", "Permission denied when accessing microphone: {{error}}": "申请麦克风权限被拒绝:{{error}}", "Personalization": "个性化", - "Pin": "", - "Pinned": "", + "Pin": "置顶", + "Pinned": "已置顶", "Pipeline deleted successfully": "Pipeline 删除成功", "Pipeline downloaded successfully": "Pipeline 下载成功", "Pipelines": "Pipeline", @@ -578,8 +578,8 @@ "This setting does not sync across browsers or devices.": "此设置不会在浏览器或设备之间同步。", "This will delete": "这将删除", "Thorough explanation": "解释较为详细", - "Tika": "", - "Tika Server URL required.": "", + "Tika": "Tika", + "Tika Server URL required.": "请输入 Tika 服务器地址。", "Tip: Update multiple variable slots consecutively by pressing the tab key in the chat input after each replacement.": "提示:在每次替换后,在对话输入中按 Tab 键可以连续更新多个变量。", "Title": "标题", "Title (e.g. Tell me a fun fact)": "标题(例如 给我讲一个有趣的事实)", @@ -614,7 +614,7 @@ "Uh-oh! There was an issue connecting to {{provider}}.": "糟糕!连接到 {{provider}} 时出现问题。", "UI": "界面", "Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.": "未知文件类型“{{file_type}}”,将无视继续上传文件。", - "Unpin": "", + "Unpin": "取消置顶", "Update": "更新", "Update and Copy Link": "更新和复制链接", "Update password": "更新密码", From ca3f8e6cb52231a21f7c157fd1e38504665b1793 Mon Sep 17 00:00:00 2001 From: rdavis Date: Thu, 4 Jul 2024 15:18:21 +0000 Subject: [PATCH 047/115] chore: format --- src/lib/components/common/Collapsible.svelte | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib/components/common/Collapsible.svelte b/src/lib/components/common/Collapsible.svelte index 14e5785a4..8a3ef9690 100644 --- a/src/lib/components/common/Collapsible.svelte +++ b/src/lib/components/common/Collapsible.svelte @@ -20,7 +20,11 @@ -
+
From 55b7c30028c96dc58e14b563dcd26780dbea34cb Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Thu, 4 Jul 2024 18:50:09 +0100 Subject: [PATCH 048/115] simplify citation API --- src/lib/components/chat/Chat.svelte | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index de64d2681..a087e76ed 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -139,7 +139,7 @@ return; } const status_keys = ["done", "description"]; - const citation_keys = ["document", "metadata", "source"]; + const citation_keys = ["document", "url", "title"]; if (type === "status" && status_keys.every(key => key in payload)) { if (message.statusHistory) { message.statusHistory.push(payload); @@ -147,10 +147,15 @@ message.statusHistory = [payload]; } } else if (type === "citation" && citation_keys.every(key => key in payload)) { + const citation = { + document: [payload.document], + metadata: [{source: payload.url}], + source: {name: payload.title} + }; if (message.citations) { - message.citations.push(payload); + message.citations.push(citation); } else { - message.citations = [payload]; + message.citations = [citation]; } } else { console.log("Unknown message type", data); From 67c2ab006d06e442c4ca7cc4e0293e119f67f715 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 4 Jul 2024 13:41:18 -0700 Subject: [PATCH 049/115] fix: pipe custom model --- backend/apps/webui/main.py | 76 ++++++++++++++++++++++++++++++++++++++ backend/main.py | 6 ++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 552edf7fa..745157ac6 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -19,8 +19,13 @@ from apps.webui.routers import ( functions, ) from apps.webui.models.functions import Functions +from apps.webui.models.models import Models + from apps.webui.utils import load_function_module_by_id + from utils.misc import stream_message_template +from utils.task import prompt_template + from config import ( WEBUI_BUILD_HASH, @@ -186,6 +191,77 @@ async def get_pipe_models(): async def generate_function_chat_completion(form_data, user): + model_id = form_data.get("model") + model_info = Models.get_model_by_id(model_id) + + if model_info: + if model_info.base_model_id: + form_data["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + if model_info.params.get("temperature", None) is not None: + form_data["temperature"] = float(model_info.params.get("temperature")) + + if model_info.params.get("top_p", None): + form_data["top_p"] = int(model_info.params.get("top_p", None)) + + if model_info.params.get("max_tokens", None): + form_data["max_tokens"] = int(model_info.params.get("max_tokens", None)) + + if model_info.params.get("frequency_penalty", None): + form_data["frequency_penalty"] = int( + model_info.params.get("frequency_penalty", None) + ) + + if model_info.params.get("seed", None): + form_data["seed"] = model_info.params.get("seed", None) + + if model_info.params.get("stop", None): + form_data["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + system = model_info.params.get("system", None) + if system: + system = prompt_template( + system, + **( + { + "user_name": user.name, + "user_location": ( + user.info.get("location") if user.info else None + ), + } + if user + else {} + ), + ) + # Check if the payload already has a system message + # If not, add a system message to the payload + if form_data.get("messages"): + for message in form_data["messages"]: + if message.get("role") == "system": + message["content"] = system + message["content"] + break + else: + form_data["messages"].insert( + 0, + { + "role": "system", + "content": system, + }, + ) + + else: + pass + async def job(): pipe_id = form_data["model"] if "." in pipe_id: diff --git a/backend/main.py b/backend/main.py index 8f818c85b..f2019b30f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -975,12 +975,16 @@ async def get_all_models(): model["info"] = custom_model.model_dump() else: owned_by = "openai" + pipe = None + for model in models: if ( custom_model.base_model_id == model["id"] or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] + if "pipe" in model: + pipe = model["pipe"] break models.append( @@ -992,11 +996,11 @@ async def get_all_models(): "owned_by": owned_by, "info": custom_model.model_dump(), "preset": True, + **({"pipe": pipe} if pipe is not None else {}), } ) app.state.MODELS = {model["id"]: model for model in models} - webui_app.state.MODELS = app.state.MODELS return models From 838134637818ae64127bcb27a9208b0466b438d4 Mon Sep 17 00:00:00 2001 From: Peter De-Ath Date: Fri, 5 Jul 2024 02:05:59 +0100 Subject: [PATCH 050/115] enh: add sideways scrolling to settings tabs container --- src/lib/components/admin/Settings.svelte | 17 ++++++++++-- src/lib/components/chat/SettingsModal.svelte | 29 +++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 5538a11cf..24cf595a7 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -1,5 +1,5 @@
-
- -
-
- + {#if open} +
+ +
+ {/if} +
From 1436bb7c61b1df4dba2b5b383ecb8c86ec452f37 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 5 Jul 2024 23:38:53 -0700 Subject: [PATCH 055/115] enh: handle peewee migration --- backend/apps/webui/internal/db.py | 36 +++++++++++-- backend/apps/webui/internal/wrappers.py | 72 +++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 backend/apps/webui/internal/wrappers.py diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py index 333e215ea..8437ae4fa 100644 --- a/backend/apps/webui/internal/db.py +++ b/backend/apps/webui/internal/db.py @@ -2,6 +2,10 @@ import os import logging import json from contextlib import contextmanager + +from peewee_migrate import Router +from apps.webui.internal.wrappers import register_connection + from typing import Optional, Any from typing_extensions import Self @@ -46,6 +50,35 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): else: pass + +# Workaround to handle the peewee migration +# This is required to ensure the peewee migration is handled before the alembic migration +def handle_peewee_migration(): + try: + db = register_connection(DATABASE_URL) + migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations" + router = Router(db, logger=log, migrate_dir=migrate_dir) + router.run() + db.close() + + # check if db connection has been closed + + except Exception as e: + log.error(f"Failed to initialize the database connection: {e}") + raise + + finally: + # Properly closing the database connection + if db and not db.is_closed(): + db.close() + + # Assert if db connection has been closed + assert db.is_closed(), "Database connection is still open." + + +handle_peewee_migration() + + SQLALCHEMY_DATABASE_URL = DATABASE_URL if "sqlite" in SQLALCHEMY_DATABASE_URL: engine = create_engine( @@ -62,9 +95,6 @@ Base = declarative_base() Session = scoped_session(SessionLocal) -from contextlib import contextmanager - - # Dependency def get_session(): db = SessionLocal() diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py new file mode 100644 index 000000000..2b5551ce2 --- /dev/null +++ b/backend/apps/webui/internal/wrappers.py @@ -0,0 +1,72 @@ +from contextvars import ContextVar +from peewee import * +from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError + +import logging +from playhouse.db_url import connect, parse +from playhouse.shortcuts import ReconnectMixin + +from config import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["DB"]) + +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +class PeeweeConnectionState(object): + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + value = self._state.get()[name] + return value + + +class CustomReconnectMixin(ReconnectMixin): + reconnect_errors = ( + # psycopg2 + (OperationalError, "termin"), + (InterfaceError, "closed"), + # peewee + (PeeWeeInterfaceError, "closed"), + ) + + +class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): + pass + + +def register_connection(db_url): + db = connect(db_url) + if isinstance(db, PostgresqlDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to PostgreSQL database") + + # Get the connection details + connection = parse(db_url) + + # Use our custom database class that supports reconnection + db = ReconnectingPostgresqlDatabase( + connection["database"], + user=connection["user"], + password=connection["password"], + host=connection["host"], + port=connection["port"], + ) + db.connect(reuse_if_open=True) + elif isinstance(db, SqliteDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to SQLite database") + else: + raise ValueError("Unsupported database connection") + return db From d5716ae751f2ea24bda45fad810750b5c7c72b29 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 5 Jul 2024 23:48:53 -0700 Subject: [PATCH 056/115] chore: format --- src/lib/components/admin/Settings.svelte | 3 ++- src/lib/i18n/locales/fr-CA/translation.json | 16 +++++++++------- src/lib/i18n/locales/fr-FR/translation.json | 20 +++++++++++--------- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/lib/components/admin/Settings.svelte b/src/lib/components/admin/Settings.svelte index 24cf595a7..661401443 100644 --- a/src/lib/components/admin/Settings.svelte +++ b/src/lib/components/admin/Settings.svelte @@ -38,7 +38,8 @@