From 88d053833d5287b924a419fb3f926b41c348a3b4 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Sat, 25 May 2024 02:05:05 -0700
Subject: [PATCH] feat: preset backend logic

---
 backend/apps/ollama/main.py | 166 +++++++++++++++++++++++++++++-------
 backend/apps/openai/main.py |  96 +++++++++++++++------
 2 files changed, 206 insertions(+), 56 deletions(-)

diff --git a/backend/apps/ollama/main.py b/backend/apps/ollama/main.py
index f9028d667..5568f359d 100644
--- a/backend/apps/ollama/main.py
+++ b/backend/apps/ollama/main.py
@@ -875,15 +875,88 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    model_id = get_model_id_from_custom_model_id(form_data.model)
-    model = model_id
+
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )
+
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }
+
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            # TODO: add "stop" back in
+            # payload["stop"] = model_info.params.get("stop", None)
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+                else:
+                    payload["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": model_info.params.get("system", None),
+                        },
+                    )
 
     if url_idx == None:
-        if ":" not in model:
-            model = f"{model}:latest"
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
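The hunk above is the core of the preset logic for Ollama: stored preset
parameters are copied into the request's "options" object, with two renames
along the way ("frequency_penalty" is reused as Ollama's "repeat_penalty",
and "max_tokens" becomes "num_predict"). A minimal standalone sketch of that
mapping, assuming params is the dict produced by
model_info.params.model_dump(); the helper name apply_ollama_preset is
hypothetical, not part of the patch:

    def apply_ollama_preset(payload: dict, params: dict) -> dict:
        # Ollama reads sampling settings from payload["options"].
        payload["options"] = {
            "mirostat": params.get("mirostat"),
            "mirostat_eta": params.get("mirostat_eta"),
            "mirostat_tau": params.get("mirostat_tau"),
            "num_ctx": params.get("num_ctx"),
            "repeat_last_n": params.get("repeat_last_n"),
            # Preset "frequency_penalty" doubles as Ollama's repeat_penalty.
            "repeat_penalty": params.get("frequency_penalty"),
            "temperature": params.get("temperature"),
            "seed": params.get("seed"),
            "tfs_z": params.get("tfs_z"),
            # Preset "max_tokens" maps to Ollama's num_predict.
            "num_predict": params.get("max_tokens"),
            "top_k": params.get("top_k"),
            "top_p": params.get("top_p"),
        }
        return payload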
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( - form_data.model_dump_json(exclude_none=True).encode() - ) - ) - def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -918,7 +980,7 @@ async def generate_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream", None): yield json.dumps({"id": request_id, "done": False}) + "\n" for chunk in r.iter_content(chunk_size=8192): @@ -936,7 +998,7 @@ async def generate_chat_completion( r = requests.request( method="POST", url=f"{url}/api/chat", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) @@ -992,14 +1054,56 @@ async def generate_openai_chat_completion( user=Depends(get_verified_user), ): + payload = { + **form_data.model_dump(exclude_none=True), + } + + model_id = form_data.model + model_info = Models.get_model_by_id(model_id) + + if model_info: + print(model_info) + if model_info.base_model_id: + payload["model"] = model_info.base_model_id + + model_info.params = model_info.params.model_dump() + + if model_info.params: + payload["temperature"] = model_info.params.get("temperature", None) + payload["top_p"] = model_info.params.get("top_p", None) + payload["max_tokens"] = model_info.params.get("max_tokens", None) + payload["frequency_penalty"] = model_info.params.get( + "frequency_penalty", None + ) + payload["seed"] = model_info.params.get("seed", None) + # TODO: add "stop" back in + # payload["stop"] = model_info.params.get("stop", None) + + if model_info.params.get("system", None): + # Check if the payload already has a system message + # If not, add a system message to the payload + if payload.get("messages"): + for message in payload["messages"]: + if message.get("role") == "system": + message["content"] = ( + model_info.params.get("system", None) + message["content"] + ) + break + else: + payload["messages"].insert( + 0, + { + "role": "system", + "content": model_info.params.get("system", None), + }, + ) + if url_idx == None: - model = form_data.model + if ":" not in payload["model"]: + payload["model"] = f"{payload['model']}:latest" - if ":" not in model: - model = f"{model}:latest" - - if model in app.state.MODELS: - url_idx = random.choice(app.state.MODELS[model]["urls"]) + if payload["model"] in app.state.MODELS: + url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"]) else: raise HTTPException( status_code=400, @@ -1012,7 +1116,7 @@ async def generate_openai_chat_completion( r = None def get_request(): - nonlocal form_data + nonlocal payload nonlocal r request_id = str(uuid.uuid4()) @@ -1021,7 +1125,7 @@ async def generate_openai_chat_completion( def stream_content(): try: - if form_data.stream: + if payload.get("stream"): yield json.dumps( {"request_id": request_id, "done": False} ) + "\n" @@ -1041,7 +1145,7 @@ async def generate_openai_chat_completion( r = requests.request( method="POST", url=f"{url}/v1/chat/completions", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(payload), stream=True, ) diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py index 0b9735238..df1b28638 100644 --- a/backend/apps/openai/main.py +++ b/backend/apps/openai/main.py @@ -315,41 +315,87 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() # TODO: Remove below after gpt-4-vision fix from Open AI # Try to decode the body of the request from bytes to a UTF-8 string (Require add 
diff --git a/backend/apps/openai/main.py b/backend/apps/openai/main.py
index 0b9735238..df1b28638 100644
--- a/backend/apps/openai/main.py
+++ b/backend/apps/openai/main.py
@@ -315,41 +315,87 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     # TODO: Remove below after gpt-4-vision fix from Open AI
     # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
+
+    payload = None
+
     try:
-        body = body.decode("utf-8")
-        body = json.loads(body)
+        if "chat/completions" in path:
+            body = body.decode("utf-8")
+            body = json.loads(body)
 
-        print(app.state.MODELS)
+            payload = {**body}
 
-        model = app.state.MODELS[body.get("model")]
+            model_id = body.get("model")
+            model_info = Models.get_model_by_id(model_id)
 
-        idx = model["urlIdx"]
+            if model_info:
+                print(model_info)
+                if model_info.base_model_id:
+                    payload["model"] = model_info.base_model_id
 
-        if "pipeline" in model and model.get("pipeline"):
-            body["user"] = {"name": user.name, "id": user.id}
-            body["title"] = (
-                True if body["stream"] == False and body["max_tokens"] == 50 else False
-            )
+                model_info.params = model_info.params.model_dump()
 
-        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-        # This is a workaround until OpenAI fixes the issue with this model
-        if body.get("model") == "gpt-4-vision-preview":
-            if "max_tokens" not in body:
-                body["max_tokens"] = 4000
-            log.debug("Modified body_dict:", body)
+                if model_info.params:
+                    payload["temperature"] = model_info.params.get("temperature", None)
+                    payload["top_p"] = model_info.params.get("top_p", None)
+                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
+                    payload["frequency_penalty"] = model_info.params.get(
+                        "frequency_penalty", None
+                    )
+                    payload["seed"] = model_info.params.get("seed", None)
+                    # TODO: add "stop" back in
+                    # payload["stop"] = model_info.params.get("stop", None)
 
-        # Fix for ChatGPT calls failing because the num_ctx key is in body
-        if "num_ctx" in body:
-            # If 'num_ctx' is in the dictionary, delete it
-            # Leaving it there generates an error with the
-            # OpenAI API (Feb 2024)
-            del body["num_ctx"]
+                if model_info.params.get("system", None):
+                    # Check if the payload already has a system message
+                    # If not, add a system message to the payload
+                    if payload.get("messages"):
+                        for message in payload["messages"]:
+                            if message.get("role") == "system":
+                                message["content"] = (
+                                    model_info.params.get("system", None)
+                                    + message["content"]
+                                )
+                                break
+                        else:
+                            payload["messages"].insert(
+                                0,
+                                {
+                                    "role": "system",
+                                    "content": model_info.params.get("system", None),
+                                },
+                            )
+            else:
+                pass
+
+            print(app.state.MODELS)
+            model = app.state.MODELS[payload.get("model")]
+
+            idx = model["urlIdx"]
+
+            if "pipeline" in model and model.get("pipeline"):
+                payload["user"] = {"name": user.name, "id": user.id}
+                payload["title"] = (
+                    True
+                    if payload["stream"] == False and payload["max_tokens"] == 50
+                    else False
+                )
+
+            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+            # This is a workaround until OpenAI fixes the issue with this model
+            if payload.get("model") == "gpt-4-vision-preview":
+                if "max_tokens" not in payload:
+                    payload["max_tokens"] = 4000
+                log.debug("Modified payload:", payload)
+
+            # Convert the modified body back to JSON
+            payload = json.dumps(payload)
 
-        # Convert the modified body back to JSON
-        body = json.dumps(body)
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)
 
+    print(payload)
+
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
 
@@ -368,7 +414,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         r = requests.request(
             method=request.method,
            url=target_url,
-            data=body,
+            data=payload if payload else body,
            headers=headers,
             stream=True,
         )
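A closing detail in the proxy hunk: payload stays None unless the path
contains "chat/completions" and the body parses as JSON, so
data=payload if payload else body forwards the preset-merged JSON string only
for chat completion calls and the untouched raw bytes for every other route.
A condensed restatement of that decision (the name choose_request_data is
illustrative, not part of the patch):

    def choose_request_data(body: bytes, payload: str | None):
        # The patch relies on truthiness ("payload if payload else body"):
        # payload is only ever set, via json.dumps, for chat/completions
        # requests that parsed cleanly, so any other route, or a body that
        # failed to decode, falls back to the original bytes.
        return payload if payload else body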