Merge pull request #3321 from open-webui/functions

feat: functions
This commit is contained in:
Timothy Jaeryang Baek 2024-06-20 04:52:32 -07:00 committed by GitHub
commit 09a81eb225
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1365 additions and 248 deletions

View File

@ -53,7 +53,7 @@ from config import (
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@ -834,18 +834,9 @@ async def generate_chat_completion(
) )
if payload.get("messages"): if payload.get("messages"):
for message in payload["messages"]: payload["messages"] = add_or_update_system_message(
if message.get("role") == "system": system, payload["messages"]
message["content"] = system + message["content"] )
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
if url_idx == None: if url_idx == None:
if ":" not in payload["model"]: if ":" not in payload["model"]:

View File

@ -432,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"] idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"): if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id} payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model # This is a workaround until OpenAI fixes the issue with this model

View File

@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")

View File

@ -13,7 +13,11 @@ from apps.webui.routers import (
memories, memories,
utils, utils,
files, files,
functions,
) )
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from config import ( from config import (
WEBUI_BUILD_HASH, WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS, SHOW_ADMIN_DETAILS,
@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {} app.state.MODELS = {}
app.state.TOOLS = {} app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -70,19 +74,22 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
app.include_router(files.router, prefix="/files", tags=["files"]) app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@app.get("/") @app.get("/")
@ -93,3 +100,58 @@ async def get_status():
"default_models": app.state.config.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
} }
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe")
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": pipe.type},
}
)
else:
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": "pipe"},
}
)
return pipe_models

View File

@ -55,6 +55,7 @@ class FunctionModel(BaseModel):
class FunctionResponse(BaseModel): class FunctionResponse(BaseModel):
id: str id: str
user_id: str user_id: str
type: str
name: str name: str
meta: FunctionMeta meta: FunctionMeta
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
@ -64,23 +65,23 @@ class FunctionResponse(BaseModel):
class FunctionForm(BaseModel): class FunctionForm(BaseModel):
id: str id: str
name: str name: str
type: str
content: str content: str
meta: FunctionMeta meta: FunctionMeta
class ToolsTable: class FunctionsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Function]) self.db.create_tables([Function])
def insert_new_function( def insert_new_function(
self, user_id: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
"user_id": user_id, "user_id": user_id,
"type": type,
"updated_at": int(time.time()), "updated_at": int(time.time()),
"created_at": int(time.time()), "created_at": int(time.time()),
} }
@ -137,4 +138,4 @@ class ToolsTable:
return False return False
Tools = ToolsTable(DB) Functions = FunctionsTable(DB)

View File

