diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py
index 13761f8cb..51fa711ca 100644
--- a/backend/apps/webui/main.py
+++ b/backend/apps/webui/main.py
@@ -19,7 +19,7 @@ from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
 
-from utils.misc import stream_message_template
+from utils.misc import stream_message_template, whole_message_template
 from utils.task import prompt_template
 
 
@@ -203,7 +203,7 @@ async def execute_pipe(pipe, params):
         return pipe(**params)
 
 
-async def get_message(res: str | Generator | AsyncGenerator) -> str:
+async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
     if isinstance(res, str):
         return res
     if isinstance(res, Generator):
@@ -212,28 +212,6 @@ async def get_message(res: str | Generator | AsyncGenerator) -> str:
         return "".join([str(stream) async for stream in res])
 
 
-def get_final_message(form_data: dict, message: str | None = None) -> dict:
-    choice = {
-        "index": 0,
-        "logprobs": None,
-        "finish_reason": "stop",
-    }
-
-    # If message is None, we're dealing with a chunk
-    if not message:
-        choice["delta"] = {}
-    else:
-        choice["message"] = {"role": "assistant", "content": message}
-
-    return {
-        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
-        "created": int(time.time()),
-        "model": form_data["model"],
-        "object": "chat.completion" if message is not None else "chat.completion.chunk",
-        "choices": [choice],
-    }
-
-
 def process_line(form_data: dict, line):
     if isinstance(line, BaseModel):
         line = line.model_dump_json()
@@ -292,7 +270,9 @@ def get_params_dict(pipe, form_data, user, extra_params, function_module):
 
 
 def get_extra_params(metadata: dict):
-    __event_emitter__ = __event_call__ = __task__ = None
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
 
     if metadata:
         if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
@@ -401,7 +381,8 @@ async def generate_function_chat_completion(form_data, user):
                     yield process_line(form_data, line)
 
             if isinstance(res, str) or isinstance(res, Generator):
-                finish_message = get_final_message(form_data)
+                finish_message = stream_message_template(form_data["model"], "")
+                finish_message["choices"][0]["finish_reason"] = "stop"
                 yield f"data: {json.dumps(finish_message)}\n\n"
                 yield "data: [DONE]"
 
@@ -419,7 +400,7 @@ async def generate_function_chat_completion(form_data, user):
         if isinstance(res, BaseModel):
             return res.model_dump()
 
-        message = await get_message(res)
-        return get_final_message(form_data, message)
+        message = await get_message_content(res)
+        return whole_message_template(form_data["model"], message)
 
     return await job()
diff --git a/backend/utils/misc.py b/backend/utils/misc.py
index f44a7ce7a..a1e0a8e80 100644
--- a/backend/utils/misc.py
+++ b/backend/utils/misc.py
@@ -87,23 +87,30 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
-def stream_message_template(model: str, message: str):
+def message_template(model: str):
     return {
         "id": f"{model}-{str(uuid.uuid4())}",
-        "object": "chat.completion.chunk",
         "created": int(time.time()),
         "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "delta": {"content": message},
-                "logprobs": None,
-                "finish_reason": None,
-            }
-        ],
+        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
     }
 
 
+def stream_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion.chunk"
+    template["choices"][0]["delta"] = {"content": message}
+    return template
+
+
+def whole_message_template(model: str, message: str):
+    template = message_template(model)
+    template["object"] = "chat.completion"
+    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    template["choices"][0]["finish_reason"] = "stop"
+    return template
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
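For context, a usage sketch (not part of the patch) of what the refactored helpers emit. The model name "llama3" and the message strings are illustrative placeholders, as are the <uuid4> and <unix ts> values in the comments:

# Sketch of the helpers introduced in backend/utils/misc.py above.
from utils.misc import stream_message_template, whole_message_template

chunk = stream_message_template("llama3", "Hel")
# {"id": "llama3-<uuid4>", "created": <unix ts>, "model": "llama3",
#  "object": "chat.completion.chunk",
#  "choices": [{"index": 0, "logprobs": None, "finish_reason": None,
#               "delta": {"content": "Hel"}}]}

final = whole_message_template("llama3", "Hello!")
# {"id": "llama3-<uuid4>", "created": <unix ts>, "model": "llama3",
#  "object": "chat.completion",
#  "choices": [{"index": 0, "logprobs": None, "finish_reason": "stop",
#               "message": {"content": "Hello!", "role": "assistant"}}]}

# main.py closes a stream by reusing the chunk template with an empty delta
# and overriding finish_reason, then sending the SSE terminator:
closing = stream_message_template("llama3", "")
closing["choices"][0]["finish_reason"] = "stop"
# yield f"data: {json.dumps(closing)}\n\n"
# yield "data: [DONE]"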