feat: preset backend logic
parent 7d2ab168f1
commit 88d053833d
@@ -875,15 +875,88 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    model_id = get_model_id_from_custom_model_id(form_data.model)
-    model = model_id
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
+
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )

     if url_idx == None:
-        if ":" not in model:
-            model = f"{model}:latest"
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
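The block added above copies preset fields from the stored model record into Ollama's "options" object, sourcing "repeat_penalty" from the preset's "frequency_penalty" field and renaming "max_tokens" to "num_predict". A minimal standalone sketch of that mapping, assuming the params arrive as a plain dict (the helper name is illustrative, not part of this commit):

def build_ollama_options(params: dict) -> dict:
    # Mirrors the mapping in the diff; None values are kept, as in the commit.
    return {
        "mirostat": params.get("mirostat"),
        "mirostat_eta": params.get("mirostat_eta"),
        "mirostat_tau": params.get("mirostat_tau"),
        "num_ctx": params.get("num_ctx"),
        "repeat_last_n": params.get("repeat_last_n"),
        "repeat_penalty": params.get("frequency_penalty"),  # note the renamed source key
        "temperature": params.get("temperature"),
        "seed": params.get("seed"),
        "tfs_z": params.get("tfs_z"),
        "num_predict": params.get("max_tokens"),  # Ollama's name for the token limit
        "top_k": params.get("top_k"),
        "top_p": params.get("top_p"),
    }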
@@ -893,23 +966,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

+    print(payload)
+
     r = None

-    # payload = {
-    #     **form_data.model_dump_json(exclude_none=True).encode(),
-    #     "model": model,
-    #     "messages": form_data.messages,
-    # }
-
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
-
     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r

         request_id = str(uuid.uuid4())
@@ -918,7 +980,7 @@ async def generate_chat_completion(

             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream", None):
                         yield json.dumps({"id": request_id, "done": False}) + "\n"

                     for chunk in r.iter_content(chunk_size=8192):
@@ -936,7 +998,7 @@ async def generate_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )

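The hunks above switch the outgoing request body from the raw `form_data` object to the merged `payload` dict, so the preset overrides actually reach Ollama. A small sketch of the difference, using a hypothetical Pydantic form model (`ChatForm` and its fields are illustrative, not the project's actual schema):

import json
from typing import Optional
from pydantic import BaseModel

class ChatForm(BaseModel):
    model: str
    messages: list
    stream: Optional[bool] = None

form_data = ChatForm(model="llama2", messages=[{"role": "user", "content": "hi"}])

# Before: the validated form was serialized as-is, so presets could not change it.
old_body = form_data.model_dump_json(exclude_none=True).encode()

# After: the form is dumped into a mutable dict, the preset logic edits it, and
# the result is what requests.request(data=...) sends.
payload = {**form_data.model_dump(exclude_none=True)}
payload.setdefault("options", {})["temperature"] = 0.2  # preset override
new_body = json.dumps(payload)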
@@ -992,14 +1054,56 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
+
     if url_idx == None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
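This endpoint repeats the system-prompt handling from the Ollama-native route: the preset's "system" text is prepended to an existing system message, and only if the loop finds none (the for/else fires when no break happens) is a new system message inserted at position 0. A standalone sketch of that logic, under a hypothetical helper name:

def apply_preset_system_prompt(payload: dict, system: str) -> dict:
    # Mirrors the injection logic in the diff.
    messages = payload.get("messages")
    if not messages:
        return payload
    for message in messages:
        if message.get("role") == "system":
            # Prepend the preset system prompt to the existing system message.
            message["content"] = system + message["content"]
            break
    else:
        # No system message found: insert one at the front.
        messages.insert(0, {"role": "system", "content": system})
    return payload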
@@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion(
     r = None

     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r

         request_id = str(uuid.uuid4())
@@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion(

             def stream_content():
                 try:
-                    if form_data.stream:
+                    if payload.get("stream"):
                         yield json.dumps(
                             {"request_id": request_id, "done": False}
                         ) + "\n"
@@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )

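With the plumbing above, the OpenAI-compatible route serializes the preset-adjusted payload instead of the raw form. One way to exercise it is an OpenAI-style request against a preset model id; the base URL, mount path, model id, and key below are placeholders, and the actual route should be checked against the router decorators:

import requests

# Placeholder values; adjust to your Open WebUI instance and preset model id.
resp = requests.post(
    "http://localhost:8080/ollama/v1/chat/completions",
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json={
        "model": "my-preset-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.json())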
@@ -315,41 +315,87 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     # TODO: Remove below after gpt-4-vision fix from Open AI
     # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
+
+    payload = None
+
     try:
-        body = body.decode("utf-8")
-        body = json.loads(body)
+        if "chat/completions" in path:
+            body = body.decode("utf-8")
+            body = json.loads(body)
+
+            payload = {**body}
+
+            model_id = body.get("model")
+            model_info = Models.get_model_by_id(model_id)
+
+            if model_info:
+                print(model_info)
+                if model_info.base_model_id:
+                    payload["model"] = model_info.base_model_id
+
+                model_info.params = model_info.params.model_dump()
+
+                if model_info.params:
+                    payload["temperature"] = model_info.params.get("temperature", None)
+                    payload["top_p"] = model_info.params.get("top_p", None)
+                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
+                    payload["frequency_penalty"] = model_info.params.get(
+                        "frequency_penalty", None
+                    )
+                    payload["seed"] = model_info.params.get("seed", None)
+                    # TODO: add "stop" back in
+                    # payload["stop"] = model_info.params.get("stop", None)
+
+                if model_info.params.get("system", None):
+                    # Check if the payload already has a system message
+                    # If not, add a system message to the payload
+                    if payload.get("messages"):
+                        for message in payload["messages"]:
+                            if message.get("role") == "system":
+                                message["content"] = (
+                                    model_info.params.get("system", None)
+                                    + message["content"]
+                                )
+                                break
+                        else:
+                            payload["messages"].insert(
+                                0,
+                                {
+                                    "role": "system",
+                                    "content": model_info.params.get("system", None),
+                                },
+                            )
+            else:
+                pass

-        print(app.state.MODELS)
-        model = app.state.MODELS[body.get("model")]
+            print(app.state.MODELS)
+            model = app.state.MODELS[payload.get("model")]

-        idx = model["urlIdx"]
+            idx = model["urlIdx"]

-        if "pipeline" in model and model.get("pipeline"):
-            body["user"] = {"name": user.name, "id": user.id}
-            body["title"] = (
-                True if body["stream"] == False and body["max_tokens"] == 50 else False
-            )
+            if "pipeline" in model and model.get("pipeline"):
+                payload["user"] = {"name": user.name, "id": user.id}
+                payload["title"] = (
+                    True
+                    if payload["stream"] == False and payload["max_tokens"] == 50
+                    else False
+                )

-        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-        # This is a workaround until OpenAI fixes the issue with this model
-        if body.get("model") == "gpt-4-vision-preview":
-            if "max_tokens" not in body:
-                body["max_tokens"] = 4000
-            log.debug("Modified body_dict:", body)
+            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+            # This is a workaround until OpenAI fixes the issue with this model
+            if payload.get("model") == "gpt-4-vision-preview":
+                if "max_tokens" not in payload:
+                    payload["max_tokens"] = 4000
+                log.debug("Modified payload:", payload)

-        # Fix for ChatGPT calls failing because the num_ctx key is in body
-        if "num_ctx" in body:
-            # If 'num_ctx' is in the dictionary, delete it
-            # Leaving it there generates an error with the
-            # OpenAI API (Feb 2024)
-            del body["num_ctx"]
+            # Convert the modified body back to JSON
+            payload = json.dumps(payload)

-        # Convert the modified body back to JSON
-        body = json.dumps(body)
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)

+    print(payload)
+
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
@@ -368,7 +414,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         r = requests.request(
             method=request.method,
             url=target_url,
-            data=body,
+            data=payload if payload else body,
             headers=headers,
             stream=True,
         )
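In the proxy, `payload` starts as None and only becomes a JSON string for chat-completion requests that parse cleanly, so every other path (and any body that fails json.loads) is still forwarded byte-for-byte. A minimal sketch of that dispatch, with illustrative values standing in for the preset overrides:

import json

def choose_request_body(path: str, body: bytes):
    # Sketch of the fallback added to the proxy: presets only rewrite
    # chat-completion bodies; everything else passes through untouched.
    payload = None
    try:
        if "chat/completions" in path:
            data = json.loads(body.decode("utf-8"))
            data["temperature"] = 0.2  # stand-in for the preset overrides
            payload = json.dumps(data)
    except json.JSONDecodeError:
        pass
    return payload if payload else body

print(choose_request_body("v1/chat/completions", b'{"model": "gpt-4o", "messages": []}'))
print(choose_request_body("v1/models", b""))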