This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 02:41:25 -08:00
parent d3d161f723
commit 4819199650
2 changed files with 334 additions and 139 deletions

View File

@ -46,6 +46,21 @@ from open_webui.routers import (
retrieval,
pipelines,
tasks,
auths,
chats,
folders,
configs,
groups,
files,
functions,
memories,
models,
knowledge,
prompts,
evaluations,
tools,
users,
utils,
)
from open_webui.retrieval.utils import get_sources_from_files
@ -117,6 +132,60 @@ from open_webui.config import (
WHISPER_MODEL,
WHISPER_MODEL_AUTO_UPDATE,
WHISPER_MODEL_DIR,
# Retrieval
RAG_TEMPLATE,
DEFAULT_RAG_TEMPLATE,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_BATCH_SIZE,
RAG_RELEVANCE_THRESHOLD,
RAG_FILE_MAX_COUNT,
RAG_FILE_MAX_SIZE,
RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY,
RAG_OLLAMA_BASE_URL,
RAG_OLLAMA_API_KEY,
CHUNK_OVERLAP,
CHUNK_SIZE,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
RAG_TOP_K,
RAG_TEXT_SPLITTER,
TIKTOKEN_ENCODING_NAME,
PDF_EXTRACT_IMAGES,
YOUTUBE_LOADER_LANGUAGE,
YOUTUBE_LOADER_PROXY_URL,
# Retrieval (Web Search)
RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
JINA_API_KEY,
SEARCHAPI_API_KEY,
SEARCHAPI_ENGINE,
SEARXNG_QUERY_URL,
SERPER_API_KEY,
SERPLY_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
TAVILY_API_KEY,
BING_SEARCH_V7_ENDPOINT,
BING_SEARCH_V7_SUBSCRIPTION_KEY,
BRAVE_SEARCH_API_KEY,
KAGI_SEARCH_API_KEY,
MOJEEK_SEARCH_API_KEY,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
ENABLE_RAG_WEB_SEARCH,
UPLOAD_DIR,
# WebUI
WEBUI_AUTH,
WEBUI_NAME,
@ -383,6 +452,72 @@ app.state.FUNCTIONS = {}
#
########################################
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
app.state.config.JINA_API_KEY = JINA_API_KEY
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.EMBEDDING_FUNCTION = None
########################################
#
# IMAGES
@ -1083,8 +1218,8 @@ def filter_pipeline(payload, user, models):
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = app.state.config.OPENAI_API_KEYS[urlIdx]
if key == "":
continue
@ -1230,14 +1365,6 @@ async def check_url(request: Request, call_next):
return response
# @app.middleware("http")
# async def update_embedding_function(request: Request, call_next):
# response = await call_next(request)
# if "/embedding/update" in request.url.path:
# webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
# return response
@app.middleware("http")
async def inspect_websocket(request: Request, call_next):
if (
@ -1268,18 +1395,36 @@ app.add_middleware(
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.include_router(ollama.router, prefix="/ollama")
app.include_router(openai.router, prefix="/openai")
app.mount("/retrieval/api/v1", retrieval_app)
app.include_router(images.router, prefix="/api/v1/images")
app.include_router(audio.router, prefix="/api/v1/audio")
app.include_router(retrieval.router, prefix="/api/v1/retrieval")
app.mount("/api/v1", webui_app)
app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"])
app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"])
app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"])
app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"])
app.include_router(files.router, prefix="/api/v1/files", tags=["files"])
app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"])
app.include_router(
evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"]
)
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
async def get_all_base_models():

View File

@ -13,7 +13,15 @@ from aiocache import cached
import requests
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
from fastapi import (
Depends,
FastAPI,
File,
HTTPException,
Request,
UploadFile,
APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
@ -26,18 +34,15 @@ from open_webui.models.models import Models
from open_webui.config import (
UPLOAD_DIR,
)
from open_webui.env import (
ENV,
SRC_LOG_LEVELS,
AIOHTTP_CLIENT_TIMEOUT,
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from open_webui.utils.misc import (
calculate_sha256,
@ -54,13 +59,15 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
router = APIRouter()
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
@app.head("/")
@app.get("/")
@router.head("/")
@router.get("/")
async def get_status():
return {"status": True}
@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel):
key: Optional[str] = None
@app.post("/verify")
@router.post("/verify")
async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
@ -110,12 +117,12 @@ async def verify_connection(
raise HTTPException(status_code=500, detail=error_detail)
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
"ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
}
@ -125,23 +132,25 @@ class OllamaConfigForm(BaseModel):
OLLAMA_API_CONFIGS: dict
@app.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
@router.post("/config/update")
async def update_config(
request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
# Remove any extra configs
config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
for url in list(app.state.config.OLLAMA_BASE_URLS):
config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
for url in list(request.app.state.config.OLLAMA_BASE_URLS):
if url not in config_urls:
app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
return {
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
"ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
"OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
"OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
}
@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None):
return None
def get_api_key(url, configs):
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
return configs.get(base_url, {}).get("key", None)
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
@ -169,7 +184,11 @@ async def cleanup_response(
async def post_streaming_url(
url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
url: str,
payload: Union[str, bytes],
stream: bool = True,
key: Optional[str] = None,
content_type=None,
):
r = None
try:
@ -177,12 +196,6 @@ async def post_streaming_url(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
if key:
headers["Authorization"] = f"Bearer {key}"
@ -246,13 +259,13 @@ def merge_models_lists(model_lists):
@cached(ttl=3)
async def get_all_models():
log.info("get_all_models()")
if app.state.config.ENABLE_OLLAMA_API:
if request.app.state.config.ENABLE_OLLAMA_API:
tasks = []
for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
if url not in app.state.config.OLLAMA_API_CONFIGS:
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
if url not in request.app.state.config.OLLAMA_API_CONFIGS:
tasks.append(aiohttp_get(f"{url}/api/tags"))
else:
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True)
key = api_config.get("key", None)
@ -265,8 +278,8 @@ async def get_all_models():
for idx, response in enumerate(responses):
if response:
url = app.state.config.OLLAMA_BASE_URLS[idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
model_ids = api_config.get("model_ids", [])
@ -298,21 +311,21 @@ async def get_all_models():
return models
@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
@router.get("/api/tags")
@router.get("/api/tags/{url_idx}")
async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user)
request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
models = []
if url_idx is None:
models = await get_all_models()
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {}
@ -356,18 +369,20 @@ async def get_ollama_tags(
return models
@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API:
@router.get("/api/version")
@router.get("/api/version/{url_idx}")
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
if request.app.state.config.ENABLE_OLLAMA_API:
if url_idx is None:
# returns lowest version
tasks = [
aiohttp_get(
f"{url}/api/version",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
"key", None
),
)
for url in app.state.config.OLLAMA_BASE_URLS
for url in request.app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
return {"version": False}
@app.get("/api/ps")
async def get_ollama_loaded_models(user=Depends(get_verified_user)):
@router.get("/api/ps")
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
"""
List models that are currently loaded into Ollama memory, and which node they are loaded on.
"""
if app.state.config.ENABLE_OLLAMA_API:
if request.app.state.config.ENABLE_OLLAMA_API:
tasks = [
aiohttp_get(
f"{url}/api/ps",
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
"key", None
),
)
for url in app.state.config.OLLAMA_BASE_URLS
for url in request.app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses))
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
else:
return {}
@ -438,18 +455,25 @@ class ModelNameForm(BaseModel):
name: str
@app.post("/api/pull")
@app.post("/api/pull/{url_idx}")
@router.post("/api/pull")
@router.post("/api/pull/{url_idx}")
async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
request: Request,
form_data: ModelNameForm,
url_idx: int = 0,
user=Depends(get_admin_user),
):
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
return await post_streaming_url(
url=f"{url}/api/pull",
payload=json.dumps(payload),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
class PushModelForm(BaseModel):
@ -458,9 +482,10 @@ class PushModelForm(BaseModel):
stream: Optional[bool] = None
@app.delete("/api/push")
@app.delete("/api/push/{url_idx}")
@router.delete("/api/push")
@router.delete("/api/push/{url_idx}")
async def push_model(
request: Request,
form_data: PushModelForm,
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
@ -477,11 +502,13 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
return await post_streaming_url(
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
url=f"{url}/api/push",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
@ -492,17 +519,22 @@ class CreateModelForm(BaseModel):
path: Optional[str] = None
@app.post("/api/create")
@app.post("/api/create/{url_idx}")
@router.post("/api/create")
@router.post("/api/create/{url_idx}")
async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
request: Request,
form_data: CreateModelForm,
url_idx: int = 0,
user=Depends(get_admin_user),
):
log.debug(f"form_data: {form_data}")
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
return await post_streaming_url(
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
url=f"{url}/api/create",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
@ -511,9 +543,10 @@ class CopyModelForm(BaseModel):
destination: str
@app.post("/api/copy")
@app.post("/api/copy/{url_idx}")
@router.post("/api/copy")
@router.post("/api/copy/{url_idx}")
async def copy_model(
request: Request,
form_data: CopyModelForm,
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
@ -530,13 +563,13 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -573,9 +606,10 @@ async def copy_model(
)
@app.delete("/api/delete")
@app.delete("/api/delete/{url_idx}")
@router.delete("/api/delete")
@router.delete("/api/delete/{url_idx}")
async def delete_model(
request: Request,
form_data: ModelNameForm,
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
@ -592,13 +626,13 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -634,8 +668,10 @@ async def delete_model(
)
@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
@router.post("/api/show")
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
)
url_idx = random.choice(models[form_data.name]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None
@app.post("/api/embed")
@app.post("/api/embed/{url_idx}")
@router.post("/api/embed")
@router.post("/api/embed/{url_idx}")
async def generate_embeddings(
form_data: GenerateEmbedForm,
url_idx: Optional[int] = None,
@ -711,8 +747,8 @@ async def generate_embeddings(
return await generate_ollama_batch_embeddings(form_data, url_idx)
@app.post("/api/embeddings")
@app.post("/api/embeddings/{url_idx}")
@router.post("/api/embeddings")
@router.post("/api/embeddings/{url_idx}")
async def generate_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
@ -744,13 +780,13 @@ async def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
key = api_config.get("key", None)
headers = {"Content-Type": "application/json"}
@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None
@app.post("/api/generate")
@app.post("/api/generate/{url_idx}")
@router.post("/api/generate")
@router.post("/api/generate/{url_idx}")
async def generate_completion(
request: Request,
form_data: GenerateCompletionForm,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
@ -897,15 +934,17 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
log.info(f"url: {url}")
return await post_streaming_url(
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
url=f"{url}/api/generate",
payload=form_data.model_dump_json(exclude_none=True).encode(),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str):
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
url_idx = random.choice(models[model]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
return url
@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
@router.post("/api/chat")
@router.post("/api/chat/{url_idx}")
async def generate_chat_completion(
request: Request,
form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
@ -1003,15 +1043,16 @@ async def generate_chat_completion(
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url(
f"{url}/api/chat",
json.dumps(payload),
url=f"{url}/api/chat",
payload=json.dumps(payload),
stream=form_data.stream,
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
content_type="application/x-ndjson",
)
@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel):
model_config = ConfigDict(extra="allow")
@app.post("/v1/completions")
@app.post("/v1/completions/{url_idx}")
@router.post("/v1/completions")
@router.post("/v1/completions/{url_idx}")
async def generate_openai_completion(
form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
request: Request,
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
try:
form_data = OpenAICompletionForm(**form_data)
@ -1099,22 +1143,24 @@ async def generate_openai_completion(
url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url(
f"{url}/v1/completions",
json.dumps(payload),
url=f"{url}/v1/completions",
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
request: Request,
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion(
url = await get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
return await post_streaming_url(
f"{url}/v1/chat/completions",
json.dumps(payload),
url=f"{url}/v1/chat/completions",
payload=json.dumps(payload),
stream=payload.get("stream", False),
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
)
@app.get("/v1/models")
@app.get("/v1/models/{url_idx}")
@router.get("/v1/models")
@router.get("/v1/models/{url_idx}")
async def get_openai_models(
request: Request,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
@ -1205,7 +1253,7 @@ async def get_openai_models(
]
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
@ -1329,9 +1377,10 @@ async def download_file_stream(
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
@router.post("/models/download")
@router.post("/models/download/{url_idx}")
async def download_model(
request: Request,
form_data: UrlForm,
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
@ -1346,7 +1395,7 @@ async def download_model(
if url_idx is None:
url_idx = 0
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url)
@ -1360,16 +1409,17 @@ async def download_model(
return None
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
@router.post("/models/upload")
@router.post("/models/upload/{url_idx}")
def upload_model(
request: Request,
file: UploadFile = File(...),
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx is None:
url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}"