From 3eb3db0663ab54ec6ef743d5352bd2c049c13771 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Fri, 7 Feb 2025 13:28:46 +0800
Subject: [PATCH] chore: refactor the OpenAICompatible and improve thinking display (#13299)

---
 .../__base/large_language_model.py   |  39 ++++
 .../openai_api_compatible/llm/llm.py  | 200 ++++++++----------
 2 files changed, 129 insertions(+), 110 deletions(-)

diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 402a30376b..780a55681a 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -30,6 +30,11 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
 
 logger = logging.getLogger(__name__)
 
+HTML_THINKING_TAG = (
' + " Thinking... " +) + class LargeLanguageModel(AIModel): """ @@ -400,6 +405,40 @@ if you are not sure about the structure. ), ) + def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]: + """ + If the reasoning response is from delta.get("reasoning_content"), we wrap + it with HTML details tag. + + :param delta: delta dictionary from LLM streaming response + :param is_reasoning: is reasoning + :return: tuple of (processed_content, is_reasoning) + """ + + content = delta.get("content") or "" + reasoning_content = delta.get("reasoning_content") + + if reasoning_content: + if not is_reasoning: + content = HTML_THINKING_TAG + reasoning_content + is_reasoning = True + else: + content = reasoning_content + elif is_reasoning: + content = "
" + content + is_reasoning = False + return content, is_reasoning + + def _wrap_thinking_by_tag(self, content: str) -> str: + """ + if the reasoning response is a ... block from delta.get("content"), + we replace to . + + :param content: delta.get("content") + :return: processed_content + """ + return content.replace("", HTML_THINKING_TAG).replace("", "") + def _invoke_result_generator( self, model: str, diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 17aefc7efc..7f79da267f 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -1,6 +1,5 @@ +import codecs import json -import logging -import re from collections.abc import Generator from decimal import Decimal from typing import Optional, Union, cast @@ -39,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat from core.model_runtime.utils import helper -logger = logging.getLogger(__name__) - class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): """ @@ -100,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): :param tools: tools for tool calling :return: """ - return self._num_tokens_from_messages(model, prompt_messages, tools, credentials) + return self._num_tokens_from_messages(prompt_messages, tools, credentials) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -399,6 +396,73 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): return self._handle_generate_response(model, credentials, response, prompt_messages) + def _create_final_llm_result_chunk( + self, + index: int, + message: AssistantPromptMessage, + finish_reason: str, + usage: dict, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + full_content: str, + ) -> LLMResultChunk: + # calculate num tokens + prompt_tokens = usage and usage.get("prompt_tokens") + if prompt_tokens is None: + prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content) + completion_tokens = usage and usage.get("completion_tokens") + if completion_tokens is None: + completion_tokens = self._num_tokens_from_string(text=full_content) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + return LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), + ) + + def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]): + """ + Get or create a tool call by ID + + :param tool_call_id: tool call ID + :param tools_calls: list of existing tool calls + :return: existing or new tool call, updated tools_calls + """ + if not tool_call_id: + return tools_calls[-1], tools_calls + + tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None) + if tool_call is None: + tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), + ) + tools_calls.append(tool_call) + + return tool_call, tools_calls + + def _increase_tool_call( + self, new_tool_calls: 
+        self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
+        for new_tool_call in new_tool_calls:
+            # get tool call
+            tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
+            # update tool call
+            if new_tool_call.id:
+                tool_call.id = new_tool_call.id
+            if new_tool_call.type:
+                tool_call.type = new_tool_call.type
+            if new_tool_call.function.name:
+                tool_call.function.name = new_tool_call.function.name
+            if new_tool_call.function.arguments:
+                tool_call.function.arguments += new_tool_call.function.arguments
+        return tools_calls
+
     def _handle_generate_stream_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> Generator:
@@ -411,71 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ""
         chunk_index = 0
-
-        def create_final_llm_result_chunk(
-            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
-        ) -> LLMResultChunk:
-            # calculate num tokens
-            prompt_tokens = usage and usage.get("prompt_tokens")
-            if prompt_tokens is None:
-                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = usage and usage.get("completion_tokens")
-            if completion_tokens is None:
-                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
-
-            # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            return LLMResultChunk(
-                id=id,
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
-            )
-
+        full_assistant_content = ""
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = None
+        usage = None
+        is_reasoning_started = False
         # delimiter for stream response, need unicode_escape
-        import codecs
-
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
-
-        tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_call_id: str):
-                if not tool_call_id:
-                    return tools_calls[-1]
-
-                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id=tool_call_id,
-                        type="function",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
-        finish_reason = None  # The default value of finish_reason is None
-        message_id, usage = None, None
-        is_reasoning_started = False
-        is_reasoning_started_tag = False
-
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -490,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                         chunk_json: dict = json.loads(decoded_chunk)
                     # stream ended
                     except json.JSONDecodeError as e:
-                        yield create_final_llm_result_chunk(
-                            id=message_id,
+                        yield self._create_final_llm_result_chunk(
                             index=chunk_index + 1,
                             message=AssistantPromptMessage(content=""),
                             finish_reason="Non-JSON encountered.",
                             usage=usage,
+                            model=model,
+                            credentials=credentials,
+                            prompt_messages=prompt_messages,
+                            full_content=full_assistant_content,
                         )
                         break
                 # handle the error here. for issue #11629
@@ -510,42 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
-                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
                     delta = choice["delta"]
-                    delta_content = delta.get("content")
-                    if not delta_content:
-                        delta_content = ""
-
-                    if not is_reasoning_started_tag and "<think>" in delta_content:
-                        is_reasoning_started_tag = True
-                        delta_content = "> 💭 " + delta_content.replace("<think>", "")
-                    elif is_reasoning_started_tag and "</think>" in delta_content:
-                        delta_content = delta_content.replace("</think>", "") + "\n\n"
-                        is_reasoning_started_tag = False
-                    elif is_reasoning_started_tag:
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-
-                    reasoning_content = delta.get("reasoning_content")
-                    if is_reasoning_started and not reasoning_content and not delta_content:
-                        delta_content = ""
-                    elif reasoning_content:
-                        if not is_reasoning_started:
-                            delta_content = "> 💭 " + reasoning_content
-                            is_reasoning_started = True
-                        else:
-                            delta_content = reasoning_content
-
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-                    elif is_reasoning_started:
-                        # If we were in reasoning mode but now getting regular content,
-                        # add \n\n to close the reasoning block
-                        delta_content = "\n\n" + delta_content
-                        is_reasoning_started = False
+                    delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
+                        delta, is_reasoning_started
+                    )
+                    delta_content = self._wrap_thinking_by_tag(delta_content)
 
                     assistant_message_tool_calls = None
@@ -559,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                             {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
                         ]
 
-                    # assistant_message_function_call = delta.delta.function_call
-
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                        increase_tool_call(tool_calls)
+                        tools_calls = self._increase_tool_call(tool_calls, tools_calls)
 
                     if delta_content is None or delta_content == "":
                         continue
@@ -589,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     continue
 
                 yield LLMResultChunk(
-                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -602,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         if tools_calls:
             yield LLMResultChunk(
-                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -611,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
             ),
         )
 
-        yield create_final_llm_result_chunk(
-            id=message_id,
+        yield self._create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason=finish_reason,
             usage=usage,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            full_content=full_assistant_content,
         )
 
     def _handle_generate_response(
@@ -730,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return message_dict
 
     def _num_tokens_from_string(
-        self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
+        self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
 
-        :param model: model name
         :param text: prompt text
         :param tools: tools for tool calling
         :return: number of tokens
@@ -758,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
     def _num_tokens_from_messages(
         self,
-        model: str,
         messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
         credentials: Optional[dict] = None,
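
Reviewer note (not part of the patch): the snippet below is a minimal, self-contained sketch of the streaming behavior the two new helpers implement, to make the refactor easier to review. The `deltas` list is an illustrative assumption, and `HTML_THINKING_TAG` is shortened here; the real constant in the patch carries inline CSS.

```python
# Standalone sketch mirroring the patch's _wrap_thinking_by_reasoning_content
# and _wrap_thinking_by_tag; the tag is abbreviated for readability.
HTML_THINKING_TAG = "<details open><summary> Thinking... </summary>"


def wrap_thinking_by_reasoning_content(delta: dict, is_reasoning: bool) -> tuple[str, bool]:
    # Open the <details> block on the first reasoning_content chunk and
    # close it as soon as normal content resumes.
    content = delta.get("content") or ""
    reasoning_content = delta.get("reasoning_content")
    if reasoning_content:
        if not is_reasoning:
            content = HTML_THINKING_TAG + reasoning_content
            is_reasoning = True
        else:
            content = reasoning_content
    elif is_reasoning:
        content = "</details>" + content
        is_reasoning = False
    return content, is_reasoning


def wrap_thinking_by_tag(content: str) -> str:
    # Rewrite inline <think>...</think> markers emitted inside delta["content"].
    return content.replace("<think>", HTML_THINKING_TAG).replace("</think>", "</details>")


# Hypothetical stream: two reasoning deltas followed by the final answer.
deltas = [
    {"reasoning_content": "Let me check"},
    {"reasoning_content": " the units."},
    {"content": "The answer is 42."},
]
is_reasoning = False
out = []
for delta in deltas:
    piece, is_reasoning = wrap_thinking_by_reasoning_content(delta, is_reasoning)
    out.append(wrap_thinking_by_tag(piece))
print("".join(out))
# <details open><summary> Thinking... </summary>Let me check the units.</details>The answer is 42.
```

The net effect: reasoning arriving via `delta.get("reasoning_content")` is wrapped once in a collapsible `<details>` block, and `<think>...</think>` markers inside `delta.get("content")` are handled by plain string replacement, replacing the line-by-line `> 💭` blockquote rewriting (and its `re.sub` calls) that this patch removes.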