This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 02:41:25 -08:00
parent d3d161f723
commit 4819199650
2 changed files with 334 additions and 139 deletions

View File

@ -46,6 +46,21 @@ from open_webui.routers import (
retrieval, retrieval,
pipelines, pipelines,
tasks, tasks,
auths,
chats,
folders,
configs,
groups,
files,
functions,
memories,
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
) )
from open_webui.retrieval.utils import get_sources_from_files from open_webui.retrieval.utils import get_sources_from_files
@ -117,6 +132,60 @@ from open_webui.config import (
WHISPER_MODEL, WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR, WHISPER_MODEL_DIR,
# Retrieval
RAG_TEMPLATE,
DEFAULT_RAG_TEMPLATE,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_RELEVANCE_THRESHOLD,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
RAG_TOP_K,
RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES,
YOUTUBE_LOADER_LANGUAGE,
YOUTUBE_LOADER_PROXY_URL,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
SEARXNG_QUERY_URL,
SERPER_API_KEY,
SERPLY_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
TAVILY_API_KEY,
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
BRAVE_SEARCH_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
UPLOAD_DIR,
# WebUI # WebUI
WEBUI_AUTH, WEBUI_AUTH,
WEBUI_NAME, WEBUI_NAME,
@ -383,6 +452,72 @@ app.state.FUNCTIONS = {}
# #
######################################## ########################################
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.EMBEDDING_FUNCTION = None
######################################## ########################################
# #
# IMAGES # IMAGES
@ -1083,8 +1218,8 @@ def filter_pipeline(payload, user, models):
try: try:
urlIdx = filter["urlIdx"] urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "": if key == "":
continue continue
@ -1230,14 +1365,6 @@ async def check_url(request: Request, call_next):
return response return response
# @app.middleware("http")
# async def update_embedding_function(request: Request, call_next):
# response = await call_next(request)
# if "/embedding/update" in request.url.path:
# webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
# return response
@app.middleware("http") @app.middleware("http")
async def inspect_websocket(request: Request, call_next): async def inspect_websocket(request: Request, call_next):
if ( if (
@ -1268,18 +1395,36 @@ app.add_middleware(
app.mount("/ws", socket_app) app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app) app.include_router(ollama.router, prefix="/ollama")
app.mount("/openai", openai_app) app.include_router(openai.router, prefix="/openai")
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/retrieval/api/v1", retrieval_app) app.include_router(images.router, prefix="/api/v1/images")
app.include_router(audio.router, prefix="/api/v1/audio")
app.include_router(retrieval.router, prefix="/api/v1/retrieval")
app.mount("/api/v1", webui_app)
app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"])
app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"])
app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"])
app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"])
app.include_router(files.router, prefix="/api/v1/files", tags=["files"])
app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"])
app.include_router(
evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"]
)
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
async def get_all_base_models(): async def get_all_base_models():

View File

@ -13,7 +13,15 @@ from aiocache import cached
import requests import requests
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile from fastapi import (
Depends,
FastAPI,
File,
HTTPException,
Request,
UploadFile,
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -26,18 +34,15 @@ from open_webui.models.models import Models
from open_webui.config import ( from open_webui.config import (
UPLOAD_DIR, UPLOAD_DIR,
) )
from open_webui.env import ( from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL, BYPASS_MODEL_ACCESS_CONTROL,
) )
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from open_webui.utils.misc import ( from open_webui.utils.misc import (
calculate_sha256, calculate_sha256,
@ -54,13 +59,15 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
router = APIRouter()
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation simply picks a backend at random (random.choice), which is not true round-robin. Consider incorporating algorithms like weighted round-robin, # Current implementation simply picks a backend at random (random.choice), which is not true round-robin. Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization. # least connections, or least response time for better resource utilization and performance optimization.
@app.head("/") @router.head("/")
@app.get("/") @router.get("/")
async def get_status(): async def get_status():
return {"status": True} return {"status": True}
@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel):
key: Optional[str] = None key: Optional[str] = None
@app.post("/verify") @router.post("/verify")
async def verify_connection( async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user) form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
): ):
@ -110,12 +117,12 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail) raise HTTPException(status_code=500, detail=error_detail)
@app.get("/config") @router.get("/config")
async def get_config(user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return { return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
} }
@ -125,23 +132,25 @@ class OllamaConfigForm(BaseModel):
OLLAMA_API_CONFIGS: dict OLLAMA_API_CONFIGS: dict
@app.post("/config/update") @router.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)): async def update_config(
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS ):
request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
# Remove any extra configs # Remove any extra configs
config_urls = app.state.config.OLLAMA_API_CONFIGS.keys() config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
for url in list(app.state.config.OLLAMA_BASE_URLS): for url in list(request.app.state.config.OLLAMA_BASE_URLS):
if url not in config_urls: if url not in config_urls:
app.state.config.OLLAMA_API_CONFIGS.pop(url, None) request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
return { return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API, "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS, "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS, "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
} }
@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None):
return None return None
def get_api_key(url, configs):
    """Return the API key configured for the backend that serves ``url``.

    ``configs`` is a mapping from a backend URL to a dict of per-connection
    settings; only the ``"key"`` entry is read here.

    The lookup is attempted twice:

    1. with ``url`` exactly as given, so a backend registered with a path
       component (e.g. ``http://host:11434/ollama``) finds its own entry;
    2. with ``scheme://netloc`` only, so a full endpoint URL such as
       ``http://host:11434/api/pull`` still resolves to the entry registered
       for ``http://host:11434``.

    Returns ``None`` when no entry matches or the entry has no ``"key"``.
    """
    entry = configs.get(url)
    if entry is None:
        # Fall back to the origin (scheme + host[:port]) of the URL.
        parsed_url = urlparse(url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        entry = configs.get(base_url)

    # An entry explicitly stored as None is treated as "no settings".
    return (entry or {}).get("key", None)
async def cleanup_response( async def cleanup_response(
response: Optional[aiohttp.ClientResponse], response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession], session: Optional[aiohttp.ClientSession],
@ -169,7 +184,11 @@ async def cleanup_response(
async def post_streaming_url( async def post_streaming_url(
url: str, payload: Union[str, bytes], stream: bool = True, content_type=None url: str,
payload: Union[str, bytes],
stream: bool = True,
key: Optional[str] = None,
content_type=None,
): ):
r = None r = None
try: try:
@ -177,12 +196,6 @@ async def post_streaming_url(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
) )
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
if key: if key:
headers["Authorization"] = f"Bearer {key}" headers["Authorization"] = f"Bearer {key}"
@ -246,13 +259,13 @@ def merge_models_lists(model_lists):
@cached(ttl=3) @cached(ttl=3)
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
if app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
tasks = [] tasks = []
for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS): for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if url not in app.state.config.OLLAMA_API_CONFIGS: if url not in request.app.state.config.OLLAMA_API_CONFIGS:
tasks.append(aiohttp_get(f"{url}/api/tags")) tasks.append(aiohttp_get(f"{url}/api/tags"))
else: else:
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True) enable = api_config.get("enable", True)
key = api_config.get("key", None) key = api_config.get("key", None)
@ -265,8 +278,8 @@ async def get_all_models():
for idx, response in enumerate(responses): for idx, response in enumerate(responses):
if response: if response:
url = app.state.config.OLLAMA_BASE_URLS[idx] url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
model_ids = api_config.get("model_ids", []) model_ids = api_config.get("model_ids", [])
@ -298,21 +311,21 @@ async def get_all_models():
return models return models
@app.get("/api/tags") @router.get("/api/tags")
@app.get("/api/tags/{url_idx}") @router.get("/api/tags/{url_idx}")
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user) request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
): ):
models = [] models = []
if url_idx is None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {} headers = {}
@ -356,18 +369,20 @@ async def get_ollama_tags(
return models return models
@app.get("/api/version") @router.get("/api/version")
@app.get("/api/version/{url_idx}") @router.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None): async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
if url_idx is None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [ tasks = [
aiohttp_get( aiohttp_get(
f"{url}/api/version", f"{url}/api/version",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
"key", None
),
) )
for url in app.state.config.OLLAMA_BASE_URLS for url in request.app.state.config.OLLAMA_BASE_URLS
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
) )
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None r = None
try: try:
@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
return {"version": False} return {"version": False}
@app.get("/api/ps") @router.get("/api/ps")
async def get_ollama_loaded_models(user=Depends(get_verified_user)): async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
""" """
List models that are currently loaded into Ollama memory, and which node they are loaded on. List models that are currently loaded into Ollama memory, and which node they are loaded on.
""" """
if app.state.config.ENABLE_OLLAMA_API: if request.app.state.config.ENABLE_OLLAMA_API:
tasks = [ tasks = [
aiohttp_get( aiohttp_get(
f"{url}/api/ps", f"{url}/api/ps",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
"key", None
),
) )
for url in app.state.config.OLLAMA_BASE_URLS for url in request.app.state.config.OLLAMA_BASE_URLS
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses)) return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else: else:
return {} return {}
@ -438,18 +455,25 @@ class ModelNameForm(BaseModel):
name: str name: str
@app.post("/api/pull") @router.post("/api/pull")
@app.post("/api/pull/{url_idx}") @router.post("/api/pull/{url_idx}")
async def pull_model( async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) request: Request,
form_data: ModelNameForm,
url_idx: int = 0,
user=Depends(get_admin_user),
): ):
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
# Admin should be able to pull models from any source # Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True} payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload)) return await post_streaming_url(
url=f"{url}/api/pull",
payload=json.dumps(payload),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
class PushModelForm(BaseModel): class PushModelForm(BaseModel):
@ -458,9 +482,10 @@ class PushModelForm(BaseModel):
stream: Optional[bool] = None stream: Optional[bool] = None
@app.delete("/api/push") @router.delete("/api/push")
@app.delete("/api/push/{url_idx}") @router.delete("/api/push/{url_idx}")
async def push_model( async def push_model(
request: Request,
form_data: PushModelForm, form_data: PushModelForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
@ -477,11 +502,13 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}") log.debug(f"url: {url}")
return await post_streaming_url( return await post_streaming_url(
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode() url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
) )
@ -492,17 +519,22 @@ class CreateModelForm(BaseModel):
path: Optional[str] = None path: Optional[str] = None
@app.post("/api/create") @router.post("/api/create")
@app.post("/api/create/{url_idx}") @router.post("/api/create/{url_idx}")
async def create_model( async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) request: Request,
form_data: CreateModelForm,
url_idx: int = 0,
user=Depends(get_admin_user),
): ):
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
return await post_streaming_url( return await post_streaming_url(
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode() url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
) )
@ -511,9 +543,10 @@ class CopyModelForm(BaseModel):
destination: str destination: str
@app.post("/api/copy") @router.post("/api/copy")
@app.post("/api/copy/{url_idx}") @router.post("/api/copy/{url_idx}")
async def copy_model( async def copy_model(
request: Request,
form_data: CopyModelForm, form_data: CopyModelForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
@ -530,13 +563,13 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
) )
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@ -573,9 +606,10 @@ async def copy_model(
) )
@app.delete("/api/delete") @router.delete("/api/delete")
@app.delete("/api/delete/{url_idx}") @router.delete("/api/delete/{url_idx}")
async def delete_model( async def delete_model(
request: Request,
form_data: ModelNameForm, form_data: ModelNameForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
@ -592,13 +626,13 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@ -634,8 +668,10 @@ async def delete_model(
) )
@app.post("/api/show") @router.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)): async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
model_list = await get_all_models() model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]} models = {model["model"]: model for model in model_list["models"]}
@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
) )
url_idx = random.choice(models[form_data.name]["urls"]) url_idx = random.choice(models[form_data.name]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
@app.post("/api/embed") @router.post("/api/embed")
@app.post("/api/embed/{url_idx}") @router.post("/api/embed/{url_idx}")
async def generate_embeddings( async def generate_embeddings(
form_data: GenerateEmbedForm, form_data: GenerateEmbedForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
@ -711,8 +747,8 @@ async def generate_embeddings(
return await generate_ollama_batch_embeddings(form_data, url_idx) return await generate_ollama_batch_embeddings(form_data, url_idx)
@app.post("/api/embeddings") @router.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}") @router.post("/api/embeddings/{url_idx}")
async def generate_embeddings( async def generate_embeddings(
form_data: GenerateEmbeddingsForm, form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
@ -744,13 +780,13 @@ async def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None) key = api_config.get("key", None)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
@app.post("/api/generate") @router.post("/api/generate")
@app.post("/api/generate/{url_idx}") @router.post("/api/generate/{url_idx}")
async def generate_completion( async def generate_completion(
request: Request,
form_data: GenerateCompletionForm, form_data: GenerateCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
@ -897,15 +934,17 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "") form_data.model = form_data.model.replace(f"{prefix_id}.", "")
log.info(f"url: {url}") log.info(f"url: {url}")
return await post_streaming_url( return await post_streaming_url(
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode() url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
) )
@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str):
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
) )
url_idx = random.choice(models[model]["urls"]) url_idx = random.choice(models[model]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
return url return url
@app.post("/api/chat") @router.post("/api/chat")
@app.post("/api/chat/{url_idx}") @router.post("/api/chat/{url_idx}")
async def generate_chat_completion( async def generate_chat_completion(
request: Request,
form_data: GenerateChatCompletionForm, form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
@ -1003,15 +1043,16 @@ async def generate_chat_completion(
parsed_url = urlparse(url) parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url( return await post_streaming_url(
f"{url}/api/chat", url=f"{url}/api/chat",
json.dumps(payload), payload=json.dumps(payload),
stream=form_data.stream, stream=form_data.stream,
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson", content_type="application/x-ndjson",
) )
@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@app.post("/v1/completions") @router.post("/v1/completions")
@app.post("/v1/completions/{url_idx}") @router.post("/v1/completions/{url_idx}")
async def generate_openai_completion( async def generate_openai_completion(
form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user) request: Request,
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
): ):
try: try:
form_data = OpenAICompletionForm(**form_data) form_data = OpenAICompletionForm(**form_data)
@ -1099,22 +1143,24 @@ async def generate_openai_completion(
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url( return await post_streaming_url(
f"{url}/v1/completions", url=f"{url}/v1/completions",
json.dumps(payload), payload=json.dumps(payload),
stream=payload.get("stream", False), stream=payload.get("stream", False),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
) )
@app.post("/v1/chat/completions") @router.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}") @router.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion( async def generate_openai_chat_completion(
request: Request,
form_data: dict, form_data: dict,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion(
url = await get_ollama_url(url_idx, payload["model"]) url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}") log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url( return await post_streaming_url(
f"{url}/v1/chat/completions", url=f"{url}/v1/chat/completions",
json.dumps(payload), payload=json.dumps(payload),
stream=payload.get("stream", False), stream=payload.get("stream", False),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
) )
@app.get("/v1/models") @router.get("/v1/models")
@app.get("/v1/models/{url_idx}") @router.get("/v1/models/{url_idx}")
async def get_openai_models( async def get_openai_models(
request: Request,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
@ -1205,7 +1253,7 @@ async def get_openai_models(
] ]
else: else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
@ -1329,9 +1377,10 @@ async def download_file_stream(
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download") @router.post("/models/download")
@app.post("/models/download/{url_idx}") @router.post("/models/download/{url_idx}")
async def download_model( async def download_model(
request: Request,
form_data: UrlForm, form_data: UrlForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
@ -1346,7 +1395,7 @@ async def download_model(
if url_idx is None: if url_idx is None:
url_idx = 0 url_idx = 0
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url) file_name = parse_huggingface_url(form_data.url)
@ -1360,16 +1409,17 @@ async def download_model(
return None return None
@app.post("/models/upload") @router.post("/models/upload")
@app.post("/models/upload/{url_idx}") @router.post("/models/upload/{url_idx}")
def upload_model( def upload_model(
request: Request,
file: UploadFile = File(...), file: UploadFile = File(...),
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx is None: if url_idx is None:
url_idx = 0 url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}" file_path = f"{UPLOAD_DIR}/{file.filename}"