From f68aba687e14cc0b539a01ee7665746def64bd01 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 00:37:02 -0700 Subject: [PATCH 01/22] feat: functions router --- backend/apps/webui/main.py | 2 +- backend/apps/webui/models/functions.py | 4 +- backend/apps/webui/routers/functions.py | 180 ++++++++++++++++++++++++ backend/apps/webui/utils.py | 24 +++- backend/config.py | 8 ++ 5 files changed, 214 insertions(+), 4 deletions(-) create mode 100644 backend/apps/webui/routers/functions.py diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index bdc6ec4f4..ee5957224 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -60,7 +60,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.MODELS = {} app.state.TOOLS = {} - +app.state.FUNCTIONS = {} app.add_middleware( CORSMiddleware, diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index cd877434d..ac12ab9e3 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -69,7 +69,7 @@ class FunctionForm(BaseModel): meta: FunctionMeta -class ToolsTable: +class FunctionsTable: def __init__(self, db): self.db = db self.db.create_tables([Function]) @@ -137,4 +137,4 @@ class ToolsTable: return False -Tools = ToolsTable(DB) +Functions = FunctionsTable(DB) diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py new file mode 100644 index 000000000..1021cc10a --- /dev/null +++ b/backend/apps/webui/routers/functions.py @@ -0,0 +1,180 @@ +from fastapi import Depends, FastAPI, HTTPException, status, Request +from datetime import datetime, timedelta +from typing import List, Union, Optional + +from fastapi import APIRouter +from pydantic import BaseModel +import json + +from apps.webui.models.functions import ( + Functions, + FunctionForm, + FunctionModel, + FunctionResponse, +) +from apps.webui.utils import load_function_module_by_id +from utils.utils import get_verified_user, get_admin_user +from constants import ERROR_MESSAGES + +from importlib import util +import os +from pathlib import Path + +from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR + + +router = APIRouter() + +############################ +# GetFunctions +############################ + + +@router.get("/", response_model=List[FunctionResponse]) +async def get_functions(user=Depends(get_verified_user)): + return Functions.get_functions() + + +############################ +# ExportFunctions +############################ + + +@router.get("/export", response_model=List[FunctionModel]) +async def get_functions(user=Depends(get_admin_user)): + return Functions.get_functions() + + +############################ +# CreateNewFunction +############################ + + +@router.post("/create", response_model=Optional[FunctionResponse]) +async def create_new_function( + request: Request, form_data: FunctionForm, user=Depends(get_admin_user) +): + if not form_data.id.isidentifier(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only alphanumeric characters and underscores are allowed in the id", + ) + + form_data.id = form_data.id.lower() + + function = Functions.get_function_by_id(form_data.id) + if function == None: + function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") + try: + with open(function_path, "w") as function_file: + function_file.write(form_data.content) + + function_module = load_function_module_by_id(form_data.id) + + FUNCTIONS = 
request.app.state.FUNCTIONS + FUNCTIONS[form_data.id] = function_module + + function = Functions.insert_new_function(user.id, form_data) + + function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id + function_cache_dir.mkdir(parents=True, exist_ok=True) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error creating function"), + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.ID_TAKEN, + ) + + +############################ +# GetFunctionById +############################ + + +@router.get("/id/{id}", response_model=Optional[FunctionModel]) +async def get_function_by_id(id: str, user=Depends(get_admin_user)): + function = Functions.get_function_by_id(id) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# UpdateFunctionById +############################ + + +@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) +async def update_toolkit_by_id( + request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) +): + function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") + + try: + with open(function_path, "w") as function_file: + function_file.write(form_data.content) + + function_module = load_function_module_by_id(id) + + FUNCTIONS = request.app.state.FUNCTIONS + FUNCTIONS[id] = function_module + + updated = {**form_data.model_dump(exclude={"id"})} + print(updated) + + function = Functions.update_function_by_id(id, updated) + + if function: + return function + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error updating function"), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +############################ +# DeleteFunctionById +############################ + + +@router.delete("/id/{id}/delete", response_model=bool) +async def delete_function_by_id( + request: Request, id: str, user=Depends(get_admin_user) +): + result = Functions.delete_function_by_id(id) + + if result: + FUNCTIONS = request.app.state.FUNCTIONS + if id in FUNCTIONS: + del FUNCTIONS[id] + + # delete the function file + function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") + os.remove(function_path) + + return result diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index 19a8615bc..64d116f11 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -1,7 +1,7 @@ from importlib import util import os -from config import TOOLS_DIR +from config import TOOLS_DIR, FUNCTIONS_DIR def load_toolkit_module_by_id(toolkit_id): @@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id): # Move the file to the error folder os.rename(toolkit_path, f"{toolkit_path}.error") raise e + + +def load_function_module_by_id(function_id): + function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py") + + spec = util.spec_from_file_location(function_id, function_path) + module = util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + print(f"Loaded module: {module.__name__}") + if hasattr(module, "Pipe"): + return module.Pipe() + elif hasattr(module, "Filter"): + return 
module.Filter() + else: + raise Exception("No Function class found") + except Exception as e: + print(f"Error loading module: {function_id}") + # Move the file to the error folder + os.rename(function_path, f"{function_path}.error") + raise e diff --git a/backend/config.py b/backend/config.py index 01ce060a3..842cea1ba 100644 --- a/backend/config.py +++ b/backend/config.py @@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) +#################################### +# Functions DIR +#################################### + +FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions") +Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True) + + #################################### # LITELLM_CONFIG #################################### From 27f8afebabde4ded3e249a6f0d69607806be9c8e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 00:49:11 -0700 Subject: [PATCH 02/22] feat: function db migration --- .../internal/migrations/015_add_functions.py | 61 ++++++ backend/apps/webui/main.py | 12 +- src/lib/apis/functions/index.ts | 193 ++++++++++++++++++ src/lib/components/workspace/Functions.svelte | 122 ++++++----- 4 files changed, 334 insertions(+), 54 deletions(-) create mode 100644 backend/apps/webui/internal/migrations/015_add_functions.py create mode 100644 src/lib/apis/functions/index.ts diff --git a/backend/apps/webui/internal/migrations/015_add_functions.py b/backend/apps/webui/internal/migrations/015_add_functions.py new file mode 100644 index 000000000..8316a9333 --- /dev/null +++ b/backend/apps/webui/internal/migrations/015_add_functions.py @@ -0,0 +1,61 @@ +"""Peewee migrations -- 009_add_models.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + @migrator.create_model + class Function(pw.Model): + id = pw.TextField(unique=True) + user_id = pw.TextField() + + name = pw.TextField() + type = pw.TextField() + + content = pw.TextField() + meta = pw.TextField() + + created_at = pw.BigIntegerField(null=False) + updated_at = pw.BigIntegerField(null=False) + + class Meta: + table_name = "function" + + +def rollback(migrator: Migrator, database: pw.Database, *, 
fake=False): + """Write your rollback migrations here.""" + + migrator.remove_model("function") diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index ee5957224..4a53b15bf 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -13,6 +13,7 @@ from apps.webui.routers import ( memories, utils, files, + functions, ) from config import ( WEBUI_BUILD_HASH, @@ -70,19 +71,22 @@ app.add_middleware( allow_headers=["*"], ) + +app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(documents.router, prefix="/documents", tags=["documents"]) -app.include_router(tools.router, prefix="/tools", tags=["tools"]) app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) -app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(configs.router, prefix="/configs", tags=["configs"]) -app.include_router(utils.router, prefix="/utils", tags=["utils"]) +app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(files.router, prefix="/files", tags=["files"]) +app.include_router(tools.router, prefix="/tools", tags=["tools"]) +app.include_router(functions.router, prefix="/functions", tags=["functions"]) + +app.include_router(utils.router, prefix="/utils", tags=["utils"]) @app.get("/") diff --git a/src/lib/apis/functions/index.ts b/src/lib/apis/functions/index.ts new file mode 100644 index 000000000..e035ef1c1 --- /dev/null +++ b/src/lib/apis/functions/index.ts @@ -0,0 +1,193 @@ +import { WEBUI_API_BASE_URL } from '$lib/constants'; + +export const createNewFunction = async (token: string, func: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...func + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getFunctions = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const exportFunctions = async (token: string = '') => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + 
return res; +}; + +export const getFunctionById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const updateFunctionById = async (token: string, id: string, func: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + ...func + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const deleteFunctionById = async (token: string, id: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, { + method: 'DELETE', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index f00e9ad2f..ebadce50c 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -3,29 +3,39 @@ import fileSaver from 'file-saver'; const { saveAs } = fileSaver; + import { WEBUI_NAME } from '$lib/stores'; import { onMount, getContext } from 'svelte'; - import { WEBUI_NAME, prompts, tools } from '$lib/stores'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { goto } from '$app/navigation'; import { - createNewTool, - deleteToolById, - exportTools, - getToolById, - getTools - } from '$lib/apis/tools'; + createNewFunction, + deleteFunctionById, + exportFunctions, + getFunctionById, + getFunctions + } from '$lib/apis/functions'; + import ArrowDownTray from '../icons/ArrowDownTray.svelte'; import Tooltip from '../common/Tooltip.svelte'; import ConfirmDialog from '../common/ConfirmDialog.svelte'; const i18n = getContext('i18n'); - let toolsImportInputElement: HTMLInputElement; + let functionsImportInputElement: HTMLInputElement; let importFiles; let showConfirm = false; let query = ''; + + let functions = []; + + onMount(async () => { + functions = await getFunctions(localStorage.token).catch((error) => { + toast.error(error); + return []; + }); + }); @@ -82,30 +92,30 @@
- {#each $tools.filter((t) => query === '' || t.name + {#each functions.filter((f) => query === '' || f.name .toLowerCase() - .includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool} + .includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
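
The two patches above give the backend a full CRUD surface for functions (list, export, create, get, update, delete) plus a matching TypeScript client. Below is a minimal sketch of the same flow driven directly from Python; the base URL, the admin token, and the meta contents are assumptions for illustration — only the FunctionForm fields (id, name, type, content, meta) and the route paths come from the patches. Note that a later patch in this series drops type from the form and derives it from the loaded module instead.

import requests

# Assumptions: the webui API is mounted at /api/v1 (the prefix behind
# WEBUI_API_BASE_URL) and ADMIN_TOKEN is a valid admin JWT.
BASE_URL = "http://localhost:8080/api/v1"
ADMIN_TOKEN = "YOUR_ADMIN_TOKEN"

headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}

form = {
    # The router requires id.isidentifier() and lower-cases it before saving.
    "id": "example_filter",
    "name": "Example Filter",
    "type": "filter",
    # The saved module must expose a Pipe or Filter class, or the loader raises.
    "content": "class Filter:\n    def __init__(self):\n        pass\n",
    # FunctionMeta fields other than 'description' are not shown in these patches.
    "meta": {"description": "Created via the new /functions/create endpoint"},
}

# POST /functions/create writes <FUNCTIONS_DIR>/example_filter.py, imports it,
# registers the instance in app.state.FUNCTIONS, and inserts a DB row.
res = requests.post(f"{BASE_URL}/functions/create", json=form, headers=headers)
res.raise_for_status()
print(res.json())

# GET /functions/ lists functions for any verified user.
print(requests.get(f"{BASE_URL}/functions/", headers=headers).json())
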
From 43e08c6afa00a536fc50bd81a667bc935e1d5fae Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 00:54:58 -0700 Subject: [PATCH 03/22] refac --- backend/apps/webui/models/functions.py | 4 +- backend/apps/webui/routers/functions.py | 8 +- backend/apps/webui/utils.py | 4 +- src/lib/components/workspace/Functions.svelte | 2 +- .../workspace/Functions/FunctionEditor.svelte | 281 ++++++++++++++++++ .../workspace/functions/create/+page.svelte | 38 ++- 6 files changed, 308 insertions(+), 29 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index ac12ab9e3..91fbdb769 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -64,7 +64,6 @@ class FunctionResponse(BaseModel): class FunctionForm(BaseModel): id: str name: str - type: str content: str meta: FunctionMeta @@ -75,12 +74,13 @@ class FunctionsTable: self.db.create_tables([Function]) def insert_new_function( - self, user_id: str, form_data: FunctionForm + self, user_id: str, type: str, form_data: FunctionForm ) -> Optional[FunctionModel]: function = FunctionModel( **{ **form_data.model_dump(), "user_id": user_id, + "type": type, "updated_at": int(time.time()), "created_at": int(time.time()), } diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py index 1021cc10a..ea5fde336 100644 --- a/backend/apps/webui/routers/functions.py +++ b/backend/apps/webui/routers/functions.py @@ -69,12 +69,12 @@ async def create_new_function( with open(function_path, "w") as function_file: function_file.write(form_data.content) - function_module = load_function_module_by_id(form_data.id) + function_module, function_type = load_function_module_by_id(form_data.id) FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[form_data.id] = function_module - function = Functions.insert_new_function(user.id, form_data) + function = Functions.insert_new_function(user.id, function_type, form_data) function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) @@ -132,12 +132,12 @@ async def update_toolkit_by_id( with open(function_path, "w") as function_file: function_file.write(form_data.content) - function_module = load_function_module_by_id(id) + function_module, function_type = load_function_module_by_id(id) FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS[id] = function_module - updated = {**form_data.model_dump(exclude={"id"})} + updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} print(updated) function = Functions.update_function_by_id(id, updated) diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py index 64d116f11..3e075a8a8 100644 --- a/backend/apps/webui/utils.py +++ b/backend/apps/webui/utils.py @@ -33,9 +33,9 @@ def load_function_module_by_id(function_id): spec.loader.exec_module(module) print(f"Loaded module: {module.__name__}") if hasattr(module, "Pipe"): - return module.Pipe() + return module.Pipe(), "pipe" elif hasattr(module, "Filter"): - return module.Filter() + return module.Filter(), "filter" else: raise Exception("No Function class found") except Exception as e: diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index ebadce50c..aeb53bde0 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -74,7 +74,7 @@
+ import { getContext, createEventDispatcher, onMount } from 'svelte'; + + const i18n = getContext('i18n'); + + import CodeEditor from '$lib/components/common/CodeEditor.svelte'; + import { goto } from '$app/navigation'; + import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; + + const dispatch = createEventDispatcher(); + + let formElement = null; + let loading = false; + let showConfirm = false; + + export let edit = false; + export let clone = false; + + export let id = ''; + export let name = ''; + export let meta = { + description: '' + }; + export let content = ''; + + $: if (name && !edit && !clone) { + id = name.replace(/\s+/g, '_').toLowerCase(); + } + + let codeEditor; + let boilerplate = `import os +import requests +from datetime import datetime + + +class Tools: + def __init__(self): + pass + + # Add your custom tools using pure Python code here, make sure to add type hints + # Use Sphinx-style docstrings to document your tools, they will be used for generating tools specifications + # Please refer to function_calling_filter_pipeline.py file from pipelines project for an example + + def get_user_name_and_email_and_id(self, __user__: dict = {}) -> str: + """ + Get the user name, Email and ID from the user object. + """ + + # Do not include :param for __user__ in the docstring as it should not be shown in the tool's specification + # The session user object will be passed as a parameter when the function is called + + print(__user__) + result = "" + + if "name" in __user__: + result += f"User: {__user__['name']}" + if "id" in __user__: + result += f" (ID: {__user__['id']})" + if "email" in __user__: + result += f" (Email: {__user__['email']})" + + if result == "": + result = "User: Unknown" + + return result + + def get_current_time(self) -> str: + """ + Get the current time in a more human-readable format. + :return: The current time. + """ + + now = datetime.now() + current_time = now.strftime("%I:%M:%S %p") # Using 12-hour format with AM/PM + current_date = now.strftime( + "%A, %B %d, %Y" + ) # Full weekday, month name, day, and year + + return f"Current Date and Time = {current_date}, {current_time}" + + def calculator(self, equation: str) -> str: + """ + Calculate the result of an equation. + :param equation: The equation to calculate. + """ + + # Avoid using eval in production code + # https://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html + try: + result = eval(equation) + return f"{equation} = {result}" + except Exception as e: + print(e) + return "Invalid equation" + + def get_current_weather(self, city: str) -> str: + """ + Get the current weather for a given city. + :param city: The name of the city to get the weather for. + :return: The current weather information or an error message. + """ + api_key = os.getenv("OPENWEATHER_API_KEY") + if not api_key: + return ( + "API key is not set in the environment variable 'OPENWEATHER_API_KEY'." 
+ ) + + base_url = "http://api.openweathermap.org/data/2.5/weather" + params = { + "q": city, + "appid": api_key, + "units": "metric", # Optional: Use 'imperial' for Fahrenheit + } + + try: + response = requests.get(base_url, params=params) + response.raise_for_status() # Raise HTTPError for bad responses (4xx and 5xx) + data = response.json() + + if data.get("cod") != 200: + return f"Error fetching weather data: {data.get('message')}" + + weather_description = data["weather"][0]["description"] + temperature = data["main"]["temp"] + humidity = data["main"]["humidity"] + wind_speed = data["wind"]["speed"] + + return f"Weather in {city}: {temperature}°C" + except requests.RequestException as e: + return f"Error fetching weather data: {str(e)}" +`; + + const saveHandler = async () => { + loading = true; + dispatch('save', { + id, + name, + meta, + content + }); + }; + + const submitHandler = async () => { + if (codeEditor) { + const res = await codeEditor.formatPythonCodeHandler(); + + if (res) { + console.log('Code formatted successfully'); + saveHandler(); + } + } + }; + + +
+
+
<form
	on:submit|preventDefault={() => {
		if (edit) {
			submitHandler();
		} else {
			showConfirm = true;
		}
	}}
>
	<!-- name, ID and description inputs, the CodeEditor, and the save button -->
</form>

<ConfirmDialog
	bind:show={showConfirm}
	on:confirm={() => {
		submitHandler();
	}}
>
	<div>
		<div>Please carefully review the following warnings:</div>

		<ul>
			<li>Tools have a function calling system that allows arbitrary code execution.</li>
			<li>Do not install tools from sources you do not fully trust.</li>
		</ul>

		<div>
			I acknowledge that I have read and I understand the implications of my action. I am aware of
			the risks associated with executing arbitrary code and I have verified the trustworthiness of
			the source.
		</div>
	</div>
</ConfirmDialog>
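
Whatever is saved through this editor lands verbatim in <FUNCTIONS_DIR>/<id>.py, and load_function_module_by_id then classifies the module by the class it exposes: Pipe is checked first, then Filter, one class per module. A minimal sketch of the two accepted shapes follows; the inlet/outlet hooks mirror the filter boilerplate added later in this series, while the Pipe stub is an assumption since these patches only require the class to exist and be instantiable.

from typing import Optional


# Saved as <FUNCTIONS_DIR>/my_filter.py — registered with type "filter".
class Filter:
    def __init__(self):
        pass

    def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        # Called before the chat completion request is forwarded upstream;
        # may modify the body or raise to reject the request.
        return body

    def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
        # Called after the completion response is produced; may rewrite messages.
        return body


# Saved as <FUNCTIONS_DIR>/my_pipe.py — registered with type "pipe".
class Pipe:
    def __init__(self):
        # The Pipe method surface is not pinned down by these patches;
        # only the class name is checked by the loader.
        pass
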
diff --git a/src/routes/(app)/workspace/functions/create/+page.svelte b/src/routes/(app)/workspace/functions/create/+page.svelte index c785c74cd..3b7dc270d 100644 --- a/src/routes/(app)/workspace/functions/create/+page.svelte +++ b/src/routes/(app)/workspace/functions/create/+page.svelte @@ -1,18 +1,18 @@ {#if mounted} - { saveHandler(e.detail); From 40cde07e5ce90cebed8e1cb5a066c355035e4386 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 01:07:55 -0700 Subject: [PATCH 04/22] feat: function filter example boilerplate --- .../workspace/Functions/FunctionEditor.svelte | 135 ++++-------------- .../workspace/functions/edit/+page.svelte | 40 +++--- 2 files changed, 49 insertions(+), 126 deletions(-) diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index 385a9ea68..d84865b93 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -1,14 +1,13 @@ -{#if tool} - { saveHandler(e.detail); }} From 9108df177c57a4e8f945c1b428585317e962180e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 01:12:09 -0700 Subject: [PATCH 05/22] refac: comments --- src/lib/components/workspace/Functions/FunctionEditor.svelte | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index d84865b93..210e8db2b 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -36,6 +36,8 @@ class Filter: pass def inlet(self, body: dict, user: Optional[dict] = None) -> dict: + # This method is invoked before the request is sent to the chat completion API. + # It can be used to modify the request body or perform validation checks. print("inlet") print(body) print(user) @@ -50,9 +52,12 @@ class Filter: return body def outlet(self, body: dict, user: Optional[dict] = None) -> dict: + # This method is invoked after the chat completion API has processed + # the request and generated a response. It can be used to overwrite the response messages. print(f"outlet") print(body) print(user) + return body`; const saveHandler = async () => { From bf5775e07a635c9bfe6491a0db5ebc905c2adc61 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 01:16:31 -0700 Subject: [PATCH 06/22] refac --- backend/apps/webui/models/functions.py | 1 + src/lib/components/workspace/Functions.svelte | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py index 91fbdb769..f5fab34db 100644 --- a/backend/apps/webui/models/functions.py +++ b/backend/apps/webui/models/functions.py @@ -55,6 +55,7 @@ class FunctionModel(BaseModel): class FunctionResponse(BaseModel): id: str user_id: str + type: str name: str meta: FunctionMeta updated_at: int # timestamp in epoch diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index aeb53bde0..15793fa01 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -35,6 +35,8 @@ toast.error(error); return []; }); + + console.log(functions); }); @@ -107,15 +109,25 @@ href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`} class="flex items-center text-left" > -
-							<!-- old layout (markup trimmed): {func.id} and {func.meta.description} -->
+							<!-- new layout (markup trimmed): an uppercase {func.type} badge beside {func.name}, -->
+							<!-- with {func.id} and {func.meta.description} shown underneath -->
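
With type now part of FunctionResponse, the list returned by GET /functions/ — which feeds the type badge rendered above — carries entries shaped roughly like the sketch below. The field names come from the model; the values and the exact meta contents are illustrative.

# Illustrative FunctionResponse entry as serialized by GET /functions/.
function_entry = {
    "id": "example_filter",
    "user_id": "6f2c-example",   # creator's user id (made up)
    "type": "filter",            # "pipe" or "filter", assigned by the loader
    "name": "Example Filter",
    "meta": {"description": "Created via /functions/create"},
    "updated_at": 1718871600,    # epoch seconds
}
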
From 08cc20cb935924dd34476c8e21b654be01b7c39c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 01:44:52 -0700 Subject: [PATCH 07/22] feat: filter selector model --- src/lib/components/workspace/Functions.svelte | 27 ++------- .../workspace/Functions/FunctionEditor.svelte | 27 ++++++--- .../workspace/Models/FiltersSelector.svelte | 59 +++++++++++++++++++ src/lib/stores/index.ts | 2 + src/routes/(app)/workspace/+layout.svelte | 7 ++- .../workspace/functions/create/+page.svelte | 2 + .../workspace/functions/edit/+page.svelte | 2 + .../(app)/workspace/models/edit/+page.svelte | 23 +++++++- 8 files changed, 117 insertions(+), 32 deletions(-) create mode 100644 src/lib/components/workspace/Models/FiltersSelector.svelte diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index 15793fa01..8f0139009 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -3,7 +3,7 @@ import fileSaver from 'file-saver'; const { saveAs } = fileSaver; - import { WEBUI_NAME } from '$lib/stores'; + import { WEBUI_NAME, functions } from '$lib/stores'; import { onMount, getContext } from 'svelte'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; @@ -27,17 +27,6 @@ let showConfirm = false; let query = ''; - - let functions = []; - - onMount(async () => { - functions = await getFunctions(localStorage.token).catch((error) => { - toast.error(error); - return []; - }); - - console.log(functions); - }); @@ -94,7 +83,7 @@
- {#each functions.filter((f) => query === '' || f.name + {#each $functions.filter((f) => query === '' || f.name .toLowerCase() .includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
+
+ func.type === 'filter')} + /> +
+
{$i18n.t('Capabilities')}
From 448ca9d836566163a6e8037d55fefcffeae2247a Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 01:51:39 -0700 Subject: [PATCH 08/22] refac --- backend/main.py | 212 +++++++++++------- .../workspace/Functions/FunctionEditor.svelte | 9 +- 2 files changed, 132 insertions(+), 89 deletions(-) diff --git a/backend/main.py b/backend/main.py index 0a0587159..11c78645b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -170,6 +170,13 @@ app.state.MODELS = {} origins = ["*"] +################################## +# +# ChatCompletion Middleware +# +################################## + + async def get_function_call_response( messages, files, tool_id, template, task_model_id, user ): @@ -469,6 +476,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) +################################## +# +# Pipeline Middleware +# +################################## + def filter_pipeline(payload, user): user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} @@ -628,7 +641,6 @@ async def update_embedding_function(request: Request, call_next): app.mount("/ws", socket_app) - app.mount("/ollama", ollama_app) app.mount("/openai", openai_app) @@ -730,6 +742,104 @@ async def get_models(user=Depends(get_verified_user)): return {"data": models} +@app.post("/api/chat/completions") +async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): + model_id = form_data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + model = app.state.MODELS[model_id] + print(model) + + if model["owned_by"] == "ollama": + return await generate_ollama_chat_completion(form_data, user=user) + else: + return await generate_openai_chat_completion(form_data, user=user) + + +@app.post("/api/chat/completed") +async def chat_completed(form_data: dict, user=Depends(get_verified_user)): + data = form_data + model_id = data["model"] + + filters = [ + model + for model in app.state.MODELS.values() + if "pipeline" in model + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) + ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) + + print(model_id) + + if model_id in app.state.MODELS: + model = app.state.MODELS[model_id] + if "pipeline" in model: + sorted_filters = [model] + sorted_filters + + for filter in sorted_filters: + r = None + try: + urlIdx = filter["urlIdx"] + + url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + + if key != "": + headers = {"Authorization": f"Bearer {key}"} + r = requests.post( + f"{url}/{filter['id']}/filter/outlet", + headers=headers, + json={ + "user": {"id": user.id, "name": user.name, "role": user.role}, + "body": data, + }, + ) + + r.raise_for_status() + data = r.json() + except Exception as e: + # Handle connection error here + print(f"Connection error: {e}") + + if r is not None: + try: + res = r.json() + if "detail" in res: + return JSONResponse( + status_code=r.status_code, + content=res, + ) + except: + pass + + else: + pass + + return data + + +################################## +# +# Task Endpoints +# +################################## + + +# TODO: Refactor task API endpoints below into a separate file + + @app.get("/api/task/config") 
async def get_task_config(user=Depends(get_verified_user)): return { @@ -1015,92 +1125,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ) -@app.post("/api/chat/completions") -async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): - model_id = form_data["model"] - if model_id not in app.state.MODELS: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = app.state.MODELS[model_id] - print(model) - - if model["owned_by"] == "ollama": - return await generate_ollama_chat_completion(form_data, user=user) - else: - return await generate_openai_chat_completion(form_data, user=user) +################################## +# +# Pipelines Endpoints +# +################################## -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - data = form_data - model_id = data["model"] - - filters = [ - model - for model in app.state.MODELS.values() - if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) - ] - sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - print(model_id) - - if model_id in app.state.MODELS: - model = app.state.MODELS[model_id] - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": {"id": user.id, "name": user.name, "role": user.role}, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except: - pass - - else: - pass - - return data +# TODO: Refactor pipelines API endpoints below into a separate file @app.get("/api/pipelines/list") @@ -1423,6 +1455,13 @@ async def update_pipeline_valves( ) +################################## +# +# Config Endpoints +# +################################## + + @app.get("/api/config") async def get_app_config(): # Checking and Handling the Absence of 'ui' in CONFIG_DATA @@ -1486,6 +1525,9 @@ async def update_model_filter_config( } +# TODO: webhook endpoint should be under config endpoints + + @app.get("/api/webhook") async def get_webhook_url(user=Depends(get_admin_user)): return { diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index 12cb0386d..6e35616f2 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -30,9 +30,10 @@ let boilerplate = `from pydantic import BaseModel from typing import Optional + class Filter: class Valves(BaseModel): - max_turns: int + max_turns: int = 4 pass def __init__(self): @@ -42,14 +43,14 @@ class Filter: # Initialize 'valves' with specific configurations. 
Using 'Valves' instance helps encapsulate settings, # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'. - self.valves = self.Valves(**{"max_turns": 10}) + self.valves = self.Valves(**{"max_turns": 2}) pass def inlet(self, body: dict, user: Optional[dict] = None) -> dict: # Modify the request body or validate it before processing by the chat completion API. # This function is the pre-processor for the API where various checks on the input can be performed. # It can also modify the request before sending it to the API. - + print("inlet") print(body) print(user) @@ -65,7 +66,7 @@ class Filter: def outlet(self, body: dict, user: Optional[dict] = None) -> dict: # Modify or analyze the response body after processing by the API. - # This function is the post-processor for the API, which can be used to modify the response + # This function is the post-processor for the API, which can be used to modify the response # or perform additional checks and analytics. print(f"outlet") print(body) From 6b8a7b993949f343fd4d1702b74218d9fe9b5c72 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 02:06:10 -0700 Subject: [PATCH 09/22] refac: chat completion middleware --- backend/main.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/backend/main.py b/backend/main.py index 11c78645b..febda4ced 100644 --- a/backend/main.py +++ b/backend/main.py @@ -316,7 +316,7 @@ async def get_function_call_response( class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): - return_citations = False + data_items = [] if request.method == "POST" and ( "/ollama/api/chat" in request.url.path @@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): # Read the original request body body = await request.body() - # Decode body to string body_str = body.decode("utf-8") - # Parse string to JSON data = json.loads(body_str) if body_str else {} + model_id = data["model"] user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), ) - # Remove the citations from the body - return_citations = data.get("citations", False) - if "citations" in data: - del data["citations"] - # Set the task model - task_model_id = data["model"] + task_model_id = model_id if task_model_id not in app.state.MODELS: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL + skip_files = False prompt = get_last_user_message(data["messages"]) context = "" # If tool_ids field is present, call the functions - - skip_files = False if "tool_ids" in data: print(data["tool_ids"]) for tool_id in data["tool_ids"]: @@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): context += ("\n" if context != "" else "") + rag_context log.debug(f"rag_context: {rag_context}, citations: {citations}") - else: - return_citations = False + + if citations: + data_items.append({"citations": citations}) del data["files"] @@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) print(system_prompt) data["messages"] = add_or_update_system_message( - f"\n{system_prompt}", data["messages"] + system_prompt, data["messages"] ) modified_body_bytes = json.dumps(data).encode("utf-8") @@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): response = await call_next(request) - 
if return_citations: - # Inject the citations into the response + # If there are data_items to inject into the response + if len(data_items) > 0: if isinstance(response, StreamingResponse): # If it's a streaming response, inject it as SSE event or NDJSON line content_type = response.headers.get("Content-Type") if "text/event-stream" in content_type: return StreamingResponse( - self.openai_stream_wrapper(response.body_iterator, citations), + self.openai_stream_wrapper(response.body_iterator, data_items), ) if "application/x-ndjson" in content_type: return StreamingResponse( - self.ollama_stream_wrapper(response.body_iterator, citations), + self.ollama_stream_wrapper(response.body_iterator, data_items), ) return response @@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): async def _receive(self, body: bytes): return {"type": "http.request", "body": body, "more_body": False} - async def openai_stream_wrapper(self, original_generator, citations): - yield f"data: {json.dumps({'citations': citations})}\n\n" + async def openai_stream_wrapper(self, original_generator, data_items): + for item in data_items: + yield f"data: {json.dumps(item)}\n\n" + async for data in original_generator: yield data - async def ollama_stream_wrapper(self, original_generator, citations): - yield f"{json.dumps({'citations': citations})}\n" + async def ollama_stream_wrapper(self, original_generator, data_items): + for item in data_items: + yield f"{json.dumps(item)}\n" + async for data in original_generator: yield data From c4bd60114eb6996f650d5ce0ab4b3ac0f41d6606 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 02:30:00 -0700 Subject: [PATCH 10/22] feat: filter inlet support --- backend/main.py | 67 +++++++++++++++---- src/lib/components/chat/Chat.svelte | 5 +- .../workspace/Models/FiltersSelector.svelte | 1 + 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/backend/main.py b/backend/main.py index febda4ced..951bf9654 100644 --- a/backend/main.py +++ b/backend/main.py @@ -50,7 +50,9 @@ from typing import List, Optional from apps.webui.models.models import Models, ModelModel from apps.webui.models.tools import Tools -from apps.webui.utils import load_toolkit_module_by_id +from apps.webui.models.functions import Functions + +from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from utils.utils import ( @@ -318,9 +320,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): data_items = [] - if request.method == "POST" and ( - "/ollama/api/chat" in request.url.path - or "/chat/completions" in request.url.path + if request.method == "POST" and any( + endpoint in request.url.path + for endpoint in ["/ollama/api/chat", "/chat/completions"] ): log.debug(f"request.url.path: {request.url.path}") @@ -328,23 +330,62 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): body = await request.body() body_str = body.decode("utf-8") data = json.loads(body_str) if body_str else {} - - model_id = data["model"] user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), ) - # Set the task model - task_model_id = model_id - if task_model_id not in app.state.MODELS: + # Flag to skip RAG completions if file_handler is present in tools/functions + skip_files = False + + model_id = data["model"] + if model_id not in app.state.MODELS: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Model not found", ) + model = app.state.MODELS[model_id] - 
# Check if the user has a custom task model - # If the user has a custom task model, use that model + print(":", data) + + # Check if the model has any filters + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if getattr(function_module, "file_handler"): + skip_files = True + + try: + if hasattr(function_module, "inlet"): + data = function_module.inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=e, + ) + + print("Filtered:", data) + # Set the task model + task_model_id = data["model"] + # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( app.state.config.TASK_MODEL @@ -358,7 +399,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL - skip_files = False prompt = get_last_user_message(data["messages"]) context = "" @@ -409,8 +449,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): log.debug(f"rag_context: {rag_context}, citations: {citations}") - if citations: + if citations and data.get("citations"): data_items.append({"citations": citations}) + del data["citations"] del data["files"] diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index b33b26fa3..1fae82415 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -630,7 +630,7 @@ keep_alive: $settings.keepAlive ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - citations: files.length > 0, + citations: files.length > 0 ? true : undefined, chat_id: $chatId }); @@ -928,7 +928,8 @@ max_tokens: $settings?.params?.max_tokens ?? undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, files: files.length > 0 ? files : undefined, - citations: files.length > 0, + citations: files.length > 0 ? true : undefined, + chat_id: $chatId }, `${OPENAI_API_BASE_URL}` diff --git a/src/lib/components/workspace/Models/FiltersSelector.svelte b/src/lib/components/workspace/Models/FiltersSelector.svelte index 291bb8939..92f64c2cf 100644 --- a/src/lib/components/workspace/Models/FiltersSelector.svelte +++ b/src/lib/components/workspace/Models/FiltersSelector.svelte @@ -31,6 +31,7 @@ {$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
+
{#if filters.length > 0}
From 96d7c3e99fd9e6c5a70d73051576b89478ead098 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 02:37:36 -0700 Subject: [PATCH 11/22] fix: raise error --- backend/main.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/backend/main.py b/backend/main.py index 951bf9654..5f845877e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -330,11 +330,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): body = await request.body() body_str = body.decode("utf-8") data = json.loads(body_str) if body_str else {} + user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), ) - # Flag to skip RAG completions if file_handler is present in tools/functions skip_files = False @@ -346,8 +346,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) model = app.state.MODELS[model_id] - print(":", data) - # Check if the model has any filters for filter_id in model["info"]["meta"].get("filterIds", []): filter = Functions.get_function_by_id(filter_id) @@ -377,12 +375,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) except Exception as e: print(f"Error: {e}") - raise HTTPException( + return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - detail=e, + content={"detail": str(e)}, ) - print("Filtered:", data) # Set the task model task_model_id = data["model"] # Check if the user has a custom task model and use that model From f14ca48334718f3f187f3fbd3f3c5221d349f685 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 02:42:56 -0700 Subject: [PATCH 12/22] refac --- .../workspace/Functions/FunctionEditor.svelte | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index 6e35616f2..b2a9ea652 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -39,6 +39,7 @@ class Filter: def __init__(self): # Indicates custom file handling logic. This flag helps disengage default routines in favor of custom # implementations, informing the WebUI to defer file-related operations to designated methods within this class. + # Alternatively, you can remove the files directly from the body in from the inlet hook self.file_handler = True # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings, @@ -50,16 +51,15 @@ class Filter: # Modify the request body or validate it before processing by the chat completion API. # This function is the pre-processor for the API where various checks on the input can be performed. # It can also modify the request before sending it to the API. + print(f"inlet:{__name__}") + print(f"inlet:body:{body}") + print(f"inlet:user:{user}") - print("inlet") - print(body) - print(user) - - if user.get("role", "admin") in ["user"]: + if user.get("role", "admin") in ["user", "admin"]: messages = body.get("messages", []) - if len(messages) > self.max_turns: + if len(messages) > self.valves.max_turns: raise Exception( - f"Conversation turn limit exceeded. Max turns: {self.max_turns}" + f"Conversation turn limit exceeded. Max turns: {self.valves.max_turns}" ) return body @@ -68,9 +68,9 @@ class Filter: # Modify or analyze the response body after processing by the API. 
# This function is the post-processor for the API, which can be used to modify the response # or perform additional checks and analytics. - print(f"outlet") - print(body) - print(user) + print(f"outlet:{__name__}") + print(f"outlet:body:{body}") + print(f"outlet:user:{user}") return body`; From 3101ff143b558b132176cb2362df5a3792c99e6f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 02:47:27 -0700 Subject: [PATCH 13/22] refac: disable continuing with error message --- src/lib/components/chat/Chat.svelte | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 1fae82415..a60aef51a 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -323,6 +323,13 @@ } else if (messages.length != 0 && messages.at(-1).done != true) { // Response not done console.log('wait'); + } else if (messages.length != 0 && messages.at(-1).error) { + // Error in response + toast.error( + $i18n.t( + `Oops! There was an error in the previous response. Please try again or contact admin.` + ) + ); } else if ( files.length > 0 && files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0 From afd270523c20af86138224bb2b20151bdb5984c0 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 03:23:50 -0700 Subject: [PATCH 14/22] feat: filter func outlet --- backend/main.py | 53 +++++++++++++++++++++++------ src/lib/components/chat/Chat.svelte | 4 ++- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/backend/main.py b/backend/main.py index 5f845877e..dade596a4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -474,10 +474,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ], ] - response = await call_next(request) - - # If there are data_items to inject into the response - if len(data_items) > 0: + response = await call_next(request) if isinstance(response, StreamingResponse): # If it's a streaming response, inject it as SSE event or NDJSON line content_type = response.headers.get("Content-Type") @@ -489,7 +486,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): return StreamingResponse( self.ollama_stream_wrapper(response.body_iterator, data_items), ) + else: + return response + # If it's not a chat completion request, just pass it through + response = await call_next(request) return response async def _receive(self, body: bytes): @@ -800,6 +801,12 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u async def chat_completed(form_data: dict, user=Depends(get_verified_user)): data = form_data model_id = data["model"] + if model_id not in app.state.MODELS: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + model = app.state.MODELS[model_id] filters = [ model @@ -815,14 +822,10 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ) ) ] + sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) - - print(model_id) - - if model_id in app.state.MODELS: - model = app.state.MODELS[model_id] - if "pipeline" in model: - sorted_filters = [model] + sorted_filters + if "pipeline" in model: + sorted_filters = [model] + sorted_filters for filter in sorted_filters: r = None @@ -863,6 +866,34 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): else: pass + # Check if the model has any filters + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = 
Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id(filter_id) + webui_app.state.FUNCTIONS[filter_id] = function_module + + try: + if hasattr(function_module, "outlet"): + data = function_module.outlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, + ) + return data diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index a60aef51a..9cf2201fc 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -278,7 +278,9 @@ })), chat_id: $chatId }).catch((error) => { - console.error(error); + toast.error(error); + messages.at(-1).error = { content: error }; + return null; }); From a3f09949c0aef73bfaa8444aeb2775bc2fef501e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 03:29:50 -0700 Subject: [PATCH 15/22] refac --- .../components/workspace/Functions/FunctionEditor.svelte | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index b2a9ea652..9706bd65c 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -72,7 +72,13 @@ class Filter: print(f"outlet:body:{body}") print(f"outlet:user:{user}") - return body`; + messages = [ + {**message, "content": f"{message['content']} - @@Modified from Outlet"} + for message in body.get("messages", []) + ] + + return {"messages": messages} +`; const saveHandler = async () => { loading = true; From e20baad60112e2d909b7a032726e906b1c34c213 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 03:35:21 -0700 Subject: [PATCH 16/22] refac --- .../components/workspace/Functions/FunctionEditor.svelte | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib/components/workspace/Functions/FunctionEditor.svelte b/src/lib/components/workspace/Functions/FunctionEditor.svelte index 9706bd65c..6e30013cc 100644 --- a/src/lib/components/workspace/Functions/FunctionEditor.svelte +++ b/src/lib/components/workspace/Functions/FunctionEditor.svelte @@ -73,11 +73,15 @@ class Filter: print(f"outlet:user:{user}") messages = [ - {**message, "content": f"{message['content']} - @@Modified from Outlet"} + { + **message, + "content": f"{message['content']} - @@Modified from Filter Outlet", + } for message in body.get("messages", []) ] return {"messages": messages} + `; const saveHandler = async () => { From 015772ef9ab2447e97855bff7ce73a8e83ad0dc2 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Thu, 20 Jun 2024 03:45:13 -0700 Subject: [PATCH 17/22] refac --- backend/apps/openai/main.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index c60c52fad..302dd8d98 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -432,7 +432,12 @@ async def generate_chat_completion( idx = model["urlIdx"] if "pipeline" in model and model.get("pipeline"): - payload["user"] = {"name": user.name, "id": user.id} + payload["user"] = { + "name": user.name, + "id": user.id, + "email": user.email, + "role": user.role, + } # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # This is a workaround until OpenAI fixes the issue with this model From c689356b31bd032983c680223804d90d8cbd6602 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 03:57:36 -0700 Subject: [PATCH 18/22] refac --- backend/apps/ollama/main.py | 17 ++++------------- src/lib/components/chat/Chat.svelte | 2 +- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py index 22a30474e..455dc89a5 100644 --- a/backend/apps/ollama/main.py +++ b/backend/apps/ollama/main.py @@ -53,7 +53,7 @@ from config import ( UPLOAD_DIR, AppConfig, ) -from utils.misc import calculate_sha256 +from utils.misc import calculate_sha256, add_or_update_system_message log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) @@ -834,18 +834,9 @@ async def generate_chat_completion( ) if payload.get("messages"): - for message in payload["messages"]: - if message.get("role") == "system": - message["content"] = system + message["content"] - break - else: - payload["messages"].insert( - 0, - { - "role": "system", - "content": system, - }, - ) + payload["messages"] = add_or_update_system_message( + system, payload["messages"] + ) if url_idx == None: if ":" not in payload["model"]: diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 9cf2201fc..d83eb3cb2 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -941,7 +941,7 @@ chat_id: $chatId }, - `${OPENAI_API_BASE_URL}` + `${WEBUI_BASE_URL}/api` ); // Wait until history/message have been updated From de26a78a16ff51312be4c1dbcfea5d6be4b0bbc7 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Thu, 20 Jun 2024 04:21:55 -0700 Subject: [PATCH 19/22] refac --- backend/main.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/backend/main.py b/backend/main.py index dade596a4..3d95d1913 100644 --- a/backend/main.py +++ b/backend/main.py @@ -42,7 +42,7 @@ from apps.openai.main import ( from apps.audio.main import app as audio_app from apps.images.main import app as images_app from apps.rag.main import app as rag_app -from apps.webui.main import app as webui_app +from apps.webui.main import app as webui_app, get_pipe_models from pydantic import BaseModel @@ -448,10 +448,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if citations and data.get("citations"): data_items.append({"citations": citations}) - del data["citations"] del data["files"] + if data.get("citations"): + del data["citations"] + if context != "": system_prompt = rag_template( rag_app.state.config.RAG_TEMPLATE, context, prompt @@ -691,17 +693,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION async def get_all_models(): + pipe_models = [] openai_models = [] ollama_models = [] + pipe_models = await get_pipe_models() + if app.state.config.ENABLE_OPENAI_API: openai_models = await get_openai_models() - openai_models = openai_models["data"] if app.state.config.ENABLE_OLLAMA_API: ollama_models = await get_ollama_models() - ollama_models = [ { "id": model["model"], @@ -714,9 +717,9 @@ async def get_all_models(): for model in ollama_models["models"] ] - models = openai_models + ollama_models - custom_models = Models.get_all_models() + models = pipe_models + openai_models + ollama_models + custom_models = Models.get_all_models() for custom_model in custom_models: if custom_model.base_model_id == None: for model in models: @@ -791,6 +794,13 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u model = app.state.MODELS[model_id] print(model) + + + if model.get('pipe') == True: + print('hi') + + + if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: From d6e4aef607350ec2d54f7c46b5417e26bb17fc55 Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Thu, 20 Jun 2024 04:38:59 -0700 Subject: [PATCH 20/22] feat: pipe function --- backend/apps/webui/main.py | 58 ++++++++++ backend/main.py | 221 ++++++++++++++++++++++++++----------- backend/utils/misc.py | 19 ++++ 3 files changed, 234 insertions(+), 64 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 4a53b15bf..5ccb8ae58 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -15,6 +15,9 @@ from apps.webui.routers import ( files, functions, ) +from apps.webui.models.functions import Functions +from apps.webui.utils import load_function_module_by_id + from config import ( WEBUI_BUILD_HASH, SHOW_ADMIN_DETAILS, @@ -97,3 +100,58 @@ async def get_status(): "default_models": app.state.config.DEFAULT_MODELS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, } + + +async def get_pipe_models(): + pipes = Functions.get_functions_by_type("pipe") + pipe_models = [] + + for pipe in pipes: + # Check if function is already loaded + if pipe.id not in app.state.FUNCTIONS: + function_module, function_type = load_function_module_by_id(pipe.id) + app.state.FUNCTIONS[pipe.id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe.id] + + # Check if function is a manifold + if hasattr(function_module, "type"): + if function_module.type == "manifold": + manifold_pipes = [] + + # Check if pipes is a function or a list + if callable(pipe.pipes): + manifold_pipes = pipe.pipes() + else: + manifold_pipes = pipe.pipes + + for p in manifold_pipes: + manifold_pipe_id = f'{pipe.id}.{p["id"]}' + manifold_pipe_name = p["name"] + + if hasattr(pipe, "name"): + manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}" + + pipe_models.append( + { + "id": manifold_pipe_id, + "name": manifold_pipe_name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": {"type": pipe.type}, + } + ) + else: + pipe_models.append( + { + "id": pipe.id, + "name": pipe.name, + "object": "model", + "created": pipe.created_at, + "owned_by": "openai", + "pipe": {"type": "pipe"}, + } + ) + + return pipe_models diff --git a/backend/main.py b/backend/main.py index 3d95d1913..d6a8c8831 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,6 +15,7 @@ import uuid import inspect import asyncio +from fastapi.concurrency import run_in_threadpool from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse @@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models from pydantic import BaseModel -from typing import List, Optional +from typing import List, Optional, Iterator, Generator, Union from apps.webui.models.models import Models, ModelModel from apps.webui.models.tools import Tools @@ -66,7 +67,11 @@ from utils.task import ( search_query_generation_template, tools_function_calling_generation_template, ) -from utils.misc import get_last_user_message, add_or_update_system_message +from utils.misc import ( + get_last_user_message, + add_or_update_system_message, + stream_message_template, +) from apps.rag.utils import get_rag_context, rag_template @@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): model = app.state.MODELS[model_id] # Check if the model has any filters - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = 
webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id( - filter_id - ) - webui_app.state.FUNCTIONS[filter_id] = function_module - - # Check if the function has a file_handler variable - if getattr(function_module, "file_handler"): - skip_files = True - - try: - if hasattr(function_module, "inlet"): - data = function_module.inlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + if "info" in model and "meta" in model["info"]: + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + # Check if the function has a file_handler variable + if getattr(function_module, "file_handler"): + skip_files = True + + try: + if hasattr(function_module, "inlet"): + data = function_module.inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, ) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) # Set the task model task_model_id = data["model"] @@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u model = app.state.MODELS[model_id] print(model) - + pipe = model.get("pipe") + if pipe: - if model.get('pipe') == True: - print('hi') - - - + def job(): + pipe_id = form_data["model"] + if "." 
in pipe_id: + pipe_id, sub_pipe_id = pipe_id.split(".", 1) + print(pipe_id) + + pipe = webui_app.state.FUNCTIONS[pipe_id].pipe + if form_data["stream"]: + + def stream_content(): + res = pipe(body=form_data) + + if isinstance(res, str): + message = stream_message_template(form_data["model"], res) + yield f"data: {json.dumps(message)}\n\n" + + if isinstance(res, Iterator): + for line in res: + if isinstance(line, BaseModel): + line = line.model_dump_json() + line = f"data: {line}" + try: + line = line.decode("utf-8") + except: + pass + + if line.startswith("data:"): + yield f"{line}\n\n" + else: + line = stream_message_template(form_data["model"], line) + yield f"data: {json.dumps(line)}\n\n" + + if isinstance(res, str) or isinstance(res, Generator): + finish_message = { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + yield f"data: {json.dumps(finish_message)}\n\n" + yield f"data: [DONE]" + + return StreamingResponse( + stream_content(), media_type="text/event-stream" + ) + else: + res = pipe(body=form_data) + + if isinstance(res, dict): + return res + elif isinstance(res, BaseModel): + return res.model_dump() + else: + message = "" + if isinstance(res, str): + message = res + if isinstance(res, Generator): + for stream in res: + message = f"{message}{stream}" + + return { + "id": f"{form_data['model']}-{str(uuid.uuid4())}", + "object": "chat.completion", + "created": int(time.time()), + "model": form_data["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": message, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + } + + return await run_in_threadpool(job) if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: @@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): pass # Check if the model has any filters - for filter_id in model["info"]["meta"].get("filterIds", []): - filter = Functions.get_function_by_id(filter_id) - if filter: - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, function_type = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - try: - if hasattr(function_module, "outlet"): - data = function_module.outlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + if "info" in model and "meta" in model["info"]: + for filter_id in model["info"]["meta"].get("filterIds", []): + filter = Functions.get_function_by_id(filter_id) + if filter: + if filter_id in webui_app.state.FUNCTIONS: + function_module = webui_app.state.FUNCTIONS[filter_id] + else: + function_module, function_type = load_function_module_by_id( + filter_id + ) + webui_app.state.FUNCTIONS[filter_id] = function_module + + try: + if hasattr(function_module, "outlet"): + data = function_module.outlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: + print(f"Error: {e}") + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": str(e)}, ) - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) 
return data diff --git a/backend/utils/misc.py b/backend/utils/misc.py index 41fbdcc75..b4e499df8 100644 --- a/backend/utils/misc.py +++ b/backend/utils/misc.py @@ -4,6 +4,8 @@ import json import re from datetime import timedelta from typing import Optional, List, Tuple +import uuid +import time def get_last_user_message(messages: List[dict]) -> str: @@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]): return messages +def stream_message_template(model: str, message: str): + return { + "id": f"{model}-{str(uuid.uuid4())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": message}, + "logprobs": None, + "finish_reason": None, + } + ], + } + + def get_gravatar_url(email): # Trim leading and trailing whitespace from # an email address and force all characters From 59fa2f8f26ca11d3089528dd01a479e59af77241 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 04:47:40 -0700 Subject: [PATCH 21/22] refac: pipe function support --- backend/apps/webui/main.py | 8 ++++---- backend/main.py | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 5ccb8ae58..ce58047ed 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -120,16 +120,16 @@ async def get_pipe_models(): manifold_pipes = [] # Check if pipes is a function or a list - if callable(pipe.pipes): - manifold_pipes = pipe.pipes() + if callable(function_module.pipes): + manifold_pipes = function_module.pipes() else: - manifold_pipes = pipe.pipes + manifold_pipes = function_module.pipes for p in manifold_pipes: manifold_pipe_id = f'{pipe.id}.{p["id"]}' manifold_pipe_name = p["name"] - if hasattr(pipe, "name"): + if hasattr(function_module, "name"): manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}" pipe_models.append( diff --git a/backend/main.py b/backend/main.py index d6a8c8831..47078b681 100644 --- a/backend/main.py +++ b/backend/main.py @@ -802,6 +802,12 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: + form_data["user"] = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } def job(): pipe_id = form_data["model"] From 9ebd308d286094f7b317630af7fe1991fb14294b Mon Sep 17 00:00:00 2001 From: "Timothy J. 
Baek" Date: Thu, 20 Jun 2024 04:51:51 -0700 Subject: [PATCH 22/22] refac --- src/lib/components/workspace/Functions.svelte | 6 +++++- src/routes/(app)/workspace/functions/create/+page.svelte | 5 ++++- src/routes/(app)/workspace/functions/edit/+page.svelte | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/lib/components/workspace/Functions.svelte b/src/lib/components/workspace/Functions.svelte index 8f0139009..35e308220 100644 --- a/src/lib/components/workspace/Functions.svelte +++ b/src/lib/components/workspace/Functions.svelte @@ -3,7 +3,7 @@ import fileSaver from 'file-saver'; const { saveAs } = fileSaver; - import { WEBUI_NAME, functions } from '$lib/stores'; + import { WEBUI_NAME, functions, models } from '$lib/stores'; import { onMount, getContext } from 'svelte'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; @@ -19,6 +19,7 @@ import ArrowDownTray from '../icons/ArrowDownTray.svelte'; import Tooltip from '../common/Tooltip.svelte'; import ConfirmDialog from '../common/ConfirmDialog.svelte'; + import { getModels } from '$lib/apis'; const i18n = getContext('i18n'); @@ -226,7 +227,9 @@ if (res) { toast.success('Function deleted successfully'); + functions.set(await getFunctions(localStorage.token)); + models.set(await getModels(localStorage.token)); } }} > @@ -349,6 +352,7 @@ toast.success('Functions imported successfully'); functions.set(await getFunctions(localStorage.token)); + models.set(await getModels(localStorage.token)); }; reader.readAsText(importFiles[0]); diff --git a/src/routes/(app)/workspace/functions/create/+page.svelte b/src/routes/(app)/workspace/functions/create/+page.svelte index 5faf3d50b..0f73cf94e 100644 --- a/src/routes/(app)/workspace/functions/create/+page.svelte +++ b/src/routes/(app)/workspace/functions/create/+page.svelte @@ -3,9 +3,10 @@ import { onMount } from 'svelte'; import { goto } from '$app/navigation'; - import { functions } from '$lib/stores'; + import { functions, models } from '$lib/stores'; import { createNewFunction, getFunctions } from '$lib/apis/functions'; import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte'; + import { getModels } from '$lib/apis'; let mounted = false; let clone = false; @@ -26,6 +27,8 @@ if (res) { toast.success('Function created successfully'); functions.set(await getFunctions(localStorage.token)); + models.set(await getModels(localStorage.token)); + await goto('/workspace/functions'); } }; diff --git a/src/routes/(app)/workspace/functions/edit/+page.svelte b/src/routes/(app)/workspace/functions/edit/+page.svelte index a61dda142..21fc5acb6 100644 --- a/src/routes/(app)/workspace/functions/edit/+page.svelte +++ b/src/routes/(app)/workspace/functions/edit/+page.svelte @@ -4,11 +4,12 @@ import { goto } from '$app/navigation'; import { page } from '$app/stores'; - import { functions } from '$lib/stores'; + import { functions, models } from '$lib/stores'; import { updateFunctionById, getFunctions, getFunctionById } from '$lib/apis/functions'; import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte'; import Spinner from '$lib/components/common/Spinner.svelte'; + import { getModels } from '$lib/apis'; let func = null; @@ -27,6 +28,7 @@ if (res) { toast.success('Function updated successfully'); functions.set(await getFunctions(localStorage.token)); + models.set(await getModels(localStorage.token)); } };