refac: toolkit -> tools

This commit is contained in:
Timothy Jaeryang Baek 2024-11-16 17:54:38 -08:00
parent a1ce8422fd
commit 0a8f69285c
3 changed files with 53 additions and 53 deletions

View File

@ -3,7 +3,7 @@ from pathlib import Path
from typing import Optional from typing import Optional
from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools from open_webui.apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
from open_webui.apps.webui.utils import load_toolkit_module_by_id, replace_imports from open_webui.apps.webui.utils import load_tools_module_by_id, replace_imports
from open_webui.config import CACHE_DIR, DATA_DIR from open_webui.config import CACHE_DIR, DATA_DIR
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request, status
@ -71,30 +71,30 @@ async def create_new_tools(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id) tools = Tools.get_tool_by_id(form_data.id)
if toolkit is None: if tools is None:
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
toolkit_module, frontmatter = load_toolkit_module_by_id( tools_module, frontmatter = load_tools_module_by_id(
form_data.id, content=form_data.content form_data.id, content=form_data.content
) )
form_data.meta.manifest = frontmatter form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module TOOLS[form_data.id] = tools_module
specs = get_tools_specs(TOOLS[form_data.id]) specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs) tools = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True) tool_cache_dir.mkdir(parents=True, exist_ok=True)
if toolkit: if tools:
return toolkit return tools
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"), detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
) )
except Exception as e: except Exception as e:
print(e) print(e)
@ -116,10 +116,10 @@ async def create_new_tools(
@router.get("/id/{id}", response_model=Optional[ToolModel]) @router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
return toolkit return tools
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -141,13 +141,13 @@ async def update_tools_by_id(
): ):
try: try:
form_data.content = replace_imports(form_data.content) form_data.content = replace_imports(form_data.content)
toolkit_module, frontmatter = load_toolkit_module_by_id( tools_module, frontmatter = load_tools_module_by_id(
id, content=form_data.content id, content=form_data.content
) )
form_data.meta.manifest = frontmatter form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module TOOLS[id] = tools_module
specs = get_tools_specs(TOOLS[id]) specs = get_tools_specs(TOOLS[id])
@ -157,14 +157,14 @@ async def update_tools_by_id(
} }
print(updated) print(updated)
toolkit = Tools.update_tool_by_id(id, updated) tools = Tools.update_tool_by_id(id, updated)
if toolkit: if tools:
return toolkit return tools
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"), detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
) )
except Exception as e: except Exception as e:
@ -200,8 +200,8 @@ async def delete_tools_by_id(
@router.get("/id/{id}/valves", response_model=Optional[dict]) @router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
try: try:
valves = Tools.get_tool_valves_by_id(id) valves = Tools.get_tool_valves_by_id(id)
return valves return valves
@ -226,16 +226,16 @@ async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_valves_spec_by_id( async def get_tools_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user)
): ):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
toolkit_module, _ = load_toolkit_module_by_id(id) tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "Valves"): if hasattr(tools_module, "Valves"):
Valves = toolkit_module.Valves Valves = tools_module.Valves
return Valves.schema() return Valves.schema()
return None return None
else: else:
@ -254,16 +254,16 @@ async def get_tools_valves_spec_by_id(
async def update_tools_valves_by_id( async def update_tools_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
): ):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
toolkit_module, _ = load_toolkit_module_by_id(id) tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "Valves"): if hasattr(tools_module, "Valves"):
Valves = toolkit_module.Valves Valves = tools_module.Valves
try: try:
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}
@ -296,8 +296,8 @@ async def update_tools_valves_by_id(
@router.get("/id/{id}/valves/user", response_model=Optional[dict]) @router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)): async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
try: try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id) user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves return user_valves
@ -317,16 +317,16 @@ async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
async def get_tools_user_valves_spec_by_id( async def get_tools_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user) request: Request, id: str, user=Depends(get_verified_user)
): ):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
toolkit_module, _ = load_toolkit_module_by_id(id) tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "UserValves"): if hasattr(tools_module, "UserValves"):
UserValves = toolkit_module.UserValves UserValves = tools_module.UserValves
return UserValves.schema() return UserValves.schema()
return None return None
else: else:
@ -340,17 +340,17 @@ async def get_tools_user_valves_spec_by_id(
async def update_tools_user_valves_by_id( async def update_tools_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user) request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
): ):
toolkit = Tools.get_tool_by_id(id) tools = Tools.get_tool_by_id(id)
if toolkit: if tools:
if id in request.app.state.TOOLS: if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id] tools_module = request.app.state.TOOLS[id]
else: else:
toolkit_module, _ = load_toolkit_module_by_id(id) tools_module, _ = load_tools_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module request.app.state.TOOLS[id] = tools_module
if hasattr(toolkit_module, "UserValves"): if hasattr(tools_module, "UserValves"):
UserValves = toolkit_module.UserValves UserValves = tools_module.UserValves
try: try:
form_data = {k: v for k, v in form_data.items() if v is not None} form_data = {k: v for k, v in form_data.items() if v is not None}

View File

@ -63,7 +63,7 @@ def replace_imports(content):
return content return content
def load_toolkit_module_by_id(toolkit_id, content=None): def load_tools_module_by_id(toolkit_id, content=None):
if content is None: if content is None:
tool = Tools.get_tool_by_id(toolkit_id) tool = Tools.get_tool_by_id(toolkit_id)

View File

@ -4,7 +4,7 @@ from typing import Awaitable, Callable, get_type_hints
from open_webui.apps.webui.models.tools import Tools from open_webui.apps.webui.models.tools import Tools
from open_webui.apps.webui.models.users import UserModel from open_webui.apps.webui.models.users import UserModel
from open_webui.apps.webui.utils import load_toolkit_module_by_id from open_webui.apps.webui.utils import load_tools_module_by_id
from open_webui.utils.schemas import json_schema_to_model from open_webui.utils.schemas import json_schema_to_model
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -40,7 +40,7 @@ def get_tools(
module = webui_app.state.TOOLS.get(tool_id, None) module = webui_app.state.TOOLS.get(tool_id, None)
if module is None: if module is None:
module, _ = load_toolkit_module_by_id(tool_id) module, _ = load_tools_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = module webui_app.state.TOOLS[tool_id] = module
extra_params["__id__"] = tool_id extra_params["__id__"] = tool_id