diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 6a355038b..bef91443a 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -18,8 +18,9 @@ import requests from pydantic import BaseModel, ConfigDict from typing import Optional, List +from apps.web.models.models import Models from utils.utils import get_verified_user, get_current_user, get_admin_user -from config import SRC_LOG_LEVELS, ENV +from config import SRC_LOG_LEVELS from constants import MESSAGES import os @@ -77,7 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file: app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value - +app.state.MODEL_CONFIG = Models.get_all_models() app.state.ENABLE = ENABLE_LITELLM app.state.CONFIG = litellm_config @@ -241,6 +242,8 @@ async def get_models(user=Depends(get_current_user)): ) ) + for model in data["data"]: + add_custom_info_to_model(model) return data except Exception as e: @@ -261,6 +264,14 @@ async def get_models(user=Depends(get_current_user)): "object": "model", "created": int(time.time()), "owned_by": "openai", + "custom_info": next( + ( + item + for item in app.state.MODEL_CONFIG + if item.id == model["model_name"] + ), + None, + ), } for model in app.state.CONFIG["model_list"] ], @@ -273,6 +284,12 @@ async def get_models(user=Depends(get_current_user)): } +def add_custom_info_to_model(model: dict): + model["custom_info"] = next( + (item for item in app.state.MODEL_CONFIG if item.id == model["id"]), None + ) + + @app.get("/model/info") async def get_model_list(user=Depends(get_admin_user)): return {"data": app.state.CONFIG["model_list"]} diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index fb8a35a17..178cfa5fd 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -29,7 +29,7 @@ import time from urllib.parse import urlparse from typing import Optional, List, Union - +from apps.web.models.models import Models from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( @@ -67,6 +67,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.MODEL_CONFIG = Models.get_all_models() app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API @@ -191,12 +192,21 @@ async def get_all_models(): else: models = {"models": []} + + for model in models["models"]: + add_custom_info_to_model(model) app.state.MODELS = {model["model"]: model for model in models["models"]} return models +def add_custom_info_to_model(model: dict): + model["custom_info"] = next( + (item for item in app.state.MODEL_CONFIG if item.id == model["model"]), None + ) + + @app.get("/api/tags") @app.get("/api/tags/{url_idx}") async def get_ollama_tags( diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6659ebfcf..0e2f28409 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -10,7 +10,7 @@ import logging from pydantic import BaseModel - +from apps.web.models.models import Models from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( @@ -52,6 +52,7 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.MODEL_CONFIG = Models.get_all_models() app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API @@ -249,10 +250,19 @@ async def get_all_models(): ) 
} + for model in models["data"]: + add_custom_info_to_model(model) + log.info(f"models: {models}") app.state.MODELS = {model["id"]: model for model in models["data"]} - return models + return models + + +def add_custom_info_to_model(model: dict): + model["custom_info"] = next( + (item for item in app.state.MODEL_CONFIG if item.id == model["id"]), None + ) @app.get("/models") diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py index a6051de50..c8011460c 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/web/internal/db.py @@ -1,3 +1,5 @@ +import json + from peewee import * from peewee_migrate import Router from playhouse.db_url import connect @@ -8,6 +10,16 @@ import logging log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) + +class JSONField(TextField): + def db_value(self, value): + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file diff --git a/backend/apps/web/internal/migrations/009_add_models.py b/backend/apps/web/internal/migrations/009_add_models.py new file mode 100644 index 000000000..276769441 --- /dev/null +++ b/backend/apps/web/internal/migrations/009_add_models.py @@ -0,0 +1,55 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + meta = pw.TextField() + base_model_id = pw.TextField(null=True) + name = pw.TextField() + params = pw.TextField() + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/web/models/models.py b/backend/apps/web/models/models.py new file mode 100644 index 000000000..cd734e67b --- /dev/null +++ b/backend/apps/web/models/models.py @@ -0,0 +1,136 @@ +import json +import logging +from typing import Optional + +import peewee as pw +from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel + +from 
apps.web.internal.db import DB, JSONField
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+
+####################
+# Models DB Schema
+####################
+
+
+# ModelParams is a model for the data stored in the params field of the Model table
+# It isn't currently used in the backend, but it's here as a reference
+class ModelParams(BaseModel):
+    pass
+
+
+# ModelMeta is a model for the data stored in the meta field of the Model table
+# It isn't currently used in the backend, but it's here as a reference
+class ModelMeta(BaseModel):
+    description: str
+    """
+    User-facing description of the model.
+    """
+
+    vision_capable: bool
+    """
+    A flag indicating if the model is capable of vision and thus image inputs
+    """
+
+
+class Model(pw.Model):
+    id = pw.TextField(unique=True)
+    """
+    The model's id as used in the API. If set to an existing model, it will override the model.
+    """
+
+    meta = JSONField()
+    """
+    Holds a JSON encoded blob of metadata, see `ModelMeta`.
+    """
+
+    base_model_id = pw.TextField(null=True)
+    """
+    An optional pointer to the actual model that should be used when proxying requests.
+    Currently unused - but will be used to support Modelfile like behaviour in the future
+    """
+
+    name = pw.TextField()
+    """
+    The human-readable display name of the model.
+    """
+
+    params = JSONField()
+    """
+    Holds a JSON encoded blob of parameters, see `ModelParams`.
+    """
+
+    class Meta:
+        database = DB
+
+
+class ModelModel(BaseModel):
+    id: str
+    meta: ModelMeta
+    base_model_id: Optional[str] = None
+    name: str
+    params: ModelParams
+
+
+####################
+# Forms
+####################
+
+
+class ModelsTable:
+
+    def __init__(
+        self,
+        db: pw.SqliteDatabase | pw.PostgresqlDatabase,
+    ):
+        self.db = db
+        self.db.create_tables([Model])
+
+    def get_all_models(self) -> list[ModelModel]:
+        return [ModelModel(**model_to_dict(model)) for model in Model.select()]
+
+    def update_all_models(self, models: list[ModelModel]) -> bool:
+        try:
+            with self.db.atomic():
+                # Fetch current models from the database
+                current_models = self.get_all_models()
+                current_model_dict = {model.id: model for model in current_models}
+
+                # Create a set of model IDs from the current models and the new models
+                current_model_keys = set(current_model_dict.keys())
+                new_model_keys = set(model.id for model in models)
+
+                # Determine which models need to be created, updated, or deleted
+                models_to_create = [
+                    model for model in models if model.id not in current_model_keys
+                ]
+                models_to_update = [
+                    model for model in models if model.id in current_model_keys
+                ]
+                models_to_delete = current_model_keys - new_model_keys
+
+                # Perform the necessary database operations
+                for model in models_to_create:
+                    Model.create(**model.model_dump())
+
+                for model in models_to_update:
+                    Model.update(**model.model_dump()).where(
+                        Model.id == model.id
+                    ).execute()
+
+                # models_to_delete is a set of plain model IDs, so iterate over IDs directly
+                for model_id in models_to_delete:
+                    Model.delete().where(Model.id == model_id).execute()
+
+            return True
+        except Exception as e:
+            log.exception(e)
+            return False
+
+
+Models = ModelsTable(DB)
diff --git a/backend/main.py b/backend/main.py
index df79a3106..7e505c7cd 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -36,9 +36,9 @@ from apps.web.main import app as webui_app
 import asyncio
 from pydantic import BaseModel