@ -0,0 +1,180 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.functions import (
Functions,
FunctionForm,
FunctionModel,
FunctionResponse,
)
from apps.webui.utils import load_function_module_by_id
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
router = APIRouter()
############################
# GetFunctions
############################
@router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
############################
# ExportFunctions
############################
@router.get("/export", response_model=List[FunctionModel])
async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# CreateNewFunction
############################
@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
if function == None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(form_data.id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetFunctionById
############################
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionById
############################
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
function = Functions.update_function_by_id(id, updated)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteFunctionById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
result = Functions.delete_function_by_id(id)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
if id in FUNCTIONS:
del FUNCTIONS[id]
# delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path)
return result

View File

@ -1,7 +1,7 @@
from importlib import util from importlib import util
import os import os
from config import TOOLS_DIR from config import TOOLS_DIR, FUNCTIONS_DIR
def load_toolkit_module_by_id(toolkit_id): def load_toolkit_module_by_id(toolkit_id):
@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
# Move the file to the error folder # Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error") os.rename(toolkit_path, f"{toolkit_path}.error")
raise e raise e
def load_function_module_by_id(function_id):
function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
spec = util.spec_from_file_location(function_id, function_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"):
return module.Pipe(), "pipe"
elif hasattr(module, "Filter"):
return module.Filter(), "filter"
else:
raise Exception("No Function class found")
except Exception as e:
print(f"Error loading module: {function_id}")
# Move the file to the error folder
os.rename(function_path, f"{function_path}.error")
raise e

View File

@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# LITELLM_CONFIG # LITELLM_CONFIG
#################################### ####################################

View File

@ -15,6 +15,7 @@ import uuid
import inspect import inspect
import asyncio import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -42,15 +43,17 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app from apps.audio.main import app as audio_app
from apps.images.main import app as images_app from apps.images.main import app as images_app
from apps.rag.main import app as rag_app from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app from apps.webui.main import app as webui_app, get_pipe_models
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.models.functions import Functions
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from utils.utils import ( from utils.utils import (
@ -64,7 +67,11 @@ from utils.task import (
search_query_generation_template, search_query_generation_template,
tools_function_calling_generation_template, tools_function_calling_generation_template,
) )
from utils.misc import get_last_user_message, add_or_update_system_message from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
)
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
@ -170,6 +177,13 @@ app.state.MODELS = {}
origins = ["*"] origins = ["*"]
##################################
#
# ChatCompletion Middleware
#
##################################
async def get_function_call_response( async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user messages, files, tool_id, template, task_model_id, user
): ):
@ -309,41 +323,72 @@ async def get_function_call_response(
class ChatCompletionMiddleware(BaseHTTPMiddleware): class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False data_items = []
if request.method == "POST" and ( if request.method == "POST" and any(
"/ollama/api/chat" in request.url.path endpoint in request.url.path
or "/chat/completions" in request.url.path for endpoint in ["/ollama/api/chat", "/chat/completions"]
): ):
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
# Read the original request body # Read the original request body
body = await request.body() body = await request.body()
# Decode body to string
body_str = body.decode("utf-8") body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
user = get_current_user( user = get_current_user(
request, request,
get_http_authorization_cred(request.headers.get("Authorization")), get_http_authorization_cred(request.headers.get("Authorization")),
) )
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
# Remove the citations from the body model_id = data["model"]
return_citations = data.get("citations", False) if model_id not in app.state.MODELS:
if "citations" in data:
del data["citations"]
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found", detail="Model not found",
) )
model = app.state.MODELS[model_id]
# Check if the user has a custom task model # Check if the model has any filters
# If the user has a custom task model, use that model if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if getattr(function_module, "file_handler"):
skip_files = True
try:
if hasattr(function_module, "inlet"):
data = function_module.inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model
task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if ( if (
app.state.config.TASK_MODEL app.state.config.TASK_MODEL
@ -361,8 +406,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context = "" context = ""
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
skip_files = False
if "tool_ids" in data: if "tool_ids" in data:
print(data["tool_ids"]) print(data["tool_ids"])
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
@ -408,18 +451,22 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
context += ("\n" if context != "" else "") + rag_context context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}") log.debug(f"rag_context: {rag_context}, citations: {citations}")
else:
return_citations = False if citations and data.get("citations"):
data_items.append({"citations": citations})
del data["files"] del data["files"]
if data.get("citations"):
del data["citations"]
if context != "": if context != "":
system_prompt = rag_template( system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt rag_app.state.config.RAG_TEMPLATE, context, prompt
) )
print(system_prompt) print(system_prompt)
data["messages"] = add_or_update_system_message( data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"] system_prompt, data["messages"]
) )
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
@ -435,40 +482,51 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
], ],
] ]
response = await call_next(request) response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse): if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line # If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type") content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type: if "text/event-stream" in content_type:
return StreamingResponse( return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations), self.openai_stream_wrapper(response.body_iterator, data_items),
) )
if "application/x-ndjson" in content_type: if "application/x-ndjson" in content_type:
return StreamingResponse( return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations), self.ollama_stream_wrapper(response.body_iterator, data_items),
) )
else:
return response
# If it's not a chat completion request, just pass it through
response = await call_next(request)
return response return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations): async def openai_stream_wrapper(self, original_generator, data_items):
yield f"data: {json.dumps({'citations': citations})}\n\n" for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator: async for data in original_generator:
yield data yield data
async def ollama_stream_wrapper(self, original_generator, citations): async def ollama_stream_wrapper(self, original_generator, data_items):
yield f"{json.dumps({'citations': citations})}\n" for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator: async for data in original_generator:
yield data yield data
app.add_middleware(ChatCompletionMiddleware) app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
#
##################################
def filter_pipeline(payload, user): def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role} user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
@ -628,7 +686,6 @@ async def update_embedding_function(request: Request, call_next):
app.mount("/ws", socket_app) app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app) app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app) app.mount("/openai", openai_app)
@ -642,17 +699,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(): async def get_all_models():
pipe_models = []
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API: if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models() openai_models = await get_openai_models()
openai_models = openai_models["data"] openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models() ollama_models = await get_ollama_models()
ollama_models = [ ollama_models = [
{ {
"id": model["model"], "id": model["model"],
@ -665,9 +723,9 @@ async def get_all_models():
for model in ollama_models["models"] for model in ollama_models["models"]
] ]
models = openai_models + ollama_models models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models()
custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id == None: if custom_model.base_model_id == None:
for model in models: for model in models:
@ -730,6 +788,234 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models} return {"data": models}
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
pipe = model.get("pipe")
if pipe:
form_data["user"] = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
if form_data["stream"]:
def stream_content():
res = pipe(body=form_data)
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(
stream_content(), media_type="text/event-stream"
)
else:
res = pipe(body=form_data)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
if hasattr(function_module, "outlet"):
data = function_module.outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
##################################
#
# Task Endpoints
#
##################################
# TODO: Refactor task API endpoints below into a separate file
@app.get("/api/task/config") @app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)): async def get_task_config(user=Depends(get_verified_user)):
return { return {
@ -1015,92 +1301,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
) )
@app.post("/api/chat/completions") ##################################
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): #
model_id = form_data["model"] # Pipelines Endpoints
if model_id not in app.state.MODELS: #
raise HTTPException( ##################################
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed") # TODO: Refactor pipelines API endpoints below into a separate file
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print(model_id)
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
return data
@app.get("/api/pipelines/list") @app.get("/api/pipelines/list")
@ -1423,6 +1631,13 @@ async def update_pipeline_valves(
) )
##################################
#
# Config Endpoints
#
##################################
@app.get("/api/config") @app.get("/api/config")
async def get_app_config(): async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA # Checking and Handling the Absence of 'ui' in CONFIG_DATA
@ -1486,6 +1701,9 @@ async def update_model_filter_config(
} }
# TODO: webhook endpoint should be under config endpoints
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {

View File

@ -4,6 +4,8 @@ import json
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
import uuid
import time
def get_last_user_message(messages: List[dict]) -> str: def get_last_user_message(messages: List[dict]) -> str:
@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages return messages
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_gravatar_url(email): def get_gravatar_url(email):
# Trim leading and trailing whitespace from # Trim leading and trailing whitespace from
# an email address and force all characters # an email address and force all characters

View File

@ -0,0 +1,193 @@
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFunction = async (token: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionById = async (token: string, id: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};

View File

@ -278,7 +278,9 @@
})), })),
chat_id: $chatId chat_id: $chatId
}).catch((error) => { }).catch((error) => {
console.error(error); toast.error(error);
messages.at(-1).error = { content: error };
return null; return null;
}); });
@ -323,6 +325,13 @@
} else if (messages.length != 0 && messages.at(-1).done != true) { } else if (messages.length != 0 && messages.at(-1).done != true) {
// Response not done // Response not done
console.log('wait'); console.log('wait');
} else if (messages.length != 0 && messages.at(-1).error) {
// Error in response
toast.error(
$i18n.t(
`Oops! There was an error in the previous response. Please try again or contact admin.`
)
);
} else if ( } else if (
files.length > 0 && files.length > 0 &&
files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0 files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0
@ -630,7 +639,7 @@
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
citations: files.length > 0, citations: files.length > 0 ? true : undefined,
chat_id: $chatId chat_id: $chatId
}); });
@ -928,10 +937,11 @@
max_tokens: $settings?.params?.max_tokens ?? undefined, max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
citations: files.length > 0, citations: files.length > 0 ? true : undefined,
chat_id: $chatId chat_id: $chatId
}, },
`${OPENAI_API_BASE_URL}` `${WEBUI_BASE_URL}/api`
); );
// Wait until history/message have been updated // Wait until history/message have been updated

