diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 2426aff27..85810c2ed 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -25,7 +25,7 @@ jobs:
--file docker-compose.api.yaml \
--file docker-compose.a1111-test.yaml \
up --detach --build
-
+
- name: Wait for Ollama to be up
timeout-minutes: 5
run: |
@@ -43,7 +43,7 @@ jobs:
uses: cypress-io/github-action@v6
with:
browser: chrome
- wait-on: 'http://localhost:3000'
+ wait-on: "http://localhost:3000"
config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4
@@ -82,18 +82,18 @@ jobs:
--health-retries 5
ports:
- 5432:5432
-# mysql:
-# image: mysql
-# env:
-# MYSQL_ROOT_PASSWORD: mysql
-# MYSQL_DATABASE: mysql
-# options: >-
-# --health-cmd "mysqladmin ping -h localhost"
-# --health-interval 10s
-# --health-timeout 5s
-# --health-retries 5
-# ports:
-# - 3306:3306
+ # mysql:
+ # image: mysql
+ # env:
+ # MYSQL_ROOT_PASSWORD: mysql
+ # MYSQL_DATABASE: mysql
+ # options: >-
+ # --health-cmd "mysqladmin ping -h localhost"
+ # --health-interval 10s
+ # --health-timeout 5s
+ # --health-retries 5
+ # ports:
+ # - 3306:3306
steps:
- name: Checkout Repository
uses: actions/checkout@v4
@@ -142,7 +142,6 @@ jobs:
echo "Server has stopped"
exit 1
fi
-
- name: Test backend with Postgres
if: success() || steps.sqlite.conclusion == 'failure'
@@ -171,6 +170,25 @@ jobs:
exit 1
fi
+          # Check that the service reconnects to Postgres after its connections are terminated
+ status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
+ if [[ "$status_code" -ne 200 ]] ; then
+ echo "Server has failed before postgres reconnect check"
+ exit 1
+ fi
+
+ echo "Terminating all connections to postgres..."
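+          # DATABASE_URL may carry a "+pool" scheme suffix; strip it so psycopg2 can parse the DSN directly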
+ python -c "import os, psycopg2 as pg2; \
+ conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
+ cur = conn.cursor(); \
+ cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
+
+ status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
+ if [[ "$status_code" -ne 200 ]] ; then
+ echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
+ exit 1
+ fi
+
# - name: Test backend with MySQL
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
# env:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6756d105b..e27c37a5b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,36 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [0.3.6] - 2024-06-27
+
+### Added
+
+- **✨ "Functions" Feature**: You can now utilize "Functions" like filters (middleware) and pipe (model) functions directly within the WebUI. While largely compatible with Pipelines, these native functions can be executed easily within Open WebUI. Example use cases for filter functions include usage monitoring, real-time translation, moderation, and automemory. For pipe functions, the scope ranges from Cohere and Anthropic integration directly within Open WebUI, enabling "Valves" for per-user OpenAI API key usage, and much more. If you encounter issues, SAFE_MODE has been introduced.
+- **📁 Files API**: Compatible with OpenAI, this feature allows for custom Retrieval-Augmented Generation (RAG) in conjunction with the Filter Function. More examples will be shared on our community platform and official documentation website.
+- **🛠️ Tool Enhancements**: Tools now support citations and "Valves". Documentation will be available shortly.
+- **🔗 Iframe Support via Files API**: Enables rendering HTML directly into your chat interface using functions and tools. Use cases include playing games like DOOM and Snake, displaying a weather applet, and implementing Anthropic "artifacts"-like features. Stay tuned for updates on our community platform and documentation.
+- **🔒 Experimental OAuth Support**: New experimental OAuth support. Check our documentation for more details.
+- **🖼️ Custom Background Support**: Set a custom background from Settings > Interface to personalize your experience.
+- **🔑 AUTOMATIC1111_API_AUTH Support**: Enhanced security for the AUTOMATIC1111 API.
+- **🎨 Code Highlight Optimization**: Improved code highlighting features.
+- **🎙️ Voice Interruption Feature**: Reintroduced and now toggleable from Settings > Interface.
+- **💤 Wakelock API**: Now in use to prevent screen dimming during important tasks.
+- **🔐 API Key Privacy**: All API keys are now hidden by default for better security.
+- **🔍 New Web Search Provider**: Added jina_search as a new option.
+- **🌐 Enhanced Internationalization (i18n)**: Improved Korean translation and updated Chinese and Ukrainian translations.
+
+### Fixed
+
+- **🔧 Conversation Mode Issue**: Fixed the issue where Conversation Mode remained active after being removed from settings.
+- **📏 Scroll Button Obstruction**: Resolved the issue where the scrollToBottom button container obstructed clicks on buttons beneath it.
+
+### Changed
+
+- **⏲️ AIOHTTP_CLIENT_TIMEOUT**: Now set to `None` by default for improved configuration flexibility.
+- **📞 Voice Call Enhancements**: Improved by skipping code blocks and expressions during calls.
+- **🚫 Error Message Handling**: Disabled the continuation of operations with error messages.
+- **🗂️ Playground Relocation**: Moved the Playground from the workspace to the user menu for better user experience.
+
## [0.3.5] - 2024-06-16
### Added
diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py
index 663e20c97..8843f376f 100644
--- a/backend/apps/audio/main.py
+++ b/backend/apps/audio/main.py
@@ -325,7 +325,7 @@ def transcribe(
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
- data = {"model": "whisper-1"}
+ data = {"model": app.state.config.STT_MODEL}
print(files, data)
diff --git a/backend/apps/images/main.py b/backend/apps/images/main.py
index af7e5592d..8f1a08e04 100644
--- a/backend/apps/images/main.py
+++ b/backend/apps/images/main.py
@@ -1,5 +1,6 @@
import re
import requests
+import base64
from fastapi import (
FastAPI,
Request,
@@ -15,7 +16,7 @@ from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
- get_current_user,
+ get_verified_user,
get_admin_user,
)
@@ -36,6 +37,7 @@ from config import (
IMAGE_GENERATION_ENGINE,
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
+ AUTOMATIC1111_API_AUTH,
COMFYUI_BASE_URL,
COMFYUI_CFG_SCALE,
COMFYUI_SAMPLER,
@@ -49,7 +51,6 @@ from config import (
AppConfig,
)
-
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -75,11 +76,10 @@ app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
-
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
-
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
@@ -88,6 +88,16 @@ app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
+def get_automatic1111_api_auth():
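+    # AUTOMATIC1111_API_AUTH holds "username:password" credentials; build an HTTP Basic auth header value from them.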
+ if app.state.config.AUTOMATIC1111_API_AUTH == None:
+ return ""
+ else:
+ auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
+ auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
+ auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
+ return f"Basic {auth1111_base64_encoded_string}"
+
+
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
@@ -113,6 +123,7 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
class EngineUrlUpdateForm(BaseModel):
AUTOMATIC1111_BASE_URL: Optional[str] = None
+ AUTOMATIC1111_API_AUTH: Optional[str] = None
COMFYUI_BASE_URL: Optional[str] = None
@@ -120,6 +131,7 @@ class EngineUrlUpdateForm(BaseModel):
async def get_engine_url(user=Depends(get_admin_user)):
return {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+ "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
}
@@ -128,7 +140,6 @@ async def get_engine_url(user=Depends(get_admin_user)):
async def update_engine_url(
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
-
if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
@@ -150,8 +161,14 @@ async def update_engine_url(
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
+ if form_data.AUTOMATIC1111_API_AUTH == None:
+ app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
+ else:
+ app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
+
return {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+ "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True,
}
@@ -241,7 +258,7 @@ async def update_image_size(
@app.get("/models")
-def get_models(user=Depends(get_current_user)):
+def get_models(user=Depends(get_verified_user)):
try:
if app.state.config.ENGINE == "openai":
return [
@@ -262,7 +279,8 @@ def get_models(user=Depends(get_current_user)):
else:
r = requests.get(
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
+ headers={"authorization": get_automatic1111_api_auth()},
)
models = r.json()
return list(
@@ -289,7 +307,8 @@ async def get_default_model(user=Depends(get_admin_user)):
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else:
r = requests.get(
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+ headers={"authorization": get_automatic1111_api_auth()},
)
options = r.json()
return {"model": options["sd_model_checkpoint"]}
@@ -307,8 +326,10 @@ def set_model_handler(model: str):
app.state.config.MODEL = model
return app.state.config.MODEL
else:
+ api_auth = get_automatic1111_api_auth()
r = requests.get(
- url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+ url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+ headers={"authorization": api_auth},
)
options = r.json()
@@ -317,6 +338,7 @@ def set_model_handler(model: str):
r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
+ headers={"authorization": api_auth},
)
return options
@@ -325,7 +347,7 @@ def set_model_handler(model: str):
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
return set_model_handler(form_data.model)
@@ -402,9 +424,8 @@ def save_url_image(url):
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
-
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
r = None
@@ -519,6 +540,7 @@ def generate_image(
r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
+ headers={"authorization": get_automatic1111_api_auth()},
)
res = r.json()
diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index 81a3b2a0e..455dc89a5 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -40,6 +40,7 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
+from utils.task import prompt_template
from config import (
@@ -52,7 +53,7 @@ from config import (
UPLOAD_DIR,
AppConfig,
)
-from utils.misc import calculate_sha256
+from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -199,9 +200,6 @@ def merge_models_lists(model_lists):
return list(merged_models.values())
-# user=Depends(get_current_user)
-
-
async def get_all_models():
log.info("get_all_models()")
@@ -817,24 +815,28 @@ async def generate_chat_completion(
"num_thread", None
)
- if model_info.params.get("system", None):
+ system = model_info.params.get("system", None)
+ if system:
# Check if the payload already has a system message
# If not, add a system message to the payload
+ system = prompt_template(
+ system,
+ **(
+ {
+ "user_name": user.name,
+ "user_location": (
+ user.info.get("location") if user.info else None
+ ),
+ }
+ if user
+ else {}
+ ),
+ )
+
if payload.get("messages"):
- for message in payload["messages"]:
- if message.get("role") == "system":
- message["content"] = (
- model_info.params.get("system", None) + message["content"]
- )
- break
- else:
- payload["messages"].insert(
- 0,
- {
- "role": "system",
- "content": model_info.params.get("system", None),
- },
- )
+ payload["messages"] = add_or_update_system_message(
+ system, payload["messages"]
+ )
if url_idx == None:
if ":" not in payload["model"]:
@@ -878,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel):
@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
- form_data: OpenAIChatCompletionForm,
+ form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
+ form_data = OpenAIChatCompletionForm(**form_data)
payload = {
**form_data.model_dump(exclude_none=True),
@@ -913,22 +916,35 @@ async def generate_openai_chat_completion(
else None
)
- if model_info.params.get("system", None):
+ system = model_info.params.get("system", None)
+
+ if system:
+ system = prompt_template(
+ system,
+ **(
+ {
+ "user_name": user.name,
+ "user_location": (
+ user.info.get("location") if user.info else None
+ ),
+ }
+ if user
+ else {}
+ ),
+ )
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
- message["content"] = (
- model_info.params.get("system", None) + message["content"]
- )
+ message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
- "content": model_info.params.get("system", None),
+ "content": system,
},
)
@@ -1094,17 +1110,13 @@ async def download_file_stream(
raise "Ollama: Could not create blob, Please try again."
-# def number_generator():
-# for i in range(1, 101):
-# yield f"data: {i}\n"
-
-
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
form_data: UrlForm,
url_idx: Optional[int] = None,
+ user=Depends(get_admin_user),
):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
@@ -1133,7 +1145,11 @@ async def download_model(
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
-def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
+def upload_model(
+ file: UploadFile = File(...),
+ url_idx: Optional[int] = None,
+ user=Depends(get_admin_user),
+):
if url_idx == None:
url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -1196,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
-
-
-# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
-# if url_idx == None:
-# url_idx = 0
-# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-# file_location = os.path.join(UPLOAD_DIR, file.filename)
-# total_size = file.size
-
-# async def file_upload_generator(file):
-# print(file)
-# try:
-# async with aiofiles.open(file_location, "wb") as f:
-# completed_size = 0
-# while True:
-# chunk = await file.read(1024*1024)
-# if not chunk:
-# break
-# await f.write(chunk)
-# completed_size += len(chunk)
-# progress = (completed_size / total_size) * 100
-
-# print(progress)
-# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
-# except Exception as e:
-# print(e)
-# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
-# finally:
-# await file.close()
-# print("done")
-# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
-
-# return StreamingResponse(
-# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
-# )
-
-
-@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def deprecated_proxy(
- path: str, request: Request, user=Depends(get_verified_user)
-):
- url = app.state.config.OLLAMA_BASE_URLS[0]
- target_url = f"{url}/{path}"
-
- body = await request.body()
- headers = dict(request.headers)
-
- if user.role in ["user", "admin"]:
- if path in ["pull", "delete", "push", "copy", "create"]:
- if user.role != "admin":
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
- )
- else:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
- )
-
- headers.pop("host", None)
- headers.pop("authorization", None)
- headers.pop("origin", None)
- headers.pop("referer", None)
-
- r = None
-
- def get_request():
- nonlocal r
-
- request_id = str(uuid.uuid4())
- try:
- REQUEST_POOL.append(request_id)
-
- def stream_content():
- try:
- if path == "generate":
- data = json.loads(body.decode("utf-8"))
-
- if data.get("stream", True):
- yield json.dumps({"id": request_id, "done": False}) + "\n"
-
- elif path == "chat":
- yield json.dumps({"id": request_id, "done": False}) + "\n"
-
- for chunk in r.iter_content(chunk_size=8192):
- if request_id in REQUEST_POOL:
- yield chunk
- else:
- log.warning("User: canceled request")
- break
- finally:
- if hasattr(r, "close"):
- r.close()
- if request_id in REQUEST_POOL:
- REQUEST_POOL.remove(request_id)
-
- r = requests.request(
- method=request.method,
- url=target_url,
- data=body,
- headers=headers,
- stream=True,
- )
-
- r.raise_for_status()
-
- # r.close()
-
- return StreamingResponse(
- stream_content(),
- status_code=r.status_code,
- headers=dict(r.headers),
- )
- except Exception as e:
- raise e
-
- try:
- return await run_in_threadpool(get_request)
- except Exception as e:
- error_detail = "Open WebUI: Server Connection Error"
- if r is not None:
- try:
- res = r.json()
- if "error" in res:
- error_detail = f"Ollama: {res['error']}"
- except:
- error_detail = f"Ollama: {e}"
-
- raise HTTPException(
- status_code=r.status_code if r else 500,
- detail=error_detail,
- )
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index c09c030d2..31dd48741 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -16,10 +16,12 @@ from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
- get_current_user,
+ get_verified_user,
get_verified_user,
get_admin_user,
)
+from utils.task import prompt_template
+
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
@@ -294,7 +296,7 @@ async def get_all_models(raw: bool = False):
@app.get("/models")
@app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx == None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
@@ -392,22 +394,34 @@ async def generate_chat_completion(
else None
)
- if model_info.params.get("system", None):
+ system = model_info.params.get("system", None)
+ if system:
+ system = prompt_template(
+ system,
+ **(
+ {
+ "user_name": user.name,
+ "user_location": (
+ user.info.get("location") if user.info else None
+ ),
+ }
+ if user
+ else {}
+ ),
+ )
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
- message["content"] = (
- model_info.params.get("system", None) + message["content"]
- )
+ message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
- "content": model_info.params.get("system", None),
+ "content": system,
},
)
@@ -418,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
- payload["user"] = {"name": user.name, "id": user.id}
+ payload["user"] = {
+ "name": user.name,
+ "id": user.id,
+ "email": user.email,
+ "role": user.role,
+ }
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 4bd5da86c..7c6974535 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
DocumentForm,
DocumentResponse,
)
+from apps.webui.models.files import (
+ Files,
+)
from apps.rag.utils import (
get_model_path,
@@ -74,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily
+from apps.rag.search.jina_search import search_jina
from utils.misc import (
calculate_sha256,
@@ -81,7 +85,7 @@ from utils.misc import (
sanitize_filename,
extract_folders_after_data_docs,
)
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
from config import (
AppConfig,
@@ -112,6 +116,7 @@ from config import (
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
+ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
@@ -165,6 +170,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@@ -523,7 +529,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
@app.get("/template")
-async def get_rag_template(user=Depends(get_current_user)):
+async def get_rag_template(user=Depends(get_verified_user)):
return {
"status": True,
"template": app.state.config.RAG_TEMPLATE,
@@ -580,7 +586,7 @@ class QueryDocForm(BaseModel):
@app.post("/query/doc")
def query_doc_handler(
form_data: QueryDocForm,
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
try:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@@ -620,7 +626,7 @@ class QueryCollectionsForm(BaseModel):
@app.post("/query/collection")
def query_collection_handler(
form_data: QueryCollectionsForm,
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
try:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@@ -651,7 +657,7 @@ def query_collection_handler(
@app.post("/youtube")
-def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
+def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
try:
loader = YoutubeLoader.from_youtube_url(
form_data.url,
@@ -680,7 +686,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
@app.post("/web")
-def store_web(form_data: UrlForm, user=Depends(get_current_user)):
+def store_web(form_data: UrlForm, user=Depends(get_verified_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
@@ -775,6 +781,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
@@ -788,6 +795,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception(
@@ -799,6 +807,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
@@ -808,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
@@ -818,6 +828,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
@@ -827,11 +838,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
- return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
+ return search_duckduckgo(
+ query,
+ app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
@@ -841,12 +857,14 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
+ elif engine == "jina":
+ return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
-def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
+def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
try:
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
@@ -1066,7 +1084,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
def store_doc(
collection_name: Optional[str] = Form(None),
file: UploadFile = File(...),
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
@@ -1119,6 +1137,60 @@ def store_doc(
)
+class ProcessDocForm(BaseModel):
+ file_id: str
+ collection_name: Optional[str] = None
+
+
+@app.post("/process/doc")
+def process_doc(
+ form_data: ProcessDocForm,
+ user=Depends(get_verified_user),
+):
+ try:
+ file = Files.get_file_by_id(form_data.file_id)
+ file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
+
+        collection_name = form_data.collection_name
+        if collection_name == None:
+            # Default the collection name to the file's SHA-256 digest (truncated to 63 characters)
+            with open(file_path, "rb") as f:
+                collection_name = calculate_sha256(f)[:63]
+
+ loader, known_type = get_loader(
+ file.filename, file.meta.get("content_type"), file_path
+ )
+ data = loader.load()
+
+ try:
+ result = store_data_in_vector_db(data, collection_name)
+
+ if result:
+ return {
+ "status": True,
+ "collection_name": collection_name,
+ "known_type": known_type,
+ }
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ except Exception as e:
+ log.exception(e)
+ if "No pandoc was found" in str(e):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+
+
class TextRAGForm(BaseModel):
name: str
content: str
@@ -1128,7 +1200,7 @@ class TextRAGForm(BaseModel):
@app.post("/text")
def store_text(
form_data: TextRAGForm,
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
collection_name = form_data.collection_name
diff --git a/backend/apps/rag/search/brave.py b/backend/apps/rag/search/brave.py
index 4e0f56807..76ad1fb47 100644
--- a/backend/apps/rag/search/brave.py
+++ b/backend/apps/rag/search/brave.py
@@ -1,15 +1,17 @@
import logging
-
+from typing import List, Optional
import requests
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
-def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_brave(
+ api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
+) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
+
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
diff --git a/backend/apps/rag/search/duckduckgo.py b/backend/apps/rag/search/duckduckgo.py
index 188ae2bea..f0cc2a710 100644
--- a/backend/apps/rag/search/duckduckgo.py
+++ b/backend/apps/rag/search/duckduckgo.py
@@ -1,6 +1,6 @@
import logging
-
-from apps.rag.search.main import SearchResult
+from typing import List, Optional
+from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
@@ -8,7 +8,9 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
-def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
+def search_duckduckgo(
+ query: str, count: int, filter_list: Optional[List[str]] = None
+) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"),
)
)
- print(results)
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
# Return the list of search results
return results
diff --git a/backend/apps/rag/search/google_pse.py b/backend/apps/rag/search/google_pse.py
index 7ff54c785..0c78512e7 100644
--- a/backend/apps/rag/search/google_pse.py
+++ b/backend/apps/rag/search/google_pse.py
@@ -1,9 +1,9 @@
import json
import logging
-
+from typing import List, Optional
import requests
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
- api_key: str, search_engine_id: str, query: str, count: int
+ api_key: str,
+ search_engine_id: str,
+ query: str,
+ count: int,
+ filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
@@ -35,6 +39,8 @@ def search_google_pse(
json_response = response.json()
results = json_response.get("items", [])
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
diff --git a/backend/apps/rag/search/jina_search.py b/backend/apps/rag/search/jina_search.py
new file mode 100644
index 000000000..65f9ad68f
--- /dev/null
+++ b/backend/apps/rag/search/jina_search.py
@@ -0,0 +1,41 @@
+import logging
+import requests
+from yarl import URL
+
+from apps.rag.search.main import SearchResult
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_jina(query: str, count: int) -> list[SearchResult]:
+ """
+ Search using Jina's Search API and return the results as a list of SearchResult objects.
+ Args:
+ query (str): The query to search for
+ count (int): The number of results to return
+
+ Returns:
+ List[SearchResult]: A list of search results
+ """
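+    # s.jina.ai takes the query in the URL path and returns JSON when "Accept: application/json" is sent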
+ jina_search_endpoint = "https://s.jina.ai/"
+ headers = {
+ "Accept": "application/json",
+ }
+ url = str(URL(jina_search_endpoint + query))
+ response = requests.get(url, headers=headers)
+ response.raise_for_status()
+ data = response.json()
+
+ results = []
+ for result in data["data"][:count]:
+ results.append(
+ SearchResult(
+ link=result["url"],
+ title=result.get("title"),
+ snippet=result.get("content"),
+ )
+ )
+
+ return results
diff --git a/backend/apps/rag/search/main.py b/backend/apps/rag/search/main.py
index b5478f949..49056f1fd 100644
--- a/backend/apps/rag/search/main.py
+++ b/backend/apps/rag/search/main.py
@@ -1,8 +1,19 @@
from typing import Optional
-
+from urllib.parse import urlparse
from pydantic import BaseModel
+def get_filtered_results(results, filter_list):
+ if not filter_list:
+ return results
+ filtered_results = []
+ for result in results:
+        # Providers return either "url" (Brave, SearXNG, Serpstack) or "link" (Google PSE, Serper, Serply),
+        # and DuckDuckGo passes already-built SearchResult models, so handle both shapes here.
+        if isinstance(result, dict):
+            url = result.get("url") or result.get("link", "")
+        else:
+            url = result.link
+        domain = urlparse(url).netloc
+ if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
+ filtered_results.append(result)
+ return filtered_results
+
+
class SearchResult(BaseModel):
link: str
title: Optional[str]
diff --git a/backend/apps/rag/search/searxng.py b/backend/apps/rag/search/searxng.py
index c8ad88813..6e545e994 100644
--- a/backend/apps/rag/search/searxng.py
+++ b/backend/apps/rag/search/searxng.py
@@ -1,9 +1,9 @@
import logging
import requests
-from typing import List
+from typing import List, Optional
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(
- query_url: str, query: str, count: int, **kwargs
+ query_url: str,
+ query: str,
+ count: int,
+ filter_list: Optional[List[str]] = None,
+ **kwargs,
) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
@@ -78,6 +82,8 @@ def search_searxng(
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
+ if filter_list:
+ sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
diff --git a/backend/apps/rag/search/serper.py b/backend/apps/rag/search/serper.py
index 150da6e07..b278a4df1 100644
--- a/backend/apps/rag/search/serper.py
+++ b/backend/apps/rag/search/serper.py
@@ -1,16 +1,18 @@
import json
import logging
-
+from typing import List, Optional
import requests
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
-def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_serper(
+ api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
+) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0)
)
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
diff --git a/backend/apps/rag/search/serply.py b/backend/apps/rag/search/serply.py
index fccf70ecd..24b249b73 100644
--- a/backend/apps/rag/search/serply.py
+++ b/backend/apps/rag/search/serply.py
@@ -1,10 +1,10 @@
import json
import logging
-
+from typing import List, Optional
import requests
from urllib.parse import urlencode
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -19,6 +19,7 @@ def search_serply(
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
+ filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
@@ -57,7 +58,8 @@ def search_serply(
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
-
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
diff --git a/backend/apps/rag/search/serpstack.py b/backend/apps/rag/search/serpstack.py
index 0d247d1ab..64b0f117d 100644
--- a/backend/apps/rag/search/serpstack.py
+++ b/backend/apps/rag/search/serpstack.py
@@ -1,9 +1,9 @@
import json
import logging
-
+from typing import List, Optional
import requests
-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack(
- api_key: str, query: str, count: int, https_enabled: bool = True
+ api_key: str,
+ query: str,
+ count: int,
+ filter_list: Optional[List[str]] = None,
+ https_enabled: bool = True,
) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
@@ -35,6 +39,8 @@ def search_serpstack(
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
+ if filter_list:
+ results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index d0570f748..7b4324d9a 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -237,7 +237,7 @@ def get_embedding_function(
def get_rag_context(
- docs,
+ files,
messages,
embedding_function,
k,
@@ -245,29 +245,29 @@ def get_rag_context(
r,
hybrid_search,
):
- log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
+ log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
- for doc in docs:
+ for file in files:
context = None
collection_names = (
- doc["collection_names"]
- if doc["type"] == "collection"
- else [doc["collection_name"]]
+ file["collection_names"]
+ if file["type"] == "collection"
+ else [file["collection_name"]]
)
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
- log.debug(f"skipping {doc} as it has already been extracted")
+ log.debug(f"skipping {file} as it has already been extracted")
continue
try:
- if doc["type"] == "text":
- context = doc["content"]
+ if file["type"] == "text":
+ context = file["content"]
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
@@ -290,7 +290,7 @@ def get_rag_context(
context = None
if context:
- relevant_contexts.append({**context, "source": doc})
+ relevant_contexts.append({**context, "source": file})
extracted_collections.extend(collection_names)
diff --git a/backend/apps/webui/internal/db.py b/backend/apps/webui/internal/db.py
index 0e7b1f95d..80c30d652 100644
--- a/backend/apps/webui/internal/db.py
+++ b/backend/apps/webui/internal/db.py
@@ -1,11 +1,12 @@
+import os
+import logging
import json
from peewee import *
from peewee_migrate import Router
-from playhouse.db_url import connect
+
+from apps.webui.internal.wrappers import register_connection
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
-import os
-import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else:
pass
-DB = connect(DATABASE_URL)
-log.info(f"Connected to a {DB.__class__.__name__} database.")
+
+# The `register_connection` function encapsulates the logic for setting up
+# the database connection based on the connection string, while `connect`
+# is a Peewee-specific method to manage the connection state and avoid errors
+# when a connection is already open.
+try:
+ DB = register_connection(DATABASE_URL)
+ log.info(f"Connected to a {DB.__class__.__name__} database.")
+except Exception as e:
+ log.error(f"Failed to initialize the database connection: {e}")
+ raise
+
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
)
router.run()
-DB.connect(reuse_if_open=True)
+try:
+ DB.connect(reuse_if_open=True)
+except OperationalError as e:
+    log.info(f"Failed to reconnect to the database: {e}")
diff --git a/backend/apps/webui/internal/migrations/014_add_files.py b/backend/apps/webui/internal/migrations/014_add_files.py
new file mode 100644
index 000000000..5e1acf0ad
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/014_add_files.py
@@ -0,0 +1,55 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+ > Model = migrator.orm['table_name'] # Return model in current state by name
+ > Model = migrator.ModelClass # Return model in current state by name
+
+ > migrator.sql(sql) # Run custom SQL
+ > migrator.run(func, *args, **kwargs) # Run python function with the given args
+ > migrator.create_model(Model) # Create a model (could be used as decorator)
+ > migrator.remove_model(model, cascade=True) # Remove a model
+ > migrator.add_fields(model, **fields) # Add fields to a model
+ > migrator.change_fields(model, **fields) # Change fields
+ > migrator.remove_fields(model, *field_names, cascade=True)
+ > migrator.rename_field(model, old_field_name, new_field_name)
+ > migrator.rename_table(model, new_table_name)
+ > migrator.add_index(model, *col_names, unique=False)
+ > migrator.add_not_null(model, *field_names)
+ > migrator.add_default(model, field_name, default)
+ > migrator.add_constraint(model, name, sql)
+ > migrator.drop_index(model, *col_names)
+ > migrator.drop_not_null(model, *field_names)
+ > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+ import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your migrations here."""
+
+ @migrator.create_model
+ class File(pw.Model):
+ id = pw.TextField(unique=True)
+ user_id = pw.TextField()
+ filename = pw.TextField()
+ meta = pw.TextField()
+ created_at = pw.BigIntegerField(null=False)
+
+ class Meta:
+ table_name = "file"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your rollback migrations here."""
+
+ migrator.remove_model("file")
diff --git a/backend/apps/webui/internal/migrations/015_add_functions.py b/backend/apps/webui/internal/migrations/015_add_functions.py
new file mode 100644
index 000000000..8316a9333
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/015_add_functions.py
@@ -0,0 +1,61 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+ > Model = migrator.orm['table_name'] # Return model in current state by name
+ > Model = migrator.ModelClass # Return model in current state by name
+
+ > migrator.sql(sql) # Run custom SQL
+ > migrator.run(func, *args, **kwargs) # Run python function with the given args
+ > migrator.create_model(Model) # Create a model (could be used as decorator)
+ > migrator.remove_model(model, cascade=True) # Remove a model
+ > migrator.add_fields(model, **fields) # Add fields to a model
+ > migrator.change_fields(model, **fields) # Change fields
+ > migrator.remove_fields(model, *field_names, cascade=True)
+ > migrator.rename_field(model, old_field_name, new_field_name)
+ > migrator.rename_table(model, new_table_name)
+ > migrator.add_index(model, *col_names, unique=False)
+ > migrator.add_not_null(model, *field_names)
+ > migrator.add_default(model, field_name, default)
+ > migrator.add_constraint(model, name, sql)
+ > migrator.drop_index(model, *col_names)
+ > migrator.drop_not_null(model, *field_names)
+ > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+ import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your migrations here."""
+
+ @migrator.create_model
+ class Function(pw.Model):
+ id = pw.TextField(unique=True)
+ user_id = pw.TextField()
+
+ name = pw.TextField()
+ type = pw.TextField()
+
+ content = pw.TextField()
+ meta = pw.TextField()
+
+ created_at = pw.BigIntegerField(null=False)
+ updated_at = pw.BigIntegerField(null=False)
+
+ class Meta:
+ table_name = "function"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your rollback migrations here."""
+
+ migrator.remove_model("function")
diff --git a/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py
new file mode 100644
index 000000000..e3af521b7
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/016_add_valves_and_is_active.py
@@ -0,0 +1,50 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+ > Model = migrator.orm['table_name'] # Return model in current state by name
+ > Model = migrator.ModelClass # Return model in current state by name
+
+ > migrator.sql(sql) # Run custom SQL
+ > migrator.run(func, *args, **kwargs) # Run python function with the given args
+ > migrator.create_model(Model) # Create a model (could be used as decorator)
+ > migrator.remove_model(model, cascade=True) # Remove a model
+ > migrator.add_fields(model, **fields) # Add fields to a model
+ > migrator.change_fields(model, **fields) # Change fields
+ > migrator.remove_fields(model, *field_names, cascade=True)
+ > migrator.rename_field(model, old_field_name, new_field_name)
+ > migrator.rename_table(model, new_table_name)
+ > migrator.add_index(model, *col_names, unique=False)
+ > migrator.add_not_null(model, *field_names)
+ > migrator.add_default(model, field_name, default)
+ > migrator.add_constraint(model, name, sql)
+ > migrator.drop_index(model, *col_names)
+ > migrator.drop_not_null(model, *field_names)
+ > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+ import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your migrations here."""
+
+ migrator.add_fields("tool", valves=pw.TextField(null=True))
+ migrator.add_fields("function", valves=pw.TextField(null=True))
+ migrator.add_fields("function", is_active=pw.BooleanField(default=False))
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your rollback migrations here."""
+
+ migrator.remove_fields("tool", "valves")
+ migrator.remove_fields("function", "valves")
+ migrator.remove_fields("function", "is_active")
diff --git a/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py b/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py
new file mode 100644
index 000000000..fd1d9b560
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/017_add_user_oauth_sub.py
@@ -0,0 +1,49 @@
+"""Peewee migrations -- 017_add_user_oauth_sub.py.
+
+Some examples (model - class or model name)::
+
+ > Model = migrator.orm['table_name'] # Return model in current state by name
+ > Model = migrator.ModelClass # Return model in current state by name
+
+ > migrator.sql(sql) # Run custom SQL
+ > migrator.run(func, *args, **kwargs) # Run python function with the given args
+ > migrator.create_model(Model) # Create a model (could be used as decorator)
+ > migrator.remove_model(model, cascade=True) # Remove a model
+ > migrator.add_fields(model, **fields) # Add fields to a model
+ > migrator.change_fields(model, **fields) # Change fields
+ > migrator.remove_fields(model, *field_names, cascade=True)
+ > migrator.rename_field(model, old_field_name, new_field_name)
+ > migrator.rename_table(model, new_table_name)
+ > migrator.add_index(model, *col_names, unique=False)
+ > migrator.add_not_null(model, *field_names)
+ > migrator.add_default(model, field_name, default)
+ > migrator.add_constraint(model, name, sql)
+ > migrator.drop_index(model, *col_names)
+ > migrator.drop_not_null(model, *field_names)
+ > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+ import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your migrations here."""
+
+ migrator.add_fields(
+ "user",
+ oauth_sub=pw.TextField(null=True, unique=True),
+ )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your rollback migrations here."""
+
+ migrator.remove_fields("user", "oauth_sub")
diff --git a/backend/apps/webui/internal/migrations/018_add_function_is_global.py b/backend/apps/webui/internal/migrations/018_add_function_is_global.py
new file mode 100644
index 000000000..04cdab705
--- /dev/null
+++ b/backend/apps/webui/internal/migrations/018_add_function_is_global.py
@@ -0,0 +1,49 @@
+"""Peewee migrations -- 017_add_user_oauth_sub.py.
+
+Some examples (model - class or model name)::
+
+ > Model = migrator.orm['table_name'] # Return model in current state by name
+ > Model = migrator.ModelClass # Return model in current state by name
+
+ > migrator.sql(sql) # Run custom SQL
+ > migrator.run(func, *args, **kwargs) # Run python function with the given args
+ > migrator.create_model(Model) # Create a model (could be used as decorator)
+ > migrator.remove_model(model, cascade=True) # Remove a model
+ > migrator.add_fields(model, **fields) # Add fields to a model
+ > migrator.change_fields(model, **fields) # Change fields
+ > migrator.remove_fields(model, *field_names, cascade=True)
+ > migrator.rename_field(model, old_field_name, new_field_name)
+ > migrator.rename_table(model, new_table_name)
+ > migrator.add_index(model, *col_names, unique=False)
+ > migrator.add_not_null(model, *field_names)
+ > migrator.add_default(model, field_name, default)
+ > migrator.add_constraint(model, name, sql)
+ > migrator.drop_index(model, *col_names)
+ > migrator.drop_not_null(model, *field_names)
+ > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+ import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your migrations here."""
+
+ migrator.add_fields(
+ "function",
+ is_global=pw.BooleanField(default=False),
+ )
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+ """Write your rollback migrations here."""
+
+ migrator.remove_fields("function", "is_global")
diff --git a/backend/apps/webui/internal/wrappers.py b/backend/apps/webui/internal/wrappers.py
new file mode 100644
index 000000000..2b5551ce2
--- /dev/null
+++ b/backend/apps/webui/internal/wrappers.py
@@ -0,0 +1,72 @@
+from contextvars import ContextVar
+from peewee import *
+from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
+
+import logging
+from playhouse.db_url import connect, parse
+from playhouse.shortcuts import ReconnectMixin
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["DB"])
+
+db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
+db_state = ContextVar("db_state", default=db_state_default.copy())
+
+
+class PeeweeConnectionState(object):
+ def __init__(self, **kwargs):
+ super().__setattr__("_state", db_state)
+ super().__init__(**kwargs)
+
+ def __setattr__(self, name, value):
+ self._state.get()[name] = value
+
+ def __getattr__(self, name):
+ value = self._state.get()[name]
+ return value
+
+
+class CustomReconnectMixin(ReconnectMixin):
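+    # (exception class, message fragment) pairs; ReconnectMixin reconnects and retries when an error matches.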
+ reconnect_errors = (
+ # psycopg2
+ (OperationalError, "termin"),
+ (InterfaceError, "closed"),
+ # peewee
+ (PeeWeeInterfaceError, "closed"),
+ )
+
+
+class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
+ pass
+
+
+def register_connection(db_url):
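+    # Resolve the URL with playhouse.db_url; Postgres gets the reconnecting wrapper below, SQLite relies on Peewee's autoconnect.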
+ db = connect(db_url)
+ if isinstance(db, PostgresqlDatabase):
+        # Enable autoconnect, managed by Peewee
+ db.autoconnect = True
+ db.reuse_if_open = True
+ log.info("Connected to PostgreSQL database")
+
+ # Get the connection details
+ connection = parse(db_url)
+
+ # Use our custom database class that supports reconnection
+ db = ReconnectingPostgresqlDatabase(
+ connection["database"],
+ user=connection["user"],
+ password=connection["password"],
+ host=connection["host"],
+ port=connection["port"],
+ )
+ db.connect(reuse_if_open=True)
+ elif isinstance(db, SqliteDatabase):
+ # Enable autoconnect for SQLite databases, managed by Peewee
+ db.autoconnect = True
+ db.reuse_if_open = True
+ log.info("Connected to SQLite database")
+ else:
+ raise ValueError("Unsupported database connection")
+ return db
diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index 190d2d1c3..28b1b4aac 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -1,6 +1,9 @@
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
+from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+
from apps.webui.routers import (
auths,
users,
@@ -12,7 +15,13 @@ from apps.webui.routers import (
configs,
memories,
utils,
+ files,
+ functions,
)
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
+from utils.misc import stream_message_template
+
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
@@ -32,6 +41,14 @@ from config import (
AppConfig,
)
+import inspect
+import uuid
+import time
+import json
+
+from typing import Iterator, Generator
+from pydantic import BaseModel
+
app = FastAPI()
origins = ["*"]
@@ -59,7 +76,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.TOOLS = {}
-
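+# Cache of loaded function modules, keyed by function id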
+app.state.FUNCTIONS = {}
app.add_middleware(
CORSMiddleware,
@@ -69,17 +86,21 @@ app.add_middleware(
allow_headers=["*"],
)
+
+app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
-app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
-app.include_router(memories.router, prefix="/memories", tags=["memories"])
-app.include_router(configs.router, prefix="/configs", tags=["configs"])
+app.include_router(memories.router, prefix="/memories", tags=["memories"])
+app.include_router(files.router, prefix="/files", tags=["files"])
+app.include_router(tools.router, prefix="/tools", tags=["tools"])
+app.include_router(functions.router, prefix="/functions", tags=["functions"])
+
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@@ -91,3 +112,226 @@ async def get_status():
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
+
+
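+# Pipe functions are exposed to the UI as OpenAI-style model entries; a
+# "manifold" pipe expands into one entry per item returned by its `pipes`
+# attribute (or callable), each registered under "<pipe_id>.<sub_id>".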
+async def get_pipe_models():
+ pipes = Functions.get_functions_by_type("pipe", active_only=True)
+ pipe_models = []
+
+ for pipe in pipes:
+ # Check if function is already loaded
+ if pipe.id not in app.state.FUNCTIONS:
+ function_module, function_type, frontmatter = load_function_module_by_id(
+ pipe.id
+ )
+ app.state.FUNCTIONS[pipe.id] = function_module
+ else:
+ function_module = app.state.FUNCTIONS[pipe.id]
+
+ if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+ print(f"Getting valves for {pipe.id}")
+ valves = Functions.get_function_valves_by_id(pipe.id)
+ function_module.valves = function_module.Valves(
+ **(valves if valves else {})
+ )
+
+ # Check if function is a manifold
+ if hasattr(function_module, "type"):
+ if function_module.type == "manifold":
+ manifold_pipes = []
+
+ # Check if pipes is a function or a list
+ if callable(function_module.pipes):
+ manifold_pipes = function_module.pipes()
+ else:
+ manifold_pipes = function_module.pipes
+
+ for p in manifold_pipes:
+ manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+ manifold_pipe_name = p["name"]
+
+ if hasattr(function_module, "name"):
+ manifold_pipe_name = (
+ f"{function_module.name}{manifold_pipe_name}"
+ )
+
+ pipe_models.append(
+ {
+ "id": manifold_pipe_id,
+ "name": manifold_pipe_name,
+ "object": "model",
+ "created": pipe.created_at,
+ "owned_by": "openai",
+ "pipe": {"type": pipe.type},
+ }
+ )
+ else:
+ pipe_models.append(
+ {
+ "id": pipe.id,
+ "name": pipe.name,
+ "object": "model",
+ "created": pipe.created_at,
+ "owned_by": "openai",
+ "pipe": {"type": "pipe"},
+ }
+ )
+
+ return pipe_models
+
+
+async def generate_function_chat_completion(form_data, user):
+ async def job():
+ pipe_id = form_data["model"]
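+        # Manifold sub-pipes are addressed as "<function_id>.<sub_pipe_id>";
+        # strip the suffix to recover the function module to load.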
+ if "." in pipe_id:
+ pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+ print(pipe_id)
+
+ # Check if function is already loaded
+ if pipe_id not in app.state.FUNCTIONS:
+ function_module, function_type, frontmatter = load_function_module_by_id(
+ pipe_id
+ )
+ app.state.FUNCTIONS[pipe_id] = function_module
+ else:
+ function_module = app.state.FUNCTIONS[pipe_id]
+
+ if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+
+ valves = Functions.get_function_valves_by_id(pipe_id)
+ function_module.valves = function_module.Valves(
+ **(valves if valves else {})
+ )
+
+ pipe = function_module.pipe
+
+ # Get the signature of the function
+ sig = inspect.signature(pipe)
+ params = {"body": form_data}
+
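+        # Inject a `__user__` dict only when the pipe's signature declares it,
+        # optionally enriched with that user's UserValves settings.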
+ if "__user__" in sig.parameters:
+ __user__ = {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ }
+
+ try:
+ if hasattr(function_module, "UserValves"):
+ __user__["valves"] = function_module.UserValves(
+ **Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
+ )
+ except Exception as e:
+ print(e)
+
+ params = {**params, "__user__": __user__}
+
+ if form_data["stream"]:
+
+ async def stream_content():
+ try:
+ if inspect.iscoroutinefunction(pipe):
+ res = await pipe(**params)
+ else:
+ res = pipe(**params)
+
+ # Directly return if the response is a StreamingResponse
+ if isinstance(res, StreamingResponse):
+ async for data in res.body_iterator:
+ yield data
+ return
+ if isinstance(res, dict):
+ yield f"data: {json.dumps(res)}\n\n"
+ return
+
+ except Exception as e:
+ print(f"Error: {e}")
+ yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
+ return
+
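+                # Plain strings and iterators are wrapped into OpenAI-style
+                # "chat.completion.chunk" events via stream_message_template.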
+ if isinstance(res, str):
+ message = stream_message_template(form_data["model"], res)
+ yield f"data: {json.dumps(message)}\n\n"
+
+ if isinstance(res, Iterator):
+ for line in res:
+ if isinstance(line, BaseModel):
+ line = line.model_dump_json()
+ line = f"data: {line}"
+ try:
+ line = line.decode("utf-8")
+ except:
+ pass
+
+ if line.startswith("data:"):
+ yield f"{line}\n\n"
+ else:
+ line = stream_message_template(form_data["model"], line)
+ yield f"data: {json.dumps(line)}\n\n"
+
+ if isinstance(res, str) or isinstance(res, Generator):
+ finish_message = {
+ "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": form_data["model"],
+ "choices": [
+ {
+ "index": 0,
+ "delta": {},
+ "logprobs": None,
+ "finish_reason": "stop",
+ }
+ ],
+ }
+
+ yield f"data: {json.dumps(finish_message)}\n\n"
+ yield f"data: [DONE]"
+
+ return StreamingResponse(stream_content(), media_type="text/event-stream")
+ else:
+
+ try:
+ if inspect.iscoroutinefunction(pipe):
+ res = await pipe(**params)
+ else:
+ res = pipe(**params)
+
+ if isinstance(res, StreamingResponse):
+ return res
+ except Exception as e:
+ print(f"Error: {e}")
+ return {"error": {"detail": str(e)}}
+
+ if isinstance(res, dict):
+ return res
+ elif isinstance(res, BaseModel):
+ return res.model_dump()
+ else:
+ message = ""
+ if isinstance(res, str):
+ message = res
+ if isinstance(res, Generator):
+ for stream in res:
+ message = f"{message}{stream}"
+
+ return {
+ "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": form_data["model"],
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": message,
+ },
+ "logprobs": None,
+ "finish_reason": "stop",
+ }
+ ],
+ }
+
+ return await job()
diff --git a/backend/apps/webui/models/auths.py b/backend/apps/webui/models/auths.py
index e3b659e43..9ea38abcb 100644
--- a/backend/apps/webui/models/auths.py
+++ b/backend/apps/webui/models/auths.py
@@ -105,6 +105,7 @@ class AuthsTable:
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
+ oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
log.info("insert_new_auth")
@@ -115,7 +116,9 @@ class AuthsTable:
)
result = Auth.create(**auth.model_dump())
- user = Users.insert_new_user(id, name, email, profile_image_url, role)
+ user = Users.insert_new_user(
+ id, name, email, profile_image_url, role, oauth_sub
+ )
if result and user:
return user
diff --git a/backend/apps/webui/models/files.py b/backend/apps/webui/models/files.py
new file mode 100644
index 000000000..6459ad725
--- /dev/null
+++ b/backend/apps/webui/models/files.py
@@ -0,0 +1,112 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+import time
+import logging
+from apps.webui.internal.db import DB, JSONField
+
+import json
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# Files DB Schema
+####################
+
+
+class File(Model):
+ id = CharField(unique=True)
+ user_id = CharField()
+ filename = TextField()
+ meta = JSONField()
+ created_at = BigIntegerField()
+
+ class Meta:
+ database = DB
+
+
+class FileModel(BaseModel):
+ id: str
+ user_id: str
+ filename: str
+ meta: dict
+ created_at: int # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class FileModelResponse(BaseModel):
+ id: str
+ user_id: str
+ filename: str
+ meta: dict
+ created_at: int # timestamp in epoch
+
+
+class FileForm(BaseModel):
+ id: str
+ filename: str
+ meta: dict = {}
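+    # Populated by the upload route with details such as content_type, size,
+    # and the on-disk path of the stored file.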
+
+
+class FilesTable:
+ def __init__(self, db):
+ self.db = db
+ self.db.create_tables([File])
+
+ def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
+ file = FileModel(
+ **{
+ **form_data.model_dump(),
+ "user_id": user_id,
+ "created_at": int(time.time()),
+ }
+ )
+
+ try:
+ result = File.create(**file.model_dump())
+ if result:
+ return file
+ else:
+ return None
+ except Exception as e:
+ print(f"Error creating tool: {e}")
+ return None
+
+ def get_file_by_id(self, id: str) -> Optional[FileModel]:
+ try:
+ file = File.get(File.id == id)
+ return FileModel(**model_to_dict(file))
+ except:
+ return None
+
+ def get_files(self) -> List[FileModel]:
+ return [FileModel(**model_to_dict(file)) for file in File.select()]
+
+ def delete_file_by_id(self, id: str) -> bool:
+ try:
+ query = File.delete().where((File.id == id))
+ query.execute() # Remove the rows, return number of rows removed.
+
+ return True
+ except:
+ return False
+
+ def delete_all_files(self) -> bool:
+ try:
+ query = File.delete()
+ query.execute() # Remove the rows, return number of rows removed.
+
+ return True
+ except:
+ return False
+
+
+Files = FilesTable(DB)
diff --git a/backend/apps/webui/models/functions.py b/backend/apps/webui/models/functions.py
new file mode 100644
index 000000000..2cace54c4
--- /dev/null
+++ b/backend/apps/webui/models/functions.py
@@ -0,0 +1,261 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+import time
+import logging
+from apps.webui.internal.db import DB, JSONField
+from apps.webui.models.users import Users
+
+import json
+import copy
+
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# Functions DB Schema
+####################
+
+
+class Function(Model):
+ id = CharField(unique=True)
+ user_id = CharField()
+ name = TextField()
+ type = TextField()
+ content = TextField()
+ meta = JSONField()
+ valves = JSONField()
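+    # Admin-level configuration for the function; per-user overrides are kept
+    # under each user's settings (see get_user_valves_by_id_and_user_id).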
+ is_active = BooleanField(default=False)
+ is_global = BooleanField(default=False)
+ updated_at = BigIntegerField()
+ created_at = BigIntegerField()
+
+ class Meta:
+ database = DB
+
+
+class FunctionMeta(BaseModel):
+ description: Optional[str] = None
+ manifest: Optional[dict] = {}
+
+
+class FunctionModel(BaseModel):
+ id: str
+ user_id: str
+ name: str
+ type: str
+ content: str
+ meta: FunctionMeta
+ is_active: bool = False
+ is_global: bool = False
+ updated_at: int # timestamp in epoch
+ created_at: int # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class FunctionResponse(BaseModel):
+ id: str
+ user_id: str
+ type: str
+ name: str
+ meta: FunctionMeta
+ is_active: bool
+ is_global: bool
+ updated_at: int # timestamp in epoch
+ created_at: int # timestamp in epoch
+
+
+class FunctionForm(BaseModel):
+ id: str
+ name: str
+ content: str
+ meta: FunctionMeta
+
+
+class FunctionValves(BaseModel):
+ valves: Optional[dict] = None
+
+
+class FunctionsTable:
+ def __init__(self, db):
+ self.db = db
+ self.db.create_tables([Function])
+
+ def insert_new_function(
+ self, user_id: str, type: str, form_data: FunctionForm
+ ) -> Optional[FunctionModel]:
+ function = FunctionModel(
+ **{
+ **form_data.model_dump(),
+ "user_id": user_id,
+ "type": type,
+ "updated_at": int(time.time()),
+ "created_at": int(time.time()),
+ }
+ )
+
+ try:
+ result = Function.create(**function.model_dump())
+ if result:
+ return function
+ else:
+ return None
+ except Exception as e:
+ print(f"Error creating tool: {e}")
+ return None
+
+ def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
+ try:
+ function = Function.get(Function.id == id)
+ return FunctionModel(**model_to_dict(function))
+ except:
+ return None
+
+ def get_functions(self, active_only=False) -> List[FunctionModel]:
+ if active_only:
+ return [
+ FunctionModel(**model_to_dict(function))
+ for function in Function.select().where(Function.is_active == True)
+ ]
+ else:
+ return [
+ FunctionModel(**model_to_dict(function))
+ for function in Function.select()
+ ]
+
+ def get_functions_by_type(
+ self, type: str, active_only=False
+ ) -> List[FunctionModel]:
+ if active_only:
+ return [
+ FunctionModel(**model_to_dict(function))
+ for function in Function.select().where(
+ Function.type == type, Function.is_active == True
+ )
+ ]
+ else:
+ return [
+ FunctionModel(**model_to_dict(function))
+ for function in Function.select().where(Function.type == type)
+ ]
+
+ def get_global_filter_functions(self) -> List[FunctionModel]:
+ return [
+ FunctionModel(**model_to_dict(function))
+ for function in Function.select().where(
+ Function.type == "filter",
+ Function.is_active == True,
+ Function.is_global == True,
+ )
+ ]
+
+ def get_function_valves_by_id(self, id: str) -> Optional[dict]:
+ try:
+ function = Function.get(Function.id == id)
+ return function.valves if function.valves else {}
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None
+
+ def update_function_valves_by_id(
+ self, id: str, valves: dict
+ ) -> Optional[FunctionValves]:
+ try:
+ query = Function.update(
+ **{"valves": valves},
+ updated_at=int(time.time()),
+ ).where(Function.id == id)
+ query.execute()
+
+ function = Function.get(Function.id == id)
+ return FunctionValves(**model_to_dict(function))
+ except:
+ return None
+
+ def get_user_valves_by_id_and_user_id(
+ self, id: str, user_id: str
+ ) -> Optional[dict]:
+ try:
+ user = Users.get_user_by_id(user_id)
+ user_settings = user.settings.model_dump()
+
+ # Check if user has "functions" and "valves" settings
+ if "functions" not in user_settings:
+ user_settings["functions"] = {}
+ if "valves" not in user_settings["functions"]:
+ user_settings["functions"]["valves"] = {}
+
+ return user_settings["functions"]["valves"].get(id, {})
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None
+
+ def update_user_valves_by_id_and_user_id(
+ self, id: str, user_id: str, valves: dict
+ ) -> Optional[dict]:
+ try:
+ user = Users.get_user_by_id(user_id)
+ user_settings = user.settings.model_dump()
+
+ # Check if user has "functions" and "valves" settings
+ if "functions" not in user_settings:
+ user_settings["functions"] = {}
+ if "valves" not in user_settings["functions"]:
+ user_settings["functions"]["valves"] = {}
+
+ user_settings["functions"]["valves"][id] = valves
+
+ # Update the user settings in the database
+ query = Users.update_user_by_id(user_id, {"settings": user_settings})
+ query.execute()
+
+ return user_settings["functions"]["valves"][id]
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None
+
+ def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
+ try:
+ query = Function.update(
+ **updated,
+ updated_at=int(time.time()),
+ ).where(Function.id == id)
+ query.execute()
+
+ function = Function.get(Function.id == id)
+ return FunctionModel(**model_to_dict(function))
+ except:
+ return None
+
+ def deactivate_all_functions(self) -> Optional[bool]:
+ try:
+ query = Function.update(
+ **{"is_active": False},
+ updated_at=int(time.time()),
+ )
+
+ query.execute()
+
+ return True
+ except:
+ return None
+
+ def delete_function_by_id(self, id: str) -> bool:
+ try:
+ query = Function.delete().where((Function.id == id))
+ query.execute() # Remove the rows, return number of rows removed.
+
+ return True
+ except:
+ return False
+
+
+Functions = FunctionsTable(DB)
diff --git a/backend/apps/webui/models/tools.py b/backend/apps/webui/models/tools.py
index e2db1e35f..694081df9 100644
--- a/backend/apps/webui/models/tools.py
+++ b/backend/apps/webui/models/tools.py
@@ -5,8 +5,11 @@ from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
+from apps.webui.models.users import Users
import json
+import copy
+
from config import SRC_LOG_LEVELS
@@ -25,6 +28,7 @@ class Tool(Model):
content = TextField()
specs = JSONField()
meta = JSONField()
+ valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
@@ -34,6 +38,7 @@ class Tool(Model):
class ToolMeta(BaseModel):
description: Optional[str] = None
+ manifest: Optional[dict] = {}
class ToolModel(BaseModel):
@@ -68,6 +73,10 @@ class ToolForm(BaseModel):
meta: ToolMeta
+class ToolValves(BaseModel):
+ valves: Optional[dict] = None
+
+
class ToolsTable:
def __init__(self, db):
self.db = db
@@ -106,6 +115,69 @@ class ToolsTable:
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
+ def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
+ try:
+ tool = Tool.get(Tool.id == id)
+ return tool.valves if tool.valves else {}
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None
+
+ def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
+ try:
+ query = Tool.update(
+ **{"valves": valves},
+ updated_at=int(time.time()),
+ ).where(Tool.id == id)
+ query.execute()
+
+ tool = Tool.get(Tool.id == id)
+ return ToolValves(**model_to_dict(tool))
+ except:
+ return None
+
+ def get_user_valves_by_id_and_user_id(
+ self, id: str, user_id: str
+ ) -> Optional[dict]:
+ try:
+ user = Users.get_user_by_id(user_id)
+ user_settings = user.settings.model_dump()
+
+ # Check if user has "tools" and "valves" settings
+ if "tools" not in user_settings:
+ user_settings["tools"] = {}
+ if "valves" not in user_settings["tools"]:
+ user_settings["tools"]["valves"] = {}
+
+ return user_settings["tools"]["valves"].get(id, {})
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None
+
+ def update_user_valves_by_id_and_user_id(
+ self, id: str, user_id: str, valves: dict
+ ) -> Optional[dict]:
+ try:
+ user = Users.get_user_by_id(user_id)
+ user_settings = user.settings.model_dump()
+
+ # Check if user has "tools" and "valves" settings
+ if "tools" not in user_settings:
+ user_settings["tools"] = {}
+ if "valves" not in user_settings["tools"]:
+ user_settings["tools"]["valves"] = {}
+
+ user_settings["tools"]["valves"][id] = valves
+
+ # Update the user settings in the database
+ query = Users.update_user_by_id(user_id, {"settings": user_settings})
+ query.execute()
+
+ return user_settings["tools"]["valves"][id]
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return None
+
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
diff --git a/backend/apps/webui/models/users.py b/backend/apps/webui/models/users.py
index 485a9eea4..e3e1842b8 100644
--- a/backend/apps/webui/models/users.py
+++ b/backend/apps/webui/models/users.py
@@ -28,6 +28,8 @@ class User(Model):
settings = JSONField(null=True)
info = JSONField(null=True)
+ oauth_sub = TextField(null=True, unique=True)
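+    # The OAuth provider's "sub" (subject) claim, so users signing in via SSO
+    # can be looked up and matched to an existing account.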
+
class Meta:
database = DB
@@ -53,6 +55,8 @@ class UserModel(BaseModel):
settings: Optional[UserSettings] = None
info: Optional[dict] = None
+ oauth_sub: Optional[str] = None
+
####################
# Forms
@@ -83,6 +87,7 @@ class UsersTable:
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
+ oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
user = UserModel(
**{
@@ -94,6 +99,7 @@ class UsersTable:
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
+ "oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
@@ -123,6 +129,13 @@ class UsersTable:
except:
return None
+ def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
+ try:
+ user = User.get(User.oauth_sub == sub)
+ return UserModel(**model_to_dict(user))
+ except:
+ return None
+
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
@@ -174,6 +187,18 @@ class UsersTable:
except:
return None
+ def update_user_oauth_sub_by_id(
+ self, id: str, oauth_sub: str
+ ) -> Optional[UserModel]:
+ try:
+ query = User.update(oauth_sub=oauth_sub).where(User.id == id)
+ query.execute()
+
+ user = User.get(User.id == id)
+ return UserModel(**model_to_dict(user))
+ except:
+ return None
+
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
diff --git a/backend/apps/webui/routers/auths.py b/backend/apps/webui/routers/auths.py
index 16e395737..1be79d259 100644
--- a/backend/apps/webui/routers/auths.py
+++ b/backend/apps/webui/routers/auths.py
@@ -2,6 +2,7 @@ import logging
from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status
+from fastapi.responses import Response
from fastapi import APIRouter
from pydantic import BaseModel
@@ -9,7 +10,6 @@ import re
import uuid
import csv
-
from apps.webui.models.auths import (
SigninForm,
SignupForm,
@@ -47,7 +47,21 @@ router = APIRouter()
@router.get("/", response_model=UserResponse)
-async def get_session_user(user=Depends(get_current_user)):
+async def get_session_user(
+ request: Request, response: Response, user=Depends(get_current_user)
+):
+ token = create_token(
+ data={"id": user.id},
+ expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+ )
+
+ # Set the cookie token
+ response.set_cookie(
+ key="token",
+ value=token,
+ httponly=True, # Ensures the cookie is not accessible via JavaScript
+ )
+
return {
"id": user.id,
"email": user.email,
@@ -108,7 +122,7 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse)
-async def signin(request: Request, form_data: SigninForm):
+async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
@@ -152,6 +166,13 @@ async def signin(request: Request, form_data: SigninForm):
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
+ # Set the cookie token
+ response.set_cookie(
+ key="token",
+ value=token,
+ httponly=True, # Ensures the cookie is not accessible via JavaScript
+ )
+
return {
"token": token,
"token_type": "Bearer",
@@ -171,7 +192,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
-async def signup(request: Request, form_data: SignupForm):
+async def signup(request: Request, response: Response, form_data: SignupForm):
if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
@@ -207,6 +228,13 @@ async def signup(request: Request, form_data: SignupForm):
)
# response.set_cookie(key='token', value=token, httponly=True)
+ # Set the cookie token
+ response.set_cookie(
+ key="token",
+ value=token,
+ httponly=True, # Ensures the cookie is not accessible via JavaScript
+ )
+
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.config.WEBHOOK_URL,
diff --git a/backend/apps/webui/routers/chats.py b/backend/apps/webui/routers/chats.py
index 9d1cceaa1..c4d6575c2 100644
--- a/backend/apps/webui/routers/chats.py
+++ b/backend/apps/webui/routers/chats.py
@@ -1,7 +1,7 @@
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
@@ -43,7 +43,7 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
- user=Depends(get_current_user), skip: int = 0, limit: int = 50
+ user=Depends(get_verified_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user.id, skip, limit)
@@ -54,7 +54,7 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool)
-async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
+async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
if (
user.role == "user"
@@ -89,7 +89,7 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -106,7 +106,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
@router.get("/all", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats_by_user_id(user.id)
@@ -119,7 +119,7 @@ async def get_user_chats(user=Depends(get_current_user)):
@router.get("/all/archived", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
@@ -151,7 +151,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
- user=Depends(get_current_user), skip: int = 0, limit: int = 50
+ user=Depends(get_verified_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
@@ -162,7 +162,7 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool)
-async def archive_all_chats(user=Depends(get_current_user)):
+async def archive_all_chats(user=Depends(get_verified_user)):
return Chats.archive_all_chats_by_user_id(user.id)
@@ -172,7 +172,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
+async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -204,7 +204,7 @@ class TagNameForm(BaseModel):
@router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
- form_data: TagNameForm, user=Depends(get_current_user)
+ form_data: TagNameForm, user=Depends(get_verified_user)
):
print(form_data)
@@ -229,7 +229,7 @@ async def get_user_chat_list_by_tag_name(
@router.get("/tags/all", response_model=List[TagModel])
-async def get_all_tags(user=Depends(get_current_user)):
+async def get_all_tags(user=Depends(get_verified_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
@@ -246,7 +246,7 @@ async def get_all_tags(user=Depends(get_current_user)):
@router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
@@ -264,7 +264,7 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
- id: str, form_data: ChatForm, user=Depends(get_current_user)
+ id: str, form_data: ChatForm, user=Depends(get_verified_user)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
@@ -285,7 +285,7 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
+async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
if user.role == "admin":
result = Chats.delete_chat_by_id(id)
@@ -307,7 +307,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
-async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
+async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
@@ -333,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
-async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
+async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
@@ -350,7 +350,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
-async def share_chat_by_id(id: str, user=Depends(get_current_user)):
+async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if chat.share_id:
@@ -382,7 +382,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
@router.delete("/{id}/share", response_model=Optional[bool])
-async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
+async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if not chat.share_id:
@@ -405,7 +405,7 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/tags", response_model=List[TagModel])
-async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
+async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None:
@@ -423,7 +423,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
- id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+ id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
@@ -450,7 +450,7 @@ async def add_chat_tag_by_id(
@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
- id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+ id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id
@@ -470,7 +470,7 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
-async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
+async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
if result:
diff --git a/backend/apps/webui/routers/configs.py b/backend/apps/webui/routers/configs.py
index c127e721b..39e435013 100644
--- a/backend/apps/webui/routers/configs.py
+++ b/backend/apps/webui/routers/configs.py
@@ -14,7 +14,7 @@ from apps.webui.models.users import Users
from utils.utils import (
get_password_hash,
- get_current_user,
+ get_verified_user,
get_admin_user,
create_token,
)
@@ -84,6 +84,6 @@ async def set_banners(
@router.get("/banners", response_model=List[BannerModel])
async def get_banners(
request: Request,
- user=Depends(get_current_user),
+ user=Depends(get_verified_user),
):
return request.app.state.config.BANNERS
diff --git a/backend/apps/webui/routers/documents.py b/backend/apps/webui/routers/documents.py
index 311455390..dc53b5246 100644
--- a/backend/apps/webui/routers/documents.py
+++ b/backend/apps/webui/routers/documents.py
@@ -14,7 +14,7 @@ from apps.webui.models.documents import (
DocumentResponse,
)
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@@ -25,7 +25,7 @@ router = APIRouter()
@router.get("/", response_model=List[DocumentResponse])
-async def get_documents(user=Depends(get_current_user)):
+async def get_documents(user=Depends(get_verified_user)):
docs = [
DocumentResponse(
**{
@@ -74,7 +74,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
@router.get("/doc", response_model=Optional[DocumentResponse])
-async def get_doc_by_name(name: str, user=Depends(get_current_user)):
+async def get_doc_by_name(name: str, user=Depends(get_verified_user)):
doc = Documents.get_doc_by_name(name)
if doc:
@@ -106,7 +106,7 @@ class TagDocumentForm(BaseModel):
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
-async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
+async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
if doc:
diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py
new file mode 100644
index 000000000..3b6d44aa5
--- /dev/null
+++ b/backend/apps/webui/routers/files.py
@@ -0,0 +1,242 @@
+from fastapi import (
+ Depends,
+ FastAPI,
+ HTTPException,
+ status,
+ Request,
+ UploadFile,
+ File,
+ Form,
+)
+
+
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+from pathlib import Path
+
+from fastapi import APIRouter
+from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
+
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.files import (
+ Files,
+ FileForm,
+ FileModel,
+ FileModelResponse,
+)
+from utils.utils import get_verified_user, get_admin_user
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+import uuid
+import shutil
+import logging
+import re
+
+
+from config import SRC_LOG_LEVELS, UPLOAD_DIR
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+
+router = APIRouter()
+
+############################
+# Upload File
+############################
+
+
+@router.post("/")
+def upload_file(
+ file: UploadFile = File(...),
+ user=Depends(get_verified_user),
+):
+ log.info(f"file.content_type: {file.content_type}")
+ try:
+ unsanitized_filename = file.filename
+ filename = os.path.basename(unsanitized_filename)
+
+ # replace filename with uuid
+ id = str(uuid.uuid4())
+ filename = f"{id}_{filename}"
+ file_path = f"{UPLOAD_DIR}/{filename}"
+
+ contents = file.file.read()
+ with open(file_path, "wb") as f:
+ f.write(contents)
+
+ file = Files.insert_new_file(
+ user.id,
+ FileForm(
+ **{
+ "id": id,
+ "filename": filename,
+ "meta": {
+ "content_type": file.content_type,
+ "size": len(contents),
+ "path": file_path,
+ },
+ }
+ ),
+ )
+
+ if file:
+ return file
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
+ )
+
+ except Exception as e:
+ log.exception(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+
+
+############################
+# List Files
+############################
+
+
+@router.get("/", response_model=List[FileModel])
+async def list_files(user=Depends(get_verified_user)):
+ files = Files.get_files()
+ return files
+
+
+############################
+# Delete All Files
+############################
+
+
+@router.delete("/all")
+async def delete_all_files(user=Depends(get_admin_user)):
+ result = Files.delete_all_files()
+
+ if result:
+ folder = f"{UPLOAD_DIR}"
+ try:
+ # Check if the directory exists
+ if os.path.exists(folder):
+ # Iterate over all the files and directories in the specified directory
+ for filename in os.listdir(folder):
+ file_path = os.path.join(folder, filename)
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.unlink(file_path) # Remove the file or link
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path) # Remove the directory
+ except Exception as e:
+ print(f"Failed to delete {file_path}. Reason: {e}")
+ else:
+ print(f"The directory {folder} does not exist")
+ except Exception as e:
+ print(f"Failed to process the directory {folder}. Reason: {e}")
+
+ return {"message": "All files deleted successfully"}
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
+ )
+
+
+############################
+# Get File By Id
+############################
+
+
+@router.get("/{id}", response_model=Optional[FileModel])
+async def get_file_by_id(id: str, user=Depends(get_verified_user)):
+ file = Files.get_file_by_id(id)
+
+ if file:
+ return file
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# Get File Content By Id
+############################
+
+
+@router.get("/{id}/content", response_model=Optional[FileModel])
+async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+ file = Files.get_file_by_id(id)
+
+ if file:
+ file_path = Path(file.meta["path"])
+
+ # Check if the file already exists in the cache
+ if file_path.is_file():
+ print(f"file_path: {file_path}")
+ return FileResponse(file_path)
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
+async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
+ file = Files.get_file_by_id(id)
+
+ if file:
+ file_path = Path(file.meta["path"])
+
+ # Check if the file already exists in the cache
+ if file_path.is_file():
+ print(f"file_path: {file_path}")
+ return FileResponse(file_path)
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# Delete File By Id
+############################
+
+
+@router.delete("/{id}")
+async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
+ file = Files.get_file_by_id(id)
+
+ if file:
+ result = Files.delete_file_by_id(id)
+ if result:
+ return {"message": "File deleted successfully"}
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error deleting file"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
diff --git a/backend/apps/webui/routers/functions.py b/backend/apps/webui/routers/functions.py
new file mode 100644
index 000000000..f01133a35
--- /dev/null
+++ b/backend/apps/webui/routers/functions.py
@@ -0,0 +1,423 @@
+from fastapi import Depends, FastAPI, HTTPException, status, Request
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.functions import (
+ Functions,
+ FunctionForm,
+ FunctionModel,
+ FunctionResponse,
+)
+from apps.webui.utils import load_function_module_by_id
+from utils.utils import get_verified_user, get_admin_user
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+from pathlib import Path
+
+from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
+
+
+router = APIRouter()
+
+############################
+# GetFunctions
+############################
+
+
+@router.get("/", response_model=List[FunctionResponse])
+async def get_functions(user=Depends(get_verified_user)):
+ return Functions.get_functions()
+
+
+############################
+# ExportFunctions
+############################
+
+
+@router.get("/export", response_model=List[FunctionModel])
+async def get_functions(user=Depends(get_admin_user)):
+ return Functions.get_functions()
+
+
+############################
+# CreateNewFunction
+############################
+
+
+@router.post("/create", response_model=Optional[FunctionResponse])
+async def create_new_function(
+ request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+ if not form_data.id.isidentifier():
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Only alphanumeric characters and underscores are allowed in the id",
+ )
+
+ form_data.id = form_data.id.lower()
+
+ function = Functions.get_function_by_id(form_data.id)
+    if function is None:
+ function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
+ try:
+ with open(function_path, "w") as function_file:
+ function_file.write(form_data.content)
+
+ function_module, function_type, frontmatter = load_function_module_by_id(
+ form_data.id
+ )
+ form_data.meta.manifest = frontmatter
+
+ FUNCTIONS = request.app.state.FUNCTIONS
+ FUNCTIONS[form_data.id] = function_module
+
+ function = Functions.insert_new_function(user.id, function_type, form_data)
+
+ function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
+ function_cache_dir.mkdir(parents=True, exist_ok=True)
+
+ if function:
+ return function
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
+ )
+ except Exception as e:
+ print(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.ID_TAKEN,
+ )
+
+
+############################
+# GetFunctionById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[FunctionModel])
+async def get_function_by_id(id: str, user=Depends(get_admin_user)):
+ function = Functions.get_function_by_id(id)
+
+ if function:
+ return function
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# ToggleFunctionById
+############################
+
+
+@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
+async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
+ function = Functions.get_function_by_id(id)
+ if function:
+ function = Functions.update_function_by_id(
+ id, {"is_active": not function.is_active}
+ )
+
+ if function:
+ return function
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# ToggleGlobalById
+############################
+
+
+@router.post("/id/{id}/toggle/global", response_model=Optional[FunctionModel])
+async def toggle_global_by_id(id: str, user=Depends(get_admin_user)):
+ function = Functions.get_function_by_id(id)
+ if function:
+ function = Functions.update_function_by_id(
+ id, {"is_global": not function.is_global}
+ )
+
+ if function:
+ return function
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# UpdateFunctionById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
+async def update_function_by_id(
+ request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+ function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+
+ try:
+ with open(function_path, "w") as function_file:
+ function_file.write(form_data.content)
+
+ function_module, function_type, frontmatter = load_function_module_by_id(id)
+ form_data.meta.manifest = frontmatter
+
+ FUNCTIONS = request.app.state.FUNCTIONS
+ FUNCTIONS[id] = function_module
+
+ updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
+ print(updated)
+
+ function = Functions.update_function_by_id(id, updated)
+
+ if function:
+ return function
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+ )
+
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+
+
+############################
+# DeleteFunctionById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_function_by_id(
+ request: Request, id: str, user=Depends(get_admin_user)
+):
+ result = Functions.delete_function_by_id(id)
+
+ if result:
+ FUNCTIONS = request.app.state.FUNCTIONS
+ if id in FUNCTIONS:
+ del FUNCTIONS[id]
+
+ # delete the function file
+ function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+ os.remove(function_path)
+
+ return result
+
+
+############################
+# GetFunctionValves
+############################
+
+
+@router.get("/id/{id}/valves", response_model=Optional[dict])
+async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
+ function = Functions.get_function_by_id(id)
+ if function:
+ try:
+ valves = Functions.get_function_valves_by_id(id)
+ return valves
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# GetFunctionValvesSpec
+############################
+
+
+@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
+async def get_function_valves_spec_by_id(
+ request: Request, id: str, user=Depends(get_admin_user)
+):
+ function = Functions.get_function_by_id(id)
+ if function:
+ if id in request.app.state.FUNCTIONS:
+ function_module = request.app.state.FUNCTIONS[id]
+ else:
+ function_module, function_type, frontmatter = load_function_module_by_id(id)
+ request.app.state.FUNCTIONS[id] = function_module
+
+ if hasattr(function_module, "Valves"):
+ Valves = function_module.Valves
+ return Valves.schema()
+ return None
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# UpdateFunctionValves
+############################
+
+
+@router.post("/id/{id}/valves/update", response_model=Optional[dict])
+async def update_function_valves_by_id(
+ request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
+):
+ function = Functions.get_function_by_id(id)
+ if function:
+
+ if id in request.app.state.FUNCTIONS:
+ function_module = request.app.state.FUNCTIONS[id]
+ else:
+ function_module, function_type, frontmatter = load_function_module_by_id(id)
+ request.app.state.FUNCTIONS[id] = function_module
+
+ if hasattr(function_module, "Valves"):
+ Valves = function_module.Valves
+
+ try:
+ form_data = {k: v for k, v in form_data.items() if v is not None}
+ valves = Valves(**form_data)
+ Functions.update_function_valves_by_id(id, valves.model_dump())
+ return valves.model_dump()
+ except Exception as e:
+ print(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# FunctionUserValves
+############################
+
+
+@router.get("/id/{id}/valves/user", response_model=Optional[dict])
+async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
+ function = Functions.get_function_by_id(id)
+ if function:
+ try:
+ user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
+ return user_valves
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
+async def get_function_user_valves_spec_by_id(
+ request: Request, id: str, user=Depends(get_verified_user)
+):
+ function = Functions.get_function_by_id(id)
+ if function:
+ if id in request.app.state.FUNCTIONS:
+ function_module = request.app.state.FUNCTIONS[id]
+ else:
+ function_module, function_type, frontmatter = load_function_module_by_id(id)
+ request.app.state.FUNCTIONS[id] = function_module
+
+ if hasattr(function_module, "UserValves"):
+ UserValves = function_module.UserValves
+ return UserValves.schema()
+ return None
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
+async def update_function_user_valves_by_id(
+ request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
+):
+ function = Functions.get_function_by_id(id)
+
+ if function:
+ if id in request.app.state.FUNCTIONS:
+ function_module = request.app.state.FUNCTIONS[id]
+ else:
+ function_module, function_type, frontmatter = load_function_module_by_id(id)
+ request.app.state.FUNCTIONS[id] = function_module
+
+ if hasattr(function_module, "UserValves"):
+ UserValves = function_module.UserValves
+
+ try:
+ form_data = {k: v for k, v in form_data.items() if v is not None}
+ user_valves = UserValves(**form_data)
+ Functions.update_user_valves_by_id_and_user_id(
+ id, user.id, user_valves.model_dump()
+ )
+ return user_valves.model_dump()
+ except Exception as e:
+ print(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
diff --git a/backend/apps/webui/routers/memories.py b/backend/apps/webui/routers/memories.py
index 3832fe9a1..e9ae96173 100644
--- a/backend/apps/webui/routers/memories.py
+++ b/backend/apps/webui/routers/memories.py
@@ -101,6 +101,7 @@ async def update_memory_by_id(
class QueryMemoryForm(BaseModel):
content: str
+ k: Optional[int] = 1
@router.post("/query")
@@ -112,7 +113,7 @@ async def query_memory(
results = collection.query(
query_embeddings=[query_embedding],
- n_results=1, # how many results to return
+ n_results=form_data.k, # how many results to return
)
return results
diff --git a/backend/apps/webui/routers/prompts.py b/backend/apps/webui/routers/prompts.py
index 47d8c7012..e609a0a1b 100644
--- a/backend/apps/webui/routers/prompts.py
+++ b/backend/apps/webui/routers/prompts.py
@@ -8,7 +8,7 @@ import json
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
@@ -19,7 +19,7 @@ router = APIRouter()
@router.get("/", response_model=List[PromptModel])
-async def get_prompts(user=Depends(get_current_user)):
+async def get_prompts(user=Depends(get_verified_user)):
return Prompts.get_prompts()
@@ -52,7 +52,7 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
@router.get("/command/{command}", response_model=Optional[PromptModel])
-async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
+async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
diff --git a/backend/apps/webui/routers/tools.py b/backend/apps/webui/routers/tools.py
index b68ed32ee..d20584c22 100644
--- a/backend/apps/webui/routers/tools.py
+++ b/backend/apps/webui/routers/tools.py
@@ -6,17 +6,20 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
+
+from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
+from pathlib import Path
-from config import DATA_DIR
+from config import DATA_DIR, CACHE_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
@@ -31,7 +34,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse])
-async def get_toolkits(user=Depends(get_current_user)):
+async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
@@ -71,7 +74,8 @@ async def create_new_toolkit(
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
- toolkit_module = load_toolkit_module_by_id(form_data.id)
+ toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
+ form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
@@ -79,6 +83,9 @@ async def create_new_toolkit(
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
+ tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
+ tool_cache_dir.mkdir(parents=True, exist_ok=True)
+
if toolkit:
return toolkit
else:
@@ -132,7 +139,8 @@ async def update_toolkit_by_id(
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
- toolkit_module = load_toolkit_module_by_id(id)
+ toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+ form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
@@ -181,3 +189,187 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
os.remove(toolkit_path)
return result
+
+
+############################
+# GetToolValves
+############################
+
+
+@router.get("/id/{id}/valves", response_model=Optional[dict])
+async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
+ toolkit = Tools.get_tool_by_id(id)
+ if toolkit:
+ try:
+ valves = Tools.get_tool_valves_by_id(id)
+ return valves
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# GetToolValvesSpec
+############################
+
+
+@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
+async def get_toolkit_valves_spec_by_id(
+ request: Request, id: str, user=Depends(get_admin_user)
+):
+ toolkit = Tools.get_tool_by_id(id)
+ if toolkit:
+ if id in request.app.state.TOOLS:
+ toolkit_module = request.app.state.TOOLS[id]
+ else:
+ toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+ request.app.state.TOOLS[id] = toolkit_module
+
+ if hasattr(toolkit_module, "Valves"):
+ Valves = toolkit_module.Valves
+ return Valves.schema()
+ return None
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# UpdateToolValves
+############################
+
+
+@router.post("/id/{id}/valves/update", response_model=Optional[dict])
+async def update_toolkit_valves_by_id(
+ request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
+):
+ toolkit = Tools.get_tool_by_id(id)
+ if toolkit:
+ if id in request.app.state.TOOLS:
+ toolkit_module = request.app.state.TOOLS[id]
+ else:
+ toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+ request.app.state.TOOLS[id] = toolkit_module
+
+ if hasattr(toolkit_module, "Valves"):
+ Valves = toolkit_module.Valves
+
+ try:
+ form_data = {k: v for k, v in form_data.items() if v is not None}
+ valves = Valves(**form_data)
+ Tools.update_tool_valves_by_id(id, valves.model_dump())
+ return valves.model_dump()
+ except Exception as e:
+ print(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# ToolUserValves
+############################
+
+
+@router.get("/id/{id}/valves/user", response_model=Optional[dict])
+async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
+ toolkit = Tools.get_tool_by_id(id)
+ if toolkit:
+ try:
+ user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
+ return user_valves
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
+async def get_toolkit_user_valves_spec_by_id(
+ request: Request, id: str, user=Depends(get_verified_user)
+):
+ toolkit = Tools.get_tool_by_id(id)
+ if toolkit:
+ if id in request.app.state.TOOLS:
+ toolkit_module = request.app.state.TOOLS[id]
+ else:
+ toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+ request.app.state.TOOLS[id] = toolkit_module
+
+ if hasattr(toolkit_module, "UserValves"):
+ UserValves = toolkit_module.UserValves
+ return UserValves.schema()
+ return None
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
+async def update_toolkit_user_valves_by_id(
+ request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
+):
+ toolkit = Tools.get_tool_by_id(id)
+
+ if toolkit:
+ if id in request.app.state.TOOLS:
+ toolkit_module = request.app.state.TOOLS[id]
+ else:
+ toolkit_module, frontmatter = load_toolkit_module_by_id(id)
+ request.app.state.TOOLS[id] = toolkit_module
+
+ if hasattr(toolkit_module, "UserValves"):
+ UserValves = toolkit_module.UserValves
+
+ try:
+ form_data = {k: v for k, v in form_data.items() if v is not None}
+ user_valves = UserValves(**form_data)
+ Tools.update_user_valves_by_id_and_user_id(
+ id, user.id, user_valves.model_dump()
+ )
+ return user_valves.model_dump()
+ except Exception as e:
+ print(e)
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(e),
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
diff --git a/backend/apps/webui/utils.py b/backend/apps/webui/utils.py
index 19a8615bc..545120835 100644
--- a/backend/apps/webui/utils.py
+++ b/backend/apps/webui/utils.py
@@ -1,19 +1,61 @@
from importlib import util
import os
+import re
-from config import TOOLS_DIR
+from config import TOOLS_DIR, FUNCTIONS_DIR
+
+
+def extract_frontmatter(file_path):
+ """
+ Extract frontmatter as a dictionary from the specified file path.
+ """
+ frontmatter = {}
+ frontmatter_started = False
+ frontmatter_ended = False
+ frontmatter_pattern = re.compile(r"^\s*([a-z_]+):\s*(.*)\s*$", re.IGNORECASE)
+
+ try:
+ with open(file_path, "r", encoding="utf-8") as file:
+ first_line = file.readline()
+ if first_line.strip() != '"""':
+ # The file doesn't start with triple quotes
+ return {}
+
+ frontmatter_started = True
+
+ for line in file:
+ if '"""' in line:
+ if frontmatter_started:
+ frontmatter_ended = True
+ break
+
+ if frontmatter_started and not frontmatter_ended:
+ match = frontmatter_pattern.match(line)
+ if match:
+ key, value = match.groups()
+ frontmatter[key.strip()] = value.strip()
+
+ except FileNotFoundError:
+ print(f"Error: The file {file_path} does not exist.")
+ return {}
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return {}
+
+ return frontmatter
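+
+
+# Illustrative example (not part of this patch): a tool file whose module
+# docstring doubles as frontmatter would begin with
+#
+# """
+# title: Example Tool
+# requirements: requests
+# """
+#
+# and extract_frontmatter() would return
+# {"title": "Example Tool", "requirements": "requests"}.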
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
+ frontmatter = extract_frontmatter(toolkit_path)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
- return module.Tools()
+ return module.Tools(), frontmatter
else:
raise Exception("No Tools class found")
except Exception as e:
@@ -21,3 +63,26 @@ def load_toolkit_module_by_id(toolkit_id):
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
+
+
+def load_function_module_by_id(function_id):
+ function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
+
+ spec = util.spec_from_file_location(function_id, function_path)
+ module = util.module_from_spec(spec)
+ frontmatter = extract_frontmatter(function_path)
+
+ try:
+ spec.loader.exec_module(module)
+ print(f"Loaded module: {module.__name__}")
+ if hasattr(module, "Pipe"):
+ return module.Pipe(), "pipe", frontmatter
+ elif hasattr(module, "Filter"):
+ return module.Filter(), "filter", frontmatter
+ else:
+ raise Exception("No Function class found")
+ except Exception as e:
+ print(f"Error loading module: {function_id}")
+ # Move the file to the error folder
+ os.rename(function_path, f"{function_path}.error")
+ raise e
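+
+
+# Minimal filter skeleton (illustrative sketch, based only on the attributes the
+# loader above and the chat middleware look for; all names here are assumptions):
+#
+# """
+# title: Example Filter
+# """
+# from pydantic import BaseModel
+#
+#
+# class Filter:
+#     class Valves(BaseModel):
+#         priority: int = 0
+#
+#     def __init__(self):
+#         self.valves = self.Valves()
+#
+#     def inlet(self, body: dict, __user__: dict = None) -> dict:
+#         # Runs on the request body before it is sent to the model.
+#         return body
+#
+#     def outlet(self, body: dict, __user__: dict = None) -> dict:
+#         # Runs on the completed response via /api/chat/completed.
+#         return body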
diff --git a/backend/config.py b/backend/config.py
index 1a38a450d..3a825f53a 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -167,6 +167,12 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json
+####################################
+# SAFE_MODE
+####################################
+
+SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
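+# Example (illustrative): setting SAFE_MODE=true in the environment starts the
+# backend with every imported function deactivated.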
+
####################################
# WEBUI_BUILD_HASH
####################################
@@ -299,6 +305,135 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
+####################################
+# OAuth config
+####################################
+
+ENABLE_OAUTH_SIGNUP = PersistentConfig(
+ "ENABLE_OAUTH_SIGNUP",
+ "oauth.enable_signup",
+ os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
+)
+
+OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
+ "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
+ "oauth.merge_accounts_by_email",
+ os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
+)
+
+OAUTH_PROVIDERS = {}
+
+GOOGLE_CLIENT_ID = PersistentConfig(
+ "GOOGLE_CLIENT_ID",
+ "oauth.google.client_id",
+ os.environ.get("GOOGLE_CLIENT_ID", ""),
+)
+
+GOOGLE_CLIENT_SECRET = PersistentConfig(
+ "GOOGLE_CLIENT_SECRET",
+ "oauth.google.client_secret",
+ os.environ.get("GOOGLE_CLIENT_SECRET", ""),
+)
+
+GOOGLE_OAUTH_SCOPE = PersistentConfig(
+ "GOOGLE_OAUTH_SCOPE",
+ "oauth.google.scope",
+ os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
+)
+
+MICROSOFT_CLIENT_ID = PersistentConfig(
+ "MICROSOFT_CLIENT_ID",
+ "oauth.microsoft.client_id",
+ os.environ.get("MICROSOFT_CLIENT_ID", ""),
+)
+
+MICROSOFT_CLIENT_SECRET = PersistentConfig(
+ "MICROSOFT_CLIENT_SECRET",
+ "oauth.microsoft.client_secret",
+ os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
+)
+
+MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
+ "MICROSOFT_CLIENT_TENANT_ID",
+ "oauth.microsoft.tenant_id",
+ os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
+)
+
+MICROSOFT_OAUTH_SCOPE = PersistentConfig(
+ "MICROSOFT_OAUTH_SCOPE",
+ "oauth.microsoft.scope",
+ os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
+)
+
+OAUTH_CLIENT_ID = PersistentConfig(
+ "OAUTH_CLIENT_ID",
+ "oauth.oidc.client_id",
+ os.environ.get("OAUTH_CLIENT_ID", ""),
+)
+
+OAUTH_CLIENT_SECRET = PersistentConfig(
+ "OAUTH_CLIENT_SECRET",
+ "oauth.oidc.client_secret",
+ os.environ.get("OAUTH_CLIENT_SECRET", ""),
+)
+
+OPENID_PROVIDER_URL = PersistentConfig(
+ "OPENID_PROVIDER_URL",
+ "oauth.oidc.provider_url",
+ os.environ.get("OPENID_PROVIDER_URL", ""),
+)
+
+OAUTH_SCOPES = PersistentConfig(
+ "OAUTH_SCOPES",
+ "oauth.oidc.scopes",
+ os.environ.get("OAUTH_SCOPES", "openid email profile"),
+)
+
+OAUTH_PROVIDER_NAME = PersistentConfig(
+ "OAUTH_PROVIDER_NAME",
+ "oauth.oidc.provider_name",
+ os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
+)
+
+
+def load_oauth_providers():
+ OAUTH_PROVIDERS.clear()
+ if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
+ OAUTH_PROVIDERS["google"] = {
+ "client_id": GOOGLE_CLIENT_ID.value,
+ "client_secret": GOOGLE_CLIENT_SECRET.value,
+ "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
+ "scope": GOOGLE_OAUTH_SCOPE.value,
+ }
+
+ if (
+ MICROSOFT_CLIENT_ID.value
+ and MICROSOFT_CLIENT_SECRET.value
+ and MICROSOFT_CLIENT_TENANT_ID.value
+ ):
+ OAUTH_PROVIDERS["microsoft"] = {
+ "client_id": MICROSOFT_CLIENT_ID.value,
+ "client_secret": MICROSOFT_CLIENT_SECRET.value,
+ "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
+ "scope": MICROSOFT_OAUTH_SCOPE.value,
+ }
+
+ if (
+ OAUTH_CLIENT_ID.value
+ and OAUTH_CLIENT_SECRET.value
+ and OPENID_PROVIDER_URL.value
+ ):
+ OAUTH_PROVIDERS["oidc"] = {
+ "client_id": OAUTH_CLIENT_ID.value,
+ "client_secret": OAUTH_CLIENT_SECRET.value,
+ "server_metadata_url": OPENID_PROVIDER_URL.value,
+ "scope": OAUTH_SCOPES.value,
+ "name": OAUTH_PROVIDER_NAME.value,
+ }
+
+
+load_oauth_providers()
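+
+# Example (illustrative) environment for the generic OIDC provider above;
+# the values are placeholders, not defaults:
+#
+#   OAUTH_CLIENT_ID=open-webui
+#   OAUTH_CLIENT_SECRET=<client secret>
+#   OPENID_PROVIDER_URL=https://idp.example.com/.well-known/openid-configuration
+#   OAUTH_PROVIDER_NAME=Keycloak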
+
####################################
# Static DIR
####################################
@@ -377,6 +512,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
+####################################
+# Functions DIR
+####################################
+
+FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
+Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
+
+
####################################
# LITELLM_CONFIG
####################################
@@ -426,12 +569,15 @@ OLLAMA_API_BASE_URL = os.environ.get(
)
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
-AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")
+AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
if AIOHTTP_CLIENT_TIMEOUT == "":
AIOHTTP_CLIENT_TIMEOUT = None
else:
- AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
+ try:
+ AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
+    except ValueError:
+ AIOHTTP_CLIENT_TIMEOUT = 300
K8S_FLAG = os.environ.get("K8S_FLAG", "")
@@ -719,6 +865,16 @@ WEBUI_SECRET_KEY = os.environ.get(
), # DEPRECATED: remove at next major version
)
+WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
+    "WEBUI_SESSION_COOKIE_SAME_SITE", "lax"
+)
+
+WEBUI_SESSION_COOKIE_SECURE = (
+    os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true"
+)
+
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
@@ -903,6 +1059,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
+# You can provide a list of your own websites to restrict web search results to
+# (a domain allowlist); limiting sources this way helps keep results safe and reliable.
+RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
+ "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
+ "rag.rag.web.search.domain.filter_list",
+ [
+ # "wikipedia.com",
+ # "wikimedia.org",
+ # "wikidata.org",
+ ],
+)
+
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
@@ -1001,6 +1169,11 @@ AUTOMATIC1111_BASE_URL = PersistentConfig(
"image_generation.automatic1111.base_url",
os.getenv("AUTOMATIC1111_BASE_URL", ""),
)
+AUTOMATIC1111_API_AUTH = PersistentConfig(
+ "AUTOMATIC1111_API_AUTH",
+ "image_generation.automatic1111.api_auth",
+ os.getenv("AUTOMATIC1111_API_AUTH", ""),
+)
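+# Illustrative note (assumption): AUTOMATIC1111_API_AUTH is expected to carry
+# "username:password" credentials for an A1111 instance launched with --api-auth.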
COMFYUI_BASE_URL = PersistentConfig(
"COMFYUI_BASE_URL",
diff --git a/backend/main.py b/backend/main.py
index 04f886162..aae305c5e 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,4 +1,9 @@
+import base64
+import uuid
from contextlib import asynccontextmanager
+
+from authlib.integrations.starlette_client import OAuth
+from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json
import markdown
@@ -11,9 +16,11 @@ import requests
import mimetypes
import shutil
import os
+import uuid
import inspect
import asyncio
+from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
@@ -22,7 +29,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import StreamingResponse, Response
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app
@@ -41,29 +49,43 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app
+from apps.webui.main import (
+ app as webui_app,
+ get_pipe_models,
+ generate_function_chat_completion,
+)
from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
+from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
-from apps.webui.utils import load_toolkit_module_by_id
+from apps.webui.models.functions import Functions
+from apps.webui.models.users import Users
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from utils.utils import (
get_admin_user,
get_verified_user,
get_current_user,
get_http_authorization_cred,
+ get_password_hash,
+ create_token,
)
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+ get_last_user_message,
+ add_or_update_system_message,
+ stream_message_template,
+ parse_duration,
+)
from apps.rag.utils import get_rag_context, rag_template
@@ -76,6 +98,7 @@ from config import (
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
+ UPLOAD_DIR,
CACHE_DIR,
STATIC_DIR,
ENABLE_OPENAI_API,
@@ -93,9 +116,22 @@ from config import (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+ SAFE_MODE,
+ OAUTH_PROVIDERS,
+ ENABLE_OAUTH_SIGNUP,
+ OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+ WEBUI_SECRET_KEY,
+ WEBUI_SESSION_COOKIE_SAME_SITE,
+ WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
)
-from constants import ERROR_MESSAGES
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from utils.webhook import post_webhook
+
+if SAFE_MODE:
+ print("SAFE MODE ENABLED")
+ Functions.deactivate_all_functions()
+
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
@@ -168,7 +204,16 @@ app.state.MODELS = {}
origins = ["*"]
-async def get_function_call_response(messages, tool_id, template, task_model_id, user):
+##################################
+#
+# ChatCompletion Middleware
+#
+##################################
+
+
+async def get_function_call_response(
+ messages, files, tool_id, template, task_model_id, user
+):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
@@ -205,12 +250,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
response = None
try:
- if model["owned_by"] == "ollama":
- response = await generate_ollama_chat_completion(
- OpenAIChatCompletionForm(**payload), user=user
- )
- else:
- response = await generate_openai_chat_completion(payload, user=user)
+ response = await generate_chat_completions(form_data=payload, user=user)
content = None
@@ -231,84 +271,241 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
result = json.loads(content)
print(result)
+ citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
- toolkit_module = load_toolkit_module_by_id(tool_id)
+ toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
+ file_handler = False
+            # Check whether the toolkit module defines a file_handler instance attribute
+ if hasattr(toolkit_module, "file_handler"):
+ file_handler = True
+ print("file_handler: ", file_handler)
+
+ if hasattr(toolkit_module, "valves") and hasattr(
+ toolkit_module, "Valves"
+ ):
+ valves = Tools.get_tool_valves_by_id(tool_id)
+ toolkit_module.valves = toolkit_module.Valves(
+ **(valves if valves else {})
+ )
+
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
- # Check if '__user__' is a parameter of the function
+ params = result["parameters"]
+
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
- function_result = function(
- **{
- **result["parameters"],
- "__user__": {
- "id": user.id,
- "email": user.email,
- "name": user.name,
- "role": user.role,
- },
- }
- )
+ __user__ = {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ }
+
+ try:
+ if hasattr(toolkit_module, "UserValves"):
+ __user__["valves"] = toolkit_module.UserValves(
+ **Tools.get_user_valves_by_id_and_user_id(
+ tool_id, user.id
+ )
+ )
+ except Exception as e:
+ print(e)
+
+ params = {**params, "__user__": __user__}
+ if "__messages__" in sig.parameters:
+ # Call the function with the '__messages__' parameter included
+ params = {
+ **params,
+ "__messages__": messages,
+ }
+
+ if "__files__" in sig.parameters:
+ # Call the function with the '__files__' parameter included
+ params = {
+ **params,
+ "__files__": files,
+ }
+
+ if "__model__" in sig.parameters:
+ # Call the function with the '__model__' parameter included
+ params = {
+ **params,
+ "__model__": model,
+ }
+
+ if "__id__" in sig.parameters:
+ # Call the function with the '__id__' parameter included
+ params = {
+ **params,
+ "__id__": tool_id,
+ }
+
+ if inspect.iscoroutinefunction(function):
+ function_result = await function(**params)
else:
- # Call the function without modifying the parameters
- function_result = function(**result["parameters"])
+ function_result = function(**params)
+
+ if hasattr(toolkit_module, "citation") and toolkit_module.citation:
+ citation = {
+ "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
+ "document": [function_result],
+ "metadata": [{"source": result["name"]}],
+ }
except Exception as e:
print(e)
# Add the function result to the system prompt
- if function_result:
- return function_result
+ if function_result is not None:
+ return function_result, citation, file_handler
except Exception as e:
print(f"Error: {e}")
- return None
+ return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
- return_citations = False
+ data_items = []
- if request.method == "POST" and (
- "/ollama/api/chat" in request.url.path
- or "/chat/completions" in request.url.path
+ show_citations = False
+ citations = []
+
+ if request.method == "POST" and any(
+ endpoint in request.url.path
+ for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
- # Decode body to string
body_str = body.decode("utf-8")
- # Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
- get_http_authorization_cred(request.headers.get("Authorization"))
+ request,
+ get_http_authorization_cred(request.headers.get("Authorization")),
)
-
- # Remove the citations from the body
- return_citations = data.get("citations", False)
- if "citations" in data:
+ # Flag to skip RAG completions if file_handler is present in tools/functions
+ skip_files = False
+ if data.get("citations"):
+ show_citations = True
del data["citations"]
- # Set the task model
- task_model_id = data["model"]
- if task_model_id not in app.state.MODELS:
+ model_id = data["model"]
+ if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
+ model = app.state.MODELS[model_id]
- # Check if the user has a custom task model
- # If the user has a custom task model, use that model
+ def get_priority(function_id):
+ function = Functions.get_function_by_id(function_id)
+ if function is not None and hasattr(function, "valves"):
+ return (function.valves if function.valves else {}).get(
+ "priority", 0
+ )
+ return 0
+
+ filter_ids = [
+ function.id for function in Functions.get_global_filter_functions()
+ ]
+ if "info" in model and "meta" in model["info"]:
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+ filter_ids = list(set(filter_ids))
+
+ enabled_filter_ids = [
+ function.id
+ for function in Functions.get_functions_by_type(
+ "filter", active_only=True
+ )
+ ]
+ filter_ids = [
+ filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+ ]
+
+ filter_ids.sort(key=get_priority)
+ for filter_id in filter_ids:
+ filter = Functions.get_function_by_id(filter_id)
+ if filter:
+ if filter_id in webui_app.state.FUNCTIONS:
+ function_module = webui_app.state.FUNCTIONS[filter_id]
+ else:
+ function_module, function_type, frontmatter = (
+ load_function_module_by_id(filter_id)
+ )
+ webui_app.state.FUNCTIONS[filter_id] = function_module
+
+ # Check if the function has a file_handler variable
+ if hasattr(function_module, "file_handler"):
+ skip_files = function_module.file_handler
+
+ if hasattr(function_module, "valves") and hasattr(
+ function_module, "Valves"
+ ):
+ valves = Functions.get_function_valves_by_id(filter_id)
+ function_module.valves = function_module.Valves(
+ **(valves if valves else {})
+ )
+
+ try:
+ if hasattr(function_module, "inlet"):
+ inlet = function_module.inlet
+
+ # Get the signature of the function
+ sig = inspect.signature(inlet)
+ params = {"body": data}
+
+ if "__user__" in sig.parameters:
+ __user__ = {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ }
+
+ try:
+ if hasattr(function_module, "UserValves"):
+ __user__["valves"] = function_module.UserValves(
+ **Functions.get_user_valves_by_id_and_user_id(
+ filter_id, user.id
+ )
+ )
+ except Exception as e:
+ print(e)
+
+ params = {**params, "__user__": __user__}
+
+ if "__id__" in sig.parameters:
+ params = {
+ **params,
+ "__id__": filter_id,
+ }
+
+ if inspect.iscoroutinefunction(inlet):
+ data = await inlet(**params)
+ else:
+ data = inlet(**params)
+
+ except Exception as e:
+ print(f"Error: {e}")
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": str(e)},
+ )
+
+ # Set the task model
+ task_model_id = data["model"]
+ # Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
@@ -331,55 +528,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
for tool_id in data["tool_ids"]:
print(tool_id)
try:
- response = await get_function_call_response(
- messages=data["messages"],
- tool_id=tool_id,
- template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
- task_model_id=task_model_id,
- user=user,
+ response, citation, file_handler = (
+ await get_function_call_response(
+ messages=data["messages"],
+ files=data.get("files", []),
+ tool_id=tool_id,
+ template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+ task_model_id=task_model_id,
+ user=user,
+ )
)
- if response:
+ print(file_handler)
+ if isinstance(response, str):
context += ("\n" if context != "" else "") + response
+
+ if citation:
+ citations.append(citation)
+ show_citations = True
+
+ if file_handler:
+ skip_files = True
+
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
- # If docs field is present, generate RAG completions
- if "docs" in data:
- data = {**data}
- rag_context, citations = get_rag_context(
- docs=data["docs"],
- messages=data["messages"],
- embedding_function=rag_app.state.EMBEDDING_FUNCTION,
- k=rag_app.state.config.TOP_K,
- reranking_function=rag_app.state.sentence_transformer_rf,
- r=rag_app.state.config.RELEVANCE_THRESHOLD,
- hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
- )
+ # If files field is present, generate RAG completions
+ # If skip_files is True, skip the RAG completions
+ if "files" in data:
+ if not skip_files:
+ data = {**data}
+ rag_context, rag_citations = get_rag_context(
+ files=data["files"],
+ messages=data["messages"],
+ embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+ k=rag_app.state.config.TOP_K,
+ reranking_function=rag_app.state.sentence_transformer_rf,
+ r=rag_app.state.config.RELEVANCE_THRESHOLD,
+ hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+ )
+ if rag_context:
+ context += ("\n" if context != "" else "") + rag_context
- if rag_context:
- context += ("\n" if context != "" else "") + rag_context
+ log.debug(f"rag_context: {rag_context}, citations: {citations}")
- del data["docs"]
+ if rag_citations:
+ citations.extend(rag_citations)
- log.debug(f"rag_context: {rag_context}, citations: {citations}")
+ del data["files"]
+
+ if show_citations and len(citations) > 0:
+ data_items.append({"citations": citations})
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
-
print(system_prompt)
-
data["messages"] = add_or_update_system_message(
- f"\n{system_prompt}", data["messages"]
+ system_prompt, data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
-
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
@@ -392,43 +605,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
],
]
- response = await call_next(request)
-
- if return_citations:
- # Inject the citations into the response
+ response = await call_next(request)
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
- self.openai_stream_wrapper(response.body_iterator, citations),
+ self.openai_stream_wrapper(response.body_iterator, data_items),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
- self.ollama_stream_wrapper(response.body_iterator, citations),
+ self.ollama_stream_wrapper(response.body_iterator, data_items),
)
+ else:
+ return response
+ # If it's not a chat completion request, just pass it through
+ response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
- async def openai_stream_wrapper(self, original_generator, citations):
- yield f"data: {json.dumps({'citations': citations})}\n\n"
+ async def openai_stream_wrapper(self, original_generator, data_items):
+ for item in data_items:
+ yield f"data: {json.dumps(item)}\n\n"
+
async for data in original_generator:
yield data
- async def ollama_stream_wrapper(self, original_generator, citations):
- yield f"{json.dumps({'citations': citations})}\n"
+ async def ollama_stream_wrapper(self, original_generator, data_items):
+ for item in data_items:
+ yield f"{json.dumps(item)}\n"
+
async for data in original_generator:
yield data
app.add_middleware(ChatCompletionMiddleware)
+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
def filter_pipeline(payload, user):
- user = {"id": user.id, "name": user.name, "role": user.role}
+ user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
filters = [
model
@@ -516,7 +740,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
data = json.loads(body_str) if body_str else {}
user = get_current_user(
- get_http_authorization_cred(request.headers.get("Authorization"))
+ request,
+ get_http_authorization_cred(request.headers.get("Authorization")),
)
try:
@@ -584,7 +809,6 @@ async def update_embedding_function(request: Request, call_next):
app.mount("/ws", socket_app)
-
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
@@ -598,17 +822,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
+ pipe_models = []
openai_models = []
ollama_models = []
+ pipe_models = await get_pipe_models()
+
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
-
openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models()
-
ollama_models = [
{
"id": model["model"],
@@ -621,9 +846,9 @@ async def get_all_models():
for model in ollama_models["models"]
]
- models = openai_models + ollama_models
- custom_models = Models.get_all_models()
+ models = pipe_models + openai_models + ollama_models
+ custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
for model in models:
@@ -686,6 +911,200 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
+@app.post("/api/chat/completions")
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+ model_id = form_data["model"]
+ if model_id not in app.state.MODELS:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+
+ model = app.state.MODELS[model_id]
+
+ pipe = model.get("pipe")
+ if pipe:
+ return await generate_function_chat_completion(form_data, user=user)
+ if model["owned_by"] == "ollama":
+ return await generate_ollama_chat_completion(form_data, user=user)
+ else:
+ return await generate_openai_chat_completion(form_data, user=user)
+
+
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+ data = form_data
+ model_id = data["model"]
+ if model_id not in app.state.MODELS:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+ model = app.state.MODELS[model_id]
+
+ filters = [
+ model
+ for model in app.state.MODELS.values()
+ if "pipeline" in model
+ and "type" in model["pipeline"]
+ and model["pipeline"]["type"] == "filter"
+ and (
+ model["pipeline"]["pipelines"] == ["*"]
+ or any(
+ model_id == target_model_id
+ for target_model_id in model["pipeline"]["pipelines"]
+ )
+ )
+ ]
+
+ sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+ if "pipeline" in model:
+ sorted_filters = [model] + sorted_filters
+
+ for filter in sorted_filters:
+ r = None
+ try:
+ urlIdx = filter["urlIdx"]
+
+ url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+ key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+ if key != "":
+ headers = {"Authorization": f"Bearer {key}"}
+ r = requests.post(
+ f"{url}/{filter['id']}/filter/outlet",
+ headers=headers,
+ json={
+ "user": {
+ "id": user.id,
+ "name": user.name,
+ "email": user.email,
+ "role": user.role,
+ },
+ "body": data,
+ },
+ )
+
+ r.raise_for_status()
+ data = r.json()
+ except Exception as e:
+ # Handle connection error here
+ print(f"Connection error: {e}")
+
+ if r is not None:
+ try:
+ res = r.json()
+ if "detail" in res:
+ return JSONResponse(
+ status_code=r.status_code,
+ content=res,
+ )
+ except:
+ pass
+
+ else:
+ pass
+
+ def get_priority(function_id):
+ function = Functions.get_function_by_id(function_id)
+ if function is not None and hasattr(function, "valves"):
+ return (function.valves if function.valves else {}).get("priority", 0)
+ return 0
+
+ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+ if "info" in model and "meta" in model["info"]:
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+ filter_ids = list(set(filter_ids))
+
+ enabled_filter_ids = [
+ function.id
+ for function in Functions.get_functions_by_type("filter", active_only=True)
+ ]
+ filter_ids = [
+ filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+ ]
+
+ # Sort filter_ids by priority, using the get_priority function
+ filter_ids.sort(key=get_priority)
+
+ for filter_id in filter_ids:
+ filter = Functions.get_function_by_id(filter_id)
+ if filter:
+ if filter_id in webui_app.state.FUNCTIONS:
+ function_module = webui_app.state.FUNCTIONS[filter_id]
+ else:
+ function_module, function_type, frontmatter = (
+ load_function_module_by_id(filter_id)
+ )
+ webui_app.state.FUNCTIONS[filter_id] = function_module
+
+ if hasattr(function_module, "valves") and hasattr(
+ function_module, "Valves"
+ ):
+ valves = Functions.get_function_valves_by_id(filter_id)
+ function_module.valves = function_module.Valves(
+ **(valves if valves else {})
+ )
+
+ try:
+ if hasattr(function_module, "outlet"):
+ outlet = function_module.outlet
+
+ # Get the signature of the function
+ sig = inspect.signature(outlet)
+ params = {"body": data}
+
+ if "__user__" in sig.parameters:
+ __user__ = {
+ "id": user.id,
+ "email": user.email,
+ "name": user.name,
+ "role": user.role,
+ }
+
+ try:
+ if hasattr(function_module, "UserValves"):
+ __user__["valves"] = function_module.UserValves(
+ **Functions.get_user_valves_by_id_and_user_id(
+ filter_id, user.id
+ )
+ )
+ except Exception as e:
+ print(e)
+
+ params = {**params, "__user__": __user__}
+
+ if "__id__" in sig.parameters:
+ params = {
+ **params,
+ "__id__": filter_id,
+ }
+
+ if inspect.iscoroutinefunction(outlet):
+ data = await outlet(**params)
+ else:
+ data = outlet(**params)
+
+ except Exception as e:
+ print(f"Error: {e}")
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": str(e)},
+ )
+
+ return data
+
+
+##################################
+#
+# Task Endpoints
+#
+##################################
+
+
+# TODO: Refactor task API endpoints below into a separate file
+
+
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
return {
@@ -791,12 +1210,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
- if model["owned_by"] == "ollama":
- return await generate_ollama_chat_completion(
- OpenAIChatCompletionForm(**payload), user=user
- )
- else:
- return await generate_openai_chat_completion(payload, user=user)
+ return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
@@ -856,12 +1270,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]},
)
- if model["owned_by"] == "ollama":
- return await generate_ollama_chat_completion(
- OpenAIChatCompletionForm(**payload), user=user
- )
- else:
- return await generate_openai_chat_completion(payload, user=user)
+ return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
@@ -925,12 +1334,7 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]},
)
- if model["owned_by"] == "ollama":
- return await generate_ollama_chat_completion(
- OpenAIChatCompletionForm(**payload), user=user
- )
- else:
- return await generate_openai_chat_completion(payload, user=user)
+ return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tools/completions")
@@ -961,8 +1365,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
- context = await get_function_call_response(
- form_data["messages"], form_data["tool_id"], template, model_id, user
+ context, citation, file_handler = await get_function_call_response(
+ form_data["messages"],
+ form_data.get("files", []),
+ form_data["tool_id"],
+ template,
+ model_id,
+ user,
)
return context
except Exception as e:
@@ -972,94 +1381,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
)
-@app.post("/api/chat/completions")
-async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
- model_id = form_data["model"]
- if model_id not in app.state.MODELS:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Model not found",
- )
-
- model = app.state.MODELS[model_id]
- print(model)
-
- if model["owned_by"] == "ollama":
- return await generate_ollama_chat_completion(
- OpenAIChatCompletionForm(**form_data), user=user
- )
- else:
- return await generate_openai_chat_completion(form_data, user=user)
+##################################
+#
+# Pipelines Endpoints
+#
+##################################
-@app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
- data = form_data
- model_id = data["model"]
-
- filters = [
- model
- for model in app.state.MODELS.values()
- if "pipeline" in model
- and "type" in model["pipeline"]
- and model["pipeline"]["type"] == "filter"
- and (
- model["pipeline"]["pipelines"] == ["*"]
- or any(
- model_id == target_model_id
- for target_model_id in model["pipeline"]["pipelines"]
- )
- )
- ]
- sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
- print(model_id)
-
- if model_id in app.state.MODELS:
- model = app.state.MODELS[model_id]
- if "pipeline" in model:
- sorted_filters = [model] + sorted_filters
-
- for filter in sorted_filters:
- r = None
- try:
- urlIdx = filter["urlIdx"]
-
- url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
- key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
-
- if key != "":
- headers = {"Authorization": f"Bearer {key}"}
- r = requests.post(
- f"{url}/{filter['id']}/filter/outlet",
- headers=headers,
- json={
- "user": {"id": user.id, "name": user.name, "role": user.role},
- "body": data,
- },
- )
-
- r.raise_for_status()
- data = r.json()
- except Exception as e:
- # Handle connection error here
- print(f"Connection error: {e}")
-
- if r is not None:
- try:
- res = r.json()
- if "detail" in res:
- return JSONResponse(
- status_code=r.status_code,
- content=res,
- )
- except:
- pass
-
- else:
- pass
-
- return data
+# TODO: Refactor pipelines API endpoints below into a separate file
@app.get("/api/pipelines/list")
@@ -1382,6 +1711,13 @@ async def update_pipeline_valves(
)
+##################################
+#
+# Config Endpoints
+#
+##################################
+
+
@app.get("/api/config")
async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
@@ -1416,6 +1752,12 @@ async def get_app_config():
"engine": audio_app.state.config.STT_ENGINE,
},
},
+ "oauth": {
+ "providers": {
+ name: config.get("name", name)
+ for name, config in OAUTH_PROVIDERS.items()
+ }
+ },
}
@@ -1445,6 +1787,9 @@ async def update_model_filter_config(
}
+# TODO: webhook endpoint should be under config endpoints
+
+
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {
@@ -1494,6 +1839,154 @@ async def get_app_latest_release_version():
)
+############################
+# OAuth Login & Callback
+############################
+
+oauth = OAuth()
+
+for provider_name, provider_config in OAUTH_PROVIDERS.items():
+ oauth.register(
+ name=provider_name,
+ client_id=provider_config["client_id"],
+ client_secret=provider_config["client_secret"],
+ server_metadata_url=provider_config["server_metadata_url"],
+ client_kwargs={
+ "scope": provider_config["scope"],
+ },
+ )
+
+# SessionMiddleware is used by authlib for oauth
+if len(OAUTH_PROVIDERS) > 0:
+ app.add_middleware(
+ SessionMiddleware,
+ secret_key=WEBUI_SECRET_KEY,
+ session_cookie="oui-session",
+ same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
+ https_only=WEBUI_SESSION_COOKIE_SECURE,
+ )
+
+
+@app.get("/oauth/{provider}/login")
+async def oauth_login(provider: str, request: Request):
+ if provider not in OAUTH_PROVIDERS:
+ raise HTTPException(404)
+ redirect_uri = request.url_for("oauth_callback", provider=provider)
+ return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
+
+
+# OAuth login logic is as follows:
+# 1. Attempt to find a user with matching subject ID, tied to the provider
+# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
+# - This is considered insecure in general, as OAuth providers do not always verify email addresses
+# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
+# - Email addresses are considered unique, so we fail registration if the email address is already taken
+@app.get("/oauth/{provider}/callback")
+async def oauth_callback(provider: str, request: Request, response: Response):
+ if provider not in OAUTH_PROVIDERS:
+ raise HTTPException(404)
+ client = oauth.create_client(provider)
+ try:
+ token = await client.authorize_access_token(request)
+ except Exception as e:
+ log.warning(f"OAuth callback error: {e}")
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+ user_data: UserInfo = token["userinfo"]
+
+ sub = user_data.get("sub")
+ if not sub:
+ log.warning(f"OAuth callback failed, sub is missing: {user_data}")
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+ provider_sub = f"{provider}@{sub}"
+ email = user_data.get("email", "").lower()
+ # We currently mandate that email addresses are provided
+ if not email:
+ log.warning(f"OAuth callback failed, email is missing: {user_data}")
+ raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
+
+ # Check if the user exists
+ user = Users.get_user_by_oauth_sub(provider_sub)
+
+ if not user:
+ # If the user does not exist, check if merging is enabled
+ if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
+ # Check if the user exists by email
+ user = Users.get_user_by_email(email)
+ if user:
+ # Update the user with the new oauth sub
+ Users.update_user_oauth_sub_by_id(user.id, provider_sub)
+
+ if not user:
+ # If the user does not exist, check if signups are enabled
+ if ENABLE_OAUTH_SIGNUP.value:
+ # Check if an existing user with the same email already exists
+ existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
+ if existing_user:
+ raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
+
+ picture_url = user_data.get("picture", "")
+ if picture_url:
+ # Download the profile image into a base64 string
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(picture_url) as resp:
+ picture = await resp.read()
+ base64_encoded_picture = base64.b64encode(picture).decode(
+ "utf-8"
+ )
+ guessed_mime_type = mimetypes.guess_type(picture_url)[0]
+ if guessed_mime_type is None:
+ # assume JPG, browsers are tolerant enough of image formats
+ guessed_mime_type = "image/jpeg"
+ picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
+ except Exception as e:
+ log.error(f"Error downloading profile image '{picture_url}': {e}")
+ picture_url = ""
+ if not picture_url:
+ picture_url = "/user.png"
+ user = Auths.insert_new_auth(
+ email=email,
+ password=get_password_hash(
+ str(uuid.uuid4())
+ ), # Random password, not used
+ name=user_data.get("name", "User"),
+ profile_image_url=picture_url,
+ role=webui_app.state.config.DEFAULT_USER_ROLE,
+ oauth_sub=provider_sub,
+ )
+
+ if webui_app.state.config.WEBHOOK_URL:
+ post_webhook(
+ webui_app.state.config.WEBHOOK_URL,
+ WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+ {
+ "action": "signup",
+ "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
+ "user": user.model_dump_json(exclude_none=True),
+ },
+ )
+ else:
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+ )
+
+ jwt_token = create_token(
+ data={"id": user.id},
+ expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
+ )
+
+ # Set the cookie token
+ response.set_cookie(
+ key="token",
+        value=jwt_token,
+ httponly=True, # Ensures the cookie is not accessible via JavaScript
+ )
+
+ # Redirect back to the frontend with the JWT token
+ redirect_url = f"{request.base_url}auth#token={jwt_token}"
+ return RedirectResponse(url=redirect_url)
+
+
@app.get("/manifest.json")
async def get_manifest_json():
return {
@@ -1502,7 +1995,6 @@ async def get_manifest_json():
"start_url": "/",
"display": "standalone",
"background_color": "#343541",
- "theme_color": "#343541",
"orientation": "portrait-primary",
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
}
diff --git a/backend/requirements.txt b/backend/requirements.txt
index be0e32c7d..a36af5497 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -17,11 +17,17 @@ peewee-migrate==1.12.2
psycopg2-binary==2.9.9
PyMySQL==1.1.1
bcrypt==4.1.3
-
+SQLAlchemy
+pymongo
+redis
boto3==1.34.110
argon2-cffi==23.1.0
APScheduler==3.10.4
+
+# AI libraries
+openai
+anthropic
google-generativeai==0.5.4
langchain==0.2.0
@@ -52,6 +58,7 @@ rank-bm25==0.2.2
faster-whisper==1.0.2
PyJWT[crypto]==2.8.0
+authlib==1.3.0
black==24.4.2
langfuse==2.33.0
diff --git a/backend/utils/misc.py b/backend/utils/misc.py
index c3c65d3f5..b4e499df8 100644
--- a/backend/utils/misc.py
+++ b/backend/utils/misc.py
@@ -3,7 +3,9 @@ import hashlib
import json
import re
from datetime import timedelta
-from typing import Optional, List
+from typing import Optional, List, Tuple
+import uuid
+import time
def get_last_user_message(messages: List[dict]) -> str:
@@ -28,6 +30,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None
+def get_system_message(messages: List[dict]) -> dict:
+ for message in messages:
+ if message["role"] == "system":
+ return message
+ return None
+
+
+def remove_system_message(messages: List[dict]) -> List[dict]:
+ return [message for message in messages if message["role"] != "system"]
+
+
+def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
+ return get_system_message(messages), remove_system_message(messages)
+
+
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
@@ -47,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages
+def stream_message_template(model: str, message: str):
+ return {
+ "id": f"{model}-{str(uuid.uuid4())}",
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": model,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"content": message},
+ "logprobs": None,
+ "finish_reason": None,
+ }
+ ],
+ }
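+
+# Usage sketch (assumption): callers wrap each partial pipe/function response
+# into an OpenAI-style streaming chunk and emit it as an SSE line, e.g.
+#   chunk = stream_message_template(model_id, "partial text")
+#   yield f"data: {json.dumps(chunk)}\n\n"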
+
+
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters
diff --git a/backend/utils/task.py b/backend/utils/task.py
index ea277eb0b..053a526a8 100644
--- a/backend/utils/task.py
+++ b/backend/utils/task.py
@@ -24,10 +24,16 @@ def prompt_template(
if user_name:
# Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name)
+ else:
+ # Replace {{USER_NAME}} in the template with "Unknown"
+ template = template.replace("{{USER_NAME}}", "Unknown")
if user_location:
# Replace {{USER_LOCATION}} in the template with the current location
template = template.replace("{{USER_LOCATION}}", user_location)
+ else:
+ # Replace {{USER_LOCATION}} in the template with "Unknown"
+ template = template.replace("{{USER_LOCATION}}", "Unknown")
return template
diff --git a/backend/utils/tools.py b/backend/utils/tools.py
index 5fef2a2b6..c1c41ed37 100644
--- a/backend/utils/tools.py
+++ b/backend/utils/tools.py
@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]:
function_list = [
{"name": func, "function": getattr(tools, func)}
for func in dir(tools)
- if callable(getattr(tools, func)) and not func.startswith("__")
+ if callable(getattr(tools, func))
+ and not func.startswith("__")
+ and not inspect.isclass(getattr(tools, func))
]
specs = []
@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]:
function
).parameters.items()
if param.default is param.empty
+ and not (name.startswith("__") and name.endswith("__"))
],
},
}
diff --git a/backend/utils/utils.py b/backend/utils/utils.py
index cc6bb06b8..8c3c899bd 100644
--- a/backend/utils/utils.py
+++ b/backend/utils/utils.py
@@ -1,5 +1,5 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from fastapi import HTTPException, status, Depends
+from fastapi import HTTPException, status, Depends, Request
from apps.webui.models.users import Users
@@ -24,7 +24,7 @@ ALGORITHM = "HS256"
# Auth Utils
##############
-bearer_security = HTTPBearer()
+bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -75,13 +75,26 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user(
+ request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
+ token = None
+
+ if auth_token is not None:
+ token = auth_token.credentials
+
+ if token is None and "token" in request.cookies:
+ token = request.cookies.get("token")
+
+ if token is None:
+ raise HTTPException(status_code=403, detail="Not authenticated")
+
# auth by api key
- if auth_token.credentials.startswith("sk-"):
- return get_current_user_by_api_key(auth_token.credentials)
+ if token.startswith("sk-"):
+ return get_current_user_by_api_key(token)
+
# auth by jwt token
- data = decode_token(auth_token.credentials)
+ data = decode_token(token)
if data != None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user is None:
diff --git a/package-lock.json b/package-lock.json
index 513993c74..bd4bc6892 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "open-webui",
- "version": "0.3.5",
+ "version": "0.3.6",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
- "version": "0.3.5",
+ "version": "0.3.6",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",
@@ -16,6 +16,7 @@
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
+ "crc-32": "^1.2.2",
"dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5",
@@ -28,11 +29,12 @@
"katex": "^0.16.9",
"marked": "^9.1.0",
"mermaid": "^10.9.1",
- "pyodide": "^0.26.0-alpha.4",
- "socket.io-client": "^4.7.5",
+ "pyodide": "^0.26.1",
+ "socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7",
+ "turndown": "^7.2.0",
"uuid": "^9.0.1"
},
"devDependencies": {
@@ -999,6 +1001,11 @@
"svelte": ">=3 <5"
}
},
+ "node_modules/@mixmark-io/domino": {
+ "version": "2.2.0",
+ "resolved": "https://registry.npmjs.org/@mixmark-io/domino/-/domino-2.2.0.tgz",
+ "integrity": "sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw=="
+ },
"node_modules/@nodelib/fs.scandir": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
@@ -2266,11 +2273,6 @@
"dev": true,
"optional": true
},
- "node_modules/base-64": {
- "version": "1.0.0",
- "resolved": "https://registry.npmjs.org/base-64/-/base-64-1.0.0.tgz",
- "integrity": "sha512-kwDPIFCGx0NZHog36dj+tHiwP4QMzsZ3AgMViUBKI0+V5n4U0ufTCUMhnQ04diaRI8EX/QcPfql7zlhZ7j4zgg=="
- },
"node_modules/base64-js": {
"version": "1.5.1",
"resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
@@ -3063,6 +3065,17 @@
"layout-base": "^1.0.0"
}
},
+ "node_modules/crc-32": {
+ "version": "1.2.2",
+ "resolved": "https://registry.npmjs.org/crc-32/-/crc-32-1.2.2.tgz",
+ "integrity": "sha512-ROmzCKrTnOwybPcJApAA6WBWij23HVfGVNKqqrZpuyZOHqK2CwHSvpGuyt/UNNvaIjEd8X5IFGp4Mh+Ie1IHJQ==",
+ "bin": {
+ "crc32": "bin/crc32.njs"
+ },
+ "engines": {
+ "node": ">=0.8"
+ }
+ },
"node_modules/crelt": {
"version": "1.0.6",
"resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz",
@@ -3984,37 +3997,17 @@
}
},
"node_modules/engine.io-client": {
- "version": "6.5.3",
- "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz",
- "integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==",
+ "version": "6.5.4",
+ "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.4.tgz",
+ "integrity": "sha512-GeZeeRjpD2qf49cZQ0Wvh/8NJNfeXkXXcoGh+F77oEAgo9gUHwT1fCRxSNU+YEEaysOJTnsFHmM5oAcPy4ntvQ==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.2.1",
- "ws": "~8.11.0",
+ "ws": "~8.17.1",
"xmlhttprequest-ssl": "~2.0.0"
}
},
- "node_modules/engine.io-client/node_modules/ws": {
- "version": "8.11.0",
- "resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
- "integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
- "engines": {
- "node": ">=10.0.0"
- },
- "peerDependencies": {
- "bufferutil": "^4.0.1",
- "utf-8-validate": "^5.0.2"
- },
- "peerDependenciesMeta": {
- "bufferutil": {
- "optional": true
- },
- "utf-8-validate": {
- "optional": true
- }
- }
- },
"node_modules/engine.io-parser": {
"version": "5.2.2",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.2.tgz",
@@ -7551,11 +7544,10 @@
}
},
"node_modules/pyodide": {
- "version": "0.26.0-alpha.4",
- "resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.26.0-alpha.4.tgz",
- "integrity": "sha512-Ixuczq99DwhQlE+Bt0RaS6Ln9MHSZOkbU6iN8azwaeorjHtr7ukaxh+FeTxViFrp2y+ITyKgmcobY+JnBPcULw==",
+ "version": "0.26.1",
+ "resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.26.1.tgz",
+ "integrity": "sha512-P+Gm88nwZqY7uBgjbQH8CqqU6Ei/rDn7pS1t02sNZsbyLJMyE2OVXjgNuqVT3KqYWnyGREUN0DbBUCJqk8R0ew==",
"dependencies": {
- "base-64": "^1.0.0",
"ws": "^8.5.0"
},
"engines": {
@@ -9065,6 +9057,14 @@
"node": "*"
}
},
+ "node_modules/turndown": {
+ "version": "7.2.0",
+ "resolved": "https://registry.npmjs.org/turndown/-/turndown-7.2.0.tgz",
+ "integrity": "sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==",
+ "dependencies": {
+ "@mixmark-io/domino": "^2.2.0"
+ }
+ },
"node_modules/tweetnacl": {
"version": "0.14.5",
"resolved": "https://registry.npmjs.org/tweetnacl/-/tweetnacl-0.14.5.tgz",
@@ -10382,9 +10382,9 @@
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
},
"node_modules/ws": {
- "version": "8.17.0",
- "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.0.tgz",
- "integrity": "sha512-uJq6108EgZMAl20KagGkzCKfMEjxmKvZHG7Tlq0Z6nOky7YF7aq4mOx6xK8TJ/i1LeK4Qus7INktacctDgY8Ow==",
+ "version": "8.17.1",
+ "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
+ "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
"engines": {
"node": ">=10.0.0"
},
diff --git a/package.json b/package.json
index 46aeb14f7..bb17cd4c8 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "open-webui",
- "version": "0.3.5",
+ "version": "0.3.6",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
@@ -56,6 +56,7 @@
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
+ "crc-32": "^1.2.2",
"dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5",
@@ -68,11 +69,12 @@
"katex": "^0.16.9",
"marked": "^9.1.0",
"mermaid": "^10.9.1",
- "pyodide": "^0.26.0-alpha.4",
- "socket.io-client": "^4.7.5",
+ "pyodide": "^0.26.1",
+ "socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7",
+ "turndown": "^7.2.0",
"uuid": "^9.0.1"
}
}
diff --git a/pyproject.toml b/pyproject.toml
index 4571e5b61..80893b15b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -59,6 +59,7 @@ dependencies = [
"faster-whisper==1.0.2",
"PyJWT[crypto]==2.8.0",
+ "authlib==1.3.0",
"black==24.4.2",
"langfuse==2.33.0",
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 6aa26dad4..f7660eae3 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -31,6 +31,8 @@ asgiref==3.8.1
# via opentelemetry-instrumentation-asgi
attrs==23.2.0
# via aiohttp
+authlib==1.3.0
+ # via open-webui
av==11.0.0
# via faster-whisper
backoff==2.2.1
@@ -93,6 +95,7 @@ coloredlogs==15.0.1
compressed-rtf==1.0.6
# via extract-msg
cryptography==42.0.7
+ # via authlib
# via msoffcrypto-tool
# via pyjwt
ctranslate2==4.2.1
@@ -395,6 +398,7 @@ pandas==2.2.2
# via open-webui
passlib==1.7.4
# via open-webui
+ # via passlib
pathspec==0.12.1
# via black
pcodedmp==1.2.6
@@ -453,6 +457,7 @@ pygments==2.18.0
# via rich
pyjwt==2.8.0
# via open-webui
+ # via pyjwt
pymysql==1.1.0
# via open-webui
pypandoc==1.13
@@ -554,9 +559,6 @@ scipy==1.13.0
# via sentence-transformers
sentence-transformers==2.7.0
# via open-webui
-setuptools==69.5.1
- # via ctranslate2
- # via opentelemetry-instrumentation
shapely==2.0.4
# via rapidocr-onnxruntime
shellingham==1.5.4
@@ -651,6 +653,7 @@ uvicorn==0.22.0
# via chromadb
# via fastapi
# via open-webui
+ # via uvicorn
uvloop==0.19.0
# via uvicorn
validators==0.28.1
@@ -678,3 +681,6 @@ youtube-transcript-api==0.6.2
# via open-webui
zipp==3.18.1
# via importlib-metadata
+setuptools==69.5.1
+ # via ctranslate2
+ # via opentelemetry-instrumentation
diff --git a/requirements.lock b/requirements.lock
index 6aa26dad4..f7660eae3 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -31,6 +31,8 @@ asgiref==3.8.1
# via opentelemetry-instrumentation-asgi
attrs==23.2.0
# via aiohttp
+authlib==1.3.0
+ # via open-webui
av==11.0.0
# via faster-whisper
backoff==2.2.1
@@ -93,6 +95,7 @@ coloredlogs==15.0.1
compressed-rtf==1.0.6
# via extract-msg
cryptography==42.0.7
+ # via authlib
# via msoffcrypto-tool
# via pyjwt
ctranslate2==4.2.1
@@ -395,6 +398,7 @@ pandas==2.2.2
# via open-webui
passlib==1.7.4
# via open-webui
+ # via passlib
pathspec==0.12.1
# via black
pcodedmp==1.2.6
@@ -453,6 +457,7 @@ pygments==2.18.0
# via rich
pyjwt==2.8.0
# via open-webui
+ # via pyjwt
pymysql==1.1.0
# via open-webui
pypandoc==1.13
@@ -554,9 +559,6 @@ scipy==1.13.0
# via sentence-transformers
sentence-transformers==2.7.0
# via open-webui
-setuptools==69.5.1
- # via ctranslate2
- # via opentelemetry-instrumentation
shapely==2.0.4
# via rapidocr-onnxruntime
shellingham==1.5.4
@@ -651,6 +653,7 @@ uvicorn==0.22.0
# via chromadb
# via fastapi
# via open-webui
+ # via uvicorn
uvloop==0.19.0
# via uvicorn
validators==0.28.1
@@ -678,3 +681,6 @@ youtube-transcript-api==0.6.2
# via open-webui
zipp==3.18.1
# via importlib-metadata
+setuptools==69.5.1
+ # via ctranslate2
+ # via opentelemetry-instrumentation
diff --git a/scripts/prepare-pyodide.js b/scripts/prepare-pyodide.js
index c14a5bf1b..5aaac5927 100644
--- a/scripts/prepare-pyodide.js
+++ b/scripts/prepare-pyodide.js
@@ -1,4 +1,6 @@
const packages = [
+ 'micropip',
+ 'packaging',
'requests',
'beautifulsoup4',
'numpy',
@@ -11,20 +13,64 @@ const packages = [
];
import { loadPyodide } from 'pyodide';
-import { writeFile, copyFile, readdir } from 'fs/promises';
+import { writeFile, readFile, copyFile, readdir, rmdir } from 'fs/promises';
async function downloadPackages() {
console.log('Setting up pyodide + micropip');
- const pyodide = await loadPyodide({
- packageCacheDir: 'static/pyodide'
- });
- await pyodide.loadPackage('micropip');
- const micropip = pyodide.pyimport('micropip');
- console.log('Downloading Pyodide packages:', packages);
- await micropip.install(packages);
- console.log('Pyodide packages downloaded, freezing into lock file');
- const lockFile = await micropip.freeze();
- await writeFile('static/pyodide/pyodide-lock.json', lockFile);
+
+ let pyodide;
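+	// Load the Pyodide runtime, caching its downloaded assets under static/pyodide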
+ try {
+ pyodide = await loadPyodide({
+ packageCacheDir: 'static/pyodide'
+ });
+ } catch (err) {
+ console.error('Failed to load Pyodide:', err);
+ return;
+ }
+
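+	// Read the pyodide version pinned in package.json (stripping the "^" range prefix)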
+ const packageJson = JSON.parse(await readFile('package.json'));
+ const pyodideVersion = packageJson.dependencies.pyodide.replace('^', '');
+
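+	// If a cached Pyodide copy under static/pyodide has a different version, remove it so fresh assets are downloaded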
+ try {
+ const pyodidePackageJson = JSON.parse(await readFile('static/pyodide/package.json'));
+ const pyodidePackageVersion = pyodidePackageJson.version.replace('^', '');
+
+ if (pyodideVersion !== pyodidePackageVersion) {
+ console.log('Pyodide version mismatch, removing static/pyodide directory');
+ await rmdir('static/pyodide', { recursive: true });
+ }
+ } catch (e) {
+ console.log('Pyodide package not found, proceeding with download.');
+ }
+
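+	// Load micropip, install each Python package individually, then freeze the result into pyodide-lock.json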
+ try {
+ console.log('Loading micropip package');
+ await pyodide.loadPackage('micropip');
+
+ const micropip = pyodide.pyimport('micropip');
+ console.log('Downloading Pyodide packages:', packages);
+
+ try {
+ for (const pkg of packages) {
+ console.log(`Installing package: ${pkg}`);
+ await micropip.install(pkg);
+ }
+ } catch (err) {
+ console.error('Package installation failed:', err);
+ return;
+ }
+
+ console.log('Pyodide packages downloaded, freezing into lock file');
+
+ try {
+ const lockFile = await micropip.freeze();
+ await writeFile('static/pyodide/pyodide-lock.json', lockFile);
+ } catch (err) {
+ console.error('Failed to write lock file:', err);
+ }
+ } catch (err) {
+ console.error('Failed to load or install micropip:', err);
+ }
}
async function copyPyodide() {
diff --git a/src/app.css b/src/app.css
index baf620845..2dacf5d72 100644
--- a/src/app.css
+++ b/src/app.css
@@ -32,6 +32,10 @@ math {
@apply underline;
}
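+/* Round the corners of embedded iframes */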
+iframe {
+ @apply rounded-lg;
+}
+
ol > li {
counter-increment: list-number;
display: block;
diff --git a/src/app.html b/src/app.html
index a79343df5..da6af2cc4 100644
--- a/src/app.html
+++ b/src/app.html
@@ -13,6 +13,12 @@
href="/opensearch.xml"
/>
+
+
-
-