diff --git a/backend/open_webui/__init__.py b/backend/open_webui/__init__.py
index d85be48da..0c70cb63a 100644
--- a/backend/open_webui/__init__.py
+++ b/backend/open_webui/__init__.py
@@ -73,8 +73,15 @@ def serve(
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
import open_webui.main # we need set environment variables before importing main
+ from open_webui.env import UVICORN_WORKERS # Import the workers setting
- uvicorn.run(open_webui.main.app, host=host, port=port, forwarded_allow_ips="*")
+    uvicorn.run(
+        "open_webui.main:app",  # import string: uvicorn only starts multiple workers from one
+        host=host,
+        port=port,
+        forwarded_allow_ips="*",
+        workers=UVICORN_WORKERS,
+    )
@app.command()
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 8238f8a87..9f5395154 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -201,6 +201,7 @@ def save_config(config):
T = TypeVar("T")
+ENABLE_PERSISTENT_CONFIG = os.environ.get("ENABLE_PERSISTENT_CONFIG", "True").lower() == "true"
class PersistentConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T):
@@ -208,7 +209,7 @@ class PersistentConfig(Generic[T]):
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
- if self.config_value is not None:
+ if self.config_value is not None and ENABLE_PERSISTENT_CONFIG:
log.info(f"'{env_name}' loaded from the latest database entry")
self.value = self.config_value
else:
@@ -456,6 +457,12 @@ OAUTH_SCOPES = PersistentConfig(
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
+OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig(
+ "OAUTH_CODE_CHALLENGE_METHOD",
+ "oauth.oidc.code_challenge_method",
+ os.environ.get("OAUTH_CODE_CHALLENGE_METHOD", None),
+)
+
OAUTH_PROVIDER_NAME = PersistentConfig(
"OAUTH_PROVIDER_NAME",
"oauth.oidc.provider_name",
@@ -560,7 +567,7 @@ def load_oauth_providers():
name="microsoft",
client_id=MICROSOFT_CLIENT_ID.value,
client_secret=MICROSOFT_CLIENT_SECRET.value,
- server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
+ server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}",
client_kwargs={
"scope": MICROSOFT_OAUTH_SCOPE.value,
},
@@ -601,14 +608,21 @@ def load_oauth_providers():
):
def oidc_oauth_register(client):
+ client_kwargs = {
+ "scope": OAUTH_SCOPES.value,
+ }
+
+ if OAUTH_CODE_CHALLENGE_METHOD.value and OAUTH_CODE_CHALLENGE_METHOD.value == "S256":
+ client_kwargs["code_challenge_method"] = "S256"
+ elif OAUTH_CODE_CHALLENGE_METHOD.value:
+ raise Exception('Code challenge methods other than "%s" not supported. Given: "%s"' % ("S256", OAUTH_CODE_CHALLENGE_METHOD.value))
+
client.register(
name="oidc",
client_id=OAUTH_CLIENT_ID.value,
client_secret=OAUTH_CLIENT_SECRET.value,
server_metadata_url=OPENID_PROVIDER_URL.value,
- client_kwargs={
- "scope": OAUTH_SCOPES.value,
- },
+ client_kwargs=client_kwargs,
redirect_uri=OPENID_REDIRECT_URI.value,
)
@@ -2141,6 +2155,18 @@ PERPLEXITY_API_KEY = PersistentConfig(
os.getenv("PERPLEXITY_API_KEY", ""),
)
+SOUGOU_API_SID = PersistentConfig(
+ "SOUGOU_API_SID",
+ "rag.web.search.sougou_api_sid",
+ os.getenv("SOUGOU_API_SID", ""),
+)
+
+SOUGOU_API_SK = PersistentConfig(
+ "SOUGOU_API_SK",
+ "rag.web.search.sougou_api_sk",
+ os.getenv("SOUGOU_API_SK", ""),
+)
+
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
@@ -2472,6 +2498,24 @@ AUDIO_STT_MODEL = PersistentConfig(
os.getenv("AUDIO_STT_MODEL", ""),
)
+AUDIO_STT_AZURE_API_KEY = PersistentConfig(
+ "AUDIO_STT_AZURE_API_KEY",
+ "audio.stt.azure.api_key",
+ os.getenv("AUDIO_STT_AZURE_API_KEY", ""),
+)
+
+AUDIO_STT_AZURE_REGION = PersistentConfig(
+ "AUDIO_STT_AZURE_REGION",
+ "audio.stt.azure.region",
+ os.getenv("AUDIO_STT_AZURE_REGION", ""),
+)
+
+AUDIO_STT_AZURE_LOCALES = PersistentConfig(
+ "AUDIO_STT_AZURE_LOCALES",
+ "audio.stt.azure.locales",
+ os.getenv("AUDIO_STT_AZURE_LOCALES", ""),
+)
+
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_TTS_OPENAI_API_BASE_URL",
"audio.tts.openai.api_base_url",
diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py
index 86d87a2c3..95c54a0d2 100644
--- a/backend/open_webui/constants.py
+++ b/backend/open_webui/constants.py
@@ -31,6 +31,7 @@ class ERROR_MESSAGES(str, Enum):
USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username."
)
+ PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long."
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py
index bcc8a02c3..dda239dce 100644
--- a/backend/open_webui/env.py
+++ b/backend/open_webui/env.py
@@ -326,6 +326,20 @@ REDIS_URL = os.environ.get("REDIS_URL", "")
REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "")
REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379")
+####################################
+# UVICORN WORKERS
+####################################
+
+# Number of uvicorn worker processes for handling requests (never less than 1).
+_uvicorn_workers_env = os.environ.get("UVICORN_WORKERS", "1")
+try:
+    UVICORN_WORKERS = max(1, int(_uvicorn_workers_env))
+except ValueError:
+    UVICORN_WORKERS = 1
+    log.warning(
+        f"Invalid UVICORN_WORKERS value {_uvicorn_workers_env!r}, defaulting to {UVICORN_WORKERS}"
+    )
+
####################################
# WEBUI_AUTH (Required for security)
####################################
@@ -411,6 +425,21 @@ else:
except Exception:
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST = 10
+
+AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = os.environ.get(
+ "AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA", "10"
+)
+
+if AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA == "":
+ AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = None
+else:
+ try:
+ AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = int(
+ AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA
+ )
+ except Exception:
+ AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10
+
####################################
# OFFLINE_MODE
####################################
@@ -463,3 +492,10 @@ OTEL_TRACES_SAMPLER = os.environ.get(
PIP_OPTIONS = os.getenv("PIP_OPTIONS", "").split()
PIP_PACKAGE_INDEX_OPTIONS = os.getenv("PIP_PACKAGE_INDEX_OPTIONS", "").split()
+
+
+####################################
+# PROGRESSIVE WEB APP OPTIONS
+####################################
+
+EXTERNAL_PWA_MANIFEST_URL = os.environ.get("EXTERNAL_PWA_MANIFEST_URL")
\ No newline at end of file
diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py
index 07efa1e5f..01af45358 100644
--- a/backend/open_webui/main.py
+++ b/backend/open_webui/main.py
@@ -148,6 +148,9 @@ from open_webui.config import (
AUDIO_STT_MODEL,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
+ AUDIO_STT_AZURE_API_KEY,
+ AUDIO_STT_AZURE_REGION,
+ AUDIO_STT_AZURE_LOCALES,
AUDIO_TTS_API_KEY,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
@@ -225,6 +228,8 @@ from open_webui.config import (
BRAVE_SEARCH_API_KEY,
EXA_API_KEY,
PERPLEXITY_API_KEY,
+ SOUGOU_API_SID,
+ SOUGOU_API_SK,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
BOCHA_SEARCH_API_KEY,
@@ -341,6 +346,7 @@ from open_webui.env import (
RESET_CONFIG_ON_START,
OFFLINE_MODE,
ENABLE_OTEL,
+ EXTERNAL_PWA_MANIFEST_URL,
)
@@ -427,6 +433,7 @@ async def lifespan(app: FastAPI):
app = FastAPI(
+ title="Open WebUI",
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
@@ -566,6 +573,7 @@ app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.SIGNOUT_REDIRECT_URI = SIGNOUT_REDIRECT_URI
+app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL
app.state.USER_COUNT = None
app.state.TOOLS = {}
@@ -653,6 +661,8 @@ app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.EXA_API_KEY = EXA_API_KEY
app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY
+app.state.config.SOUGOU_API_SID = SOUGOU_API_SID
+app.state.config.SOUGOU_API_SK = SOUGOU_API_SK
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
@@ -780,6 +790,10 @@ app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.config.DEEPGRAM_API_KEY = DEEPGRAM_API_KEY
+app.state.config.AUDIO_STT_AZURE_API_KEY = AUDIO_STT_AZURE_API_KEY
+app.state.config.AUDIO_STT_AZURE_REGION = AUDIO_STT_AZURE_REGION
+app.state.config.AUDIO_STT_AZURE_LOCALES = AUDIO_STT_AZURE_LOCALES
+
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
@@ -1055,6 +1069,7 @@ async def chat_completion(
model_item = form_data.pop("model_item", {})
tasks = form_data.pop("background_tasks", None)
+ metadata = {}
try:
if not model_item.get("direct", False):
model_id = form_data.get("model", None)
@@ -1110,13 +1125,15 @@ async def chat_completion(
except Exception as e:
log.debug(f"Error processing chat payload: {e}")
- Chats.upsert_message_to_chat_by_id_and_message_id(
- metadata["chat_id"],
- metadata["message_id"],
- {
- "error": {"content": str(e)},
- },
- )
+ if metadata.get("chat_id") and metadata.get("message_id"):
+ # Update the chat message with the error
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ metadata["chat_id"],
+ metadata["message_id"],
+ {
+ "error": {"content": str(e)},
+ },
+ )
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -1392,29 +1409,32 @@ async def oauth_callback(provider: str, request: Request, response: Response):
@app.get("/manifest.json")
async def get_manifest_json():
- return {
- "name": app.state.WEBUI_NAME,
- "short_name": app.state.WEBUI_NAME,
- "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
- "start_url": "/",
- "display": "standalone",
- "background_color": "#343541",
- "orientation": "natural",
- "icons": [
- {
- "src": "/static/logo.png",
- "type": "image/png",
- "sizes": "500x500",
- "purpose": "any",
- },
- {
- "src": "/static/logo.png",
- "type": "image/png",
- "sizes": "500x500",
- "purpose": "maskable",
- },
- ],
- }
+    if app.state.EXTERNAL_PWA_MANIFEST_URL:
+        # `requests` has no default timeout; bound the wait so a stalled host cannot hang this route.
+        return requests.get(app.state.EXTERNAL_PWA_MANIFEST_URL, timeout=10).json()
+    return {
+        "name": app.state.WEBUI_NAME,
+        "short_name": app.state.WEBUI_NAME,
+        "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.",
+        "start_url": "/",
+        "display": "standalone",
+        "background_color": "#343541",
+        "orientation": "natural",
+        "icons": [
+            {
+                "src": "/static/logo.png",
+                "type": "image/png",
+                "sizes": "500x500",
+                "purpose": "any",
+            },
+            {
+                "src": "/static/logo.png",
+                "type": "image/png",
+                "sizes": "500x500",
+                "purpose": "maskable",
+            },
+        ],
+    }
@app.get("/opensearch.xml")
diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py
index 8eb48488b..f59dd7df5 100644
--- a/backend/open_webui/retrieval/loaders/youtube.py
+++ b/backend/open_webui/retrieval/loaders/youtube.py
@@ -110,7 +110,7 @@ class YoutubeLoader:
transcript = " ".join(
map(
- lambda transcript_piece: transcript_piece["text"].strip(" "),
+ lambda transcript_piece: transcript_piece.text.strip(" "),
transcript_pieces,
)
)
diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py
index 12d48f869..a00e6982c 100644
--- a/backend/open_webui/retrieval/utils.py
+++ b/backend/open_webui/retrieval/utils.py
@@ -77,6 +77,7 @@ def query_doc(
collection_name: str, query_embedding: list[float], k: int, user: UserModel = None
):
try:
+ log.debug(f"query_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.search(
collection_name=collection_name,
vectors=[query_embedding],
@@ -94,6 +95,7 @@ def query_doc(
def get_doc(collection_name: str, user: UserModel = None):
try:
+ log.debug(f"get_doc:doc {collection_name}")
result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
if result:
@@ -116,6 +118,7 @@ def query_doc_with_hybrid_search(
r: float,
) -> dict:
try:
+ log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
bm25_retriever = BM25Retriever.from_texts(
texts=collection_result.documents[0],
metadatas=collection_result.metadatas[0],
@@ -168,6 +171,7 @@ def query_doc_with_hybrid_search(
)
return result
except Exception as e:
+ log.exception(f"Error querying doc {collection_name} with hybrid search: {e}")
raise e
@@ -257,6 +261,7 @@ def query_collection(
) -> dict:
results = []
for query in queries:
+ log.debug(f"query_collection:query {query}")
query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
for collection_name in collection_names:
if collection_name:
@@ -292,6 +297,7 @@ def query_collection_with_hybrid_search(
collection_results = {}
for collection_name in collection_names:
try:
+ log.debug(f"query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}")
collection_results[collection_name] = VECTOR_DB_CLIENT.get(
collection_name=collection_name
)
@@ -613,6 +619,7 @@ def generate_openai_batch_embeddings(
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
+ log.debug(f"generate_openai_batch_embeddings:model {model} batch size: {len(texts)}")
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
@@ -655,6 +662,7 @@ def generate_ollama_batch_embeddings(
user: UserModel = None,
) -> Optional[list[list[float]]]:
try:
+ log.debug(f"generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}")
json_data = {"input": texts, "model": model}
if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py
index d95086671..bf8ae6880 100644
--- a/backend/open_webui/retrieval/web/duckduckgo.py
+++ b/backend/open_webui/retrieval/web/duckduckgo.py
@@ -3,6 +3,7 @@ from typing import Optional
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
+from duckduckgo_search.exceptions import RatelimitException
from open_webui.env import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
@@ -22,16 +23,15 @@ def search_duckduckgo(
list[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
+ search_results = []
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
- ddgs_gen = ddgs.text(
- query, safesearch="moderate", max_results=count, backend="api"
- )
- # Check if there are search results
- if ddgs_gen:
- # Convert the search results into a list
- search_results = [r for r in ddgs_gen]
-
+ try:
+ search_results = ddgs.text(
+ query, safesearch="moderate", max_results=count, backend="lite"
+ )
+ except RatelimitException as e:
+ log.error(f"RatelimitException: {e}")
if filter_list:
search_results = get_filtered_results(search_results, filter_list)
diff --git a/backend/open_webui/retrieval/web/sougou.py b/backend/open_webui/retrieval/web/sougou.py
new file mode 100644
index 000000000..af7957c4f
--- /dev/null
+++ b/backend/open_webui/retrieval/web/sougou.py
@@ -0,0 +1,72 @@
+import logging
+import json
+from typing import Optional, List
+
+
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+def search_sougou(
+    sougou_api_sid: str,
+    sougou_api_sk: str,
+    query: str,
+    count: int,
+    filter_list: Optional[List[str]] = None,
+) -> List[SearchResult]:
+    """Query Tencent Cloud's ``SearchPro`` action and return ranked results.
+
+    Args:
+        sougou_api_sid: Secret id passed to the Tencent Cloud credential.
+        sougou_api_sk: Secret key passed to the Tencent Cloud credential.
+        query: Search query string.
+        count: Maximum number of results to return.
+        filter_list: Optional filter forwarded to ``get_filtered_results``.
+
+    Returns:
+        At most ``count`` results, highest score first; an empty list when the
+        Tencent Cloud SDK raises ``TencentCloudSDKException``.
+    """
+    # Function-scope imports: the Tencent Cloud SDK is only loaded when this runs.
+    from tencentcloud.common.common_client import CommonClient
+    from tencentcloud.common import credential
+    from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
+        TencentCloudSDKException,
+    )
+    from tencentcloud.common.profile.client_profile import ClientProfile
+    from tencentcloud.common.profile.http_profile import HttpProfile
+
+    try:
+        cred = credential.Credential(sougou_api_sid, sougou_api_sk)
+        http_profile = HttpProfile()
+        http_profile.endpoint = "tms.tencentcloudapi.com"
+        client_profile = ClientProfile()
+        client_profile.http_profile = http_profile
+        common_client = CommonClient(
+            "tms", "2020-12-29", cred, "", profile=client_profile
+        )
+        # Pass the parameters as a dict directly; each returned page is a JSON string.
+        response = common_client.call_json("SearchPro", {"Query": query, "Cnt": 20})
+        results = [json.loads(page) for page in response["Response"]["Pages"]]
+        # Rank by each page's relevance "score".
+        # NOTE(review): field name taken from the SearchPro response format — confirm.
+        sorted_results = sorted(
+            results, key=lambda x: x.get("score", 0.0), reverse=True
+        )
+        if filter_list:
+            sorted_results = get_filtered_results(sorted_results, filter_list)
+
+        return [
+            SearchResult(
+                link=result.get("url"),
+                title=result.get("title"),
+                snippet=result.get("passage"),
+            )
+            for result in sorted_results[:count]
+        ]
+    except TencentCloudSDKException as err:
+        log.error(f"Error in Sougou search: {err}")
+        return []
diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py
index ea1372623..3fa066acf 100644
--- a/backend/open_webui/routers/audio.py
+++ b/backend/open_webui/routers/audio.py
@@ -50,6 +50,8 @@ router = APIRouter()
# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
+AZURE_MAX_FILE_SIZE_MB = 200
+AZURE_MAX_FILE_SIZE = AZURE_MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -68,8 +70,8 @@ from pydub import AudioSegment
from pydub.utils import mediainfo
-def is_mp4_audio(file_path):
- """Check if the given file is an MP4 audio file."""
+def get_audio_format(file_path):
+ """Check if the given file needs to be converted to a different format."""
if not os.path.isfile(file_path):
log.error(f"File not found: {file_path}")
return False
@@ -80,13 +82,17 @@ def is_mp4_audio(file_path):
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
- return True
- return False
+ return "mp4"
+ elif info.get("format_name") == "ogg":
+ return "ogg"
+ elif info.get("format_name") == "matroska,webm":
+ return "webm"
+ return None
-def convert_mp4_to_wav(file_path, output_path):
- """Convert MP4 audio file to WAV format."""
- audio = AudioSegment.from_file(file_path, format="mp4")
+def convert_audio_to_wav(file_path, output_path, conversion_type):
+ """Convert MP4/OGG audio file to WAV format."""
+ audio = AudioSegment.from_file(file_path, format=conversion_type)
audio.export(output_path, format="wav")
log.info(f"Converted {file_path} to {output_path}")
@@ -141,6 +147,9 @@ class STTConfigForm(BaseModel):
MODEL: str
WHISPER_MODEL: str
DEEPGRAM_API_KEY: str
+ AZURE_API_KEY: str
+ AZURE_REGION: str
+ AZURE_LOCALES: str
class AudioConfigUpdateForm(BaseModel):
@@ -169,6 +178,9 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)):
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
+ "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
+ "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
+ "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
},
}
@@ -195,6 +207,9 @@ async def update_audio_config(
request.app.state.config.STT_MODEL = form_data.stt.MODEL
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY
+ request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY
+ request.app.state.config.AUDIO_STT_AZURE_REGION = form_data.stt.AZURE_REGION
+ request.app.state.config.AUDIO_STT_AZURE_LOCALES = form_data.stt.AZURE_LOCALES
if request.app.state.config.STT_ENGINE == "":
request.app.state.faster_whisper_model = set_faster_whisper_model(
@@ -220,6 +235,9 @@ async def update_audio_config(
"MODEL": request.app.state.config.STT_MODEL,
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
"DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY,
+ "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY,
+ "AZURE_REGION": request.app.state.config.AUDIO_STT_AZURE_REGION,
+ "AZURE_LOCALES": request.app.state.config.AUDIO_STT_AZURE_LOCALES,
},
}
@@ -496,10 +514,15 @@ def transcribe(request: Request, file_path):
log.debug(data)
return data
elif request.app.state.config.STT_ENGINE == "openai":
- if is_mp4_audio(file_path):
- os.rename(file_path, file_path.replace(".wav", ".mp4"))
- # Convert MP4 audio file to WAV format
- convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+ audio_format = get_audio_format(file_path)
+ if audio_format:
+ os.rename(file_path, file_path.replace(".wav", f".{audio_format}"))
+ # Convert unsupported audio file to WAV format
+ convert_audio_to_wav(
+ file_path.replace(".wav", f".{audio_format}"),
+ file_path,
+ audio_format,
+ )
r = None
try:
@@ -598,6 +621,107 @@ def transcribe(request: Request, file_path):
detail = f"External: {e}"
raise Exception(detail if detail else "Open WebUI: Server Connection Error")
+ elif request.app.state.config.STT_ENGINE == "azure":
+ # Check file exists and size
+ if not os.path.exists(file_path):
+ raise HTTPException(
+ status_code=400,
+ detail="Audio file not found"
+ )
+
+ # Check file size (Azure has a larger limit of 200MB)
+ file_size = os.path.getsize(file_path)
+ if file_size > AZURE_MAX_FILE_SIZE:
+ raise HTTPException(
+ status_code=400,
+ detail=f"File size exceeds Azure's limit of {AZURE_MAX_FILE_SIZE_MB}MB",
+ )
+
+ api_key = request.app.state.config.AUDIO_STT_AZURE_API_KEY
+ region = request.app.state.config.AUDIO_STT_AZURE_REGION
+ locales = request.app.state.config.AUDIO_STT_AZURE_LOCALES
+
+ # IF NO LOCALES, USE DEFAULTS
+ if len(locales) < 2:
+ locales = ['en-US', 'es-ES', 'es-MX', 'fr-FR', 'hi-IN',
+ 'it-IT','de-DE', 'en-GB', 'en-IN', 'ja-JP',
+ 'ko-KR', 'pt-BR', 'zh-CN']
+ locales = ','.join(locales)
+
+
+ if not api_key or not region:
+ raise HTTPException(
+ status_code=400,
+ detail="Azure API key and region are required for Azure STT",
+ )
+
+ r = None
+ try:
+ # Prepare the request
+ data = {'definition': json.dumps({
+ "locales": locales.split(','),
+ "diarization": {"maxSpeakers": 3,"enabled": True}
+ } if locales else {}
+ )
+ }
+ url = f"https://{region}.api.cognitive.microsoft.com/speechtotext/transcriptions:transcribe?api-version=2024-11-15"
+
+ # Use context manager to ensure file is properly closed
+ with open(file_path, 'rb') as audio_file:
+ r = requests.post(
+ url=url,
+ files={'audio': audio_file},
+ data=data,
+ headers={
+ 'Ocp-Apim-Subscription-Key': api_key,
+ },
+ )
+
+ r.raise_for_status()
+ response = r.json()
+
+ # Extract transcript from response
+ if not response.get('combinedPhrases'):
+ raise ValueError("No transcription found in response")
+
+ # Get the full transcript from combinedPhrases
+ transcript = response['combinedPhrases'][0].get('text', '').strip()
+ if not transcript:
+ raise ValueError("Empty transcript in response")
+
+ data = {"text": transcript}
+
+ # Save transcript to json file (consistent with other providers)
+ transcript_file = f"{file_dir}/{id}.json"
+ with open(transcript_file, "w") as f:
+ json.dump(data, f)
+
+ log.debug(data)
+ return data
+
+ except (KeyError, IndexError, ValueError) as e:
+ log.exception("Error parsing Azure response")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to parse Azure response: {str(e)}",
+ )
+        except requests.exceptions.RequestException as e:
+            log.exception(e)
+            detail = None
+
+            try:
+                if r is not None and r.status_code != 200:
+                    res = r.json()
+                    if "error" in res:
+                        detail = f"External: {res['error'].get('message', '')}"
+            except Exception:
+                detail = f"External: {e}"
+            # A requests Response is falsy for 4xx/5xx (bool(r) is r.ok), so compare against None.
+            raise HTTPException(
+                status_code=r.status_code if r is not None else 500,
+                detail=detail if detail else "Open WebUI: Server Connection Error",
+            )
+
def compress_audio(file_path):
if os.path.getsize(file_path) > MAX_FILE_SIZE:
diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py
index 04e04afbf..5c91fce29 100644
--- a/backend/open_webui/routers/auths.py
+++ b/backend/open_webui/routers/auths.py
@@ -231,11 +231,13 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
entry = connection_app.entries[0]
username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower()
- email = str(entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"])
- if not email or email == "" or email == "[]":
+        email = entry[f"{LDAP_ATTRIBUTE_FOR_MAIL}"].value  # ldap3 Attribute -> raw str, list or None
+ if not email:
raise HTTPException(400, "User does not have a valid email address.")
- else:
+ elif isinstance(email, str):
email = email.lower()
+ elif isinstance(email, list):
+ email = email[0].lower()
cn = str(entry["cn"])
user_dn = entry.entry_dn
@@ -455,6 +457,13 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
# Disable signup after the first user is created
request.app.state.config.ENABLE_SIGNUP = False
+ # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
+ if len(form_data.password.encode("utf-8")) > 72:
+ raise HTTPException(
+ status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.PASSWORD_TOO_LONG,
+ )
+
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(),
diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py
index c30366545..8a2888d86 100644
--- a/backend/open_webui/routers/files.py
+++ b/backend/open_webui/routers/files.py
@@ -1,6 +1,7 @@
import logging
import os
import uuid
+from fnmatch import fnmatch
from pathlib import Path
from typing import Optional
from urllib.parse import quote
@@ -177,6 +178,47 @@ async def list_files(user=Depends(get_verified_user), content: bool = Query(True
return files
+############################
+# Search Files
+############################
+
+
+@router.get("/search", response_model=list[FileModelResponse])
+async def search_files(
+    filename: str = Query(
+        ...,
+        description="Filename pattern to search for. Supports wildcards such as '*.txt'",
+    ),
+    content: bool = Query(True),
+    user=Depends(get_verified_user),
+):
+    """
+    Search for files by filename with support for wildcard patterns.
+
+    Matching is a case-insensitive ``fnmatch``; admins search every file, other
+    users only their own. Raises 404 when nothing matches. With
+    ``content=False`` the ``content`` entry is removed from each file's ``data``.
+    """
+    if user.role == "admin":
+        files = Files.get_files()
+    else:
+        files = Files.get_files_by_user_id(user.id)
+
+    pattern = filename.lower()
+    matching_files = [file for file in files if fnmatch(file.filename.lower(), pattern)]
+    if not matching_files:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="No files found matching the pattern.",
+        )
+
+    if not content:
+        for file in matching_files:
+            (file.data or {}).pop("content", None)  # tolerate empty data / no "content" key
+
+    return matching_files
+
+
############################
# Delete All Files
############################
diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py
index bc1e2429e..ab745cf84 100644
--- a/backend/open_webui/routers/knowledge.py
+++ b/backend/open_webui/routers/knowledge.py
@@ -159,6 +159,72 @@ async def create_new_knowledge(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_EXISTS,
)
+
+
+
+############################
+# ReindexKnowledgeFiles
+############################
+
+
+@router.post("/reindex", response_model=bool)
+async def reindex_knowledge_files(request: Request, user=Depends(get_verified_user)):
+    """
+    Rebuild the vector DB collection of every knowledge base (admin only).
+
+    Each existing collection is dropped and its files are re-run through
+    ``process_file``. A failing file is logged and skipped; any other failure
+    aborts the run with a 500. Non-admin users get a 401. Returns True.
+    """
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )
+
+    knowledge_bases = Knowledges.get_knowledge_bases()
+    log.info(f"Starting reindexing for {len(knowledge_bases)} knowledge bases")
+
+    for knowledge_base in knowledge_bases:
+        failed_files = []
+        try:
+            # NOTE(review): assumes `data` may be None for a knowledge base without files — confirm.
+            files = Files.get_files_by_ids((knowledge_base.data or {}).get("file_ids", []))
+            try:
+                if VECTOR_DB_CLIENT.has_collection(collection_name=knowledge_base.id):
+                    VECTOR_DB_CLIENT.delete_collection(collection_name=knowledge_base.id)
+            except Exception as e:
+                log.error(f"Error deleting collection {knowledge_base.id}: {str(e)}")
+                raise HTTPException(
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail="Error deleting vector DB collection",
+                )
+            for file in files:
+                try:
+                    process_file(
+                        request,
+                        ProcessFileForm(file_id=file.id, collection_name=knowledge_base.id),
+                        user=user,
+                    )
+                except Exception as e:
+                    log.error(f"Error processing file {file.filename} (ID: {file.id}): {str(e)}")
+                    failed_files.append({"file_id": file.id, "error": str(e)})
+        except HTTPException:
+            # Raised above with its own detail and already logged; do not re-wrap it.
+            raise
+        except Exception as e:
+            log.error(f"Error processing knowledge base {knowledge_base.id}: {str(e)}")
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Error processing knowledge base",
+            )
+
+        if failed_files:
+            log.warning(f"Failed to process {len(failed_files)} files in knowledge base {knowledge_base.id}")
+            for failed in failed_files:
+                log.warning(f"File ID: {failed['file_id']}, Error: {failed['error']}")
+
+    log.info("Reindexing completed successfully")
+    return True
############################
@@ -676,3 +742,6 @@ def add_files_to_knowledge_batch(
return KnowledgeFilesResponse(
**knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
)
+
+
+
diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py
index f31abd9ff..8e1708c65 100644
--- a/backend/open_webui/routers/retrieval.py
+++ b/backend/open_webui/routers/retrieval.py
@@ -60,6 +60,7 @@ from open_webui.retrieval.web.tavily import search_tavily
from open_webui.retrieval.web.bing import search_bing
from open_webui.retrieval.web.exa import search_exa
from open_webui.retrieval.web.perplexity import search_perplexity
+from open_webui.retrieval.web.sougou import search_sougou
from open_webui.retrieval.utils import (
get_embedding_function,
@@ -411,6 +412,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
+ "sougou_api_sid": request.app.state.config.SOUGOU_API_SID,
+ "sougou_api_sk": request.app.state.config.SOUGOU_API_SK,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -478,6 +481,8 @@ class WebSearchConfig(BaseModel):
bing_search_v7_subscription_key: Optional[str] = None
exa_api_key: Optional[str] = None
perplexity_api_key: Optional[str] = None
+ sougou_api_sid: Optional[str] = None
+ sougou_api_sk: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
trust_env: Optional[bool] = None
@@ -640,6 +645,12 @@ async def update_rag_config(
request.app.state.config.PERPLEXITY_API_KEY = (
form_data.web.search.perplexity_api_key
)
+ request.app.state.config.SOUGOU_API_SID = (
+ form_data.web.search.sougou_api_sid
+ )
+ request.app.state.config.SOUGOU_API_SK = (
+ form_data.web.search.sougou_api_sk
+ )
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = (
form_data.web.search.result_count
@@ -712,6 +723,8 @@ async def update_rag_config(
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
"exa_api_key": request.app.state.config.EXA_API_KEY,
"perplexity_api_key": request.app.state.config.PERPLEXITY_API_KEY,
+ "sougou_api_sid": request.app.state.config.SOUGOU_API_SID,
+ "sougou_api_sk": request.app.state.config.SOUGOU_API_SK,
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
@@ -1267,6 +1280,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
- TAVILY_API_KEY
- EXA_API_KEY
- PERPLEXITY_API_KEY
+ - SOUGOU_API_SID + SOUGOU_API_SK
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
- SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
Args:
@@ -1438,6 +1452,17 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]:
request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
+ elif engine == 'sougou':
+ if request.app.state.config.SOUGOU_API_SID and request.app.state.config.SOUGOU_API_SK:
+ return search_sougou(
+ request.app.state.config.SOUGOU_API_SID,
+ request.app.state.config.SOUGOU_API_SK,
+ query,
+ request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+ request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+ )
+ else:
+ raise Exception("No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables")
else:
raise Exception("No search engine API key found in environment variables")
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
index 8a98b4e20..318f61398 100644
--- a/backend/open_webui/routers/tools.py
+++ b/backend/open_webui/routers/tools.py
@@ -10,11 +10,11 @@ from open_webui.models.tools import (
ToolUserResponse,
Tools,
)
-from open_webui.utils.plugin import load_tools_module_by_id, replace_imports
+from open_webui.utils.plugin import load_tool_module_by_id, replace_imports
from open_webui.config import CACHE_DIR
from open_webui.constants import ERROR_MESSAGES
from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.tools import get_tools_specs
+from open_webui.utils.tools import get_tool_specs
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, has_permission
from open_webui.env import SRC_LOG_LEVELS
@@ -45,7 +45,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
)
tools = Tools.get_tools()
- for idx, server in enumerate(request.app.state.TOOL_SERVERS):
+ for server in request.app.state.TOOL_SERVERS:
tools.append(
ToolUserResponse(
**{
@@ -60,7 +60,7 @@ async def get_tools(request: Request, user=Depends(get_verified_user)):
.get("description", ""),
},
"access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
- idx
+ server["idx"]
]
.get("config", {})
.get("access_control", None),
@@ -137,15 +137,15 @@ async def create_new_tools(
if tools is None:
try:
form_data.content = replace_imports(form_data.content)
- tools_module, frontmatter = load_tools_module_by_id(
+ tool_module, frontmatter = load_tool_module_by_id(
form_data.id, content=form_data.content
)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
- TOOLS[form_data.id] = tools_module
+ TOOLS[form_data.id] = tool_module
- specs = get_tools_specs(TOOLS[form_data.id])
+ specs = get_tool_specs(TOOLS[form_data.id])
tools = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = CACHE_DIR / "tools" / form_data.id
@@ -226,15 +226,13 @@ async def update_tools_by_id(
try:
form_data.content = replace_imports(form_data.content)
- tools_module, frontmatter = load_tools_module_by_id(
- id, content=form_data.content
- )
+ tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
- TOOLS[id] = tools_module
+ TOOLS[id] = tool_module
- specs = get_tools_specs(TOOLS[id])
+ specs = get_tool_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
@@ -332,7 +330,7 @@ async def get_tools_valves_spec_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
- tools_module, _ = load_tools_module_by_id(id)
+ tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "Valves"):
@@ -375,7 +373,7 @@ async def update_tools_valves_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
- tools_module, _ = load_tools_module_by_id(id)
+ tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if not hasattr(tools_module, "Valves"):
@@ -431,7 +429,7 @@ async def get_tools_user_valves_spec_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
- tools_module, _ = load_tools_module_by_id(id)
+ tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
@@ -455,7 +453,7 @@ async def update_tools_user_valves_by_id(
if id in request.app.state.TOOLS:
tools_module = request.app.state.TOOLS[id]
else:
- tools_module, _ = load_tools_module_by_id(id)
+ tools_module, _ = load_tool_module_by_id(id)
request.app.state.TOOLS[id] = tools_module
if hasattr(tools_module, "UserValves"):
diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py
index badae9906..f246bc84b 100644
--- a/backend/open_webui/utils/middleware.py
+++ b/backend/open_webui/utils/middleware.py
@@ -897,12 +897,16 @@ async def process_chat_payload(request, form_data, user, metadata, model):
# If context is not empty, insert it into the messages
if len(sources) > 0:
context_string = ""
- for source_idx, source in enumerate(sources):
+ citated_file_idx = {}
+ for _, source in enumerate(sources, 1):
if "document" in source:
- for doc_idx, doc_context in enumerate(source["document"]):
- context_string += (
- f'