diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 402a30376b..780a55681a 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -30,6 +30,11 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
 logger = logging.getLogger(__name__)
 
+HTML_THINKING_TAG = (
+    '<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> '
+    "<summary> Thinking... </summary>"
+)
+
 
 class LargeLanguageModel(AIModel):
     """
@@ -400,6 +405,40 @@ if you are not sure about the structure.
             ),
         )
 
+    def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
+        """
+        If the reasoning response comes from delta.get("reasoning_content"),
+        wrap it with an HTML <details> tag.
+
+        :param delta: delta dictionary from the LLM streaming response
+        :param is_reasoning: whether a reasoning block is currently open
+        :return: tuple of (processed_content, is_reasoning)
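+
+        Example (illustrative): a first delta like {"reasoning_content": "Let
+        me think"} returns (HTML_THINKING_TAG + "Let me think", True); once a
+        plain "content" delta arrives, it is prefixed with "</details>" to
+        close the open block.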
+ """
+
+ content = delta.get("content") or ""
+ reasoning_content = delta.get("reasoning_content")
+
+ if reasoning_content:
+ if not is_reasoning:
+ content = HTML_THINKING_TAG + reasoning_content
+ is_reasoning = True
+ else:
+ content = reasoning_content
+ elif is_reasoning:
+ content = " " + content
+ is_reasoning = False
+ return content, is_reasoning
+
+    def _wrap_thinking_by_tag(self, content: str) -> str:
+        """
+        If the reasoning response is a <think>...</think> block inside
+        delta.get("content"), replace the tags with <details> markup.
+
+        :param content: delta.get("content")
+        :return: processed_content
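+
+        Example (illustrative):
+            "<think>step</think>answer" -> HTML_THINKING_TAG + "step</details>answer"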
+ """
+ return content.replace("", HTML_THINKING_TAG).replace("", "")
+
     def _invoke_result_generator(
         self,
         model: str,
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 17aefc7efc..7f79da267f 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -1,6 +1,5 @@
+import codecs
 import json
-import logging
-import re
 from collections.abc import Generator
 from decimal import Decimal
 from typing import Optional, Union, cast
@@ -39,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
 from core.model_runtime.utils import helper
 
-logger = logging.getLogger(__name__)
-
 
 class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     """
@@ -100,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
+        return self._num_tokens_from_messages(prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -399,6 +396,73 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
+    def _create_final_llm_result_chunk(
+        self,
+        index: int,
+        message: AssistantPromptMessage,
+        finish_reason: str,
+        usage: dict,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        full_content: str,
+    ) -> LLMResultChunk:
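+        """
+        Build the final LLMResultChunk of a stream.
+
+        Uses provider-reported token counts from `usage` when present;
+        otherwise approximates them from the first prompt message and the
+        accumulated assistant content.
+        """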
+        # calculate num tokens
+        prompt_tokens = usage and usage.get("prompt_tokens")
+        if prompt_tokens is None:
+            prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content)
+        completion_tokens = usage and usage.get("completion_tokens")
+        if completion_tokens is None:
+            completion_tokens = self._num_tokens_from_string(text=full_content)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        return LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+        )
+
+    def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :param tools_calls: list of existing tool calls
+        :return: existing or new tool call, updated tools_calls
+        """
+        if not tool_call_id:
+            return tools_calls[-1], tools_calls
+
+        tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+        if tool_call is None:
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            tools_calls.append(tool_call)
+
+        return tool_call, tools_calls
+
+    def _increase_tool_call(
+        self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
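+        """
+        Merge streamed tool-call deltas into tools_calls, concatenating
+        argument fragments onto the matching (or newly created) tool call.
+        """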
+        for new_tool_call in new_tool_calls:
+            # get tool call
+            tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
+            # update tool call
+            if new_tool_call.id:
+                tool_call.id = new_tool_call.id
+            if new_tool_call.type:
+                tool_call.type = new_tool_call.type
+            if new_tool_call.function.name:
+                tool_call.function.name = new_tool_call.function.name
+            if new_tool_call.function.arguments:
+                tool_call.function.arguments += new_tool_call.function.arguments
+        return tools_calls
+
     def _handle_generate_stream_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> Generator:
@@ -411,71 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ""
         chunk_index = 0
-
-        def create_final_llm_result_chunk(
-            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
-        ) -> LLMResultChunk:
-            # calculate num tokens
-            prompt_tokens = usage and usage.get("prompt_tokens")
-            if prompt_tokens is None:
-                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = usage and usage.get("completion_tokens")
-            if completion_tokens is None:
-                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
-
-            # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            return LLMResultChunk(
-                id=id,
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
-            )
-
+        full_assistant_content = ""
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = None
+        usage = None
+        is_reasoning_started = False
         # delimiter for stream response, need unicode_escape
-        import codecs
-
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
-
-        tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_call_id: str):
-                if not tool_call_id:
-                    return tools_calls[-1]
-
-                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id=tool_call_id,
-                        type="function",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
-        finish_reason = None  # The default value of finish_reason is None
-        message_id, usage = None, None
-        is_reasoning_started = False
-        is_reasoning_started_tag = False
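+        # Each SSE event below is parsed into chunk_json; reasoning deltas are
+        # rewrapped by the base-class helpers before being re-emitted.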
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -490,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
-                    yield create_final_llm_result_chunk(
-                        id=message_id,
+                    yield self._create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
                         usage=usage,
+                        model=model,
+                        credentials=credentials,
+                        prompt_messages=prompt_messages,
+                        full_content=full_assistant_content,
                     )
                     break
                 # handle the error here. for issue #11629
@@ -510,42 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
-                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
                     delta = choice["delta"]
-                    delta_content = delta.get("content")
-                    if not delta_content:
-                        delta_content = ""
-
-                    if not is_reasoning_started_tag and "<think>" in delta_content:
-                        is_reasoning_started_tag = True
-                        delta_content = "> 💠" + delta_content.replace("<think>", "")
-                    elif is_reasoning_started_tag and "</think>" in delta_content:
-                        delta_content = delta_content.replace("</think>", "") + "\n\n"
-                        is_reasoning_started_tag = False
-                    elif is_reasoning_started_tag:
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-
-                    reasoning_content = delta.get("reasoning_content")
-                    if is_reasoning_started and not reasoning_content and not delta_content:
-                        delta_content = ""
-                    elif reasoning_content:
-                        if not is_reasoning_started:
-                            delta_content = "> 💠" + reasoning_content
-                            is_reasoning_started = True
-                        else:
-                            delta_content = reasoning_content
-
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-                    elif is_reasoning_started:
-                        # If we were in reasoning mode but now getting regular content,
-                        # add \n\n to close the reasoning block
-                        delta_content = "\n\n" + delta_content
-                        is_reasoning_started = False
+                    delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
+                        delta, is_reasoning_started
+                    )
+                    delta_content = self._wrap_thinking_by_tag(delta_content)
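+                    # Reasoning may arrive as delta["reasoning_content"]
+                    # (DeepSeek-style APIs) or inline as <think>...</think> in
+                    # delta["content"]; both render as the same details block.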
 
                     assistant_message_tool_calls = None
@@ -559,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
]
-                    # assistant_message_function_call = delta.delta.function_call
-
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                        increase_tool_call(tool_calls)
+                        tools_calls = self._increase_tool_call(tool_calls, tools_calls)
 
                     if delta_content is None or delta_content == "":
                         continue
@@ -589,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     continue
 
                 yield LLMResultChunk(
-                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -602,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         if tools_calls:
             yield LLMResultChunk(
-                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -611,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 ),
             )
 
-        yield create_final_llm_result_chunk(
-            id=message_id,
+        yield self._create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason=finish_reason,
             usage=usage,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            full_content=full_assistant_content,
         )
     def _handle_generate_response(
@@ -730,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return message_dict
 
     def _num_tokens_from_string(
-        self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
+        self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
 
-        :param model: model name
         :param text: prompt text
         :param tools: tools for tool calling
         :return: number of tokens
@@ -758,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     def _num_tokens_from_messages(
         self,
-        model: str,
         messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
         credentials: Optional[dict] = None,