View File

@ -3,25 +3,27 @@
import fileSaver from 'file-saver'; import fileSaver from 'file-saver';
const { saveAs } = fileSaver; const { saveAs } = fileSaver;
import { WEBUI_NAME, functions, models } from '$lib/stores';
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, prompts, tools } from '$lib/stores';
import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts'; import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { import {
createNewTool, createNewFunction,
deleteToolById, deleteFunctionById,
exportTools, exportFunctions,
getToolById, getFunctionById,
getTools getFunctions
} from '$lib/apis/tools'; } from '$lib/apis/functions';
import ArrowDownTray from '../icons/ArrowDownTray.svelte'; import ArrowDownTray from '../icons/ArrowDownTray.svelte';
import Tooltip from '../common/Tooltip.svelte'; import Tooltip from '../common/Tooltip.svelte';
import ConfirmDialog from '../common/ConfirmDialog.svelte'; import ConfirmDialog from '../common/ConfirmDialog.svelte';
import { getModels } from '$lib/apis';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
let toolsImportInputElement: HTMLInputElement; let functionsImportInputElement: HTMLInputElement;
let importFiles; let importFiles;
let showConfirm = false; let showConfirm = false;
@ -64,7 +66,7 @@
<div> <div>
<a <a
class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1" class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1"
href="/workspace/tools/create" href="/workspace/functions/create"
> >
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
@ -82,30 +84,40 @@
<hr class=" dark:border-gray-850 my-2.5" /> <hr class=" dark:border-gray-850 my-2.5" />
<div class="my-3 mb-5"> <div class="my-3 mb-5">
{#each $tools.filter((t) => query === '' || t.name {#each $functions.filter((f) => query === '' || f.name
.toLowerCase() .toLowerCase()
.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool} .includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
<button <button
class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl" class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
type="button" type="button"
on:click={() => { on:click={() => {
goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`); goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
}} }}
> >
<div class=" flex flex-1 space-x-4 cursor-pointer w-full"> <div class=" flex flex-1 space-x-4 cursor-pointer w-full">
<a <a
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`} href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
class="flex items-center text-left" class="flex items-center text-left"
> >
<div class=" flex-1 self-center pl-5"> <div class=" flex-1 self-center pl-1">
<div class=" font-semibold flex items-center gap-1.5"> <div class=" font-semibold flex items-center gap-1.5">
<div> <div
{tool.name} class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
>
{func.type}
</div>
<div>
{func.name}
</div> </div>
<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
</div> </div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{tool.meta.description} <div class="flex gap-1.5 px-1">
<div class=" text-gray-500 text-xs font-medium">{func.id}</div>
<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
{func.meta.description}
</div>
</div> </div>
</div> </div>
</a> </a>
@ -115,7 +127,7 @@
<a <a
class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl" class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
type="button" type="button"
href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`} href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
> >
<svg <svg
xmlns="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg"
@ -141,18 +153,20 @@
on:click={async (e) => { on:click={async (e) => {
e.stopPropagation(); e.stopPropagation();
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => { const _function = await getFunctionById(localStorage.token, func.id).catch(
toast.error(error); (error) => {
return null; toast.error(error);
}); return null;
}
);
if (_tool) { if (_function) {
sessionStorage.tool = JSON.stringify({ sessionStorage.function = JSON.stringify({
..._tool, ..._function,
id: `${_tool.id}_clone`, id: `${_function.id}_clone`,
name: `${_tool.name} (Clone)` name: `${_function.name} (Clone)`
}); });
goto('/workspace/tools/create'); goto('/workspace/functions/create');
} }
}} }}
> >
@ -180,16 +194,18 @@
on:click={async (e) => { on:click={async (e) => {
e.stopPropagation(); e.stopPropagation();
const _tool = await getToolById(localStorage.token, tool.id).catch((error) => { const _function = await getFunctionById(localStorage.token, func.id).catch(
toast.error(error); (error) => {
return null; toast.error(error);
}); return null;
}
);
if (_tool) { if (_function) {
let blob = new Blob([JSON.stringify([_tool])], { let blob = new Blob([JSON.stringify([_function])], {
type: 'application/json' type: 'application/json'
}); });
saveAs(blob, `tool-${_tool.id}-export-${Date.now()}.json`); saveAs(blob, `function-${_function.id}-export-${Date.now()}.json`);
} }
}} }}
> >
@ -204,14 +220,16 @@
on:click={async (e) => { on:click={async (e) => {
e.stopPropagation(); e.stopPropagation();
const res = await deleteToolById(localStorage.token, tool.id).catch((error) => { const res = await deleteFunctionById(localStorage.token, func.id).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
}); });
if (res) { if (res) {
toast.success('Tool deleted successfully'); toast.success('Function deleted successfully');
tools.set(await getTools(localStorage.token));
functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
} }
}} }}
> >
@ -246,7 +264,7 @@
<div class="flex space-x-2"> <div class="flex space-x-2">
<input <input
id="documents-import-input" id="documents-import-input"
bind:this={toolsImportInputElement} bind:this={functionsImportInputElement}
bind:files={importFiles} bind:files={importFiles}
type="file" type="file"
accept=".json" accept=".json"
@ -260,7 +278,7 @@
<button <button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition" class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={() => { on:click={() => {
toolsImportInputElement.click(); functionsImportInputElement.click();
}} }}
> >
<div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div> <div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div>
@ -284,16 +302,16 @@
<button <button
class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition" class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
on:click={async () => { on:click={async () => {
const _tools = await exportTools(localStorage.token).catch((error) => { const _functions = await exportFunctions(localStorage.token).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
}); });
if (_tools) { if (_functions) {
let blob = new Blob([JSON.stringify(_tools)], { let blob = new Blob([JSON.stringify(_functions)], {
type: 'application/json' type: 'application/json'
}); });
saveAs(blob, `tools-export-${Date.now()}.json`); saveAs(blob, `functions-export-${Date.now()}.json`);
} }
}} }}
> >
@ -322,18 +340,19 @@
on:confirm={() => { on:confirm={() => {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = async (event) => { reader.onload = async (event) => {
const _tools = JSON.parse(event.target.result); const _functions = JSON.parse(event.target.result);
console.log(_tools); console.log(_functions);
for (const tool of _tools) { for (const func of _functions) {
const res = await createNewTool(localStorage.token, tool).catch((error) => { const res = await createNewFunction(localStorage.token, func).catch((error) => {
toast.error(error); toast.error(error);
return null; return null;
}); });
} }
toast.success('Tool imported successfully'); toast.success('Functions imported successfully');
tools.set(await getTools(localStorage.token)); functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
}; };
reader.readAsText(importFiles[0]); reader.readAsText(importFiles[0]);
@ -344,8 +363,8 @@
<div>Please carefully review the following warnings:</div> <div>Please carefully review the following warnings:</div>
<ul class=" mt-1 list-disc pl-4 text-xs"> <ul class=" mt-1 list-disc pl-4 text-xs">
<li>Tools have a function calling system that allows arbitrary code execution.</li> <li>Functions allow arbitrary code execution.</li>
<li>Do not install tools from sources you do not fully trust.</li> <li>Do not install functions from sources you do not fully trust.</li>
</ul> </ul>
</div> </div>

View File

@ -0,0 +1,235 @@
<script>
import { getContext, createEventDispatcher, onMount } from 'svelte';
import { goto } from '$app/navigation';
const dispatch = createEventDispatcher();
const i18n = getContext('i18n');
import CodeEditor from '$lib/components/common/CodeEditor.svelte';
import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
let formElement = null;
let loading = false;
let showConfirm = false;
export let edit = false;
export let clone = false;
export let id = '';
export let name = '';
export let meta = {
description: ''
};
export let content = '';
$: if (name && !edit && !clone) {
id = name.replace(/\s+/g, '_').toLowerCase();
}
let codeEditor;
let boilerplate = `from pydantic import BaseModel
from typing import Optional
class Filter:
class Valves(BaseModel):
max_turns: int = 4
pass
def __init__(self):
# Indicates custom file handling logic. This flag helps disengage default routines in favor of custom
# implementations, informing the WebUI to defer file-related operations to designated methods within this class.
# Alternatively, you can remove the files directly from the body in from the inlet hook
self.file_handler = True
# Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
# which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
self.valves = self.Valves(**{"max_turns": 2})
pass
def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
# Modify the request body or validate it before processing by the chat completion API.
# This function is the pre-processor for the API where various checks on the input can be performed.
# It can also modify the request before sending it to the API.
print(f"inlet:{__name__}")
print(f"inlet:body:{body}")
print(f"inlet:user:{user}")
if user.get("role", "admin") in ["user", "admin"]:
messages = body.get("messages", [])
if len(messages) > self.valves.max_turns:
raise Exception(
f"Conversation turn limit exceeded. Max turns: {self.valves.max_turns}"
)
return body
def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
# Modify or analyze the response body after processing by the API.
# This function is the post-processor for the API, which can be used to modify the response
# or perform additional checks and analytics.
print(f"outlet:{__name__}")
print(f"outlet:body:{body}")
print(f"outlet:user:{user}")
messages = [
{
**message,
"content": f"{message['content']} - @@Modified from Filter Outlet",
}
for message in body.get("messages", [])
]
return {"messages": messages}
`;
const saveHandler = async () => {
loading = true;
dispatch('save', {
id,
name,
meta,
content
});
};
const submitHandler = async () => {
if (codeEditor) {
const res = await codeEditor.formatPythonCodeHandler();
if (res) {
console.log('Code formatted successfully');
saveHandler();
}
}
};
</script>
<div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
<div class="mx-auto w-full md:px-0 h-full">
<form
bind:this={formElement}
class=" flex flex-col max-h-[100dvh] h-full"
on:submit|preventDefault={() => {
if (edit) {
submitHandler();
} else {
showConfirm = true;
}
}}
>
<div class="mb-2.5">
<button
class="flex space-x-1"
on:click={() => {
goto('/workspace/functions');
}}
type="button"
>
<div class=" self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-4 h-4"
>
<path
fill-rule="evenodd"
d="M17 10a.75.75 0 01-.75.75H5.612l4.158 3.96a.75.75 0 11-1.04 1.08l-5.5-5.25a.75.75 0 010-1.08l5.5-5.25a.75.75 0 111.04 1.08L5.612 9.25H16.25A.75.75 0 0117 10z"
clip-rule="evenodd"
/>
</svg>
</div>
<div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div>
</button>
</div>
<div class="flex flex-col flex-1 overflow-auto h-0 rounded-lg">
<div class="w-full mb-2 flex flex-col gap-1.5">
<div class="flex gap-2 w-full">
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function Name (e.g. My Filter)"
bind:value={name}
required
/>
<input
class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function ID (e.g. my_filter)"
bind:value={id}
required
disabled={edit}
/>
</div>
<input
class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
type="text"
placeholder="Function Description (e.g. A filter to remove profanity from text)"
bind:value={meta.description}
required
/>
</div>
<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
<CodeEditor
bind:value={content}
bind:this={codeEditor}
{boilerplate}
on:save={() => {
if (formElement) {
formElement.requestSubmit();
}
}}
/>
</div>
<div class="pb-3 flex justify-between">
<div class="flex-1 pr-3">
<div class="text-xs text-gray-500 line-clamp-2">
<span class=" font-semibold dark:text-gray-200">Warning:</span> Functions allow
arbitrary code execution <br />
<span class=" font-medium dark:text-gray-400"
>don't install random functions from sources you don't trust.</span
>
</div>
</div>
<button
class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
type="submit"
>
{$i18n.t('Save')}
</button>
</div>
</div>
</form>
</div>
</div>
<ConfirmDialog
bind:show={showConfirm}
on:confirm={() => {
submitHandler();
}}
>
<div class="text-sm text-gray-500">
<div class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-lg px-4 py-3">
<div>Please carefully review the following warnings:</div>
<ul class=" mt-1 list-disc pl-4 text-xs">
<li>Functions allow arbitrary code execution.</li>
<li>Do not install functions from sources you do not fully trust.</li>
</ul>
</div>
<div class="my-3">
I acknowledge that I have read and I understand the implications of my action. I am aware of
the risks associated with executing arbitrary code and I have verified the trustworthiness of
the source.
</div>
</div>
</ConfirmDialog>

View File

@ -0,0 +1,60 @@
<script lang="ts">
import { getContext, onMount } from 'svelte';
import Checkbox from '$lib/components/common/Checkbox.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
const i18n = getContext('i18n');
export let filters = [];
export let selectedFilterIds = [];
let _filters = {};
onMount(() => {
_filters = filters.reduce((acc, filter) => {
acc[filter.id] = {
...filter,
selected: selectedFilterIds.includes(filter.id)
};
return acc;
}, {});
});
</script>
<div>
<div class="flex w-full justify-between mb-1">
<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
</div>
<div class=" text-xs dark:text-gray-500">
{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
</div>
<!-- TODO: Filer order matters -->
<div class="flex flex-col">
{#if filters.length > 0}
<div class=" flex items-center mt-2 flex-wrap">
{#each Object.keys(_filters) as filter, filterIdx}
<div class=" flex items-center gap-2 mr-3">
<div class="self-center flex items-center">
<Checkbox
state={_filters[filter].selected ? 'checked' : 'unchecked'}
on:change={(e) => {
_filters[filter].selected = e.detail === 'checked';
selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected);
}}
/>
</div>
<div class=" py-0.5 text-sm w-full capitalize font-medium">
<Tooltip content={_filters[filter].meta.description}>
{_filters[filter].name}
</Tooltip>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@ -27,7 +27,9 @@ export const tags = writable([]);
export const models: Writable<Model[]> = writable([]); export const models: Writable<Model[]> = writable([]);
export const prompts: Writable<Prompt[]> = writable([]); export const prompts: Writable<Prompt[]> = writable([]);
export const documents: Writable<Document[]> = writable([]); export const documents: Writable<Document[]> = writable([]);
export const tools = writable([]); export const tools = writable([]);
export const functions = writable([]);
export const banners: Writable<Banner[]> = writable([]); export const banners: Writable<Banner[]> = writable([]);

View File

@ -1,11 +1,16 @@
<script lang="ts"> <script lang="ts">
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { WEBUI_NAME, showSidebar } from '$lib/stores'; import { WEBUI_NAME, showSidebar, functions } from '$lib/stores';
import MenuLines from '$lib/components/icons/MenuLines.svelte'; import MenuLines from '$lib/components/icons/MenuLines.svelte';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { getFunctions } from '$lib/apis/functions';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
onMount(async () => {
functions.set(await getFunctions(localStorage.token));
});
</script> </script>
<svelte:head> <svelte:head>

View File

@ -1,18 +1,20 @@
<script> <script>
import { goto } from '$app/navigation';
import { createNewTool, getTools } from '$lib/apis/tools';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner'; import { toast } from 'svelte-sonner';
import { onMount } from 'svelte';
import { goto } from '$app/navigation';
import { functions, models } from '$lib/stores';
import { createNewFunction, getFunctions } from '$lib/apis/functions';
import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
import { getModels } from '$lib/apis';
let mounted = false; let mounted = false;
let clone = false; let clone = false;
let tool = null; let func = null;
const saveHandler = async (data) => { const saveHandler = async (data) => {
console.log(data); console.log(data);
const res = await createNewTool(localStorage.token, { const res = await createNewFunction(localStorage.token, {
id: data.id, id: data.id,
name: data.name, name: data.name,
meta: data.meta, meta: data.meta,
@ -23,19 +25,20 @@
}); });
if (res) { if (res) {
toast.success('Tool created successfully'); toast.success('Function created successfully');
tools.set(await getTools(localStorage.token)); functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
await goto('/workspace/tools'); await goto('/workspace/functions');
} }
}; };
onMount(() => { onMount(() => {
if (sessionStorage.tool) { if (sessionStorage.function) {
tool = JSON.parse(sessionStorage.tool); func = JSON.parse(sessionStorage.function);
sessionStorage.removeItem('tool'); sessionStorage.removeItem('function');
console.log(tool); console.log(func);
clone = true; clone = true;
} }
@ -44,11 +47,11 @@
</script> </script>
{#if mounted} {#if mounted}
<ToolkitEditor <FunctionEditor
id={tool?.id ?? ''} id={func?.id ?? ''}
name={tool?.name ?? ''} name={func?.name ?? ''}
meta={tool?.meta ?? { description: '' }} meta={func?.meta ?? { description: '' }}
content={tool?.content ?? ''} content={func?.content ?? ''}
{clone} {clone}
on:save={(e) => { on:save={(e) => {
saveHandler(e.detail); saveHandler(e.detail);

View File

@ -1,18 +1,21 @@
<script> <script>
import { toast } from 'svelte-sonner';
import { onMount } from 'svelte';
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { getToolById, getTools, updateToolById } from '$lib/apis/tools'; import { functions, models } from '$lib/stores';
import Spinner from '$lib/components/common/Spinner.svelte'; import { updateFunctionById, getFunctions, getFunctionById } from '$lib/apis/functions';
import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
import { tools } from '$lib/stores';
import { onMount } from 'svelte';
import { toast } from 'svelte-sonner';
let tool = null; import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
import Spinner from '$lib/components/common/Spinner.svelte';
import { getModels } from '$lib/apis';
let func = null;
const saveHandler = async (data) => { const saveHandler = async (data) => {
console.log(data); console.log(data);
const res = await updateToolById(localStorage.token, tool.id, { const res = await updateFunctionById(localStorage.token, func.id, {
id: data.id, id: data.id,
name: data.name, name: data.name,
meta: data.meta, meta: data.meta,
@ -23,10 +26,9 @@
}); });
if (res) { if (res) {
toast.success('Tool updated successfully'); toast.success('Function updated successfully');
tools.set(await getTools(localStorage.token)); functions.set(await getFunctions(localStorage.token));
models.set(await getModels(localStorage.token));
// await goto('/workspace/tools');
} }
}; };
@ -35,24 +37,24 @@
const id = $page.url.searchParams.get('id'); const id = $page.url.searchParams.get('id');
if (id) { if (id) {
tool = await getToolById(localStorage.token, id).catch((error) => { func = await getFunctionById(localStorage.token, id).catch((error) => {
toast.error(error); toast.error(error);
goto('/workspace/tools'); goto('/workspace/functions');
return null; return null;
}); });
console.log(tool); console.log(func);
} }
}); });
</script> </script>
{#if tool} {#if func}
<ToolkitEditor <FunctionEditor
edit={true} edit={true}
id={tool.id} id={func.id}
name={tool.name} name={func.name}
meta={tool.meta} meta={func.meta}
content={tool.content} content={func.content}
on:save={(e) => { on:save={(e) => {
saveHandler(e.detail); saveHandler(e.detail);
}} }}

View File

@ -5,7 +5,7 @@
import { onMount, getContext } from 'svelte'; import { onMount, getContext } from 'svelte';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { settings, user, config, models, tools } from '$lib/stores'; import { settings, user, config, models, tools, functions } from '$lib/stores';
import { splitStream } from '$lib/utils'; import { splitStream } from '$lib/utils';
import { getModelInfos, updateModelById } from '$lib/apis/models'; import { getModelInfos, updateModelById } from '$lib/apis/models';
@ -16,6 +16,7 @@
import Tags from '$lib/components/common/Tags.svelte'; import Tags from '$lib/components/common/Tags.svelte';
import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte'; import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte'; import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte';
import FiltersSelector from '$lib/components/workspace/Models/FiltersSelector.svelte';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
@ -62,6 +63,7 @@
let knowledge = []; let knowledge = [];
let toolIds = []; let toolIds = [];
let filterIds = [];
const updateHandler = async () => { const updateHandler = async () => {
loading = true; loading = true;
@ -86,6 +88,14 @@
} }
} }
if (filterIds.length > 0) {
info.meta.filterIds = filterIds;
} else {
if (info.meta.filterIds) {
delete info.meta.filterIds;
}
}
info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null; info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
Object.keys(info.params).forEach((key) => { Object.keys(info.params).forEach((key) => {
if (info.params[key] === '' || info.params[key] === null) { if (info.params[key] === '' || info.params[key] === null) {
@ -147,6 +157,10 @@
toolIds = [...model?.info?.meta?.toolIds]; toolIds = [...model?.info?.meta?.toolIds];
} }
if (model?.info?.meta?.filterIds) {
filterIds = [...model?.info?.meta?.filterIds];
}
if (model?.owned_by === 'openai') { if (model?.owned_by === 'openai') {
capabilities.usage = false; capabilities.usage = false;
} }
@ -534,6 +548,13 @@
<ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} /> <ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} />
</div> </div>
<div class="my-2">
<FiltersSelector
bind:selectedFilterIds={filterIds}
filters={$functions.filter((func) => func.type === 'filter')}
/>
</div>
<div class="my-2"> <div class="my-2">
<div class="flex w-full justify-between mb-1"> <div class="flex w-full justify-between mb-1">
<div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div> <div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div>