From d24c21b40f27a7605485233b764c49c89bd848c2 Mon Sep 17 00:00:00 2001 From: DmitriyAlergant-T1A Date: Thu, 21 Nov 2024 23:14:05 -0500 Subject: [PATCH 01/30] Fix Logging cleanup: removed some extraneous hard prints (including some that revealed message content!); improved debug logging a bit. + added chat_id to task metadata (helpful for logging/tracking in some pipe functions) --- backend/open_webui/apps/openai/main.py | 2 - backend/open_webui/apps/webui/main.py | 13 ++++-- backend/open_webui/apps/webui/utils.py | 17 +++++--- backend/open_webui/main.py | 56 +++++++++++++------------- 4 files changed, 47 insertions(+), 41 deletions(-) diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 6d6ac50c6..31c36a8a1 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -585,8 +585,6 @@ async def generate_chat_completion( # Convert the modified body back to JSON payload = json.dumps(payload) - log.debug(payload) - headers = {} headers["Authorization"] = f"Bearer {key}" headers["Content-Type"] = "application/json" diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index ce4945b69..8995dcdff 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -68,6 +68,7 @@ from open_webui.config import ( ) from open_webui.env import ( ENV, + SRC_LOG_LEVELS, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, ) @@ -94,6 +95,7 @@ app = FastAPI( ) log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) app.state.config = AppConfig() @@ -270,7 +272,7 @@ async def get_pipe_models(): log.exception(e) sub_pipes = [] - print(sub_pipes) + log.debug(f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}") for p in sub_pipes: sub_pipe_id = f'{pipe.id}.{p["id"]}' @@ -280,6 +282,7 @@ async def get_pipe_models(): sub_pipe_name = f"{function_module.name}{sub_pipe_name}" pipe_flag = {"type": 
pipe.type} + pipe_models.append( { "id": sub_pipe_id, @@ -293,6 +296,8 @@ async def get_pipe_models(): else: pipe_flag = {"type": "pipe"} + log.debug(f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}") + pipe_models.append( { "id": pipe.id, @@ -346,7 +351,7 @@ def get_pipe_id(form_data: dict) -> str: pipe_id = form_data["model"] if "." in pipe_id: pipe_id, _ = pipe_id.split(".", 1) - print(pipe_id) + return pipe_id @@ -453,7 +458,7 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}): return except Exception as e: - print(f"Error: {e}") + log.error(f"Error: {e}") yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" return @@ -483,7 +488,7 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}): res = await execute_pipe(pipe, params) except Exception as e: - print(f"Error: {e}") + log.error(f"Error: {e}") return {"error": {"detail": str(e)}} if isinstance(res, StreamingResponse) or isinstance(res, dict): diff --git a/backend/open_webui/apps/webui/utils.py b/backend/open_webui/apps/webui/utils.py index 6bfddd072..054158b3e 100644 --- a/backend/open_webui/apps/webui/utils.py +++ b/backend/open_webui/apps/webui/utils.py @@ -5,10 +5,15 @@ import sys from importlib import util import types import tempfile +import logging +from open_webui.env import SRC_LOG_LEVELS from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.tools import Tools +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + def extract_frontmatter(content): """ @@ -95,7 +100,7 @@ def load_tools_module_by_id(toolkit_id, content=None): # Executing the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) - print(f"Loaded module: {module.__name__}") + log.info(f"Loaded module: {module.__name__}") # Create and return the object if the class 'Tools' is found in the 
module if hasattr(module, "Tools"): @@ -103,7 +108,7 @@ def load_tools_module_by_id(toolkit_id, content=None): else: raise Exception("No Tools class found in the module") except Exception as e: - print(f"Error loading module: {toolkit_id}: {e}") + log.error(f"Error loading module: {toolkit_id}: {e}") del sys.modules[module_name] # Clean up raise e finally: @@ -139,7 +144,7 @@ def load_function_module_by_id(function_id, content=None): # Execute the modified content in the created module's namespace exec(content, module.__dict__) frontmatter = extract_frontmatter(content) - print(f"Loaded module: {module.__name__}") + log.info(f"Loaded module: {module.__name__}") # Create appropriate object based on available class type in the module if hasattr(module, "Pipe"): @@ -151,7 +156,7 @@ def load_function_module_by_id(function_id, content=None): else: raise Exception("No Function class found in the module") except Exception as e: - print(f"Error loading module: {function_id}: {e}") + log.error(f"Error loading module: {function_id}: {e}") del sys.modules[module_name] # Cleanup by removing the module in case of error Functions.update_function_by_id(function_id, {"is_active": False}) @@ -164,7 +169,7 @@ def install_frontmatter_requirements(requirements): if requirements: req_list = [req.strip() for req in requirements.split(",")] for req in req_list: - print(f"Installing requirement: {req}") + log.info(f"Installing requirement: {req}") subprocess.check_call([sys.executable, "-m", "pip", "install", req]) else: - print("No requirements found in frontmatter.") + log.info("No requirements found in frontmatter.") diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index cfa13e0a5..9aeffae2e 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -539,7 +539,6 @@ async def chat_completion_files_handler( if len(queries) == 0: queries = [get_last_user_message(body["messages"])] - print(f"{queries=}") sources = get_sources_from_files( files=files, 
@@ -970,7 +969,7 @@ app.add_middleware(SecurityHeadersMiddleware) @app.middleware("http") async def commit_session_after_request(request: Request, call_next): response = await call_next(request) - log.debug("Commit session after request") + #log.debug("Commit session after request") Session.commit() return response @@ -1177,6 +1176,8 @@ async def get_all_models(): model["actions"].extend( get_action_items_from_module(action_function, function_module) ) + log.debug(f"get_all_models() returned {len(models)} models") + return models @@ -1214,6 +1215,8 @@ async def get_models(user=Depends(get_verified_user)): filtered_models.append(model) models = filtered_models + log.debug(f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}") + return {"data": models} @@ -1704,7 +1707,6 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u @app.post("/api/task/title/completions") async def generate_title(form_data: dict, user=Depends(get_verified_user)): - print("generate_title") model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -1725,9 +1727,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): models, ) - print(task_model_id) - - model = models[task_model_id] + log.debug(f"generating chat title using model {task_model_id} for user {user.email} ") if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE @@ -1766,10 +1766,12 @@ Artificial Intelligence in Healthcare "max_completion_tokens": 50, } ), - "chat_id": form_data.get("chat_id", None), - "metadata": {"task": str(TASKS.TITLE_GENERATION), "task_body": form_data}, + "metadata": { + "task": str(TASKS.TITLE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None) + }, } - log.debug(payload) # Handle pipeline filters try: @@ -1793,7 +1795,7 @@ Artificial Intelligence in Healthcare 
@app.post("/api/task/tags/completions") async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): - print("generate_chat_tags") + if not app.state.config.ENABLE_TAGS_GENERATION: return JSONResponse( status_code=status.HTTP_200_OK, @@ -1818,7 +1820,8 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): app.state.config.TASK_MODEL_EXTERNAL, models, ) - print(task_model_id) + + log.debug(f"generating chat tags using model {task_model_id} for user {user.email} ") if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE @@ -1849,9 +1852,12 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - "metadata": {"task": str(TASKS.TAGS_GENERATION), "task_body": form_data}, + "metadata": { + "task": str(TASKS.TAGS_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None) + } } - log.debug(payload) # Handle pipeline filters try: @@ -1875,7 +1881,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } @app.post("/api/task/queries/completions") async def generate_queries(form_data: dict, user=Depends(get_verified_user)): - print("generate_queries") + type = form_data.get("type") if type == "web_search": if not app.state.config.ENABLE_SEARCH_QUERY_GENERATION: @@ -1908,9 +1914,8 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): app.state.config.TASK_MODEL_EXTERNAL, models, ) - print(task_model_id) - - model = models[task_model_id] + + log.debug(f"generating {type} queries using model {task_model_id} for user {user.email}") if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE @@ -1925,9 +1930,8 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": 
False, - "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data}, + "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data, "chat_id": form_data.get("chat_id", None)}, } - log.debug(payload) # Handle pipeline filters try: @@ -1951,7 +1955,6 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): @app.post("/api/task/emoji/completions") async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): - print("generate_emoji") model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -1971,9 +1974,8 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): app.state.config.TASK_MODEL_EXTERNAL, models, ) - print(task_model_id) - model = models[task_model_id] + log.debug(f"generating emoji using model {task_model_id} for user {user.email} ") template = ''' Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). 
@@ -2003,7 +2005,6 @@ Message: """{{prompt}}""" "chat_id": form_data.get("chat_id", None), "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data}, } - log.debug(payload) # Handle pipeline filters try: @@ -2027,7 +2028,6 @@ Message: """{{prompt}}""" @app.post("/api/task/moa/completions") async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)): - print("generate_moa_response") model_list = await get_all_models() models = {model["id"]: model for model in model_list} @@ -2047,9 +2047,8 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user) app.state.config.TASK_MODEL_EXTERNAL, models, ) - print(task_model_id) - - model = models[task_model_id] + + log.debug(f"generating MOA model {task_model_id} for user {user.email} ") template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" @@ -2073,7 +2072,6 @@ Responses from models: {{responses}}""" "task_body": form_data, }, } - log.debug(payload) try: payload = filter_pipeline(payload, user, models) @@ -2108,7 +2106,7 @@ Responses from models: {{responses}}""" async def get_pipelines_list(user=Depends(get_admin_user)): responses = await get_openai_models_responses() - print(responses) + log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}") urlIdxs = [ idx for idx, response in enumerate(responses) From 374d6cad18eab9e36b01f35bc1fa495986086608 Mon Sep 17 00:00:00 2001 From: DmitriyAlergant-T1A Date: Fri, 22 Nov 2024 23:11:46 -0500 Subject: [PATCH 02/30] Python Formatting (Failed CI - fixed) --- backend/open_webui/apps/webui/main.py | 12 ++++--- backend/open_webui/main.py | 45 +++++++++++++++++---------- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 8995dcdff..bedd49ae3 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ 
-272,7 +272,9 @@ async def get_pipe_models(): log.exception(e) sub_pipes = [] - log.debug(f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}") + log.debug( + f"get_pipe_models: function '{pipe.id}' is a manifold of {sub_pipes}" + ) for p in sub_pipes: sub_pipe_id = f'{pipe.id}.{p["id"]}' @@ -282,7 +284,7 @@ async def get_pipe_models(): sub_pipe_name = f"{function_module.name}{sub_pipe_name}" pipe_flag = {"type": pipe.type} - + pipe_models.append( { "id": sub_pipe_id, @@ -296,8 +298,10 @@ async def get_pipe_models(): else: pipe_flag = {"type": "pipe"} - log.debug(f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}") - + log.debug( + f"get_pipe_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" + ) + pipe_models.append( { "id": pipe.id, diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 9aeffae2e..3761a0e39 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -539,7 +539,6 @@ async def chat_completion_files_handler( if len(queries) == 0: queries = [get_last_user_message(body["messages"])] - sources = get_sources_from_files( files=files, queries=queries, @@ -969,7 +968,7 @@ app.add_middleware(SecurityHeadersMiddleware) @app.middleware("http") async def commit_session_after_request(request: Request, call_next): response = await call_next(request) - #log.debug("Commit session after request") + # log.debug("Commit session after request") Session.commit() return response @@ -1215,7 +1214,9 @@ async def get_models(user=Depends(get_verified_user)): filtered_models.append(model) models = filtered_models - log.debug(f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}") + log.debug( + f"/api/models returned filtered models accessible to the user: {json.dumps([model['id'] for model in models])}" + ) return {"data": models} @@ -1727,7 +1728,9 @@ async def 
generate_title(form_data: dict, user=Depends(get_verified_user)): models, ) - log.debug(f"generating chat title using model {task_model_id} for user {user.email} ") + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE @@ -1767,10 +1770,10 @@ Artificial Intelligence in Healthcare } ), "metadata": { - "task": str(TASKS.TITLE_GENERATION), - "task_body": form_data, - "chat_id": form_data.get("chat_id", None) - }, + "task": str(TASKS.TITLE_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, } # Handle pipeline filters @@ -1820,8 +1823,10 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)): app.state.config.TASK_MODEL_EXTERNAL, models, ) - - log.debug(f"generating chat tags using model {task_model_id} for user {user.email} ") + + log.debug( + f"generating chat tags using model {task_model_id} for user {user.email} " + ) if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "": template = app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE @@ -1853,10 +1858,10 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] } "messages": [{"role": "user", "content": content}], "stream": False, "metadata": { - "task": str(TASKS.TAGS_GENERATION), + "task": str(TASKS.TAGS_GENERATION), "task_body": form_data, - "chat_id": form_data.get("chat_id", None) - } + "chat_id": form_data.get("chat_id", None), + }, } # Handle pipeline filters @@ -1914,8 +1919,10 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): app.state.config.TASK_MODEL_EXTERNAL, models, ) - - log.debug(f"generating {type} queries using model {task_model_id} for user {user.email}") + + log.debug( + f"generating {type} queries using model {task_model_id} for user {user.email}" + ) if app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE != "": template = 
app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE @@ -1930,7 +1937,11 @@ async def generate_queries(form_data: dict, user=Depends(get_verified_user)): "model": task_model_id, "messages": [{"role": "user", "content": content}], "stream": False, - "metadata": {"task": str(TASKS.QUERY_GENERATION), "task_body": form_data, "chat_id": form_data.get("chat_id", None)}, + "metadata": { + "task": str(TASKS.QUERY_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, } # Handle pipeline filters @@ -2047,7 +2058,7 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user) app.state.config.TASK_MODEL_EXTERNAL, models, ) - + log.debug(f"generating MOA model {task_model_id} for user {user.email} ") template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}" From b2d3bfa3a8ac330e3ed789289cf3df1768c55d07 Mon Sep 17 00:00:00 2001 From: Pierre Glandon Date: Fri, 22 Nov 2024 17:25:52 +0100 Subject: [PATCH 03/30] feat: add description in Tool --- backend/open_webui/utils/tools.py | 33 ++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 3cdcf15bf..60a9f942f 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -90,6 +90,32 @@ def get_tools( return tools_dict +def parse_description(docstring: str | None) -> str: + """ + Parse a function's docstring to extract the description. + + Args: + docstring (str): The docstring to parse. + + Returns: + str: The description. 
+ """ + + if not docstring: + return "" + + lines = [line.strip() for line in docstring.strip().split("\n")] + description_lines: list[str] = [] + + for line in lines: + if re.match(r":param", line) or re.match(r":return", line): + break + + description_lines.append(line) + + return "\n".join(description_lines) + + def parse_docstring(docstring): """ Parse a function's docstring to extract parameter descriptions in reST format. @@ -138,6 +164,8 @@ def function_to_pydantic_model(func: Callable) -> type[BaseModel]: docstring = func.__doc__ descriptions = parse_docstring(docstring) + tool_description = parse_description(docstring) + field_defs = {} for name, param in parameters.items(): type_hint = type_hints.get(name, Any) @@ -148,7 +176,10 @@ def function_to_pydantic_model(func: Callable) -> type[BaseModel]: continue field_defs[name] = type_hint, Field(default_value, description=description) - return create_model(func.__name__, **field_defs) + model = create_model(func.__name__, **field_defs) + model.__doc__ = tool_description + + return model def get_callable_attributes(tool: object) -> list[Callable]: From a83f89d4305556427f922add4dc858632dcf4d99 Mon Sep 17 00:00:00 2001 From: houcheng Date: Sun, 24 Nov 2024 00:28:14 +0800 Subject: [PATCH 04/30] fix: prevent TTS blocking using aiohttp and aiofiles --- backend/open_webui/apps/audio/main.py | 106 ++++++++++++-------------- backend/requirements.txt | 1 + 2 files changed, 49 insertions(+), 58 deletions(-) diff --git a/backend/open_webui/apps/audio/main.py b/backend/open_webui/apps/audio/main.py index 384bb3cd4..83ea23e2c 100644 --- a/backend/open_webui/apps/audio/main.py +++ b/backend/open_webui/apps/audio/main.py @@ -8,6 +8,8 @@ from pathlib import Path from pydub import AudioSegment from pydub.silence import split_on_silence +import aiohttp +import aiofiles import requests from open_webui.config import ( AUDIO_STT_ENGINE, @@ -292,46 +294,39 @@ async def speech(request: Request, user=Depends(get_verified_user)): 
except Exception: pass - r = None try: - r = requests.post( - url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", - data=body, - headers=headers, - stream=True, - ) + async with aiohttp.ClientSession() as session: + async with session.post( + url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", + data=body, + headers=headers + ) as r: + r.raise_for_status() + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(json.loads(body.decode("utf-8")))) - r.raise_for_status() - - # Save the streaming content to a file - with open(file_path, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) - - # Return the saved file return FileResponse(file_path) except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() + try: + if r.status != 200: + res = await r.json() if "error" in res: error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" + except Exception: + error_detail = f"External: {e}" raise HTTPException( - status_code=r.status_code if r != None else 500, + status_code=getattr(r, 'status', 500), detail=error_detail, ) elif app.state.config.TTS_ENGINE == "elevenlabs": - payload = None try: payload = json.loads(body.decode("utf-8")) except Exception as e: @@ -339,7 +334,6 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=400, detail="Invalid JSON payload") voice_id = payload.get("voice", "") - if voice_id not in get_available_voices(): raise HTTPException( status_code=400, @@ -347,13 +341,11 @@ async def speech(request: Request, user=Depends(get_verified_user)): ) url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}" - headers = { "Accept": 
"audio/mpeg", "Content-Type": "application/json", "xi-api-key": app.state.config.TTS_API_KEY, } - data = { "text": payload["input"], "model_id": app.state.config.TTS_MODEL, @@ -361,39 +353,34 @@ async def speech(request: Request, user=Depends(get_verified_user)): } try: - r = requests.post(url, json=data, headers=headers) + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data, headers=headers) as r: + r.raise_for_status() + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(json.loads(body.decode("utf-8")))) - r.raise_for_status() - - # Save the streaming content to a file - with open(file_path, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - with open(file_body_path, "w") as f: - json.dump(json.loads(body.decode("utf-8")), f) - - # Return the saved file return FileResponse(file_path) except Exception as e: log.exception(e) error_detail = "Open WebUI: Server Connection Error" - if r is not None: - try: - res = r.json() + try: + if r.status != 200: + res = await r.json() if "error" in res: error_detail = f"External: {res['error']['message']}" - except Exception: - error_detail = f"External: {e}" + except Exception: + error_detail = f"External: {e}" raise HTTPException( - status_code=r.status_code if r != None else 500, + status_code=getattr(r, 'status', 500), detail=error_detail, ) elif app.state.config.TTS_ENGINE == "azure": - payload = None try: payload = json.loads(body.decode("utf-8")) except Exception as e: @@ -416,17 +403,20 @@ async def speech(request: Request, user=Depends(get_verified_user)): {payload["input"]} """ - response = requests.post(url, headers=headers, data=data) - - if response.status_code == 200: - with open(file_path, "wb") as f: - f.write(response.content) - return FileResponse(file_path) - else: - log.error(f"Error synthesizing speech - {response.reason}") - raise 
HTTPException( - status_code=500, detail=f"Error synthesizing speech - {response.reason}" - ) + try: + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, data=data) as response: + if response.status == 200: + async with aiofiles.open(file_path, "wb") as f: + await f.write(await response.read()) + return FileResponse(file_path) + else: + error_msg = f"Error synthesizing speech - {response.reason}" + log.error(error_msg) + raise HTTPException(status_code=500, detail=error_msg) + except Exception as e: + log.exception(e) + raise HTTPException(status_code=500, detail=str(e)) elif app.state.config.TTS_ENGINE == "transformers": payload = None try: diff --git a/backend/requirements.txt b/backend/requirements.txt index 258f69e25..c83e6b3b7 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -14,6 +14,7 @@ requests==2.32.3 aiohttp==3.10.8 async-timeout aiocache +aiofiles sqlalchemy==2.0.32 alembic==1.13.2 From c567185cb168723049e20dbc7f89d0e86856f6ac Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sat, 23 Nov 2024 20:31:33 -0800 Subject: [PATCH 05/30] refac: rich text input behaviour --- src/app.css | 10 ++++++++++ src/lib/components/common/RichTextInput.svelte | 7 +++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/app.css b/src/app.css index 3974974bc..ea38b2b2b 100644 --- a/src/app.css +++ b/src/app.css @@ -231,6 +231,16 @@ input[type='number'] { @apply dark:bg-gray-800 bg-gray-100; } + +.tiptap p code { + color: #eb5757; + border-width: 0px; + padding: 3px 8px; + font-size: 0.8em; + font-weight: 600; + @apply rounded-md dark:bg-gray-800 bg-gray-100 mx-0.5; +} + /* Code styling */ .hljs-comment, .hljs-quote { diff --git a/src/lib/components/common/RichTextInput.svelte b/src/lib/components/common/RichTextInput.svelte index 39db40c27..3e3f79a55 100644 --- a/src/lib/components/common/RichTextInput.svelte +++ b/src/lib/components/common/RichTextInput.svelte @@ -1,7 +1,10 @@ From 
bd28e1ed7d42529ed4ac561c63d09a4664af147a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 24 Nov 2024 18:49:56 -0800 Subject: [PATCH 15/30] refac: rag prompt template --- backend/open_webui/apps/retrieval/utils.py | 40 ------------------ backend/open_webui/config.py | 16 ++++---- backend/open_webui/main.py | 13 +++--- backend/open_webui/utils/task.py | 47 ++++++++++++++++++++++ 4 files changed, 63 insertions(+), 53 deletions(-) diff --git a/backend/open_webui/apps/retrieval/utils.py b/backend/open_webui/apps/retrieval/utils.py index 420ed7bf4..35159f80d 100644 --- a/backend/open_webui/apps/retrieval/utils.py +++ b/backend/open_webui/apps/retrieval/utils.py @@ -15,8 +15,6 @@ from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT from open_webui.utils.misc import get_last_user_message from open_webui.env import SRC_LOG_LEVELS -from open_webui.config import DEFAULT_RAG_TEMPLATE - log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -238,44 +236,6 @@ def query_collection_with_hybrid_search( return merge_and_sort_query_results(results, k=k, reverse=True) -def rag_template(template: str, context: str, query: str): - if template == "": - template = DEFAULT_RAG_TEMPLATE - - if "[context]" not in template and "{{CONTEXT}}" not in template: - log.debug( - "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder." - ) - - if "" in context and "" in context: - log.debug( - "WARNING: Potential prompt injection attack: the RAG " - "context contains '' and ''. This might be " - "nothing, or the user might be trying to hack something." 
- ) - - query_placeholders = [] - if "[query]" in context: - query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" - template = template.replace("[query]", query_placeholder) - query_placeholders.append(query_placeholder) - - if "{{QUERY}}" in context: - query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" - template = template.replace("{{QUERY}}", query_placeholder) - query_placeholders.append(query_placeholder) - - template = template.replace("[context]", context) - template = template.replace("{{CONTEXT}}", context) - template = template.replace("[query]", query) - template = template.replace("{{QUERY}}", query) - - for query_placeholder in query_placeholders: - template = template.replace(query_placeholder, query) - - return template - - def get_embedding_function( embedding_engine, embedding_model, diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 8c185d146..db24403e5 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -969,16 +969,16 @@ QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig( ) DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task: -Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list. +Analyze the chat history to determine the necessity of generating search queries. By default, **prioritize generating 1-3 broad and relevant search queries** unless it is absolutely certain that no additional information is required. The aim is to retrieve comprehensive, updated, and valuable information even with minimal uncertainty. If no search is unequivocally needed, return an empty list. ### Guidelines: -- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is prohibited. 
-- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise. -- If no search query is necessary, output should be: { "queries": [] } -- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required. -- Be concise, focusing strictly on composing search queries with no additional commentary or text. -- When in doubt, prefer to suggest a search for comprehensiveness. -- Today's date is: {{CURRENT_DATE}} +- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is strictly prohibited. +- When generating search queries, respond in the format: { "queries": ["query1", "query2"] }, ensuring each query is distinct, concise, and relevant to the topic. +- If and only if it is entirely certain that no useful results can be retrieved by a search, return: { "queries": [] }. +- Err on the side of suggesting search queries if there is **any chance** they might provide useful or updated information. +- Be concise and focused on composing high-quality search queries, avoiding unnecessary elaboration, commentary, or assumptions. +- Assume today's date is: {{CURRENT_DATE}}. +- Always prioritize providing actionable and broad queries that maximize informational coverage. 
### Output: Strictly return in JSON format: diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 3761a0e39..4c52ceb0f 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -49,7 +49,9 @@ from open_webui.apps.openai.main import ( get_all_models_responses as get_openai_models_responses, ) from open_webui.apps.retrieval.main import app as retrieval_app -from open_webui.apps.retrieval.utils import get_sources_from_files, rag_template +from open_webui.apps.retrieval.utils import get_sources_from_files + + from open_webui.apps.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, @@ -122,11 +124,12 @@ from open_webui.utils.response import ( ) from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( - moa_response_generation_template, - tags_generation_template, - query_generation_template, - emoji_generation_template, + rag_template, title_generation_template, + query_generation_template, + tags_generation_template, + emoji_generation_template, + moa_response_generation_template, tools_function_calling_generation_template, ) from open_webui.utils.tools import get_tools diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index 28b07da37..b6d0d3bce 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -1,11 +1,20 @@ +import logging import math import re from datetime import datetime from typing import Optional +import uuid from open_webui.utils.misc import get_last_user_message, get_messages_content +from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import DEFAULT_RAG_TEMPLATE + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + def prompt_template( template: str, user_name: Optional[str] = None, user_location: Optional[str] = None @@ -110,6 +119,44 @@ def replace_messages_variable(template: str, messages: list[str]) -> str: # {{prompt:middletruncate:8000}} +def 
rag_template(template: str, context: str, query: str): + if template == "": + template = DEFAULT_RAG_TEMPLATE + + if "[context]" not in template and "{{CONTEXT}}" not in template: + log.debug( + "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder." + ) + + if "<context>" in context and "</context>" in context: + log.debug( + "WARNING: Potential prompt injection attack: the RAG " + "context contains '<context>' and '</context>'. This might be " + "nothing, or the user might be trying to hack something." + ) + + query_placeholders = [] + if "[query]" in context: + query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" + template = template.replace("[query]", query_placeholder) + query_placeholders.append(query_placeholder) + + if "{{QUERY}}" in context: + query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}" + template = template.replace("{{QUERY}}", query_placeholder) + query_placeholders.append(query_placeholder) + + template = template.replace("[context]", context) + template = template.replace("{{CONTEXT}}", context) + template = template.replace("[query]", query) + template = template.replace("{{QUERY}}", query) + + for query_placeholder in query_placeholders: + template = template.replace(query_placeholder, query) + + return template + + def title_generation_template( template: str, messages: list[dict], user: Optional[dict] = None ) -> str: From 840437e58fd7b3991293323f3ff6af9cb4ce7334 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 24 Nov 2024 19:07:51 -0800 Subject: [PATCH 16/30] refac: o1 title generation issue --- src/lib/components/chat/Chat.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte index 61a8a54cf..03080d9b5 100644 --- a/src/lib/components/chat/Chat.svelte +++ b/src/lib/components/chat/Chat.svelte @@ -1976,7 +1976,7 @@ } ); - return title; + return title ? title : (lastUserMessage?.content ?? 
'New Chat'); } else { return lastUserMessage?.content ?? 'New Chat'; } From 8dc73e87440c92fd8ab41052a3b3cbfd46593911 Mon Sep 17 00:00:00 2001 From: bnodnarb <97063458+bnodnarb@users.noreply.github.com> Date: Sun, 24 Nov 2024 20:29:54 -1000 Subject: [PATCH 17/30] Fix: Add authorization header with bearer token for remote Ollama server endpoints --- backend/open_webui/apps/ollama/main.py | 46 ++++++++++++++++++++------ 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index b44f68017..0ac1f0401 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -195,7 +195,10 @@ async def post_streaming_url( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -210,13 +213,13 @@ async def post_streaming_url( r.raise_for_status() if stream: - headers = dict(r.headers) + response_headers = dict(r.headers) if content_type: - headers["Content-Type"] = content_type + response_headers["Content-Type"] = content_type return StreamingResponse( r.content, status_code=r.status, - headers=headers, + headers=response_headers, background=BackgroundTask( cleanup_response, response=r, session=session ), @@ -324,7 +327,10 @@ async def get_ollama_tags( else: url = app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {} @@ -525,7 +531,10 @@ async def copy_model( url = 
app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -584,7 +593,10 @@ async def delete_model( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -635,7 +647,10 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -730,7 +745,10 @@ async def generate_ollama_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -797,7 +815,10 @@ async def generate_ollama_batch_embeddings( url = app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + 
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -974,7 +995,10 @@ async def generate_chat_completion( log.info(f"url: {url}") log.debug(f"generate_chat_completion() - 2.payload = {payload}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + parsed_url = urlparse(url) + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" + + api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") From c4f82309dc6283b8737975f5d3cc5019e19eb0f3 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 25 Nov 2024 16:11:49 -0800 Subject: [PATCH 18/30] fix: min_p save issue --- src/lib/components/chat/Settings/General.svelte | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib/components/chat/Settings/General.svelte b/src/lib/components/chat/Settings/General.svelte index 9c1ccb7f7..694648206 100644 --- a/src/lib/components/chat/Settings/General.svelte +++ b/src/lib/components/chat/Settings/General.svelte @@ -55,6 +55,7 @@ mirostat_tau: null, top_k: null, top_p: null, + min_p: null, stop: null, tfs_z: null, num_ctx: null, @@ -340,6 +341,7 @@ mirostat_tau: params.mirostat_tau !== null ? params.mirostat_tau : undefined, top_k: params.top_k !== null ? params.top_k : undefined, top_p: params.top_p !== null ? params.top_p : undefined, + min_p: params.min_p !== null ? params.min_p : undefined, tfs_z: params.tfs_z !== null ? params.tfs_z : undefined, num_ctx: params.num_ctx !== null ? params.num_ctx : undefined, num_batch: params.num_batch !== null ? 
params.num_batch : undefined, From f9e24968e3d421e87d850166d7a528e3bd6772fd Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 25 Nov 2024 22:43:34 -0800 Subject: [PATCH 19/30] fix: input issue --- src/lib/components/chat/MessageInput.svelte | 102 ++++++++++-------- .../components/common/RichTextInput.svelte | 17 +-- 2 files changed, 60 insertions(+), 59 deletions(-) diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 9016e979d..042159bb4 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -592,29 +592,6 @@ placeholder={placeholder ? placeholder : $i18n.t('Send a Message')} largeTextAsFile={$settings?.largeTextAsFile ?? false} bind:value={prompt} - on:enter={async (e) => { - const commandsContainerElement = - document.getElementById('commands-container'); - if (commandsContainerElement) { - e.preventDefault(); - - const commandOptionButton = [ - ...document.getElementsByClassName('selected-command-option-button') - ]?.at(-1); - - if (commandOptionButton) { - commandOptionButton?.click(); - return; - } - } - - if (prompt !== '') { - dispatch('submit', prompt); - } - }} - on:keypress={(e) => { - e = e.detail.event; - }} on:keydown={async (e) => { e = e.detail.event; @@ -657,34 +634,69 @@ editButton?.click(); } - if (commandsContainerElement && e.key === 'ArrowUp') { - e.preventDefault(); - commandsElement.selectUp(); + if (commandsContainerElement) { + if (commandsContainerElement && e.key === 'ArrowUp') { + e.preventDefault(); + commandsElement.selectUp(); - const commandOptionButton = [ - ...document.getElementsByClassName('selected-command-option-button') - ]?.at(-1); - commandOptionButton.scrollIntoView({ block: 'center' }); - } + const commandOptionButton = [ + ...document.getElementsByClassName('selected-command-option-button') + ]?.at(-1); + commandOptionButton.scrollIntoView({ block: 'center' }); + } - if (commandsContainerElement && 
e.key === 'ArrowDown') { - e.preventDefault(); - commandsElement.selectDown(); + if (commandsContainerElement && e.key === 'ArrowDown') { + e.preventDefault(); + commandsElement.selectDown(); - const commandOptionButton = [ - ...document.getElementsByClassName('selected-command-option-button') - ]?.at(-1); - commandOptionButton.scrollIntoView({ block: 'center' }); - } + const commandOptionButton = [ + ...document.getElementsByClassName('selected-command-option-button') + ]?.at(-1); + commandOptionButton.scrollIntoView({ block: 'center' }); + } - if (commandsContainerElement && e.key === 'Tab') { - e.preventDefault(); + if (commandsContainerElement && e.key === 'Tab') { + e.preventDefault(); - const commandOptionButton = [ - ...document.getElementsByClassName('selected-command-option-button') - ]?.at(-1); + const commandOptionButton = [ + ...document.getElementsByClassName('selected-command-option-button') + ]?.at(-1); - commandOptionButton?.click(); + commandOptionButton?.click(); + } + + if (commandsContainerElement && e.key === 'Enter') { + e.preventDefault(); + + const commandOptionButton = [ + ...document.getElementsByClassName('selected-command-option-button') + ]?.at(-1); + + if (commandOptionButton) { + commandOptionButton?.click(); + } else { + document.getElementById('send-message-button')?.click(); + } + } + } else { + if ( + !$mobile || + !( + 'ontouchstart' in window || + navigator.maxTouchPoints > 0 || + navigator.msMaxTouchPoints > 0 + ) + ) { + // Prevent Enter key from creating a new line + if (e.keyCode === 13 && !e.shiftKey) { + e.preventDefault(); + } + + // Submit the prompt when Enter key is pressed + if (prompt !== '' && e.keyCode === 13 && !e.shiftKey) { + dispatch('submit', prompt); + } + } } if (e.key === 'Escape') { diff --git a/src/lib/components/common/RichTextInput.svelte b/src/lib/components/common/RichTextInput.svelte index 8200154fa..fc24ab063 100644 --- a/src/lib/components/common/RichTextInput.svelte +++ 
b/src/lib/components/common/RichTextInput.svelte @@ -171,11 +171,10 @@ eventDispatch('focus', { event }); return false; }, - keypress: (view, event) => { - eventDispatch('keypress', { event }); + keyup: (view, event) => { + eventDispatch('keyup', { event }); return false; }, - keydown: (view, event) => { // Handle Tab Key if (event.key === 'Tab') { @@ -217,22 +216,12 @@ // Handle shift + Enter for a line break if (shiftEnter) { - if (event.key === 'Enter' && event.shiftKey) { + if (event.key === 'Enter' && event.shiftKey && !event.ctrlKey && !event.metaKey) { editor.commands.setHardBreak(); // Insert a hard break view.dispatch(view.state.tr.scrollIntoView()); // Move viewport to the cursor event.preventDefault(); return true; } - if (event.key === 'Enter') { - eventDispatch('enter', { event }); - event.preventDefault(); - return true; - } - } - if (event.key === 'Enter') { - eventDispatch('enter', { event }); - event.preventDefault(); - return true; } } eventDispatch('keydown', { event }); From 29fac5eccaacdf061ebfa1b0064f7c3efb3c5e30 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Mon, 25 Nov 2024 22:57:54 -0800 Subject: [PATCH 20/30] refac: admin models settings --- src/lib/components/admin/Settings/Models.svelte | 13 ++++++++++--- src/lib/components/chat/MessageInput.svelte | 1 + 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index b7084d8ce..920d280f5 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -38,9 +38,16 @@ let showResetModal = false; $: if (models) { - filteredModels = models.filter( - (m) => searchValue === '' || m.name.toLowerCase().includes(searchValue.toLowerCase()) - ); + filteredModels = models + .filter((m) => searchValue === '' || m.name.toLowerCase().includes(searchValue.toLowerCase())) + .sort((a, b) => { + // Check if either model is inactive and push 
them to the bottom + if ((a.is_active ?? true) !== (b.is_active ?? true)) { + return (b.is_active ?? true) - (a.is_active ?? true); + } + // If both models' active states are the same, sort alphabetically + return a.name.localeCompare(b.name); + }); } let searchValue = ''; diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 042159bb4..16e3cdb91 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -688,6 +688,7 @@ ) ) { // Prevent Enter key from creating a new line + // Uses keyCode '13' for Enter key for chinese/japanese keyboards if (e.keyCode === 13 && !e.shiftKey) { e.preventDefault(); } From 2b4e8f6ceac2611b5d0cb1522fa315dc0ee026a2 Mon Sep 17 00:00:00 2001 From: Pieter Becking Date: Tue, 26 Nov 2024 09:42:13 +0100 Subject: [PATCH 21/30] fix(i18n): Correct capitalization typo in Dutch localization (GRoepen -> Groepen) --- src/lib/i18n/locales/nl-NL/translation.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/i18n/locales/nl-NL/translation.json b/src/lib/i18n/locales/nl-NL/translation.json index 1d5dc32b9..b6cc5d051 100644 --- a/src/lib/i18n/locales/nl-NL/translation.json +++ b/src/lib/i18n/locales/nl-NL/translation.json @@ -441,7 +441,7 @@ "Group Description": "Groepsbeschrijving", "Group Name": "Groepsnaam", "Group updated successfully": "Groep succesvol bijgewerkt", - "Groups": "GRoepen", + "Groups": "Groepen", "h:mm a": "h:mm a", "Haptic Feedback": "Haptische feedback", "has no conversations.": "heeft geen gesprekken.", From 5fac25a0028f4cb625a78de3926b2258209b7b00 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 26 Nov 2024 00:55:58 -0800 Subject: [PATCH 22/30] enh: reintroduce model order/default models --- backend/open_webui/apps/webui/main.py | 2 + .../open_webui/apps/webui/routers/configs.py | 45 +-- backend/open_webui/config.py | 6 + backend/open_webui/main.py | 8 + src/lib/apis/configs/index.ts | 
38 ++- src/lib/apis/index.ts | 20 -- .../Evaluations/ArenaModelModal.svelte | 4 +- .../components/admin/Settings/Models.svelte | 26 +- .../Models/ConfigureModelsModal.svelte | 258 ++++++++++++++++++ .../admin/Settings/Models/ModelList.svelte | 58 ++++ src/lib/components/chat/ModelSelector.svelte | 2 - src/lib/components/common/Modal.svelte | 6 +- 12 files changed, 407 insertions(+), 66 deletions(-) create mode 100644 src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte create mode 100644 src/lib/components/admin/Settings/Models/ModelList.svelte diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index bedd49ae3..054c6280e 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -31,6 +31,7 @@ from open_webui.config import ( DEFAULT_MODELS, DEFAULT_PROMPT_SUGGESTIONS, DEFAULT_USER_ROLE, + MODEL_ORDER_LIST, ENABLE_COMMUNITY_SHARING, ENABLE_LOGIN_FORM, ENABLE_MESSAGE_RATING, @@ -120,6 +121,7 @@ app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.BANNERS = WEBUI_BANNERS +app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING diff --git a/backend/open_webui/apps/webui/routers/configs.py b/backend/open_webui/apps/webui/routers/configs.py index 1c30b0b3b..b19fc1745 100644 --- a/backend/open_webui/apps/webui/routers/configs.py +++ b/backend/open_webui/apps/webui/routers/configs.py @@ -34,8 +34,32 @@ async def export_config(user=Depends(get_admin_user)): return get_config() -class SetDefaultModelsForm(BaseModel): - models: str +############################ +# SetDefaultModels +############################ +class ModelsConfigForm(BaseModel): + DEFAULT_MODELS: str + MODEL_ORDER_LIST: list[str] + + +@router.get("/models", 
response_model=ModelsConfigForm) +async def get_models_config(request: Request, user=Depends(get_admin_user)): + return { + "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, + "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, + } + + +@router.post("/models", response_model=ModelsConfigForm) +async def set_models_config( + request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS + request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST + return { + "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS, + "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST, + } class PromptSuggestion(BaseModel): @@ -47,21 +71,8 @@ class SetDefaultSuggestionsForm(BaseModel): suggestions: list[PromptSuggestion] -############################ -# SetDefaultModels -############################ - - -@router.post("/default/models", response_model=str) -async def set_global_default_models( - request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user) -): - request.app.state.config.DEFAULT_MODELS = form_data.models - return request.app.state.config.DEFAULT_MODELS - - -@router.post("/default/suggestions", response_model=list[PromptSuggestion]) -async def set_global_default_suggestions( +@router.post("/suggestions", response_model=list[PromptSuggestion]) +async def set_default_suggestions( request: Request, form_data: SetDefaultSuggestionsForm, user=Depends(get_admin_user), diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index db24403e5..3c1ee798d 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -740,6 +740,12 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig( ], ) +MODEL_ORDER_LIST = PersistentConfig( + "MODEL_ORDER_LIST", + "ui.model_order_list", + [], +) + DEFAULT_USER_ROLE = PersistentConfig( "DEFAULT_USER_ROLE", "ui.default_user_role", diff --git a/backend/open_webui/main.py 
b/backend/open_webui/main.py index 4c52ceb0f..0dca21b08 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1194,6 +1194,14 @@ async def get_models(user=Depends(get_verified_user)): if "pipeline" not in model or model["pipeline"].get("type", None) != "filter" ] + model_order_list = webui_app.state.config.MODEL_ORDER_LIST + if model_order_list: + model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} + # Sort models by order list priority, with fallback for those not in the list + models.sort( + key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]) + ) + # Filter out models that the user does not have access to if user.role == "user": filtered_models = [] diff --git a/src/lib/apis/configs/index.ts b/src/lib/apis/configs/index.ts index 0c4de6ad6..b3b002557 100644 --- a/src/lib/apis/configs/index.ts +++ b/src/lib/apis/configs/index.ts @@ -58,17 +58,46 @@ export const exportConfig = async (token: string) => { return res; }; -export const setDefaultModels = async (token: string, models: string) => { + +export const getModelsConfig = async (token: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/models`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/models`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}` + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + console.log(err); + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + + +export const setModelsConfig = async (token: string, config: object) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/models`, { method: 'POST', headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, body: JSON.stringify({ - models: models + ...config }) }) .then(async (res) => { @@ -88,10 
+117,11 @@ export const setDefaultModels = async (token: string, models: string) => { return res; }; + export const setDefaultPromptSuggestions = async (token: string, promptSuggestions: string) => { let error = null; - const res = await fetch(`${WEBUI_API_BASE_URL}/configs/default/suggestions`, { + const res = await fetch(`${WEBUI_API_BASE_URL}/configs/suggestions`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 9c726e4d0..a33610c31 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -25,26 +25,6 @@ export const getModels = async (token: string = '', base: boolean = false) => { } let models = res?.data ?? []; - models = models - .filter((models) => models) - // Sort the models - .sort((a, b) => { - // Compare case-insensitively by name for models without position property - const lowerA = a.name.toLowerCase(); - const lowerB = b.name.toLowerCase(); - - if (lowerA < lowerB) return -1; - if (lowerA > lowerB) return 1; - - // If same case-insensitively, sort by original strings, - // lowercase will come before uppercase due to ASCII values - if (a.name < b.name) return -1; - if (a.name > b.name) return 1; - - return 0; // They are equal - }); - - console.log(models); return models; }; diff --git a/src/lib/components/admin/Settings/Evaluations/ArenaModelModal.svelte b/src/lib/components/admin/Settings/Evaluations/ArenaModelModal.svelte index 5f64137e4..ac45e7f2c 100644 --- a/src/lib/components/admin/Settings/Evaluations/ArenaModelModal.svelte +++ b/src/lib/components/admin/Settings/Evaluations/ArenaModelModal.svelte @@ -375,7 +375,7 @@
{#if edit}
- +
diff --git a/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte b/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte new file mode 100644 index 000000000..be7fc8a45 --- /dev/null +++ b/src/lib/components/admin/Settings/Models/ConfigureModelsModal.svelte @@ -0,0 +1,258 @@ + + + { + const res = deleteAllModels(localStorage.token); + if (res) { + toast.success($i18n.t('All models deleted successfully')); + init(); + } + }} +/> + + +
+
+
+ {$i18n.t('Configure Models')} +
+ +
+ +
+
+ {#if config} +
{ + submitHandler(); + }} + > +
+
+
+
{$i18n.t('Reorder Models')}
+
+ + +
+
+ +
+ +
+
+
+
{$i18n.t('Default Models')}
+
+ + {#if defaultModelIds.length > 0} +
+ {#each defaultModelIds as modelId, modelIdx} +
+
+ {$models.find((model) => model.id === modelId)?.name} +
+
+ +
+
+ {/each} +
+ {:else} +
+ {$i18n.t('No models selected')} +
+ {/if} + +
+ +
+ + +
+ +
+
+
+
+ +
+ + + + + +
+
+ {:else} +
+ +
+ {/if} +
+
+
+
diff --git a/src/lib/components/admin/Settings/Models/ModelList.svelte b/src/lib/components/admin/Settings/Models/ModelList.svelte new file mode 100644 index 000000000..c54d19ee4 --- /dev/null +++ b/src/lib/components/admin/Settings/Models/ModelList.svelte @@ -0,0 +1,58 @@ + + +{#if modelIds.length > 0} +
+ {#each modelIds as modelId, modelIdx (modelId)} +
+ +
+ + +
+ {#if $models.find((model) => model.id === modelId)} + {$models.find((model) => model.id === modelId).name} + {:else} + {modelId} + {/if} +
+
+
+
+ {/each} +
+{:else} +
+ {$i18n.t('No models found')} +
+{/if} diff --git a/src/lib/components/chat/ModelSelector.svelte b/src/lib/components/chat/ModelSelector.svelte index 9b16a6500..9b77cd8ce 100644 --- a/src/lib/components/chat/ModelSelector.svelte +++ b/src/lib/components/chat/ModelSelector.svelte @@ -5,9 +5,7 @@ import Selector from './ModelSelector/Selector.svelte'; import Tooltip from '../common/Tooltip.svelte'; - import { setDefaultModels } from '$lib/apis/configs'; import { updateUserSettings } from '$lib/apis/users'; - const i18n = getContext('i18n'); export let selectedModels = ['']; diff --git a/src/lib/components/common/Modal.svelte b/src/lib/components/common/Modal.svelte index 795d3d0f1..86c741e23 100644 --- a/src/lib/components/common/Modal.svelte +++ b/src/lib/components/common/Modal.svelte @@ -6,7 +6,7 @@ export let show = true; export let size = 'md'; - export let className = 'bg-gray-50 dark:bg-gray-900 rounded-2xl'; + export let className = 'bg-gray-50 dark:bg-gray-900 rounded-2xl'; let modalElement = null; let mounted = false; @@ -65,7 +65,7 @@