-from typing import List
-
+from typing import List, Optional
+from apps.web.models.models import Models, ModelModel
 from utils.utils import get_admin_user
 from 
apps.rag.utils import rag_messages @@ -113,6 +113,8 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST +app.state.MODEL_CONFIG = Models.get_all_models() + app.state.config.WEBHOOK_URL = WEBHOOK_URL origins = ["*"] @@ -318,6 +320,33 @@ async def update_model_filter_config( } +class SetModelConfigForm(BaseModel): + models: List[ModelModel] + + +@app.post("/api/config/models") +async def update_model_config( + form_data: SetModelConfigForm, user=Depends(get_admin_user) +): + if not Models.update_all_models(form_data.models): + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=ERROR_MESSAGES.DEFAULT("Failed to update model config"), + ) + + ollama_app.state.MODEL_CONFIG = form_data.models + openai_app.state.MODEL_CONFIG = form_data.models + litellm_app.state.MODEL_CONFIG = form_data.models + app.state.MODEL_CONFIG = form_data.models + + return {"models": app.state.MODEL_CONFIG} + + +@app.get("/api/config/models") +async def get_model_config(user=Depends(get_admin_user)): + return {"models": app.state.MODEL_CONFIG} + + @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a610f7210..a7b59a7ca 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -196,3 +196,77 @@ export const updateWebhookUrl = async (token: string, url: string) => { return res.url; }; + +export const getModelConfig = async (token: string): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res.models; +}; + +export interface ModelConfig { + id: string; + name: string; + meta: ModelMeta; + base_model_id?: string; + params: ModelParams; +} + +export interface ModelMeta { + description?: string; + vision_capable?: boolean; +} + +export interface ModelParams {} + +export type GlobalModelConfig = ModelConfig[]; + +export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + models: config + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts index 643146b73..b1c24c5bd 100644 --- a/src/lib/apis/litellm/index.ts +++ b/src/lib/apis/litellm/index.ts @@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => { id: model.id, name: model.name ?? 
model.id, external: true, - source: 'LiteLLM' + source: 'LiteLLM', + custom_info: model.custom_info })) .sort((a, b) => { return a.name.localeCompare(b.name); diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 02281eff0..8afcec018 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -230,7 +230,12 @@ export const getOpenAIModels = async (token: string = '') => { return models ? models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) + .map((model) => ({ + id: model.id, + name: model.name ?? model.id, + external: true, + custom_info: model.custom_info + })) .sort((a, b) => { return a.name.localeCompare(b.name); }) diff --git a/src/lib/components/admin/Settings/Users.svelte b/src/lib/components/admin/Settings/Users.svelte index f2a8bb19a..6f4144634 100644 --- a/src/lib/components/admin/Settings/Users.svelte +++ b/src/lib/components/admin/Settings/Users.svelte @@ -125,7 +125,7 @@ {#each $models.filter((model) => model.id) as model} {model.custom_info?.name ?? model.name} {/each} diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 871647101..f2cf16007 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -10,6 +10,7 @@ chatId, chats, config, + type Model, modelfiles, models, settings, @@ -60,7 +61,7 @@ let showModelSelector = true; let selectedModels = ['']; - let atSelectedModel = ''; + let atSelectedModel: Model | undefined; let selectedModelfile = null; $: selectedModelfile = @@ -328,75 +329,91 @@ const _chatId = JSON.parse(JSON.stringify($chatId)); await Promise.all( - (modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map( - async (modelId) => { - console.log('modelId', modelId); - const model = $models.filter((m) => m.id === modelId).at(0); + (modelId + ? [modelId] + : atSelectedModel !== undefined + ? [atSelectedModel.id] + : selectedModels + ).map(async (modelId) => { + console.log('modelId', modelId); + const model = $models.filter((m) => m.id === modelId).at(0); - if (model) { - // Create response message - let responseMessageId = uuidv4(); - let responseMessage = { - parentId: parentId, - id: responseMessageId, - childrenIds: [], - role: 'assistant', - content: '', - model: model.id, - userContext: null, - timestamp: Math.floor(Date.now() / 1000) // Unix epoch - }; + if (model) { + // If there are image files, check if model is vision capable + const hasImages = messages.some((message) => + message.files?.some((file) => file.type === 'image') + ); + if (hasImages && !(model.custom_info?.meta.vision_capable ?? true)) { + toast.error( + $i18n.t('Model {{modelName}} is not vision capable', { + modelName: model.custom_info?.name ?? model.name ?? model.id + }) + ); + } - // Add message to history and Set currentId to messageId - history.messages[responseMessageId] = responseMessage; - history.currentId = responseMessageId; + // Create response message + let responseMessageId = uuidv4(); + let responseMessage = { + parentId: parentId, + id: responseMessageId, + childrenIds: [], + role: 'assistant', + content: '', + model: model.id, + modelName: model.custom_info?.name ?? model.name ?? 
model.id, + userContext: null, + timestamp: Math.floor(Date.now() / 1000) // Unix epoch + }; - // Append messageId to childrenIds of parent message - if (parentId !== null) { - history.messages[parentId].childrenIds = [ - ...history.messages[parentId].childrenIds, - responseMessageId - ]; - } + // Add message to history and Set currentId to messageId + history.messages[responseMessageId] = responseMessage; + history.currentId = responseMessageId; - await tick(); + // Append messageId to childrenIds of parent message + if (parentId !== null) { + history.messages[parentId].childrenIds = [ + ...history.messages[parentId].childrenIds, + responseMessageId + ]; + } - let userContext = null; - if ($settings?.memory ?? false) { - if (userContext === null) { - const res = await queryMemory(localStorage.token, prompt).catch((error) => { - toast.error(error); - return null; - }); + await tick(); - if (res) { - if (res.documents[0].length > 0) { - userContext = res.documents.reduce((acc, doc, index) => { - const createdAtTimestamp = res.metadatas[index][0].created_at; - const createdAtDate = new Date(createdAtTimestamp * 1000) - .toISOString() - .split('T')[0]; - acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`); - return acc; - }, []); - } + let userContext = null; + if ($settings?.memory ?? false) { + if (userContext === null) { + const res = await queryMemory(localStorage.token, prompt).catch((error) => { + toast.error(error); + return null; + }); - console.log(userContext); + if (res) { + if (res.documents[0].length > 0) { + userContext = res.documents.reduce((acc, doc, index) => { + const createdAtTimestamp = res.metadatas[index][0].created_at; + const createdAtDate = new Date(createdAtTimestamp * 1000) + .toISOString() + .split('T')[0]; + acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`); + return acc; + }, []); } + + console.log(userContext); } } - responseMessage.userContext = userContext; - - if (model?.external) { - await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); - } else if (model) { - await sendPromptOllama(model, prompt, responseMessageId, _chatId); - } - } else { - toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); } + responseMessage.userContext = userContext; + + if (model?.external) { + await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); + } else if (model) { + await sendPromptOllama(model, prompt, responseMessageId, _chatId); + } + } else { + toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); } - ) + }) ); await chats.set(await getChatList(localStorage.token)); @@ -855,7 +872,7 @@ responseMessage.error = true; responseMessage.content = $i18n.t(`Uh-oh! There was an issue connecting to {{provider}}.`, { - provider: model.name ?? model.id + provider: model.custom_info?.name ?? model.name ?? model.id }) + '\n' + errorMessage; @@ -1049,6 +1066,7 @@ bind:prompt bind:autoScroll bind:selectedModel={atSelectedModel} + {selectedModels} {messages} {submitPrompt} {stopResponse} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index acf797cd1..3f7250c4a 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -1,7 +1,7 @@ @@ -587,24 +666,28 @@ viewBox="0 0 24 24" fill="currentColor" xmlns="http://www.w3.org/2000/svg" - > + + + /> + {:else} m.size != null && (selectedOllamaUrlIdx === null ? true : (m?.urls ?? []).includes(selectedOllamaUrlIdx))) as model} {(model.custom_info?.name ?? 
model.name) + + ' (' + + (model.size / 1024 ** 3).toFixed(1) + + ' GB)'} {/each} @@ -833,24 +919,28 @@ viewBox="0 0 24 24" fill="currentColor" xmlns="http://www.w3.org/2000/svg" - > + /> + {:else} {/if} +
@@ -1126,6 +1217,146 @@ {/if}
+				<!-- [Markup not recoverable in this copy of the patch] The remainder of this hunk adds a
+				     "Manage Model Information" section ({$i18n.t('Manage Model Information')}) gated by a
+				     showModelInfo toggle: a {$i18n.t('Current Models')} selector bound to selectedModelId,
+				     and, once a model is selected, editable {$i18n.t('Model Display Name')} and
+				     {$i18n.t('Model Description')} fields for that model's custom info. -->