chore: refactor the OpenAICompatible and improve thinking display (#13299)
parent be46f32056
commit 3eb3db0663
@@ -30,6 +30,11 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
 
 logger = logging.getLogger(__name__)
 
+HTML_THINKING_TAG = (
+    '<details style="color:gray;background-color: #f5f5f5;padding: 8px;border-radius: 4px;" open> '
+    "<summary> Thinking... </summary>"
+)
+
 
 class LargeLanguageModel(AIModel):
     """
@@ -400,6 +405,40 @@ if you are not sure about the structure.
                 ),
             )
 
+    def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
+        """
+        If the reasoning response is from delta.get("reasoning_content"), we wrap
+        it with HTML details tag.
+
+        :param delta: delta dictionary from LLM streaming response
+        :param is_reasoning: is reasoning
+        :return: tuple of (processed_content, is_reasoning)
+        """
+
+        content = delta.get("content") or ""
+        reasoning_content = delta.get("reasoning_content")
+
+        if reasoning_content:
+            if not is_reasoning:
+                content = HTML_THINKING_TAG + reasoning_content
+                is_reasoning = True
+            else:
+                content = reasoning_content
+        elif is_reasoning:
+            content = "</details>" + content
+            is_reasoning = False
+        return content, is_reasoning
+
+    def _wrap_thinking_by_tag(self, content: str) -> str:
+        """
+        if the reasoning response is a <think>...</think> block from delta.get("content"),
+        we replace <think> to <detail>.
+
+        :param content: delta.get("content")
+        :return: processed_content
+        """
+        return content.replace("<think>", HTML_THINKING_TAG).replace("</think>", "</details>")
+
     def _invoke_result_generator(
         self,
         model: str,
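Note: the following is a minimal, standalone sketch (not part of the commit) of what the new reasoning wrapper does to a stream of deltas; the sample deltas and the module-level function are illustrative only.

# Standalone sketch: fold "reasoning_content" deltas into an HTML <details> block,
# mirroring the logic of _wrap_thinking_by_reasoning_content above.
HTML_THINKING_TAG = (
    '<details style="color:gray;background-color: #f5f5f5;padding: 8px;border-radius: 4px;" open> '
    "<summary> Thinking... </summary>"
)


def wrap_thinking_by_reasoning_content(delta: dict, is_reasoning: bool) -> tuple[str, bool]:
    # Open the <details> block on the first reasoning delta, pass later reasoning
    # deltas through unchanged, and close the block once normal content resumes.
    content = delta.get("content") or ""
    reasoning_content = delta.get("reasoning_content")
    if reasoning_content:
        if not is_reasoning:
            content = HTML_THINKING_TAG + reasoning_content
            is_reasoning = True
        else:
            content = reasoning_content
    elif is_reasoning:
        content = "</details>" + content
        is_reasoning = False
    return content, is_reasoning


# Illustrative stream of deltas as an OpenAI-compatible endpoint might emit them.
deltas = [
    {"reasoning_content": "Let me think"},
    {"reasoning_content": " about this."},
    {"content": "The answer is 42."},
]
is_reasoning = False
out = ""
for delta in deltas:
    piece, is_reasoning = wrap_thinking_by_reasoning_content(delta, is_reasoning)
    out += piece
print(out)
# <details ... open> <summary> Thinking... </summary>Let me think about this.</details>The answer is 42.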
@@ -1,6 +1,5 @@
+import codecs
 import json
-import logging
-import re
 from collections.abc import Generator
 from decimal import Decimal
 from typing import Optional, Union, cast
@@ -39,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
 from core.model_runtime.utils import helper
 
-logger = logging.getLogger(__name__)
-
 
 class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     """
@@ -100,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
+        return self._num_tokens_from_messages(prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -399,6 +396,73 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
+    def _create_final_llm_result_chunk(
+        self,
+        index: int,
+        message: AssistantPromptMessage,
+        finish_reason: str,
+        usage: dict,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        full_content: str,
+    ) -> LLMResultChunk:
+        # calculate num tokens
+        prompt_tokens = usage and usage.get("prompt_tokens")
+        if prompt_tokens is None:
+            prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content)
+        completion_tokens = usage and usage.get("completion_tokens")
+        if completion_tokens is None:
+            completion_tokens = self._num_tokens_from_string(text=full_content)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        return LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+        )
+
+    def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :param tools_calls: list of existing tool calls
+        :return: existing or new tool call, updated tools_calls
+        """
+        if not tool_call_id:
+            return tools_calls[-1], tools_calls
+
+        tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+        if tool_call is None:
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            tools_calls.append(tool_call)
+
+        return tool_call, tools_calls
+
+    def _increase_tool_call(
+        self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
+        for new_tool_call in new_tool_calls:
+            # get tool call
+            tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
+            # update tool call
+            if new_tool_call.id:
+                tool_call.id = new_tool_call.id
+            if new_tool_call.type:
+                tool_call.type = new_tool_call.type
+            if new_tool_call.function.name:
+                tool_call.function.name = new_tool_call.function.name
+            if new_tool_call.function.arguments:
+                tool_call.function.arguments += new_tool_call.function.arguments
+        return tools_calls
+
     def _handle_generate_stream_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> Generator:
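Note: a simplified, standalone sketch (not part of the commit) of the tool-call accumulation idea above; the dataclasses stand in for AssistantPromptMessage.ToolCall, and merging by id is an illustrative simplification of the helper's lookup.

# Standalone sketch: merge streamed tool-call fragments into complete calls.
from dataclasses import dataclass, field


@dataclass
class ToolCallFunction:
    name: str = ""
    arguments: str = ""


@dataclass
class ToolCall:
    id: str = ""
    type: str = "function"
    function: ToolCallFunction = field(default_factory=ToolCallFunction)


def merge_tool_call_fragments(fragments: list[ToolCall]) -> list[ToolCall]:
    calls: list[ToolCall] = []
    for fragment in fragments:
        # Reuse the call with the same id (or the most recent call when the
        # fragment carries no id); otherwise start a new call.
        existing = next((c for c in calls if fragment.id and c.id == fragment.id), None)
        if existing is None and fragment.id:
            existing = ToolCall(id=fragment.id)
            calls.append(existing)
        elif existing is None:
            existing = calls[-1]
        if fragment.function.name:
            existing.function.name = fragment.function.name
        if fragment.function.arguments:
            existing.function.arguments += fragment.function.arguments
    return calls


# Illustrative fragments: the arguments string arrives split across two chunks.
fragments = [
    ToolCall(id="call_1", function=ToolCallFunction(name="get_weather", arguments='{"city": ')),
    ToolCall(id="", function=ToolCallFunction(arguments='"Paris"}')),
]
print(merge_tool_call_fragments(fragments))
# [ToolCall(id='call_1', type='function', function=ToolCallFunction(name='get_weather', arguments='{"city": "Paris"}'))]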
@@ -411,71 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ""
         chunk_index = 0
-
-        def create_final_llm_result_chunk(
-            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
-        ) -> LLMResultChunk:
-            # calculate num tokens
-            prompt_tokens = usage and usage.get("prompt_tokens")
-            if prompt_tokens is None:
-                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = usage and usage.get("completion_tokens")
-            if completion_tokens is None:
-                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
-
-            # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            return LLMResultChunk(
-                id=id,
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
-            )
+        full_assistant_content = ""
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = None
+        usage = None
+        is_reasoning_started = False
 
         # delimiter for stream response, need unicode_escape
-        import codecs
-
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
 
-        tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_call_id: str):
-                if not tool_call_id:
-                    return tools_calls[-1]
-
-                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id=tool_call_id,
-                        type="function",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
-        finish_reason = None  # The default value of finish_reason is None
-        message_id, usage = None, None
-        is_reasoning_started = False
-        is_reasoning_started_tag = False
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
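Note: a minimal sketch (not part of the commit) of why the configured stream_mode_delimiter is passed through codecs.decode(..., "unicode_escape") before the response is split; the sample stream and the plain str.split stand in for response.iter_lines.

# Standalone sketch: turn an escaped delimiter string into the real separator
# before splitting a streamed SSE body into chunks.
import codecs

configured = "\\n\\n"  # e.g. a credentials form storing literal backslash-n characters
delimiter = codecs.decode(configured, "unicode_escape")
assert delimiter == "\n\n"

raw_stream = 'data: {"id": 1}\n\ndata: {"id": 2}\n\ndata: [DONE]\n\n'
chunks = [chunk.strip() for chunk in raw_stream.split(delimiter) if chunk.strip()]
print(chunks)
# ['data: {"id": 1}', 'data: {"id": 2}', 'data: [DONE]']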
@@ -490,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
-                    yield create_final_llm_result_chunk(
-                        id=message_id,
+                    yield self._create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
                         usage=usage,
+                        model=model,
+                        credentials=credentials,
+                        prompt_messages=prompt_messages,
+                        full_content=full_assistant_content,
                     )
                     break
                 # handle the error here. for issue #11629
@@ -510,42 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
-                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
                     delta = choice["delta"]
-                    delta_content = delta.get("content")
-                    if not delta_content:
-                        delta_content = ""
-
-                    if not is_reasoning_started_tag and "<think>" in delta_content:
-                        is_reasoning_started_tag = True
-                        delta_content = "> 💭 " + delta_content.replace("<think>", "")
-                    elif is_reasoning_started_tag and "</think>" in delta_content:
-                        delta_content = delta_content.replace("</think>", "") + "\n\n"
-                        is_reasoning_started_tag = False
-                    elif is_reasoning_started_tag:
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-
-                    reasoning_content = delta.get("reasoning_content")
-                    if is_reasoning_started and not reasoning_content and not delta_content:
-                        delta_content = ""
-                    elif reasoning_content:
-                        if not is_reasoning_started:
-                            delta_content = "> 💭 " + reasoning_content
-                            is_reasoning_started = True
-                        else:
-                            delta_content = reasoning_content
-
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-                    elif is_reasoning_started:
-                        # If we were in reasoning mode but now getting regular content,
-                        # add \n\n to close the reasoning block
-                        delta_content = "\n\n" + delta_content
-                        is_reasoning_started = False
+                    delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
+                        delta, is_reasoning_started
+                    )
+                    delta_content = self._wrap_thinking_by_tag(delta_content)
 
                     assistant_message_tool_calls = None
 
@@ -559,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                             {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
                         ]
 
-                    # assistant_message_function_call = delta.delta.function_call
-
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                        increase_tool_call(tool_calls)
+                        tools_calls = self._increase_tool_call(tool_calls, tools_calls)
 
                     if delta_content is None or delta_content == "":
                         continue
@@ -589,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                         continue
 
                 yield LLMResultChunk(
-                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -602,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         if tools_calls:
             yield LLMResultChunk(
-                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -611,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 ),
             )
 
-        yield create_final_llm_result_chunk(
-            id=message_id,
+        yield self._create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason=finish_reason,
             usage=usage,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            full_content=full_assistant_content,
         )
 
     def _handle_generate_response(
@@ -730,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return message_dict
 
     def _num_tokens_from_string(
-        self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
+        self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
 
-        :param model: model name
         :param text: prompt text
         :param tools: tools for tool calling
         :return: number of tokens
@@ -758,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
     def _num_tokens_from_messages(
         self,
-        model: str,
         messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
         credentials: Optional[dict] = None,