fix: invalid new tool call creation logic during response handling in OAI-Compat model (#17781)

This commit is contained in:
Vitor 2025-04-17 16:52:49 +08:00 committed by GitHub
parent b6b608219a
commit defd5520ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 153 additions and 32 deletions

View File

@ -1,5 +1,6 @@
import logging
import time
import uuid
from collections.abc import Generator, Sequence
from typing import Optional, Union
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
def get_tool_call(tool_call_id: str):
"""
Get or create a tool call by ID
:param tool_call_id: tool call ID
:return: existing or new tool call
"""
if not tool_call_id:
return existing_tools_calls[-1]
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
if _tool_call is None:
_tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(_tool_call)
return _tool_call
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
class LargeLanguageModel(AIModel):
"""
Model class for large language model.
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in result:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls)
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint

View File

@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
ToolCall = AssistantPromptMessage.ToolCall
# CASE 1: Single tool call
INPUTS_CASE_1 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_1 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
]
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_2 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
ToolCall(
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
),
]
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_3 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
ToolCall(
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
),
]
# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_4 = [
ToolCall(
id="RANDOM_ID_1",
type="function",
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
),
ToolCall(
id="RANDOM_ID_2",
type="function",
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
),
]
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
actual = []
_increase_tool_call(inputs, actual)
assert actual == expected
def test__increase_tool_call():
# case 1:
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
# case 2:
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
# case 3:
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
# case 4:
mock_id_generator = MagicMock()
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)