diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ab43ef8b4..308489ee6 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -46,6 +46,21 @@ from open_webui.routers import ( retrieval, pipelines, tasks, + auths, + chats, + folders, + configs, + groups, + files, + functions, + memories, + models, + knowledge, + prompts, + evaluations, + tools, + users, + utils, ) from open_webui.retrieval.utils import get_sources_from_files @@ -117,6 +132,60 @@ from open_webui.config import ( WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_DIR, + # Retrieval + RAG_TEMPLATE, + DEFAULT_RAG_TEMPLATE, + RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + RAG_EMBEDDING_ENGINE, + RAG_EMBEDDING_BATCH_SIZE, + RAG_RELEVANCE_THRESHOLD, + RAG_FILE_MAX_COUNT, + RAG_FILE_MAX_SIZE, + RAG_OPENAI_API_BASE_URL, + RAG_OPENAI_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OLLAMA_API_KEY, + CHUNK_OVERLAP, + CHUNK_SIZE, + CONTENT_EXTRACTION_ENGINE, + TIKA_SERVER_URL, + RAG_TOP_K, + RAG_TEXT_SPLITTER, + TIKTOKEN_ENCODING_NAME, + PDF_EXTRACT_IMAGES, + YOUTUBE_LOADER_LANGUAGE, + YOUTUBE_LOADER_PROXY_URL, + # Retrieval (Web Search) + RAG_WEB_SEARCH_ENGINE, + RAG_WEB_SEARCH_RESULT_COUNT, + RAG_WEB_SEARCH_CONCURRENT_REQUESTS, + RAG_WEB_SEARCH_DOMAIN_FILTER_LIST, + JINA_API_KEY, + SEARCHAPI_API_KEY, + SEARCHAPI_ENGINE, + SEARXNG_QUERY_URL, + SERPER_API_KEY, + SERPLY_API_KEY, + SERPSTACK_API_KEY, + SERPSTACK_HTTPS, + TAVILY_API_KEY, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, + BRAVE_SEARCH_API_KEY, + KAGI_SEARCH_API_KEY, + MOJEEK_SEARCH_API_KEY, + GOOGLE_PSE_API_KEY, + GOOGLE_PSE_ENGINE_ID, + ENABLE_RAG_HYBRID_SEARCH, + ENABLE_RAG_LOCAL_WEB_FETCH, + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, + ENABLE_RAG_WEB_SEARCH, + UPLOAD_DIR, # WebUI WEBUI_AUTH, WEBUI_NAME, @@ -383,6 +452,72 @@ app.state.FUNCTIONS = {} # ######################################## + +app.state.config.TOP_K = RAG_TOP_K +app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE +app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT + +app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH +app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( + ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION +) + +app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE +app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL + +app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER +app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME + +app.state.config.CHUNK_SIZE = CHUNK_SIZE +app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP + +app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE +app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL +app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE +app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL +app.state.config.RAG_TEMPLATE = RAG_TEMPLATE + +app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY + +app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL +app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY + +app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES + +app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE +app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL + + 
+app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH +app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE +app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST + +app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL +app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY +app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID +app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY +app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY +app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY +app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY +app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS +app.state.config.SERPER_API_KEY = SERPER_API_KEY +app.state.config.SERPLY_API_KEY = SERPLY_API_KEY +app.state.config.TAVILY_API_KEY = TAVILY_API_KEY +app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY +app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE +app.state.config.JINA_API_KEY = JINA_API_KEY +app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT +app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY + +app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT +app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS + + +app.state.YOUTUBE_LOADER_TRANSLATION = None +app.state.EMBEDDING_FUNCTION = None + ######################################## # # IMAGES @@ -1083,8 +1218,8 @@ def filter_pipeline(payload, user, models): try: urlIdx = filter["urlIdx"] - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] + url = app.state.config.OPENAI_API_BASE_URLS[urlIdx] + key = app.state.config.OPENAI_API_KEYS[urlIdx] if key == "": continue @@ -1230,14 +1365,6 @@ async def check_url(request: Request, call_next): return response -# @app.middleware("http") -# async def update_embedding_function(request: Request, call_next): -# response = await call_next(request) -# if "/embedding/update" in request.url.path: -# webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION -# return response - - @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( @@ -1268,18 +1395,36 @@ app.add_middleware( app.mount("/ws", socket_app) -app.mount("/ollama", ollama_app) -app.mount("/openai", openai_app) - -app.mount("/images/api/v1", images_app) -app.mount("/audio/api/v1", audio_app) +app.include_router(ollama.router, prefix="/ollama") +app.include_router(openai.router, prefix="/openai") -app.mount("/retrieval/api/v1", retrieval_app) +app.include_router(images.router, prefix="/api/v1/images") +app.include_router(audio.router, prefix="/api/v1/audio") +app.include_router(retrieval.router, prefix="/api/v1/retrieval") -app.mount("/api/v1", webui_app) -app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION +app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) + +app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"]) +app.include_router(users.router, prefix="/api/v1/users", tags=["users"]) + +app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"]) + +app.include_router(models.router, prefix="/api/v1/models", tags=["models"]) +app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"]) +app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"]) +app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"]) + 
+app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"]) +app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"]) +app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"]) +app.include_router(files.router, prefix="/api/v1/files", tags=["files"]) +app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"]) +app.include_router( + evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"] +) +app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) async def get_all_base_models(): diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 581a881b7..8a43d5c52 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -13,7 +13,15 @@ from aiocache import cached import requests -from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile +from fastapi import ( + Depends, + FastAPI, + File, + HTTPException, + Request, + UploadFile, + APIRouter, +) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict @@ -26,18 +34,15 @@ from open_webui.models.models import Models from open_webui.config import ( UPLOAD_DIR, ) - - from open_webui.env import ( + ENV, + SRC_LOG_LEVELS, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, BYPASS_MODEL_ACCESS_CONTROL, ) - from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS - from open_webui.utils.misc import ( calculate_sha256, @@ -54,13 +59,15 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) +router = APIRouter() + # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. 
-@app.head("/")
-@app.get("/")
+@router.head("/")
+@router.get("/")
 async def get_status():
     return {"status": True}


@@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel):
     key: Optional[str] = None


-@app.post("/verify")
+@router.post("/verify")
 async def verify_connection(
     form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
 ):
@@ -110,12 +117,12 @@ async def verify_connection(
         raise HTTPException(status_code=500, detail=error_detail)


-@app.get("/config")
-async def get_config(user=Depends(get_admin_user)):
+@router.get("/config")
+async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
-        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
-        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
-        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
+        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
+        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
+        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
     }


@@ -125,23 +132,24 @@ class OllamaConfigForm(BaseModel):
     OLLAMA_API_CONFIGS: dict


-@app.post("/config/update")
-async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
-    app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
+@router.post("/config/update")
+async def update_config(
+    request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
+    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS

-    app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
+    request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS

     # Remove any extra configs
-    config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
-    for url in list(app.state.config.OLLAMA_BASE_URLS):
-        if url not in config_urls:
-            app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
+    for url in list(request.app.state.config.OLLAMA_API_CONFIGS.keys()):
+        if url not in request.app.state.config.OLLAMA_BASE_URLS:
+            request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)

     return {
-        "ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
-        "OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
-        "OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
+        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
+        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
+        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
     }


@@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None):
     return None


+def get_api_key(url, configs):
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    return configs.get(base_url, {}).get("key", None)
+
+
 async def cleanup_response(
     response: Optional[aiohttp.ClientResponse],
     session: Optional[aiohttp.ClientSession],
@@ -169,7 +184,11 @@ async def cleanup_response(


 async def post_streaming_url(
-    url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
+    url: str,
+    payload: Union[str, bytes],
+    stream: bool = True,
+    key: Optional[str] = None,
+    content_type=None,
 ):
     r = None
     try:
@@ -177,12 +196,6 @@ async def post_streaming_url(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )

-        parsed_url = urlparse(url)
-        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-
-        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
-        key = api_config.get("key", None)
-
headers = {"Content-Type": "application/json"} if key: headers["Authorization"] = f"Bearer {key}" @@ -246,13 +259,13 @@ def merge_models_lists(model_lists): @cached(ttl=3) async def get_all_models(): log.info("get_all_models()") - if app.state.config.ENABLE_OLLAMA_API: + if request.app.state.config.ENABLE_OLLAMA_API: tasks = [] - for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS): - if url not in app.state.config.OLLAMA_API_CONFIGS: + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): + if url not in request.app.state.config.OLLAMA_API_CONFIGS: tasks.append(aiohttp_get(f"{url}/api/tags")) else: - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) key = api_config.get("key", None) @@ -265,8 +278,8 @@ async def get_all_models(): for idx, response in enumerate(responses): if response: - url = app.state.config.OLLAMA_BASE_URLS[idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) model_ids = api_config.get("model_ids", []) @@ -298,21 +311,21 @@ async def get_all_models(): return models -@app.get("/api/tags") -@app.get("/api/tags/{url_idx}") +@router.get("/api/tags") +@router.get("/api/tags/{url_idx}") async def get_ollama_tags( - url_idx: Optional[int] = None, user=Depends(get_verified_user) + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) ): models = [] if url_idx is None: models = await get_all_models() else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {} @@ -356,18 +369,20 @@ async def get_ollama_tags( return models -@app.get("/api/version") -@app.get("/api/version/{url_idx}") -async def get_ollama_versions(url_idx: Optional[int] = None): - if app.state.config.ENABLE_OLLAMA_API: +@router.get("/api/version") +@router.get("/api/version/{url_idx}") +async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): + if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: # returns lowest version tasks = [ aiohttp_get( f"{url}/api/version", - app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), ) - for url in app.state.config.OLLAMA_BASE_URLS + for url in request.app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) responses = list(filter(lambda x: x is not None, responses)) @@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, ) else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] r = None try: @@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None): return {"version": False} -@app.get("/api/ps") -async def get_ollama_loaded_models(user=Depends(get_verified_user)): +@router.get("/api/ps") +async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)): """ List models that 
are currently loaded into Ollama memory, and which node they are loaded on. """ - if app.state.config.ENABLE_OLLAMA_API: + if request.app.state.config.ENABLE_OLLAMA_API: tasks = [ aiohttp_get( f"{url}/api/ps", - app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), + request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get( + "key", None + ), ) - for url in app.state.config.OLLAMA_BASE_URLS + for url in request.app.state.config.OLLAMA_BASE_URLS ] responses = await asyncio.gather(*tasks) - return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses)) + return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) else: return {} @@ -438,18 +455,25 @@ class ModelNameForm(BaseModel): name: str -@app.post("/api/pull") -@app.post("/api/pull/{url_idx}") +@router.post("/api/pull") +@router.post("/api/pull/{url_idx}") async def pull_model( - form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) + request: Request, + form_data: ModelNameForm, + url_idx: int = 0, + user=Depends(get_admin_user), ): - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") # Admin should be able to pull models from any source payload = {**form_data.model_dump(exclude_none=True), "insecure": True} - return await post_streaming_url(f"{url}/api/pull", json.dumps(payload)) + return await post_streaming_url( + url=f"{url}/api/pull", + payload=json.dumps(payload), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), + ) class PushModelForm(BaseModel): @@ -458,9 +482,10 @@ class PushModelForm(BaseModel): stream: Optional[bool] = None -@app.delete("/api/push") -@app.delete("/api/push/{url_idx}") +@router.delete("/api/push") +@router.delete("/api/push/{url_idx}") async def push_model( + request: Request, form_data: PushModelForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -477,11 +502,13 @@ async def push_model( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.debug(f"url: {url}") return await post_streaming_url( - f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode() + url=f"{url}/api/push", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -492,17 +519,22 @@ class CreateModelForm(BaseModel): path: Optional[str] = None -@app.post("/api/create") -@app.post("/api/create/{url_idx}") +@router.post("/api/create") +@router.post("/api/create/{url_idx}") async def create_model( - form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) + request: Request, + form_data: CreateModelForm, + url_idx: int = 0, + user=Depends(get_admin_user), ): log.debug(f"form_data: {form_data}") - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") return await post_streaming_url( - f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode() + url=f"{url}/api/create", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -511,9 +543,10 @@ class CopyModelForm(BaseModel): destination: str -@app.post("/api/copy") -@app.post("/api/copy/{url_idx}") +@router.post("/api/copy") +@router.post("/api/copy/{url_idx}") async def copy_model( + request: Request, form_data: CopyModelForm, 
    url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
@@ -530,13 +563,13 @@ async def copy_model(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
         )

-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)

     headers = {"Content-Type": "application/json"}
@@ -573,9 +606,10 @@ async def copy_model(
     )


-@app.delete("/api/delete")
-@app.delete("/api/delete/{url_idx}")
+@router.delete("/api/delete")
+@router.delete("/api/delete/{url_idx}")
 async def delete_model(
+    request: Request,
     form_data: ModelNameForm,
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
@@ -592,13 +626,13 @@ async def delete_model(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
         )

-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)

     headers = {"Content-Type": "application/json"}
@@ -634,8 +668,10 @@ async def delete_model(
     )


-@app.post("/api/show")
-async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
+@router.post("/api/show")
+async def show_model_info(
+    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
+):
-    model_list = await get_all_models()
+    model_list = await get_all_models(request)
     models = {model["model"]: model for model in model_list["models"]}

@@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
         )

     url_idx = random.choice(models[form_data.name]["urls"])
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)

     headers = {"Content-Type": "application/json"}
@@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None


-@app.post("/api/embed")
-@app.post("/api/embed/{url_idx}")
+@router.post("/api/embed")
+@router.post("/api/embed/{url_idx}")
 async def generate_embeddings(
     form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
@@ -711,8 +747,8 @@ async def generate_embeddings(
     return await generate_ollama_batch_embeddings(form_data, url_idx)


-@app.post("/api/embeddings")
-@app.post("/api/embeddings/{url_idx}")
+@router.post("/api/embeddings")
+@router.post("/api/embeddings/{url_idx}")
 async def generate_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
@@ -744,13 +780,13 @@ async def generate_ollama_embeddings(
         detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
     )

-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     parsed_url = urlparse(url)
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
+ api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) key = api_config.get("key", None) headers = {"Content-Type": "application/json"} @@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel): keep_alive: Optional[Union[int, str]] = None -@app.post("/api/generate") -@app.post("/api/generate/{url_idx}") +@router.post("/api/generate") +@router.post("/api/generate/{url_idx}") async def generate_completion( + request: Request, form_data: GenerateCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -897,15 +934,17 @@ async def generate_completion( detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: form_data.model = form_data.model.replace(f"{prefix_id}.", "") log.info(f"url: {url}") return await post_streaming_url( - f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode() + url=f"{url}/api/generate", + payload=form_data.model_dump_json(exclude_none=True).encode(), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) @@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str): detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) url_idx = random.choice(models[model]["urls"]) - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] return url -@app.post("/api/chat") -@app.post("/api/chat/{url_idx}") +@router.post("/api/chat") +@router.post("/api/chat/{url_idx}") async def generate_chat_completion( + request: Request, form_data: GenerateChatCompletionForm, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -1003,15 +1043,16 @@ async def generate_chat_completion( parsed_url = urlparse(url) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") return await post_streaming_url( - f"{url}/api/chat", - json.dumps(payload), + url=f"{url}/api/chat", + payload=json.dumps(payload), stream=form_data.stream, + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", ) @@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel): model_config = ConfigDict(extra="allow") -@app.post("/v1/completions") -@app.post("/v1/completions/{url_idx}") +@router.post("/v1/completions") +@router.post("/v1/completions/{url_idx}") async def generate_openai_completion( - form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user) + request: Request, + 
form_data: dict, + url_idx: Optional[int] = None, + user=Depends(get_verified_user), ): try: form_data = OpenAICompletionForm(**form_data) @@ -1099,22 +1143,24 @@ async def generate_openai_completion( url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") return await post_streaming_url( - f"{url}/v1/completions", - json.dumps(payload), + url=f"{url}/v1/completions", + payload=json.dumps(payload), stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) -@app.post("/v1/chat/completions") -@app.post("/v1/chat/completions/{url_idx}") +@router.post("/v1/chat/completions") +@router.post("/v1/chat/completions/{url_idx}") async def generate_openai_chat_completion( + request: Request, form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user), @@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion( url = await get_ollama_url(url_idx, payload["model"]) log.info(f"url: {url}") - api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") return await post_streaming_url( - f"{url}/v1/chat/completions", - json.dumps(payload), + url=f"{url}/v1/chat/completions", + payload=json.dumps(payload), stream=payload.get("stream", False), + key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS), ) -@app.get("/v1/models") -@app.get("/v1/models/{url_idx}") +@router.get("/v1/models") +@router.get("/v1/models/{url_idx}") async def get_openai_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user), ): @@ -1205,7 +1253,7 @@ async def get_openai_models( ] else: - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] try: r = requests.request(method="GET", url=f"{url}/api/tags") r.raise_for_status() @@ -1329,9 +1377,10 @@ async def download_file_stream( # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" -@app.post("/models/download") -@app.post("/models/download/{url_idx}") +@router.post("/models/download") +@router.post("/models/download/{url_idx}") async def download_model( + request: Request, form_data: UrlForm, url_idx: Optional[int] = None, user=Depends(get_admin_user), @@ -1346,7 +1395,7 @@ async def download_model( if url_idx is None: url_idx = 0 - url = app.state.config.OLLAMA_BASE_URLS[url_idx] + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] file_name = parse_huggingface_url(form_data.url) @@ -1360,16 +1409,17 @@ async def download_model( return None -@app.post("/models/upload") -@app.post("/models/upload/{url_idx}") +@router.post("/models/upload") +@router.post("/models/upload/{url_idx}") def upload_model( + request: Request, file: UploadFile = File(...), url_idx: Optional[int] = None, user=Depends(get_admin_user), ): if url_idx is None: url_idx = 0 - ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] + ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] file_path = f"{UPLOAD_DIR}/{file.filename}"
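
Notes on the patterns above (sketches are illustrative, not part of the patch). The core move in main.py is replacing the mounted sub-apps (`app.mount("/ollama", ollama_app)`) with plain `APIRouter`s registered via `app.include_router(...)`. With a single FastAPI app there is no per-module `app` object for a router to close over, so handlers that need shared configuration now take a `request: Request` parameter and read `request.app.state.config` instead. A minimal sketch of the pattern, assuming only stock FastAPI; `ENABLE_API` and the `/demo` prefix are hypothetical names:

```python
# Minimal sketch of the sub-app -> router migration, using stock FastAPI only.
# `ENABLE_API` and the /demo prefix are illustrative names, not Open WebUI's.
from types import SimpleNamespace

from fastapi import APIRouter, FastAPI, Request

router = APIRouter()


@router.get("/config")
async def get_config(request: Request):
    # No module-level `app` to reference: shared state is reached
    # through the application object attached to the incoming request.
    return {"ENABLE_API": request.app.state.config.ENABLE_API}


app = FastAPI()
app.state.config = SimpleNamespace(ENABLE_API=True)  # one shared config object
app.include_router(router, prefix="/demo", tags=["demo"])
```

This is also why the patch threads `request` through helpers such as `get_all_models`: a shared router-level function can only see application state through the request that invoked it.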
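
The new `get_api_key` helper (copied verbatim from the diff) normalizes a full endpoint URL down to `scheme://netloc` before looking up per-connection settings, so a lookup for `http://host:11434/api/pull` finds the entry stored under `http://host:11434`. The example config dict below is hypothetical:

```python
from urllib.parse import urlparse


def get_api_key(url, configs):
    # Per-connection settings are keyed by "scheme://host[:port]",
    # so strip any path (e.g. "/api/pull") before the lookup.
    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
    return configs.get(base_url, {}).get("key", None)


# Hypothetical OLLAMA_API_CONFIGS contents, for illustration only:
configs = {"http://localhost:11434": {"key": "secret", "prefix_id": "local"}}

assert get_api_key("http://localhost:11434/api/pull", configs) == "secret"
assert get_api_key("http://other-host:11434/api/pull", configs) is None
```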
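
`post_streaming_url` previously derived the API key itself from `app.state.config.OLLAMA_API_CONFIGS`; after the refactor each caller resolves the key (via `get_api_key`) and passes it in explicitly. A condensed sketch of that contract, assuming aiohttp and Starlette as in the real module; error payload handling and the non-streaming branch of the actual function are omitted:

```python
# Condensed sketch of the new post_streaming_url contract: the caller passes
# the bearer key explicitly instead of the helper re-deriving it from
# app.state, which a shared router-level function can no longer see.
import aiohttp
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

AIOHTTP_CLIENT_TIMEOUT = 300  # stand-in for the value from open_webui.env


async def post_streaming_url(url, payload, key=None, content_type=None):
    session = aiohttp.ClientSession(
        trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
    )
    headers = {"Content-Type": "application/json"}
    if key:  # per-connection bearer token, now supplied by the caller
        headers["Authorization"] = f"Bearer {key}"

    r = await session.post(url, data=payload, headers=headers)
    r.raise_for_status()  # the real code converts failures into HTTPException

    return StreamingResponse(
        r.content.iter_chunked(8192),
        status_code=r.status,
        media_type=content_type or r.headers.get("Content-Type"),
        # Close the upstream session only after the response has streamed.
        background=BackgroundTask(session.close),
    )
```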
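
Finally, the TODO kept in routers/ollama.py notes that multi-node scheduling is plain uniform `random.choice` over the URLs serving a model. For reference, the weighted variant it suggests is a one-line change; the weights below are hypothetical per-node capacities, not a field `OLLAMA_API_CONFIGS` currently defines:

```python
# Uniform choice (current behaviour) next to the weighted variant the
# TODO suggests; `weights` is a hypothetical per-node capacity list.
import random
from typing import Optional


def pick_url(urls: list, weights: Optional[list] = None) -> str:
    if weights is None:
        return random.choice(urls)  # what get_ollama_url does today
    return random.choices(urls, weights=weights, k=1)[0]


pick_url(["http://a:11434", "http://b:11434"], weights=[3, 1])
```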