diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 6a355038b..2b771d5c6 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -18,8 +18,9 @@ import requests from pydantic import BaseModel, ConfigDict from typing import Optional, List +from apps.web.models.models import Models from utils.utils import get_verified_user, get_current_user, get_admin_user -from config import SRC_LOG_LEVELS, ENV +from config import SRC_LOG_LEVELS from constants import MESSAGES import os @@ -77,7 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file: app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value - +app.state.MODEL_CONFIG = Models.get_all_models() app.state.ENABLE = ENABLE_LITELLM app.state.CONFIG = litellm_config @@ -261,6 +262,14 @@ async def get_models(user=Depends(get_current_user)): "object": "model", "created": int(time.time()), "owned_by": "openai", + "custom_info": next( + ( + item + for item in app.state.MODEL_CONFIG + if item.id == model["model_name"] + ), + None, + ), } for model in app.state.CONFIG["model_list"] ], diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index fb8a35a17..7288f3467 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -29,7 +29,7 @@ import time from urllib.parse import urlparse from typing import Optional, List, Union - +from apps.web.models.models import Models from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( @@ -39,6 +39,8 @@ from utils.utils import ( get_admin_user, ) +from utils.models import get_model_id_from_custom_model_id + from config import ( SRC_LOG_LEVELS, @@ -68,7 +70,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.MODELS = {} @@ -875,14 +876,93 @@ async def generate_chat_completion( user=Depends(get_verified_user), ): + log.debug( + "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( + form_data.model_dump_json(exclude_none=True).encode() + ) + ) + + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["options"] = {} + + payload["options"]["mirostat"] = model_info.params.get("mirostat", None) + payload["options"]["mirostat_eta"] = model_info.params.get( + "mirostat_eta", None + ) + payload["options"]["mirostat_tau"] = model_info.params.get( + "mirostat_tau", None + ) + payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) + + payload["options"]["repeat_last_n"] = model_info.params.get( + "repeat_last_n", None + ) + payload["options"]["repeat_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + + payload["options"]["temperature"] = model_info.params.get( + "temperature", None + ) + payload["options"]["seed"] = model_info.params.get("seed", None) + + payload["options"]["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None) + + payload["options"]["num_predict"] = model_info.params.get( + "max_tokens", None + ) + payload["options"]["top_k"] = model_info.params.get("top_k", None) + + payload["options"]["top_p"] = model_info.params.get("top_p", None) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -892,16 +972,12 @@ async def generate_chat_completion( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") + print(payload) + r = None - log.debug( - "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( - form_data.model_dump_json(exclude_none=True).encode() - ) - ) - def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -910,7 +986,7 @@ async def generate_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream", None): yield json.dumps({"id": request_id, "done": False}) + "\n" for chunk in r.iter_content(chunk_size=8192): @@ -928,7 +1004,7 @@ async def generate_chat_completion( r = requests.request( method="POST", url=f"{url}/api/chat", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) @@ -984,14 +1060,62 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -1004,7 +1128,7 @@ async def generate_openai_chat_completion( r = None def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -1013,7 +1137,7 @@ async def generate_openai_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream"): yield json.dumps( {"request_id": request_id, "done": False} ) + "\n" @@ -1033,7 +1157,7 @@ async def generate_openai_chat_completion( r = requests.request( method="POST", url=f"{url}/v1/chat/completions", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 6659ebfcf..ded7e6626 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -10,7 +10,7 @@ import logging from pydantic import BaseModel - +from apps.web.models.models import Models from apps.web.models.users import Users from constants import ERROR_MESSAGES from utils.utils import ( @@ -53,7 +53,6 @@ app.state.config = AppConfig() app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST - app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS @@ -206,7 +205,13 @@ def merge_models_lists(model_lists): if models is not None and "error" not in models: merged_list.extend( [ - {**model, "urlIdx": idx} + { + **model, + "name": model.get("name", model["id"]), + "owned_by": "openai", + "openai": model, + "urlIdx": idx, + } for model in models if "api.openai.com" not in app.state.config.OPENAI_API_BASE_URLS[idx] @@ -252,7 +257,7 @@ async def get_all_models(): log.info(f"models: {models}") app.state.MODELS = {model["id"]: model for model in models["data"]} - return models + return models @app.get("/models") @@ -310,39 +315,93 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() # TODO: Remove below after gpt-4-vision fix from Open AI # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision) + + payload = None + try: - body = body.decode("utf-8") - body = json.loads(body) + if "chat/completions" in path: + body = body.decode("utf-8") + body = json.loads(body) - model = app.state.MODELS[body.get("model")] + payload = {**body} - idx = model["urlIdx"] + model_id = body.get("model") + model_info = Models.get_model_by_id(model_id) - if "pipeline" in model and model.get("pipeline"): - body["user"] = {"name": user.name, "id": user.id} - body["title"] = ( - True if body["stream"] == False and body["max_tokens"] == 50 else False - ) + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id - # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 - # This is a workaround until OpenAI fixes the issue with this model - if body.get("model") == "gpt-4-vision-preview": - if "max_tokens" not in body: - body["max_tokens"] = 4000 - log.debug("Modified body_dict:", body) + model_info.params = model_info.params.model_dump() - # Fix for ChatGPT calls failing because the num_ctx key is in body - if "num_ctx" in body: - # If 'num_ctx' is in the dictionary, delete it - # Leaving it there generates an error with the - # OpenAI API (Feb 2024) - del body["num_ctx"] + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + payload["stop"] = ( + [ + bytes(stop, "utf-8").decode("unicode_escape") + for stop in model_info.params["stop"] + ] + if model_info.params.get("stop", None) + else None + ) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + else: + pass + + print(app.state.MODELS) + model = app.state.MODELS[payload.get("model")] + + idx = model["urlIdx"] + + if "pipeline" in model and model.get("pipeline"): + payload["user"] = {"name": user.name, "id": user.id} + payload["title"] = ( + True + if payload["stream"] == False and payload["max_tokens"] == 50 + else False + ) + + # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 + # This is a workaround until OpenAI fixes the issue with this model + if payload.get("model") == "gpt-4-vision-preview": + if "max_tokens" not in payload: + payload["max_tokens"] = 4000 + log.debug("Modified payload:", payload) + + # Convert the modified body back to JSON + payload = json.dumps(payload) - # Convert the modified body back to JSON - body = json.dumps(body) except json.JSONDecodeError as e: log.error("Error loading request body into a dictionary:", e) + print(payload) + url = app.state.config.OPENAI_API_BASE_URLS[idx] key = app.state.config.OPENAI_API_KEYS[idx] @@ -361,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): r = requests.request( method=request.method, url=target_url, - data=body, + data=payload if payload else body, headers=headers, stream=True, ) diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py index a6051de50..c8011460c 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/web/internal/db.py @@ -1,3 +1,5 @@ +import json + from peewee import * from peewee_migrate import Router from playhouse.db_url import connect @@ -8,6 +10,16 @@ import logging log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["DB"]) + +class JSONField(TextField): + def db_value(self, value): + return json.dumps(value) + + def python_value(self, value): + if value is not None: + return json.loads(value) + + # Check if the file exists if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file diff --git a/backend/apps/web/internal/migrations/009_add_models.py b/backend/apps/web/internal/migrations/009_add_models.py new file mode 100644 index 000000000..548ec7cdc --- /dev/null +++ b/backend/apps/web/internal/migrations/009_add_models.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Model(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + base_model_id = pw.TextField(null=True) + + name = pw.TextField() + + meta = pw.TextField() + params = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "model" + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("model") diff --git a/backend/apps/web/internal/migrations/010_migrate_modelfiles_to_models.py b/backend/apps/web/internal/migrations/010_migrate_modelfiles_to_models.py new file mode 100644 index 000000000..2ef814c06 --- /dev/null +++ b/backend/apps/web/internal/migrations/010_migrate_modelfiles_to_models.py @@ -0,0 +1,130 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator +import json + +from utils.misc import parse_ollama_modelfile + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Fetch data from 'modelfile' table and insert into 'model' table + migrate_modelfile_to_model(migrator, database) + # Drop the 'modelfile' table + migrator.remove_model("modelfile") + + +def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): + ModelFile = migrator.orm["modelfile"] + Model = migrator.orm["model"] + + modelfiles = ModelFile.select() + + for modelfile in modelfiles: + # Extract and transform data in Python + + modelfile.modelfile = json.loads(modelfile.modelfile) + meta = json.dumps( + { + "description": modelfile.modelfile.get("desc"), + "profile_image_url": modelfile.modelfile.get("imageUrl"), + "ollama": {"modelfile": modelfile.modelfile.get("content")}, + "suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), + "categories": modelfile.modelfile.get("categories"), + "user": {**modelfile.modelfile.get("user", {}), "community": True}, + } + ) + + info = parse_ollama_modelfile(modelfile.modelfile.get("content")) + + # Insert the processed data into the 'model' table + Model.create( + id=f"ollama-{modelfile.tag_name}", + user_id=modelfile.user_id, + base_model_id=info.get("base_model_id"), + name=modelfile.modelfile.get("title"), + meta=meta, + params=json.dumps(info.get("params", {})), + created_at=modelfile.timestamp, + updated_at=modelfile.timestamp, + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + recreate_modelfile_table(migrator, database) + move_data_back_to_modelfile(migrator, database) + migrator.remove_model("model") + + +def recreate_modelfile_table(migrator: Migrator, database: pw.Database): + query = """ + CREATE TABLE IF NOT EXISTS modelfile ( + user_id TEXT, + tag_name TEXT, + modelfile JSON, + timestamp BIGINT + ) + """ + migrator.sql(query) + + +def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): + Model = migrator.orm["model"] + Modelfile = migrator.orm["modelfile"] + + models = Model.select() + + for model in models: + # Extract and transform data in Python + meta = json.loads(model.meta) + + modelfile_data = { + "title": model.name, + "desc": meta.get("description"), + "imageUrl": meta.get("profile_image_url"), + "content": meta.get("ollama", {}).get("modelfile"), + "suggestionPrompts": meta.get("suggestion_prompts"), + "categories": meta.get("categories"), + "user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, + } + + # Insert the processed data back into the 'modelfile' table + Modelfile.create( + user_id=model.user_id, + tag_name=model.id, + modelfile=modelfile_data, + timestamp=model.created_at, + ) diff --git a/backend/apps/web/main.py b/backend/apps/web/main.py index 2b6966381..9704cde77 100644 --- a/backend/apps/web/main.py +++ b/backend/apps/web/main.py @@ -6,7 +6,7 @@ from apps.web.routers import ( users, chats, documents, - modelfiles, + models, prompts, configs, memories, @@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL + + +app.state.MODELS = {} app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER @@ -56,11 +59,10 @@ app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) -app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) +app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(memories.router, prefix="/memories", tags=["memories"]) - app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(utils.router, prefix="/utils", tags=["utils"]) diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py index 1d60d7c55..fe278ed5f 100644 --- a/backend/apps/web/models/modelfiles.py +++ b/backend/apps/web/models/modelfiles.py @@ -1,3 +1,11 @@ +################################################################################ +# DEPRECATION NOTICE # +# # +# This file has been deprecated since version 0.2.0. # +# # +################################################################################ + + from pydantic import BaseModel from peewee import * from playhouse.shortcuts import model_to_dict diff --git a/backend/apps/web/models/models.py b/backend/apps/web/models/models.py new file mode 100644 index 000000000..bf835c8fd --- /dev/null +++ b/backend/apps/web/models/models.py @@ -0,0 +1,179 @@ +import json +import logging +from typing import Optional + +import peewee as pw +from peewee import * + +from playhouse.shortcuts import model_to_dict +from pydantic import BaseModel, ConfigDict + +from apps.web.internal.db import DB, JSONField + +from typing import List, Union, Optional +from config import SRC_LOG_LEVELS + +import time + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + + +#################### +# Models DB Schema +#################### + + +# ModelParams is a model for the data stored in the params field of the Model table +class ModelParams(BaseModel): + model_config = ConfigDict(extra="allow") + pass + + +# ModelMeta is a model for the data stored in the meta field of the Model table +class ModelMeta(BaseModel): + profile_image_url: Optional[str] = "/favicon.png" + + description: Optional[str] = None + """ + User-facing description of the model. + """ + + capabilities: Optional[dict] = None + + model_config = ConfigDict(extra="allow") + + pass + + +class Model(pw.Model): + id = pw.TextField(unique=True) + """ + The model's id as used in the API. If set to an existing model, it will override the model. + """ + user_id = pw.TextField() + + base_model_id = pw.TextField(null=True) + """ + An optional pointer to the actual model that should be used when proxying requests. + """ + + name = pw.TextField() + """ + The human-readable display name of the model. + """ + + params = JSONField() + """ + Holds a JSON encoded blob of parameters, see `ModelParams`. + """ + + meta = JSONField() + """ + Holds a JSON encoded blob of metadata, see `ModelMeta`. + """ + + updated_at = BigIntegerField() + created_at = BigIntegerField() + + class Meta: + database = DB + + +class ModelModel(BaseModel): + id: str + user_id: str + base_model_id: Optional[str] = None + + name: str + params: ModelParams + meta: ModelMeta + + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +#################### +# Forms +#################### + + +class ModelResponse(BaseModel): + id: str + name: str + meta: ModelMeta + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + +class ModelForm(BaseModel): + id: str + base_model_id: Optional[str] = None + name: str + meta: ModelMeta + params: ModelParams + + +class ModelsTable: + def __init__( + self, + db: pw.SqliteDatabase | pw.PostgresqlDatabase, + ): + self.db = db + self.db.create_tables([Model]) + + def insert_new_model( + self, form_data: ModelForm, user_id: str + ) -> Optional[ModelModel]: + model = ModelModel( + **{ + **form_data.model_dump(), + "user_id": user_id, + "created_at": int(time.time()), + "updated_at": int(time.time()), + } + ) + try: + result = Model.create(**model.model_dump()) + + if result: + return model + else: + return None + except Exception as e: + print(e) + return None + + def get_all_models(self) -> List[ModelModel]: + return [ModelModel(**model_to_dict(model)) for model in Model.select()] + + def get_model_by_id(self, id: str) -> Optional[ModelModel]: + try: + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except: + return None + + def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: + try: + # update only the fields that are present in the model + query = Model.update(**model.model_dump()).where(Model.id == id) + query.execute() + + model = Model.get(Model.id == id) + return ModelModel(**model_to_dict(model)) + except Exception as e: + print(e) + + return None + + def delete_model_by_id(self, id: str) -> bool: + try: + query = Model.delete().where(Model.id == id) + query.execute() + return True + except: + return False + + +Models = ModelsTable(DB) diff --git a/backend/apps/web/routers/modelfiles.py b/backend/apps/web/routers/modelfiles.py deleted file mode 100644 index 3cdbf8a74..000000000 --- a/backend/apps/web/routers/modelfiles.py +++ /dev/null @@ -1,124 +0,0 @@ -from fastapi import Depends, FastAPI, HTTPException, status -from datetime import datetime, timedelta -from typing import List, Union, Optional - -from fastapi import APIRouter -from pydantic import BaseModel -import json -from apps.web.models.modelfiles import ( - Modelfiles, - ModelfileForm, - ModelfileTagNameForm, - ModelfileUpdateForm, - ModelfileResponse, -) - -from utils.utils import get_current_user, get_admin_user -from constants import ERROR_MESSAGES - -router = APIRouter() - -############################ -# GetModelfiles -############################ - - -@router.get("/", response_model=List[ModelfileResponse]) -async def get_modelfiles( - skip: int = 0, limit: int = 50, user=Depends(get_current_user) -): - return Modelfiles.get_modelfiles(skip, limit) - - -############################ -# CreateNewModelfile -############################ - - -@router.post("/create", response_model=Optional[ModelfileResponse]) -async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)): - modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.DEFAULT(), - ) - - -############################ -# GetModelfileByTagName -############################ - - -@router.post("/", response_model=Optional[ModelfileResponse]) -async def get_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_current_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - - if modelfile: - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - -############################ -# UpdateModelfileByTagName -############################ - - -@router.post("/update", response_model=Optional[ModelfileResponse]) -async def update_modelfile_by_tag_name( - form_data: ModelfileUpdateForm, user=Depends(get_admin_user) -): - modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) - if modelfile: - updated_modelfile = { - **json.loads(modelfile.modelfile), - **form_data.modelfile, - } - - modelfile = Modelfiles.update_modelfile_by_tag_name( - form_data.tag_name, updated_modelfile - ) - - return ModelfileResponse( - **{ - **modelfile.model_dump(), - "modelfile": json.loads(modelfile.modelfile), - } - ) - else: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, - ) - - -############################ -# DeleteModelfileByTagName -############################ - - -@router.delete("/delete", response_model=bool) -async def delete_modelfile_by_tag_name( - form_data: ModelfileTagNameForm, user=Depends(get_admin_user) -): - result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) - return result diff --git a/backend/apps/web/routers/models.py b/backend/apps/web/routers/models.py new file mode 100644 index 000000000..654d0d2fb --- /dev/null +++ b/backend/apps/web/routers/models.py @@ -0,0 +1,108 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json +from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse + +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +router = APIRouter() + +########################### +# getModels +########################### + + +@router.get("/", response_model=List[ModelResponse]) +async def get_models(user=Depends(get_verified_user)): + return Models.get_all_models() + + +############################ +# AddNewModel +############################ + + +@router.post("/add", response_model=Optional[ModelModel]) +async def add_new_model( + request: Request, form_data: ModelForm, user=Depends(get_admin_user) +): + if form_data.id in request.app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.MODEL_ID_TAKEN, + ) + else: + model = Models.insert_new_model(form_data, user.id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# GetModelById +############################ + + +@router.get("/{id}", response_model=Optional[ModelModel]) +async def get_model_by_id(id: str, user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateModelById +############################ + + +@router.post("/{id}/update", response_model=Optional[ModelModel]) +async def update_model_by_id( + request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) +): + model = Models.get_model_by_id(id) + if model: + model = Models.update_model_by_id(id, form_data) + return model + else: + if form_data.id in request.app.state.MODELS: + model = Models.insert_new_model(form_data, user.id) + print(model) + if model: + return model + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.DEFAULT(), + ) + + +############################ +# DeleteModelById +############################ + + +@router.delete("/{id}/delete", response_model=bool) +async def delete_model_by_id(id: str, user=Depends(get_admin_user)): + result = Models.delete_model_by_id(id) + return result diff --git a/backend/constants.py b/backend/constants.py index be4d135b2..86875d2df 100644 --- a/backend/constants.py +++ b/backend/constants.py @@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum): COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." + MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." + NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." diff --git a/backend/main.py b/backend/main.py index df79a3106..aa3004865 100644 --- a/backend/main.py +++ b/backend/main.py @@ -19,8 +19,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import StreamingResponse, Response -from apps.ollama.main import app as ollama_app -from apps.openai.main import app as openai_app +from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models +from apps.openai.main import app as openai_app, get_all_models as get_openai_models from apps.litellm.main import ( app as litellm_app, @@ -36,10 +36,10 @@ from apps.web.main import app as webui_app import asyncio from pydantic import BaseModel -from typing import List +from typing import List, Optional - -from utils.utils import get_admin_user +from apps.web.models.models import Models, ModelModel +from utils.utils import get_admin_user, get_verified_user from apps.rag.utils import rag_messages from config import ( @@ -53,6 +53,8 @@ from config import ( FRONTEND_BUILD_DIR, CACHE_DIR, STATIC_DIR, + ENABLE_OPENAI_API, + ENABLE_OLLAMA_API, ENABLE_LITELLM, ENABLE_MODEL_FILTER, MODEL_FILTER_LIST, @@ -110,11 +112,19 @@ app = FastAPI( ) app.state.config = AppConfig() + +app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API +app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API + app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST + app.state.config.WEBHOOK_URL = WEBHOOK_URL + +app.state.MODELS = {} + origins = ["*"] @@ -231,6 +241,11 @@ app.add_middleware( @app.middleware("http") async def check_url(request: Request, call_next): + if len(app.state.MODELS) == 0: + await get_all_models() + else: + pass + start_time = int(time.time()) response = await call_next(request) process_time = int(time.time()) - start_time @@ -247,9 +262,11 @@ async def update_embedding_function(request: Request, call_next): return response +# TODO: Deprecate LiteLLM app.mount("/litellm/api", litellm_app) + app.mount("/ollama", ollama_app) -app.mount("/openai/api", openai_app) +app.mount("/openai", openai_app) app.mount("/images/api/v1", images_app) app.mount("/audio/api/v1", audio_app) @@ -260,6 +277,87 @@ app.mount("/api/v1", webui_app) webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION +async def get_all_models(): + openai_models = [] + ollama_models = [] + + if app.state.config.ENABLE_OPENAI_API: + openai_models = await get_openai_models() + + openai_models = openai_models["data"] + + if app.state.config.ENABLE_OLLAMA_API: + ollama_models = await get_ollama_models() + + ollama_models = [ + { + "id": model["model"], + "name": model["name"], + "object": "model", + "created": int(time.time()), + "owned_by": "ollama", + "ollama": model, + } + for model in ollama_models["models"] + ] + + models = openai_models + ollama_models + custom_models = Models.get_all_models() + + for custom_model in custom_models: + if custom_model.base_model_id == None: + for model in models: + if ( + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] + ): + model["name"] = custom_model.name + model["info"] = custom_model.model_dump() + else: + owned_by = "openai" + for model in models: + if ( + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] + ): + owned_by = model["owned_by"] + break + + models.append( + { + "id": custom_model.id, + "name": custom_model.name, + "object": "model", + "created": custom_model.created_at, + "owned_by": owned_by, + "info": custom_model.model_dump(), + "preset": True, + } + ) + + app.state.MODELS = {model["id"]: model for model in models} + + webui_app.state.MODELS = app.state.MODELS + + return models + + +@app.get("/api/models") +async def get_models(user=Depends(get_verified_user)): + models = await get_all_models() + if app.state.config.ENABLE_MODEL_FILTER: + if user.role == "user": + models = list( + filter( + lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST, + models, + ) + ) + return {"data": models} + + return {"data": models} + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 5efff4a35..fca941263 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -1,5 +1,6 @@ from pathlib import Path import hashlib +import json import re from datetime import timedelta from typing import Optional @@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]: total_duration += timedelta(weeks=number) return total_duration + + +def parse_ollama_modelfile(model_text): + parameters_meta = { + "mirostat": int, + "mirostat_eta": float, + "mirostat_tau": float, + "num_ctx": int, + "repeat_last_n": int, + "repeat_penalty": float, + "temperature": float, + "seed": int, + "stop": str, + "tfs_z": float, + "num_predict": int, + "top_k": int, + "top_p": float, + } + + data = {"base_model_id": None, "params": {}} + + # Parse base model + base_model_match = re.search( + r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE + ) + if base_model_match: + data["base_model_id"] = base_model_match.group(1) + + # Parse template + template_match = re.search( + r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE + ) + if template_match: + data["params"] = {"template": template_match.group(1).strip()} + + # Parse stops + stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE) + if stops: + data["params"]["stop"] = stops + + # Parse other parameters from the provided list + for param, param_type in parameters_meta.items(): + param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE) + if param_match: + value = param_match.group(1) + if param_type == int: + value = int(value) + elif param_type == float: + value = float(value) + data["params"][param] = value + + # Parse adapter + adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE) + if adapter_match: + data["params"]["adapter"] = adapter_match.group(1) + + # Parse system description + system_desc_match = re.search( + r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE + ) + if system_desc_match: + data["params"]["system"] = system_desc_match.group(1).strip() + + # Parse messages + messages = [] + message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE) + for role, content in message_matches: + messages.append({"role": role, "content": content}) + + if messages: + data["params"]["messages"] = messages + + return data diff --git a/backend/utils/models.py b/backend/utils/models.py new file mode 100644 index 000000000..7a57b4fdb --- /dev/null +++ b/backend/utils/models.py @@ -0,0 +1,10 @@ +from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse + + +def get_model_id_from_custom_model_id(id: str): + model = Models.get_model_by_id(id) + + if model: + return model.id + else: + return id diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index a610f7210..5d94e7678 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,5 +1,54 @@ import { WEBUI_BASE_URL } from '$lib/constants'; +export const getModels = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/models`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + let models = res?.data ?? []; + + models = models + .filter((models) => models) + .sort((a, b) => { + // Compare case-insensitively + const lowerA = a.name.toLowerCase(); + const lowerB = b.name.toLowerCase(); + + if (lowerA < lowerB) return -1; + if (lowerA > lowerB) return 1; + + // If same case-insensitively, sort by original strings, + // lowercase will come before uppercase due to ASCII values + if (a < b) return -1; + if (a > b) return 1; + + return 0; // They are equal + }); + + console.log(models); + return models; +}; + export const getBackendConfig = async () => { let error = null; @@ -196,3 +245,77 @@ export const updateWebhookUrl = async (token: string, url: string) => { return res.url; }; + +export const getModelConfig = async (token: string): Promise => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res.models; +}; + +export interface ModelConfig { + id: string; + name: string; + meta: ModelMeta; + base_model_id?: string; + params: ModelParams; +} + +export interface ModelMeta { + description?: string; + capabilities?: object; +} + +export interface ModelParams {} + +export type GlobalModelConfig = ModelConfig[]; + +export const updateModelConfig = async (token: string, config: GlobalModelConfig) => { + let error = null; + + const res = await fetch(`${WEBUI_BASE_URL}/api/config/models`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + models: config + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/apis/litellm/index.ts b/src/lib/apis/litellm/index.ts index 643146b73..b1c24c5bd 100644 --- a/src/lib/apis/litellm/index.ts +++ b/src/lib/apis/litellm/index.ts @@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => { id: model.id, name: model.name ?? model.id, external: true, - source: 'LiteLLM' + source: 'LiteLLM', + custom_info: model.custom_info })) .sort((a, b) => { return a.name.localeCompare(b.name); diff --git a/src/lib/apis/modelfiles/index.ts b/src/lib/apis/models/index.ts similarity index 66% rename from src/lib/apis/modelfiles/index.ts rename to src/lib/apis/models/index.ts index 91af5e381..092926583 100644 --- a/src/lib/apis/modelfiles/index.ts +++ b/src/lib/apis/models/index.ts @@ -1,18 +1,16 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; -export const createNewModelfile = async (token: string, modelfile: object) => { +export const addNewModel = async (token: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/add`, { method: 'POST', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` }, - body: JSON.stringify({ - modelfile: modelfile - }) + body: JSON.stringify(model) }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => { return res; }; -export const getModelfiles = async (token: string = '') => { +export const getModelInfos = async (token: string = '') => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/`, { method: 'GET', headers: { Accept: 'application/json', @@ -59,62 +57,19 @@ export const getModelfiles = async (token: string = '') => { throw error; } - return res.map((modelfile) => modelfile.modelfile); + return res; }; -export const getModelfileByTagName = async (token: string, tagName: string) => { +export const getModelById = async (token: string, id: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, { - method: 'POST', + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}`, { + method: 'GET', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) - }) - .then(async (res) => { - if (!res.ok) throw await res.json(); - return res.json(); - }) - .then((json) => { - return json; - }) - .catch((err) => { - error = err; - - console.log(err); - return null; - }); - - if (error) { - throw error; - } - - return res.modelfile; -}; - -export const updateModelfileByTagName = async ( - token: string, - tagName: string, - modelfile: object -) => { - let error = null; - - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, { - method: 'POST', - headers: { - Accept: 'application/json', - 'Content-Type': 'application/json', - authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName, - modelfile: modelfile - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); @@ -137,19 +92,49 @@ export const updateModelfileByTagName = async ( return res; }; -export const deleteModelfileByTagName = async (token: string, tagName: string) => { +export const updateModelById = async (token: string, id: string, model: object) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify(model) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteModelById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/models/${id}/delete`, { method: 'DELETE', headers: { Accept: 'application/json', 'Content-Type': 'application/json', authorization: `Bearer ${token}` - }, - body: JSON.stringify({ - tag_name: tagName - }) + } }) .then(async (res) => { if (!res.ok) throw await res.json(); diff --git a/src/lib/apis/openai/index.ts b/src/lib/apis/openai/index.ts index 02281eff0..8afcec018 100644 --- a/src/lib/apis/openai/index.ts +++ b/src/lib/apis/openai/index.ts @@ -230,7 +230,12 @@ export const getOpenAIModels = async (token: string = '') => { return models ? models - .map((model) => ({ id: model.id, name: model.name ?? model.id, external: true })) + .map((model) => ({ + id: model.id, + name: model.name ?? model.id, + external: true, + custom_info: model.custom_info + })) .sort((a, b) => { return a.name.localeCompare(b.name); }) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 871647101..ff025868c 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -10,7 +10,7 @@ chatId, chats, config, - modelfiles, + type Model, models, settings, showSidebar, @@ -60,25 +60,7 @@ let showModelSelector = true; let selectedModels = ['']; - let atSelectedModel = ''; - - let selectedModelfile = null; - $: selectedModelfile = - selectedModels.length === 1 && - $modelfiles.filter((modelfile) => modelfile.tagName === selectedModels[0]).length > 0 - ? $modelfiles.filter((modelfile) => modelfile.tagName === selectedModels[0])[0] - : null; - - let selectedModelfiles = {}; - $: selectedModelfiles = selectedModels.reduce((a, tagName, i, arr) => { - const modelfile = - $modelfiles.filter((modelfile) => modelfile.tagName === tagName)?.at(0) ?? undefined; - - return { - ...a, - ...(modelfile && { [tagName]: modelfile }) - }; - }, {}); + let atSelectedModel: Model | undefined; let chat = null; let tags = []; @@ -164,6 +146,7 @@ if ($page.url.searchParams.get('q')) { prompt = $page.url.searchParams.get('q') ?? ''; + if (prompt) { await tick(); submitPrompt(prompt); @@ -211,7 +194,7 @@ await settings.set({ ..._settings, system: chatContent.system ?? _settings.system, - options: chatContent.options ?? _settings.options + params: chatContent.options ?? _settings.params }); autoScroll = true; await tick(); @@ -300,7 +283,7 @@ models: selectedModels, system: $settings.system ?? undefined, options: { - ...($settings.options ?? {}) + ...($settings.params ?? {}) }, messages: messages, history: history, @@ -317,6 +300,7 @@ // Reset chat input textarea prompt = ''; + document.getElementById('chat-textarea').style.height = ''; files = []; // Send prompt @@ -328,75 +312,92 @@ const _chatId = JSON.parse(JSON.stringify($chatId)); await Promise.all( - (modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map( - async (modelId) => { - console.log('modelId', modelId); - const model = $models.filter((m) => m.id === modelId).at(0); + (modelId + ? [modelId] + : atSelectedModel !== undefined + ? [atSelectedModel.id] + : selectedModels + ).map(async (modelId) => { + console.log('modelId', modelId); + const model = $models.filter((m) => m.id === modelId).at(0); - if (model) { - // Create response message - let responseMessageId = uuidv4(); - let responseMessage = { - parentId: parentId, - id: responseMessageId, - childrenIds: [], - role: 'assistant', - content: '', - model: model.id, - userContext: null, - timestamp: Math.floor(Date.now() / 1000) // Unix epoch - }; + if (model) { + // If there are image files, check if model is vision capable + const hasImages = messages.some((message) => + message.files?.some((file) => file.type === 'image') + ); - // Add message to history and Set currentId to messageId - history.messages[responseMessageId] = responseMessage; - history.currentId = responseMessageId; + if (hasImages && !(model.info?.meta?.capabilities?.vision ?? true)) { + toast.error( + $i18n.t('Model {{modelName}} is not vision capable', { + modelName: model.name ?? model.id + }) + ); + } - // Append messageId to childrenIds of parent message - if (parentId !== null) { - history.messages[parentId].childrenIds = [ - ...history.messages[parentId].childrenIds, - responseMessageId - ]; - } + // Create response message + let responseMessageId = uuidv4(); + let responseMessage = { + parentId: parentId, + id: responseMessageId, + childrenIds: [], + role: 'assistant', + content: '', + model: model.id, + modelName: model.name ?? model.id, + userContext: null, + timestamp: Math.floor(Date.now() / 1000) // Unix epoch + }; - await tick(); + // Add message to history and Set currentId to messageId + history.messages[responseMessageId] = responseMessage; + history.currentId = responseMessageId; - let userContext = null; - if ($settings?.memory ?? false) { - if (userContext === null) { - const res = await queryMemory(localStorage.token, prompt).catch((error) => { - toast.error(error); - return null; - }); + // Append messageId to childrenIds of parent message + if (parentId !== null) { + history.messages[parentId].childrenIds = [ + ...history.messages[parentId].childrenIds, + responseMessageId + ]; + } - if (res) { - if (res.documents[0].length > 0) { - userContext = res.documents.reduce((acc, doc, index) => { - const createdAtTimestamp = res.metadatas[index][0].created_at; - const createdAtDate = new Date(createdAtTimestamp * 1000) - .toISOString() - .split('T')[0]; - acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`); - return acc; - }, []); - } + await tick(); - console.log(userContext); + let userContext = null; + if ($settings?.memory ?? false) { + if (userContext === null) { + const res = await queryMemory(localStorage.token, prompt).catch((error) => { + toast.error(error); + return null; + }); + + if (res) { + if (res.documents[0].length > 0) { + userContext = res.documents.reduce((acc, doc, index) => { + const createdAtTimestamp = res.metadatas[index][0].created_at; + const createdAtDate = new Date(createdAtTimestamp * 1000) + .toISOString() + .split('T')[0]; + acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`); + return acc; + }, []); } + + console.log(userContext); } } - responseMessage.userContext = userContext; - - if (model?.external) { - await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); - } else if (model) { - await sendPromptOllama(model, prompt, responseMessageId, _chatId); - } - } else { - toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); } + responseMessage.userContext = userContext; + + if (model?.owned_by === 'openai') { + await sendPromptOpenAI(model, prompt, responseMessageId, _chatId); + } else if (model) { + await sendPromptOllama(model, prompt, responseMessageId, _chatId); + } + } else { + toast.error($i18n.t(`Model {{modelId}} not found`, { modelId })); } - ) + }) ); await chats.set(await getChatList(localStorage.token)); @@ -430,7 +431,7 @@ // Prepare the base message object const baseMessage = { role: message.role, - content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content + content: message.content }; // Extract and format image URLs if any exist @@ -442,7 +443,6 @@ if (imageUrls && imageUrls.length > 0 && message.role === 'user') { baseMessage.images = imageUrls; } - return baseMessage; }); @@ -473,13 +473,15 @@ model: model, messages: messagesBody, options: { - ...($settings.options ?? {}), + ...($settings.params ?? {}), stop: - $settings?.options?.stop ?? undefined - ? $settings.options.stop.map((str) => + $settings?.params?.stop ?? undefined + ? $settings.params.stop.map((str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) ) - : undefined + : undefined, + num_predict: $settings?.params?.max_tokens ?? undefined, + repeat_penalty: $settings?.params?.frequency_penalty ?? undefined }, format: $settings.requestFormat ?? undefined, keep_alive: $settings.keepAlive ?? undefined, @@ -605,7 +607,8 @@ if ($settings.saveChatHistory ?? true) { chat = await updateChatById(localStorage.token, _chatId, { messages: messages, - history: history + history: history, + models: selectedModels }); await chats.set(await getChatList(localStorage.token)); } @@ -716,18 +719,17 @@ : message?.raContent ?? message.content }) })), - seed: $settings?.options?.seed ?? undefined, + seed: $settings?.params?.seed ?? undefined, stop: - $settings?.options?.stop ?? undefined - ? $settings.options.stop.map((str) => + $settings?.params?.stop ?? undefined + ? $settings.params.stop.map((str) => decodeURIComponent(JSON.parse('"' + str.replace(/\"/g, '\\"') + '"')) ) : undefined, - temperature: $settings?.options?.temperature ?? undefined, - top_p: $settings?.options?.top_p ?? undefined, - num_ctx: $settings?.options?.num_ctx ?? undefined, - frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, - max_tokens: $settings?.options?.num_predict ?? undefined, + temperature: $settings?.params?.temperature ?? undefined, + top_p: $settings?.params?.top_p ?? undefined, + frequency_penalty: $settings?.params?.frequency_penalty ?? undefined, + max_tokens: $settings?.params?.max_tokens ?? undefined, docs: docs.length > 0 ? docs : undefined, citations: docs.length > 0 }, @@ -797,6 +799,7 @@ if ($chatId == _chatId) { if ($settings.saveChatHistory ?? true) { chat = await updateChatById(localStorage.token, _chatId, { + models: selectedModels, messages: messages, history: history }); @@ -935,10 +938,8 @@ ) + ' {{prompt}}', titleModelId, userPrompt, - titleModel?.external ?? false - ? titleModel?.source?.toLowerCase() === 'litellm' - ? `${LITELLM_API_BASE_URL}/v1` - : `${OPENAI_API_BASE_URL}` + titleModel?.owned_by === 'openai' ?? false + ? `${OPENAI_API_BASE_URL}` : `${OLLAMA_API_BASE_URL}/v1` ); @@ -1025,16 +1026,12 @@ 0} - suggestionPrompts={chatIdProp - ? [] - : selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions} {sendPrompt} {continueGeneration} {regenerateResponse} @@ -1048,7 +1045,8 @@ bind:files bind:prompt bind:autoScroll - bind:selectedModel={atSelectedModel} + bind:atSelectedModel + {selectedModels} {messages} {submitPrompt} {stopResponse} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index acf797cd1..afff1217d 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -1,7 +1,7 @@ -{#if code} -
-
-
{@html lang}
+
+
+
{@html lang}
-
- {#if lang === 'python' || (lang === '' && checkPythonCode(code))} - {#if executing} -
Running
- {:else} - - {/if} +
+ {#if lang === 'python' || (lang === '' && checkPythonCode(code))} + {#if executing} +
Running
+ {:else} + {/if} - -
+ {/if} +
- -
{@html highlightedCode || code}
- -
- - {#if executing} -
-
STDOUT/STDERR
-
Running...
-
- {:else if stdout || stderr || result} -
-
STDOUT/STDERR
-
{stdout || stderr || result}
-
- {/if}
-{/if} + +
{@html highlightedCode || code}
+ +
+ + {#if executing} +
+
STDOUT/STDERR
+
Running...
+
+ {:else if stdout || stderr || result} +
+
STDOUT/STDERR
+
{stdout || stderr || result}
+
+ {/if} +
diff --git a/src/lib/components/chat/Messages/CompareMessages.svelte b/src/lib/components/chat/Messages/CompareMessages.svelte index 60efdb2ab..f904a57ab 100644 --- a/src/lib/components/chat/Messages/CompareMessages.svelte +++ b/src/lib/components/chat/Messages/CompareMessages.svelte @@ -13,8 +13,6 @@ export let parentMessage; - export let selectedModelfiles; - export let updateChatMessages: Function; export let confirmEditResponseMessage: Function; export let rateMessage: Function; @@ -130,7 +128,6 @@ > m.id)} isLastMessage={true} {updateChatMessages} diff --git a/src/lib/components/chat/Messages/Placeholder.svelte b/src/lib/components/chat/Messages/Placeholder.svelte index dfb6cfb36..ed121dbe6 100644 --- a/src/lib/components/chat/Messages/Placeholder.svelte +++ b/src/lib/components/chat/Messages/Placeholder.svelte @@ -1,6 +1,6 @@ - -
-
-
{$i18n.t('Parameters')}
- - -
- -
-
-
{$i18n.t('Keep Alive')}
- - -
- - {#if keepAlive !== null} -
- -
- {/if} -
- -
-
-
{$i18n.t('Request Mode')}
- - -
-
-
- -
- -
-
diff --git a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte index 6eaf82da8..93c482711 100644 --- a/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte +++ b/src/lib/components/chat/Settings/Advanced/AdvancedParams.svelte @@ -1,14 +1,16 @@ -
-
-
-
{$i18n.t('Seed')}
-
- -
+
+
+
+
{$i18n.t('Seed')}
+ +
+ + {#if (params?.seed ?? null) !== null} +
+
+ +
+
+ {/if}
-
-
-
{$i18n.t('Stop Sequence')}
-
- -
+
+
+
{$i18n.t('Stop Sequence')}
+ +
+ + {#if (params?.stop ?? null) !== null} +
+
+ +
+
+ {/if}
@@ -61,10 +109,10 @@ class="p-1 px-3 text-xs flex rounded transition" type="button" on:click={() => { - options.temperature = options.temperature === '' ? 0.8 : ''; + params.temperature = (params?.temperature ?? '') === '' ? 0.8 : ''; }} > - {#if options.temperature === ''} + {#if (params?.temperature ?? '') === ''} {$i18n.t('Default')} {:else} {$i18n.t('Custom')} @@ -72,7 +120,7 @@
- {#if options.temperature !== ''} + {#if (params?.temperature ?? '') !== ''}
{ - options.mirostat = options.mirostat === '' ? 0 : ''; + params.mirostat = (params?.mirostat ?? '') === '' ? 0 : ''; }} > - {#if options.mirostat === ''} + {#if (params?.mirostat ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.mirostat !== ''} + {#if (params?.mirostat ?? '') !== ''}
{ - options.mirostat_eta = options.mirostat_eta === '' ? 0.1 : ''; + params.mirostat_eta = (params?.mirostat_eta ?? '') === '' ? 0.1 : ''; }} > - {#if options.mirostat_eta === ''} + {#if (params?.mirostat_eta ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.mirostat_eta !== ''} + {#if (params?.mirostat_eta ?? '') !== ''}
{ - options.mirostat_tau = options.mirostat_tau === '' ? 5.0 : ''; + params.mirostat_tau = (params?.mirostat_tau ?? '') === '' ? 5.0 : ''; }} > - {#if options.mirostat_tau === ''} + {#if (params?.mirostat_tau ?? '') === ''} {$i18n.t('Default')} {:else} {$i18n.t('Custom')} @@ -210,7 +258,7 @@
- {#if options.mirostat_tau !== ''} + {#if (params?.mirostat_tau ?? '') !== ''}
{ - options.top_k = options.top_k === '' ? 40 : ''; + params.top_k = (params?.top_k ?? '') === '' ? 40 : ''; }} > - {#if options.top_k === ''} + {#if (params?.top_k ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.top_k !== ''} + {#if (params?.top_k ?? '') !== ''}
{ - options.top_p = options.top_p === '' ? 0.9 : ''; + params.top_p = (params?.top_p ?? '') === '' ? 0.9 : ''; }} > - {#if options.top_p === ''} + {#if (params?.top_p ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.top_p !== ''} + {#if (params?.top_p ?? '') !== ''}
-
{$i18n.t('Repeat Penalty')}
+
{$i18n.t('Frequencey Penalty')}
- {#if options.repeat_penalty !== ''} + {#if (params?.frequency_penalty ?? '') !== ''}
{ - options.repeat_last_n = options.repeat_last_n === '' ? 64 : ''; + params.repeat_last_n = (params?.repeat_last_n ?? '') === '' ? 64 : ''; }} > - {#if options.repeat_last_n === ''} + {#if (params?.repeat_last_n ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.repeat_last_n !== ''} + {#if (params?.repeat_last_n ?? '') !== ''}
{ - options.tfs_z = options.tfs_z === '' ? 1 : ''; + params.tfs_z = (params?.tfs_z ?? '') === '' ? 1 : ''; }} > - {#if options.tfs_z === ''} + {#if (params?.tfs_z ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.tfs_z !== ''} + {#if (params?.tfs_z ?? '') !== ''}
{ - options.num_ctx = options.num_ctx === '' ? 2048 : ''; + params.num_ctx = (params?.num_ctx ?? '') === '' ? 2048 : ''; }} > - {#if options.num_ctx === ''} + {#if (params?.num_ctx ?? '') === ''} {$i18n.t('Default')} {:else} - {$i18n.t('Default')} + {$i18n.t('Custom')} {/if}
- {#if options.num_ctx !== ''} + {#if (params?.num_ctx ?? '') !== ''}
-
{$i18n.t('Max Tokens')}
+
{$i18n.t('Max Tokens (num_predict)')}
- {#if options.num_predict !== ''} + {#if (params?.max_tokens ?? '') !== ''}
{/if}
+
+
+
{$i18n.t('Template')}
+ + +
+ + {#if (params?.template ?? null) !== null} +
+
+