Merge pull request #2731 from cheahjs/fix/ollama-cancellation

fix: ollama and openai stream cancellation
Timothy Jaeryang Baek 2024-06-02 16:27:11 -07:00 committed by GitHub
commit d9a3e4db39
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
7 changed files with 184 additions and 501 deletions


@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
REQUEST_POOL = []
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=5)
try:
@@ -154,6 +143,45 @@ async def fetch_url(url):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
async def post_streaming_url(url: str, payload: str):
r = None
try:
session = aiohttp.ClientSession()
r = await session.post(url, data=payload)
r.raise_for_status()
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(cleanup_response, response=r, session=session),
)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=error_detail,
)
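The `post_streaming_url` helper above replaces the per-endpoint `requests` + `REQUEST_POOL` plumbing: cancellation now rides on the HTTP connection itself. A minimal standalone sketch of the same pattern (the route path, upstream URL, and payload below are illustrative, not part of this commit): when the browser aborts the fetch, Starlette stops iterating the `StreamingResponse` and runs the background task, which closes the upstream aiohttp response and session, ending the generation on the Ollama side.

```python
# Illustrative sketch only -- the route path and UPSTREAM are placeholders.
import aiohttp
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()
UPSTREAM = "http://localhost:11434/api/generate"  # assumed default Ollama port


async def cleanup(response: aiohttp.ClientResponse, session: aiohttp.ClientSession):
    # Runs once the client stops reading (normal completion or disconnect).
    response.close()
    await session.close()


@app.post("/demo/generate")
async def demo_generate(payload: dict):
    session = aiohttp.ClientSession()
    r = await session.post(UPSTREAM, json=payload)
    r.raise_for_status()
    return StreamingResponse(
        r.content,  # async iterator over the upstream body
        status_code=r.status,
        headers=dict(r.headers),
        background=BackgroundTask(cleanup, response=r, session=session),
    )
```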
def merge_models_lists(model_lists):
merged_models = {}
@@ -313,65 +341,7 @@ async def pull_model(
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
def get_request():
nonlocal url
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class PushModelForm(BaseModel):
@@ -399,50 +369,9 @@ async def push_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
return await post_streaming_url(
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
)
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/push",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class CreateModelForm(BaseModel):
@@ -461,53 +390,9 @@ async def create_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
return await post_streaming_url(
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
)
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/create",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
log.debug(f"r: {r}")
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class CopyModelForm(BaseModel):
@@ -797,66 +682,9 @@ async def generate_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
return await post_streaming_url(
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
)
r = None
def get_request():
nonlocal form_data
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if form_data.stream:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/generate",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
class ChatMessage(BaseModel):
@@ -1014,67 +842,7 @@ async def generate_chat_completion(
print(payload)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream", None):
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/chat",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
log.exception(e)
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
# TODO: we should update this part once Ollama supports other types
@@ -1165,68 +933,7 @@ async def generate_openai_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream"):
yield json.dumps(
{"request_id": request_id, "done": False}
) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/v1/chat/completions",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
@app.get("/v1/models")
@@ -1555,7 +1262,7 @@ async def deprecated_proxy(
if path == "generate":
data = json.loads(body.decode("utf-8"))
if not ("stream" in data and data["stream"] == False):
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":
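The last hunk simplifies the stream check in the deprecated proxy. For the payloads Ollama clients actually send (`stream` absent, `true`, or `false`), the two forms behave the same; a quick standalone comparison (not repository code) shows they only diverge on unusual values such as `null`:

```python
def old_check(data: dict) -> bool:
    # previous form: stream unless the key is present and exactly False
    return not ("stream" in data and data["stream"] == False)  # noqa: E712


def new_check(data: dict) -> bool:
    # new form: stream unless the value is falsy, defaulting to True
    return bool(data.get("stream", True))


for payload in ({}, {"stream": True}, {"stream": False}, {"stream": None}):
    print(payload, old_check(payload), new_check(payload))
# {} and {"stream": True} stream under both; {"stream": False} under neither.
# Only {"stream": None} differs: the old check would stream, the new one does not.
```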


@@ -9,6 +9,7 @@ import json
import logging
from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
@@ -447,40 +458,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
r = requests.request(
session = aiohttp.ClientSession()
r = await session.request(
method=request.method,
url=target_url,
data=payload if payload else body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.iter_content(chunk_size=8192),
r.content,
status_code=r.status_code,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = r.json()
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
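In the OpenAI proxy, cleanup is now split between the SSE branch (handled by `BackgroundTask(cleanup_response, ...)` once the client finishes or disconnects) and the `finally` block for non-streaming and error paths. Both rely on the same aiohttp behaviour, sketched standalone below (URL and payload values are placeholders): closing the client response mid-stream drops the connection, which is what actually cancels the in-flight completion upstream.

```python
# Standalone sketch, not repository code.
import asyncio
import aiohttp

URL = "http://localhost:11434/api/generate"  # placeholder upstream


async def main():
    session = aiohttp.ClientSession()
    try:
        resp = await session.post(URL, json={"model": "llama3", "prompt": "hi"})
        async for chunk in resp.content.iter_chunked(1024):
            print(len(chunk))
            break  # stop reading early, e.g. the user pressed "stop"
        resp.close()  # drops the connection; the upstream request is abandoned
    finally:
        await session.close()


asyncio.run(main())
```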


@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
return [res, controller];
};
export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
let error = null;
const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
method: 'GET',
headers: {
'Content-Type': 'text/event-stream',
Authorization: `Bearer ${token}`
}
}).catch((err) => {
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const createModel = async (token: string, tagName: string, content: string) => {
let error = null;
@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
let error = null;
const controller = new AbortController();
const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
signal: controller.signal,
method: 'POST',
headers: {
Accept: 'application/json',
@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
if (error) {
throw error;
}
return res;
return [res, controller];
};
export const downloadModel = async (


@@ -26,7 +26,7 @@
splitStream
} from '$lib/utils';
import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
import { generateChatCompletion } from '$lib/apis/ollama';
import {
addTagById,
createNewChat,
@@ -65,7 +65,6 @@
let autoScroll = true;
let processing = '';
let messagesContainerElement: HTMLDivElement;
let currentRequestId = null;
let showModelSelector = true;
@@ -130,10 +129,6 @@
//////////////////////////
const initNewChat = async () => {
if (currentRequestId !== null) {
await cancelOllamaRequest(localStorage.token, currentRequestId);
currentRequestId = null;
}
window.history.replaceState(history.state, '', `/`);
await chatId.set('');
@@ -616,7 +611,6 @@
if (stopResponseFlag) {
controller.abort('User: Stop Response');
await cancelOllamaRequest(localStorage.token, currentRequestId);
} else {
const messages = createMessagesList(responseMessageId);
const res = await chatCompleted(localStorage.token, {
@@ -647,8 +641,6 @@
}
}
currentRequestId = null;
break;
}
@@ -669,63 +661,58 @@
throw data;
}
if ('id' in data) {
console.log(data);
currentRequestId = data.id;
} else {
if (data.done == false) {
if (responseMessage.content == '' && data.message.content == '\n') {
continue;
} else {
responseMessage.content += data.message.content;
messages = messages;
}
} else {
responseMessage.done = true;

if (responseMessage.content == '') {
responseMessage.error = {
code: 400,
content: `Oops! No text generated from Ollama, Please try again.`
};
}

responseMessage.context = data.context ?? null;
responseMessage.info = {
total_duration: data.total_duration,
load_duration: data.load_duration,
sample_count: data.sample_count,
sample_duration: data.sample_duration,
prompt_eval_count: data.prompt_eval_count,
prompt_eval_duration: data.prompt_eval_duration,
eval_count: data.eval_count,
eval_duration: data.eval_duration
};
messages = messages;

if ($settings.notificationEnabled && !document.hasFocus()) {
const notification = new Notification(
selectedModelfile
? `${
selectedModelfile.title.charAt(0).toUpperCase() +
selectedModelfile.title.slice(1)
}`
: `${model}`,
{
body: responseMessage.content,
icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
}
);
}

if ($settings.responseAutoCopy) {
copyToClipboard(responseMessage.content);
}

if ($settings.responseAutoPlayback) {
await tick();
document.getElementById(`speak-button-${responseMessage.id}`)?.click();
}
}
}


@@ -8,7 +8,7 @@
import Check from '$lib/components/icons/Check.svelte';
import Search from '$lib/components/icons/Search.svelte';
import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
import { toast } from 'svelte-sonner';
@@ -72,10 +72,12 @@
return;
}
const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
toast.error(error);
return null;
});
const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
(error) => {
toast.error(error);
return null;
}
);
if (res) {
const reader = res.body
@@ -83,6 +85,16 @@
.pipeThrough(splitStream('\n'))
.getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
while (true) {
try {
const { value, done } = await reader.read();
@@ -101,19 +113,6 @@
throw data.detail;
}
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
if (data.status) {
if (data.digest) {
let downloadProgress = 0;
@@ -181,11 +180,12 @@
});
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
if (abortController) {
abortController.abort();
}
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete $MODEL_DOWNLOAD_POOL[model];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL


@@ -8,7 +8,6 @@
getOllamaUrls,
getOllamaVersion,
pullModel,
cancelOllamaRequest,
uploadModel,
getOllamaConfig
} from '$lib/apis/ollama';
@@ -70,12 +69,14 @@
console.log(model);
updateModelId = model.id;
const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
(error) => {
toast.error(error);
return null;
}
);
const [res, controller] = await pullModel(
localStorage.token,
model.id,
selectedOllamaUrlIdx
).catch((error) => {
toast.error(error);
return null;
});
if (res) {
const reader = res.body
@@ -144,10 +145,12 @@
return;
}
const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
toast.error(error);
return null;
});
const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
(error) => {
toast.error(error);
return null;
}
);
if (res) {
const reader = res.body
@@ -155,6 +158,16 @@
.pipeThrough(splitStream('\n'))
.getReader();
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
abortController: controller,
reader,
done: false
}
});
while (true) {
try {
const { value, done } = await reader.read();
@@ -173,19 +186,6 @@
throw data.detail;
}
if (data.id) {
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL,
[sanitizedModelTag]: {
...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
requestId: data.id,
reader,
done: false
}
});
console.log(data);
}
if (data.status) {
if (data.digest) {
let downloadProgress = 0;
@@ -419,11 +419,12 @@
};
const cancelModelPullHandler = async (model: string) => {
const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
if (abortController) {
abortController.abort();
}
if (reader) {
await reader.cancel();
await cancelOllamaRequest(localStorage.token, requestId);
delete $MODEL_DOWNLOAD_POOL[model];
MODEL_DOWNLOAD_POOL.set({
...$MODEL_DOWNLOAD_POOL


@@ -8,7 +8,7 @@
import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
import { generateChatCompletion } from '$lib/apis/ollama';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import { splitStream } from '$lib/utils';
@@ -24,7 +24,6 @@
let selectedModelId = '';
let loading = false;
let currentRequestId = null;
let stopResponseFlag = false;
let messagesContainerElement: HTMLDivElement;
@@ -46,14 +45,6 @@
}
};
// const cancelHandler = async () => {
// if (currentRequestId) {
// const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
// currentRequestId = null;
// loading = false;
// }
// };
const stopResponse = () => {
stopResponseFlag = true;
console.log('stopResponse');
@@ -171,8 +162,6 @@
if (stopResponseFlag) {
controller.abort('User: Stop Response');
}
currentRequestId = null;
break;
}
@@ -229,7 +218,6 @@
loading = false;
stopResponseFlag = false;
currentRequestId = null;
}
};