chore: refactor the OpenAICompatible and improve thinking display (#13299)

Author: 非法操作 (2025-02-07 13:28:46 +08:00), committed by GitHub
parent be46f32056
commit 3eb3db0663
2 changed files with 129 additions and 110 deletions

File: core/model_runtime/model_providers/__base/large_language_model.py

@@ -30,6 +30,11 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
 logger = logging.getLogger(__name__)

+HTML_THINKING_TAG = (
+    '<details style="color:gray;background-color: #f5f5f5;padding: 8px;border-radius: 4px;" open> '
+    "<summary> Thinking... </summary>"
+)

 class LargeLanguageModel(AIModel):
     """
@@ -400,6 +405,40 @@ if you are not sure about the structure.
             ),
         )

+    def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
+        """
+        If the reasoning response comes from delta.get("reasoning_content"), wrap
+        it with an HTML details tag.
+
+        :param delta: delta dictionary from the LLM streaming response
+        :param is_reasoning: whether reasoning is already in progress
+        :return: tuple of (processed_content, is_reasoning)
+        """
+        content = delta.get("content") or ""
+        reasoning_content = delta.get("reasoning_content")
+        if reasoning_content:
+            if not is_reasoning:
+                content = HTML_THINKING_TAG + reasoning_content
+                is_reasoning = True
+            else:
+                content = reasoning_content
+        elif is_reasoning:
+            content = "</details>" + content
+            is_reasoning = False
+        return content, is_reasoning
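The method above is a small state machine over the stream: the first delta carrying reasoning_content opens the details block, later reasoning deltas pass through unchanged, and the first plain-content delta closes it. A minimal standalone sketch of that behavior (the delta dicts are invented examples, and HTML_THINKING_TAG is abbreviated):

    # Standalone sketch of the wrapping state machine; mirrors the method above.
    HTML_THINKING_TAG = "<details open> <summary> Thinking... </summary>"  # styling abbreviated

    def wrap_thinking(delta: dict, is_reasoning: bool) -> tuple[str, bool]:
        content = delta.get("content") or ""
        reasoning_content = delta.get("reasoning_content")
        if reasoning_content:
            if not is_reasoning:
                return HTML_THINKING_TAG + reasoning_content, True  # first reasoning delta opens the block
            return reasoning_content, True  # later reasoning deltas pass through
        if is_reasoning:
            return "</details>" + content, False  # first plain delta closes the block
        return content, False

    deltas = [
        {"reasoning_content": "Let me think"},
        {"reasoning_content": " about this."},
        {"content": "The answer is 42."},
    ]
    is_reasoning, pieces = False, []
    for delta in deltas:
        piece, is_reasoning = wrap_thinking(delta, is_reasoning)
        pieces.append(piece)
    print("".join(pieces))
    # <details open> <summary> Thinking... </summary>Let me think about this.</details>The answer is 42.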
+
+    def _wrap_thinking_by_tag(self, content: str) -> str:
+        """
+        If the reasoning response is a <think>...</think> block from delta.get("content"),
+        replace <think> with the <details> tag.
+
+        :param content: delta.get("content")
+        :return: processed_content
+        """
+        return content.replace("<think>", HTML_THINKING_TAG).replace("</think>", "</details>")
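For models that emit the reasoning inline as a <think>...</think> block instead of a separate field, the wrapper is a plain string substitution. A quick illustration (the content string is an invented example):

    HTML_THINKING_TAG = "<details open> <summary> Thinking... </summary>"  # styling abbreviated

    content = "<think>Weighing the options...</think>The answer is 42."
    wrapped = content.replace("<think>", HTML_THINKING_TAG).replace("</think>", "</details>")
    print(wrapped)
    # <details open> <summary> Thinking... </summary>Weighing the options...</details>The answer is 42.

Since the substitution runs once per streamed delta, each tag is rewritten as long as it arrives whole within a single chunk.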
     def _invoke_result_generator(
         self,
         model: str,

File: core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -1,6 +1,5 @@
+import codecs
 import json
-import logging
-import re
 from collections.abc import Generator
 from decimal import Decimal
 from typing import Optional, Union, cast
@@ -39,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
 from core.model_runtime.utils import helper

-logger = logging.getLogger(__name__)
-

 class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     """
@@ -100,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
+        return self._num_tokens_from_messages(prompt_messages, tools, credentials)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -399,6 +396,73 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+    def _create_final_llm_result_chunk(
+        self,
+        index: int,
+        message: AssistantPromptMessage,
+        finish_reason: str,
+        usage: dict,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        full_content: str,
+    ) -> LLMResultChunk:
+        # calculate num tokens
+        prompt_tokens = usage and usage.get("prompt_tokens")
+        if prompt_tokens is None:
+            prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content)
+        completion_tokens = usage and usage.get("completion_tokens")
+        if completion_tokens is None:
+            completion_tokens = self._num_tokens_from_string(text=full_content)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        return LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+        )
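One detail worth noting in the helper above: `usage and usage.get("prompt_tokens")` short-circuits to None both when the upstream API sent no usage object at all and when the field is missing, so either case falls through to the local token estimate. For example:

    usage = None
    prompt_tokens = usage and usage.get("prompt_tokens")  # -> None: fall back to a local estimate
    usage = {"prompt_tokens": 12, "completion_tokens": 34}
    prompt_tokens = usage and usage.get("prompt_tokens")  # -> 12: trust the API-reported count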
+
+    def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :param tools_calls: list of existing tool calls
+        :return: existing or new tool call, updated tools_calls
+        """
+        if not tool_call_id:
+            return tools_calls[-1], tools_calls
+
+        tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+        if tool_call is None:
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            tools_calls.append(tool_call)
+
+        return tool_call, tools_calls
+
+    def _increase_tool_call(
+        self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
+        for new_tool_call in new_tool_calls:
+            # get tool call
+            tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
+            # update tool call
+            if new_tool_call.id:
+                tool_call.id = new_tool_call.id
+            if new_tool_call.type:
+                tool_call.type = new_tool_call.type
+            if new_tool_call.function.name:
+                tool_call.function.name = new_tool_call.function.name
+            if new_tool_call.function.arguments:
+                tool_call.function.arguments += new_tool_call.function.arguments
+        return tools_calls
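Streaming APIs deliver a function call in fragments: typically one chunk with the id and name, then chunks that only append to the arguments string. The two helpers above fold those fragments into complete calls; a condensed standalone sketch with a simplified ToolCall type (the fragment values are invented):

    from dataclasses import dataclass, field

    @dataclass
    class Fn:
        name: str = ""
        arguments: str = ""

    @dataclass
    class ToolCall:
        id: str = ""
        type: str = "function"
        function: Fn = field(default_factory=Fn)

    def get_tool_call(key: str, calls: list[ToolCall]) -> ToolCall:
        if not key:  # no key: continue the most recently started call
            return calls[-1]
        for call in calls:
            if call.id == key:
                return call
        call = ToolCall(id=key)  # unseen key: start a new accumulator
        calls.append(call)
        return call

    # Invented fragments: name arrives first, arguments trickle in afterwards.
    calls: list[ToolCall] = []
    for frag in [
        ToolCall(id="call_1", function=Fn(name="get_weather")),
        ToolCall(function=Fn(arguments='{"city": ')),
        ToolCall(function=Fn(arguments='"Paris"}')),
    ]:
        call = get_tool_call(frag.function.name, calls)
        if frag.id:
            call.id = frag.id
        if frag.function.name:
            call.function.name = frag.function.name
        call.function.arguments += frag.function.arguments

    print(calls[0].function.arguments)  # {"city": "Paris"}

Note that the lookup key is the fragment's function name, exactly as in the diff above; fragments with an empty name simply extend the most recently started call.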
     def _handle_generate_stream_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> Generator:
@@ -411,71 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ""
         chunk_index = 0
+        full_assistant_content = ""
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = None
+        usage = None
+        is_reasoning_started = False

-        def create_final_llm_result_chunk(
-            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
-        ) -> LLMResultChunk:
-            # calculate num tokens
-            prompt_tokens = usage and usage.get("prompt_tokens")
-            if prompt_tokens is None:
-                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = usage and usage.get("completion_tokens")
-            if completion_tokens is None:
-                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
-
-            # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            return LLMResultChunk(
-                id=id,
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
-            )

         # delimiter for stream response, need unicode_escape
-        import codecs
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")

-        tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_call_id: str):
-                if not tool_call_id:
-                    return tools_calls[-1]
-
-                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id=tool_call_id,
-                        type="function",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
-        finish_reason = None  # The default value of finish_reason is None
-        message_id, usage = None, None
-        is_reasoning_started = False
-        is_reasoning_started_tag = False

         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
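The unicode_escape decode retained above exists because a delimiter configured through credentials arrives as an escaped string: a user typing \n\n into a settings form stores the two literal backslash-n sequences, not newlines. Hoisting import codecs to the top of the module (first hunk) removes the function-local import the old code used. A quick check (the raw value is a hypothetical credential setting):

    import codecs

    raw = "\\n\\n"  # what credentials.get("stream_mode_delimiter") might return
    delimiter = codecs.decode(raw, "unicode_escape")
    assert delimiter == "\n\n"  # a real blank line, ready for response.iter_lines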
@@ -490,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
-                    yield create_final_llm_result_chunk(
-                        id=message_id,
+                    yield self._create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
                         usage=usage,
+                        model=model,
+                        credentials=credentials,
+                        prompt_messages=prompt_messages,
+                        full_content=full_assistant_content,
                     )
                     break
                 # handle the error here. for issue #11629
@@ -510,42 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
-                message_id = chunk_json.get("id")
                 chunk_index += 1

                 if "delta" in choice:
                     delta = choice["delta"]
-                    delta_content = delta.get("content")
-                    if not delta_content:
-                        delta_content = ""
+                    delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
+                        delta, is_reasoning_started
+                    )
+                    delta_content = self._wrap_thinking_by_tag(delta_content)

-                    if not is_reasoning_started_tag and "<think>" in delta_content:
-                        is_reasoning_started_tag = True
-                        delta_content = "> 💭 " + delta_content.replace("<think>", "")
-                    elif is_reasoning_started_tag and "</think>" in delta_content:
-                        delta_content = delta_content.replace("</think>", "") + "\n\n"
-                        is_reasoning_started_tag = False
-                    elif is_reasoning_started_tag:
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-
-                    reasoning_content = delta.get("reasoning_content")
-                    if is_reasoning_started and not reasoning_content and not delta_content:
-                        delta_content = ""
-                    elif reasoning_content:
-                        if not is_reasoning_started:
-                            delta_content = "> 💭 " + reasoning_content
-                            is_reasoning_started = True
-                        else:
-                            delta_content = reasoning_content
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-                    elif is_reasoning_started:
-                        # If we were in reasoning mode but now getting regular content,
-                        # add \n\n to close the reasoning block
-                        delta_content = "\n\n" + delta_content
-                        is_reasoning_started = False

                     assistant_message_tool_calls = None
@ -559,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
] ]
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response # extract tool calls from response
if assistant_message_tool_calls: if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls) tools_calls = self._increase_tool_call(tool_calls, tools_calls)
if delta_content is None or delta_content == "": if delta_content is None or delta_content == "":
continue continue
@@ -589,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                         continue

                     yield LLMResultChunk(
-                        id=message_id,
                         model=model,
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
@@ -602,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         if tools_calls:
             yield LLMResultChunk(
-                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -611,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 ),
             )

-        yield create_final_llm_result_chunk(
-            id=message_id,
+        yield self._create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason=finish_reason,
             usage=usage,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            full_content=full_assistant_content,
         )

     def _handle_generate_response(
@@ -730,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return message_dict

     def _num_tokens_from_string(
-        self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
+        self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.

-        :param model: model name
         :param text: prompt text
         :param tools: tools for tool calling
         :return: number of tokens
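The docstring still describes the count as a GPT-2 approximation, which is why the model argument could be dropped: the tokenizer never depended on it. As an assumed stand-in for the project's internal gpt2 helper, the estimate amounts to something like the following sketch:

    # Assumption: Hugging Face's GPT-2 tokenizer as a stand-in for the project's
    # internal gpt2 helper; this only illustrates the kind of estimate being made.
    from transformers import GPT2TokenizerFast

    _tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    def approx_num_tokens(text: str) -> int:
        # Count of GPT-2 BPE tokens; an approximation for whatever tokenizer
        # the remote OpenAI-compatible model actually uses.
        return len(_tokenizer.encode(text))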
@@ -758,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     def _num_tokens_from_messages(
         self,
-        model: str,
         messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
         credentials: Optional[dict] = None,