feat: add agent thinking content display of deepseek R1 (#12949)

This commit is contained in:
非法操作 2025-01-24 20:13:42 +08:00 committed by GitHub
parent a2f8bce8f5
commit 59b3e672aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,10 +1,13 @@
import json
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union from typing import Optional, Union
import requests
from yarl import URL from yarl import URL
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageTool, PromptMessageTool,
) )
@ -36,3 +39,208 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
credentials["mode"] = LLMMode.CHAT.value credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call" credentials["function_calling_type"] = "tool_call"
credentials["stream_function_calling"] = "support" credentials["stream_function_calling"] = "support"
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
chunk_index = 0
is_reasoning_started = False # Add flag to track reasoning state
def create_final_llm_result_chunk(
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = usage and usage.get("prompt_tokens")
if prompt_tokens is None:
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = usage and usage.get("completion_tokens")
if completion_tokens is None:
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
id=id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)
# delimiter for stream response, need unicode_escape
import codecs
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")
tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_call_id: str):
if not tool_call_id:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
finish_reason = None # The default value of finish_reason is None
message_id, usage = None, None
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip()
if chunk:
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
continue
try:
chunk_json: dict = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered.",
usage=usage,
)
break
# handle the error here. for issue #11629
if chunk_json.get("error") and chunk_json.get("choices") is None:
raise ValueError(chunk_json.get("error"))
if chunk_json:
if u := chunk_json.get("usage"):
usage = u
if not chunk_json or len(chunk_json["choices"]) == 0:
continue
choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
message_id = chunk_json.get("id")
chunk_index += 1
if "delta" in choice:
delta = choice["delta"]
is_reasoning = delta.get("reasoning_content")
delta_content = delta.get("content") or delta.get("reasoning_content")
assistant_message_tool_calls = None
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
assistant_message_tool_calls = delta.get("tool_calls", None)
elif (
"function_call" in delta
and credentials.get("function_calling_type", "no_call") == "function_call"
):
assistant_message_tool_calls = [
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
]
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == "":
continue
# Add markdown quote markers for reasoning content
if is_reasoning:
if not is_reasoning_started:
delta_content = "> 💭 " + delta_content
is_reasoning_started = True
elif "\n\n" in delta_content:
delta_content = delta_content.replace("\n\n", "\n> ")
elif "\n" in delta_content:
delta_content = delta_content.replace("\n", "\n> ")
elif is_reasoning_started:
# If we were in reasoning mode but now getting regular content,
# add \n\n to close the reasoning block
delta_content = "\n\n" + delta_content
is_reasoning_started = False
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
)
# reset tool calls
tool_calls = []
full_assistant_content += delta_content
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
),
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
),
)
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason,
usage=usage,
)