mirror of
https://git.mirrors.martin98.com/https://github.com/open-webui/open-webui
synced 2025-08-16 21:26:00 +08:00
wip
This commit is contained in:
parent
d3d161f723
commit
4819199650
@ -46,6 +46,21 @@ from open_webui.routers import (
|
||||
retrieval,
|
||||
pipelines,
|
||||
tasks,
|
||||
auths,
|
||||
chats,
|
||||
folders,
|
||||
configs,
|
||||
groups,
|
||||
files,
|
||||
functions,
|
||||
memories,
|
||||
models,
|
||||
knowledge,
|
||||
prompts,
|
||||
evaluations,
|
||||
tools,
|
||||
users,
|
||||
utils,
|
||||
)
|
||||
|
||||
from open_webui.retrieval.utils import get_sources_from_files
|
||||
@ -117,6 +132,60 @@ from open_webui.config import (
|
||||
WHISPER_MODEL,
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
WHISPER_MODEL_DIR,
|
||||
# Retrieval
|
||||
RAG_TEMPLATE,
|
||||
DEFAULT_RAG_TEMPLATE,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_RERANKING_MODEL,
|
||||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_BATCH_SIZE,
|
||||
RAG_RELEVANCE_THRESHOLD,
|
||||
RAG_FILE_MAX_COUNT,
|
||||
RAG_FILE_MAX_SIZE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
RAG_OPENAI_API_KEY,
|
||||
RAG_OLLAMA_BASE_URL,
|
||||
RAG_OLLAMA_API_KEY,
|
||||
CHUNK_OVERLAP,
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL,
|
||||
RAG_TOP_K,
|
||||
RAG_TEXT_SPLITTER,
|
||||
TIKTOKEN_ENCODING_NAME,
|
||||
PDF_EXTRACT_IMAGES,
|
||||
YOUTUBE_LOADER_LANGUAGE,
|
||||
YOUTUBE_LOADER_PROXY_URL,
|
||||
# Retrieval (Web Search)
|
||||
RAG_WEB_SEARCH_ENGINE,
|
||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
JINA_API_KEY,
|
||||
SEARCHAPI_API_KEY,
|
||||
SEARCHAPI_ENGINE,
|
||||
SEARXNG_QUERY_URL,
|
||||
SERPER_API_KEY,
|
||||
SERPLY_API_KEY,
|
||||
SERPSTACK_API_KEY,
|
||||
SERPSTACK_HTTPS,
|
||||
TAVILY_API_KEY,
|
||||
BING_SEARCH_V7_ENDPOINT,
|
||||
BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
BRAVE_SEARCH_API_KEY,
|
||||
KAGI_SEARCH_API_KEY,
|
||||
MOJEEK_SEARCH_API_KEY,
|
||||
GOOGLE_PSE_API_KEY,
|
||||
GOOGLE_PSE_ENGINE_ID,
|
||||
ENABLE_RAG_HYBRID_SEARCH,
|
||||
ENABLE_RAG_LOCAL_WEB_FETCH,
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
|
||||
ENABLE_RAG_WEB_SEARCH,
|
||||
UPLOAD_DIR,
|
||||
# WebUI
|
||||
WEBUI_AUTH,
|
||||
WEBUI_NAME,
|
||||
@ -383,6 +452,72 @@ app.state.FUNCTIONS = {}
|
||||
#
|
||||
########################################
|
||||
|
||||
|
||||
app.state.config.TOP_K = RAG_TOP_K
|
||||
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||
app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE
|
||||
app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT
|
||||
|
||||
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
|
||||
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
|
||||
)
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
|
||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
||||
|
||||
app.state.config.CHUNK_SIZE = CHUNK_SIZE
|
||||
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||
|
||||
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
|
||||
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
|
||||
app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||
|
||||
app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
|
||||
app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
|
||||
|
||||
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
|
||||
|
||||
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
|
||||
app.state.config.YOUTUBE_LOADER_PROXY_URL = YOUTUBE_LOADER_PROXY_URL
|
||||
|
||||
|
||||
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
|
||||
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
||||
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
|
||||
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
|
||||
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
|
||||
app.state.config.KAGI_SEARCH_API_KEY = KAGI_SEARCH_API_KEY
|
||||
app.state.config.MOJEEK_SEARCH_API_KEY = MOJEEK_SEARCH_API_KEY
|
||||
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
|
||||
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
|
||||
app.state.config.SERPER_API_KEY = SERPER_API_KEY
|
||||
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
|
||||
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
|
||||
app.state.config.SEARCHAPI_API_KEY = SEARCHAPI_API_KEY
|
||||
app.state.config.SEARCHAPI_ENGINE = SEARCHAPI_ENGINE
|
||||
app.state.config.JINA_API_KEY = JINA_API_KEY
|
||||
app.state.config.BING_SEARCH_V7_ENDPOINT = BING_SEARCH_V7_ENDPOINT
|
||||
app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY
|
||||
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
|
||||
|
||||
app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||
app.state.EMBEDDING_FUNCTION = None
|
||||
|
||||
########################################
|
||||
#
|
||||
# IMAGES
|
||||
@ -1083,8 +1218,8 @@ def filter_pipeline(payload, user, models):
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key == "":
|
||||
continue
|
||||
@ -1230,14 +1365,6 @@ async def check_url(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
# @app.middleware("http")
|
||||
# async def update_embedding_function(request: Request, call_next):
|
||||
# response = await call_next(request)
|
||||
# if "/embedding/update" in request.url.path:
|
||||
# webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
|
||||
# return response
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def inspect_websocket(request: Request, call_next):
|
||||
if (
|
||||
@ -1268,18 +1395,36 @@ app.add_middleware(
|
||||
app.mount("/ws", socket_app)
|
||||
|
||||
|
||||
app.mount("/ollama", ollama_app)
|
||||
app.mount("/openai", openai_app)
|
||||
|
||||
app.mount("/images/api/v1", images_app)
|
||||
app.mount("/audio/api/v1", audio_app)
|
||||
app.include_router(ollama.router, prefix="/ollama")
|
||||
app.include_router(openai.router, prefix="/openai")
|
||||
|
||||
|
||||
app.mount("/retrieval/api/v1", retrieval_app)
|
||||
app.include_router(images.router, prefix="/api/v1/images")
|
||||
app.include_router(audio.router, prefix="/api/v1/audio")
|
||||
app.include_router(retrieval.router, prefix="/api/v1/retrieval")
|
||||
|
||||
app.mount("/api/v1", webui_app)
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
|
||||
app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])
|
||||
|
||||
app.include_router(auths.router, prefix="/api/v1/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/api/v1/users", tags=["users"])
|
||||
|
||||
app.include_router(chats.router, prefix="/api/v1/chats", tags=["chats"])
|
||||
|
||||
app.include_router(models.router, prefix="/api/v1/models", tags=["models"])
|
||||
app.include_router(knowledge.router, prefix="/api/v1/knowledge", tags=["knowledge"])
|
||||
app.include_router(prompts.router, prefix="/api/v1/prompts", tags=["prompts"])
|
||||
app.include_router(tools.router, prefix="/api/v1/tools", tags=["tools"])
|
||||
|
||||
app.include_router(memories.router, prefix="/api/v1/memories", tags=["memories"])
|
||||
app.include_router(folders.router, prefix="/api/v1/folders", tags=["folders"])
|
||||
app.include_router(groups.router, prefix="/api/v1/groups", tags=["groups"])
|
||||
app.include_router(files.router, prefix="/api/v1/files", tags=["files"])
|
||||
app.include_router(functions.router, prefix="/api/v1/functions", tags=["functions"])
|
||||
app.include_router(
|
||||
evaluations.router, prefix="/api/v1/evaluations", tags=["evaluations"]
|
||||
)
|
||||
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])
|
||||
|
||||
|
||||
async def get_all_base_models():
|
||||
|
@ -13,7 +13,15 @@ from aiocache import cached
|
||||
|
||||
import requests
|
||||
|
||||
from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
APIRouter,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@ -26,18 +34,15 @@ from open_webui.models.models import Models
|
||||
from open_webui.config import (
|
||||
UPLOAD_DIR,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.env import (
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
|
||||
BYPASS_MODEL_ACCESS_CONTROL,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import ENV, SRC_LOG_LEVELS
|
||||
|
||||
|
||||
from open_webui.utils.misc import (
|
||||
calculate_sha256,
|
||||
@ -54,13 +59,15 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
||||
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
|
||||
# least connections, or least response time for better resource utilization and performance optimization.
|
||||
|
||||
|
||||
@app.head("/")
|
||||
@app.get("/")
|
||||
@router.head("/")
|
||||
@router.get("/")
|
||||
async def get_status():
|
||||
return {"status": True}
|
||||
|
||||
@ -70,7 +77,7 @@ class ConnectionVerificationForm(BaseModel):
|
||||
key: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/verify")
|
||||
@router.post("/verify")
|
||||
async def verify_connection(
|
||||
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
|
||||
):
|
||||
@ -110,12 +117,12 @@ async def verify_connection(
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(user=Depends(get_admin_user)):
|
||||
@router.get("/config")
|
||||
async def get_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
||||
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
||||
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
||||
"ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
|
||||
"OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
|
||||
"OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
|
||||
}
|
||||
|
||||
|
||||
@ -125,23 +132,25 @@ class OllamaConfigForm(BaseModel):
|
||||
OLLAMA_API_CONFIGS: dict
|
||||
|
||||
|
||||
@app.post("/config/update")
|
||||
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
|
||||
app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
||||
app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
||||
@router.post("/config/update")
|
||||
async def update_config(
|
||||
request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
|
||||
request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
|
||||
|
||||
app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
||||
request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
|
||||
|
||||
# Remove any extra configs
|
||||
config_urls = app.state.config.OLLAMA_API_CONFIGS.keys()
|
||||
for url in list(app.state.config.OLLAMA_BASE_URLS):
|
||||
config_urls = request.app.state.config.OLLAMA_API_CONFIGS.keys()
|
||||
for url in list(request.app.state.config.OLLAMA_BASE_URLS):
|
||||
if url not in config_urls:
|
||||
app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)
|
||||
|
||||
return {
|
||||
"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API,
|
||||
"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS,
|
||||
"OLLAMA_API_CONFIGS": app.state.config.OLLAMA_API_CONFIGS,
|
||||
"ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
|
||||
"OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
|
||||
"OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
|
||||
}
|
||||
|
||||
|
||||
@ -158,6 +167,12 @@ async def aiohttp_get(url, key=None):
|
||||
return None
|
||||
|
||||
|
||||
def get_api_key(url, configs):
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
return configs.get(base_url, {}).get("key", None)
|
||||
|
||||
|
||||
async def cleanup_response(
|
||||
response: Optional[aiohttp.ClientResponse],
|
||||
session: Optional[aiohttp.ClientSession],
|
||||
@ -169,7 +184,11 @@ async def cleanup_response(
|
||||
|
||||
|
||||
async def post_streaming_url(
|
||||
url: str, payload: Union[str, bytes], stream: bool = True, content_type=None
|
||||
url: str,
|
||||
payload: Union[str, bytes],
|
||||
stream: bool = True,
|
||||
key: Optional[str] = None,
|
||||
content_type=None,
|
||||
):
|
||||
r = None
|
||||
try:
|
||||
@ -177,12 +196,6 @@ async def post_streaming_url(
|
||||
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
)
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if key:
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
@ -246,13 +259,13 @@ def merge_models_lists(model_lists):
|
||||
@cached(ttl=3)
|
||||
async def get_all_models():
|
||||
log.info("get_all_models()")
|
||||
if app.state.config.ENABLE_OLLAMA_API:
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
tasks = []
|
||||
for idx, url in enumerate(app.state.config.OLLAMA_BASE_URLS):
|
||||
if url not in app.state.config.OLLAMA_API_CONFIGS:
|
||||
for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
|
||||
if url not in request.app.state.config.OLLAMA_API_CONFIGS:
|
||||
tasks.append(aiohttp_get(f"{url}/api/tags"))
|
||||
else:
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
enable = api_config.get("enable", True)
|
||||
key = api_config.get("key", None)
|
||||
|
||||
@ -265,8 +278,8 @@ async def get_all_models():
|
||||
|
||||
for idx, response in enumerate(responses):
|
||||
if response:
|
||||
url = app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
model_ids = api_config.get("model_ids", [])
|
||||
@ -298,21 +311,21 @@ async def get_all_models():
|
||||
return models
|
||||
|
||||
|
||||
@app.get("/api/tags")
|
||||
@app.get("/api/tags/{url_idx}")
|
||||
@router.get("/api/tags")
|
||||
@router.get("/api/tags/{url_idx}")
|
||||
async def get_ollama_tags(
|
||||
url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
||||
request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
||||
):
|
||||
models = []
|
||||
if url_idx is None:
|
||||
models = await get_all_models()
|
||||
else:
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {}
|
||||
@ -356,18 +369,20 @@ async def get_ollama_tags(
|
||||
return models
|
||||
|
||||
|
||||
@app.get("/api/version")
|
||||
@app.get("/api/version/{url_idx}")
|
||||
async def get_ollama_versions(url_idx: Optional[int] = None):
|
||||
if app.state.config.ENABLE_OLLAMA_API:
|
||||
@router.get("/api/version")
|
||||
@router.get("/api/version/{url_idx}")
|
||||
async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
if url_idx is None:
|
||||
# returns lowest version
|
||||
tasks = [
|
||||
aiohttp_get(
|
||||
f"{url}/api/version",
|
||||
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
|
||||
"key", None
|
||||
),
|
||||
)
|
||||
for url in app.state.config.OLLAMA_BASE_URLS
|
||||
for url in request.app.state.config.OLLAMA_BASE_URLS
|
||||
]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
responses = list(filter(lambda x: x is not None, responses))
|
||||
@ -387,7 +402,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
||||
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
r = None
|
||||
try:
|
||||
@ -414,22 +429,24 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
|
||||
return {"version": False}
|
||||
|
||||
|
||||
@app.get("/api/ps")
|
||||
async def get_ollama_loaded_models(user=Depends(get_verified_user)):
|
||||
@router.get("/api/ps")
|
||||
async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
List models that are currently loaded into Ollama memory, and which node they are loaded on.
|
||||
"""
|
||||
if app.state.config.ENABLE_OLLAMA_API:
|
||||
if request.app.state.config.ENABLE_OLLAMA_API:
|
||||
tasks = [
|
||||
aiohttp_get(
|
||||
f"{url}/api/ps",
|
||||
app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
|
||||
request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
|
||||
"key", None
|
||||
),
|
||||
)
|
||||
for url in app.state.config.OLLAMA_BASE_URLS
|
||||
for url in request.app.state.config.OLLAMA_BASE_URLS
|
||||
]
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
return dict(zip(app.state.config.OLLAMA_BASE_URLS, responses))
|
||||
return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
|
||||
else:
|
||||
return {}
|
||||
|
||||
@ -438,18 +455,25 @@ class ModelNameForm(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@app.post("/api/pull")
|
||||
@app.post("/api/pull/{url_idx}")
|
||||
@router.post("/api/pull")
|
||||
@router.post("/api/pull/{url_idx}")
|
||||
async def pull_model(
|
||||
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
|
||||
request: Request,
|
||||
form_data: ModelNameForm,
|
||||
url_idx: int = 0,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
# Admin should be able to pull models from any source
|
||||
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
|
||||
|
||||
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
|
||||
return await post_streaming_url(
|
||||
url=f"{url}/api/pull",
|
||||
payload=json.dumps(payload),
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
)
|
||||
|
||||
|
||||
class PushModelForm(BaseModel):
|
||||
@ -458,9 +482,10 @@ class PushModelForm(BaseModel):
|
||||
stream: Optional[bool] = None
|
||||
|
||||
|
||||
@app.delete("/api/push")
|
||||
@app.delete("/api/push/{url_idx}")
|
||||
@router.delete("/api/push")
|
||||
@router.delete("/api/push/{url_idx}")
|
||||
async def push_model(
|
||||
request: Request,
|
||||
form_data: PushModelForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
@ -477,11 +502,13 @@ async def push_model(
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.debug(f"url: {url}")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
|
||||
url=f"{url}/api/push",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
)
|
||||
|
||||
|
||||
@ -492,17 +519,22 @@ class CreateModelForm(BaseModel):
|
||||
path: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/api/create")
|
||||
@app.post("/api/create/{url_idx}")
|
||||
@router.post("/api/create")
|
||||
@router.post("/api/create/{url_idx}")
|
||||
async def create_model(
|
||||
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
|
||||
request: Request,
|
||||
form_data: CreateModelForm,
|
||||
url_idx: int = 0,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
log.debug(f"form_data: {form_data}")
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
|
||||
url=f"{url}/api/create",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
)
|
||||
|
||||
|
||||
@ -511,9 +543,10 @@ class CopyModelForm(BaseModel):
|
||||
destination: str
|
||||
|
||||
|
||||
@app.post("/api/copy")
|
||||
@app.post("/api/copy/{url_idx}")
|
||||
@router.post("/api/copy")
|
||||
@router.post("/api/copy/{url_idx}")
|
||||
async def copy_model(
|
||||
request: Request,
|
||||
form_data: CopyModelForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
@ -530,13 +563,13 @@ async def copy_model(
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@ -573,9 +606,10 @@ async def copy_model(
|
||||
)
|
||||
|
||||
|
||||
@app.delete("/api/delete")
|
||||
@app.delete("/api/delete/{url_idx}")
|
||||
@router.delete("/api/delete")
|
||||
@router.delete("/api/delete/{url_idx}")
|
||||
async def delete_model(
|
||||
request: Request,
|
||||
form_data: ModelNameForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
@ -592,13 +626,13 @@ async def delete_model(
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@ -634,8 +668,10 @@ async def delete_model(
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/show")
|
||||
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
|
||||
@router.post("/api/show")
|
||||
async def show_model_info(
|
||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||
):
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
@ -646,13 +682,13 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
|
||||
)
|
||||
|
||||
url_idx = random.choice(models[form_data.name]["urls"])
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@ -701,8 +737,8 @@ class GenerateEmbedForm(BaseModel):
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
@app.post("/api/embed")
|
||||
@app.post("/api/embed/{url_idx}")
|
||||
@router.post("/api/embed")
|
||||
@router.post("/api/embed/{url_idx}")
|
||||
async def generate_embeddings(
|
||||
form_data: GenerateEmbedForm,
|
||||
url_idx: Optional[int] = None,
|
||||
@ -711,8 +747,8 @@ async def generate_embeddings(
|
||||
return await generate_ollama_batch_embeddings(form_data, url_idx)
|
||||
|
||||
|
||||
@app.post("/api/embeddings")
|
||||
@app.post("/api/embeddings/{url_idx}")
|
||||
@router.post("/api/embeddings")
|
||||
@router.post("/api/embeddings/{url_idx}")
|
||||
async def generate_embeddings(
|
||||
form_data: GenerateEmbeddingsForm,
|
||||
url_idx: Optional[int] = None,
|
||||
@ -744,13 +780,13 @@ async def generate_ollama_embeddings(
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@ -814,13 +850,13 @@ async def generate_ollama_batch_embeddings(
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
key = api_config.get("key", None)
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
@ -873,9 +909,10 @@ class GenerateCompletionForm(BaseModel):
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
@app.post("/api/generate")
|
||||
@app.post("/api/generate/{url_idx}")
|
||||
@router.post("/api/generate")
|
||||
@router.post("/api/generate/{url_idx}")
|
||||
async def generate_completion(
|
||||
request: Request,
|
||||
form_data: GenerateCompletionForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
@ -897,15 +934,17 @@ async def generate_completion(
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
|
||||
)
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
form_data.model = form_data.model.replace(f"{prefix_id}.", "")
|
||||
log.info(f"url: {url}")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
|
||||
url=f"{url}/api/generate",
|
||||
payload=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
)
|
||||
|
||||
|
||||
@ -936,13 +975,14 @@ async def get_ollama_url(url_idx: Optional[int], model: str):
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
||||
)
|
||||
url_idx = random.choice(models[model]["urls"])
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
return url
|
||||
|
||||
|
||||
@app.post("/api/chat")
|
||||
@app.post("/api/chat/{url_idx}")
|
||||
@router.post("/api/chat")
|
||||
@router.post("/api/chat/{url_idx}")
|
||||
async def generate_chat_completion(
|
||||
request: Request,
|
||||
form_data: GenerateChatCompletionForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
@ -1003,15 +1043,16 @@ async def generate_chat_completion(
|
||||
parsed_url = urlparse(url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/api/chat",
|
||||
json.dumps(payload),
|
||||
url=f"{url}/api/chat",
|
||||
payload=json.dumps(payload),
|
||||
stream=form_data.stream,
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
content_type="application/x-ndjson",
|
||||
)
|
||||
|
||||
@ -1043,10 +1084,13 @@ class OpenAICompletionForm(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
@app.post("/v1/completions/{url_idx}")
|
||||
@router.post("/v1/completions")
|
||||
@router.post("/v1/completions/{url_idx}")
|
||||
async def generate_openai_completion(
|
||||
form_data: dict, url_idx: Optional[int] = None, user=Depends(get_verified_user)
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
try:
|
||||
form_data = OpenAICompletionForm(**form_data)
|
||||
@ -1099,22 +1143,24 @@ async def generate_openai_completion(
|
||||
url = await get_ollama_url(url_idx, payload["model"])
|
||||
log.info(f"url: {url}")
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
|
||||
if prefix_id:
|
||||
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/v1/completions",
|
||||
json.dumps(payload),
|
||||
url=f"{url}/v1/completions",
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@app.post("/v1/chat/completions/{url_idx}")
|
||||
@router.post("/v1/chat/completions")
|
||||
@router.post("/v1/chat/completions/{url_idx}")
|
||||
async def generate_openai_chat_completion(
|
||||
request: Request,
|
||||
form_data: dict,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
@ -1172,21 +1218,23 @@ async def generate_openai_chat_completion(
|
||||
url = await get_ollama_url(url_idx, payload["model"])
|
||||
log.info(f"url: {url}")
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
prefix_id = api_config.get("prefix_id", None)
|
||||
if prefix_id:
|
||||
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||
|
||||
return await post_streaming_url(
|
||||
f"{url}/v1/chat/completions",
|
||||
json.dumps(payload),
|
||||
url=f"{url}/v1/chat/completions",
|
||||
payload=json.dumps(payload),
|
||||
stream=payload.get("stream", False),
|
||||
key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
@app.get("/v1/models/{url_idx}")
|
||||
@router.get("/v1/models")
|
||||
@router.get("/v1/models/{url_idx}")
|
||||
async def get_openai_models(
|
||||
request: Request,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
@ -1205,7 +1253,7 @@ async def get_openai_models(
|
||||
]
|
||||
|
||||
else:
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
try:
|
||||
r = requests.request(method="GET", url=f"{url}/api/tags")
|
||||
r.raise_for_status()
|
||||
@ -1329,9 +1377,10 @@ async def download_file_stream(
|
||||
|
||||
|
||||
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
||||
@app.post("/models/download")
|
||||
@app.post("/models/download/{url_idx}")
|
||||
@router.post("/models/download")
|
||||
@router.post("/models/download/{url_idx}")
|
||||
async def download_model(
|
||||
request: Request,
|
||||
form_data: UrlForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
@ -1346,7 +1395,7 @@ async def download_model(
|
||||
|
||||
if url_idx is None:
|
||||
url_idx = 0
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
file_name = parse_huggingface_url(form_data.url)
|
||||
|
||||
@ -1360,16 +1409,17 @@ async def download_model(
|
||||
return None
|
||||
|
||||
|
||||
@app.post("/models/upload")
|
||||
@app.post("/models/upload/{url_idx}")
|
||||
@router.post("/models/upload")
|
||||
@router.post("/models/upload/{url_idx}")
|
||||
def upload_model(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
url_idx = 0
|
||||
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
file_path = f"{UPLOAD_DIR}/{file.filename}"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user