def _gen_tool_call_id() -> str:
    """
    Generate a synthetic tool call ID.

    Used for tool calls that a provider streams without an ID, so that they can still be tracked.

    :return: ID of the form ``chatcmpl-tool-<uuid4 hex>``
    """
    return f"chatcmpl-tool-{uuid.uuid4().hex}"


def _increase_tool_call(
    new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
) -> None:
    """
    Merge incremental tool call updates into existing tool calls.

    Each delta is routed to a target tool call as follows:

    - a delta with an ID updates the tool call with that ID (created on first sight);
    - a delta with a function name but no ID starts a new tool call under a generated ID;
    - a delta with neither continues the most recently added tool call.

    :param new_tool_calls: List of new tool call deltas to be merged. The deltas themselves are not modified.
    :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
    """

    def get_tool_call(tool_call_id: str) -> AssistantPromptMessage.ToolCall:
        """
        Get or create a tool call by ID

        :param tool_call_id: tool call ID; empty means "continue the latest tool call"
        :return: existing or new tool call
        """
        if not tool_call_id:
            if existing_tools_calls:
                return existing_tools_calls[-1]
            # A continuation delta arrived before any tool call was started.
            # Indexing [-1] on the empty list would raise IndexError and abort
            # the whole stream, so open a placeholder call to collect the fragment.
            tool_call_id = _gen_tool_call_id()

        _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
        if _tool_call is None:
            _tool_call = AssistantPromptMessage.ToolCall(
                id=tool_call_id,
                type="function",
                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
            )
            existing_tools_calls.append(_tool_call)

        return _tool_call

    for new_tool_call in new_tool_calls:
        # Keep the resolved ID in a local so the caller's delta is left untouched.
        tool_call_id = new_tool_call.id
        # generate ID for tool calls with function name but no ID to track them
        if new_tool_call.function.name and not tool_call_id:
            tool_call_id = _gen_tool_call_id()
        # get tool call (its ID already equals the resolved ID, so it needs no update)
        tool_call = get_tool_call(tool_call_id)
        # update tool call
        if new_tool_call.type:
            tool_call.type = new_tool_call.type
        if new_tool_call.function.name:
            tool_call.function.name = new_tool_call.function.name
        if new_tool_call.function.arguments:
            # argument JSON is streamed in fragments; concatenate in arrival order
            tool_call.function.arguments += new_tool_call.function.arguments
@@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel): system_fingerprint = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): - def get_tool_call(tool_name: str): - if not tool_name: - return tools_calls[-1] - - tool_call = next( - (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None - ) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id="", - type="", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), - ) - tools_calls.append(tool_call) - - return tool_call - - for new_tool_call in new_tool_calls: - # get tool call - tool_call = get_tool_call(new_tool_call.function.name) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments - for chunk in result: if isinstance(chunk.delta.message.content, str): content += chunk.delta.message.content elif isinstance(chunk.delta.message.content, list): content_list.extend(chunk.delta.message.content) if chunk.delta.message.tool_calls: - increase_tool_call(chunk.delta.message.tool_calls) + _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) usage = chunk.delta.usage or LLMUsage.empty_usage() system_fingerprint = chunk.system_fingerprint diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/tests/unit_tests/core/model_runtime/__base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py new file mode 100644 index 0000000000..93d8a20cac --- /dev/null +++ 
from unittest.mock import MagicMock, patch

from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call

ToolCall = AssistantPromptMessage.ToolCall


def _call(call_id: str, name: str, arguments: str) -> ToolCall:
    """Build a function-type ToolCall fixture from its three varying fields."""
    return ToolCall(id=call_id, type="function", function=ToolCall.ToolCallFunction(name=name, arguments=arguments))


# CASE 1: Single tool call
INPUTS_CASE_1 = [
    _call("1", "func_foo", ""),
    _call("", "", '{"arg1": '),
    _call("", "", '"value"}'),
]
EXPECTED_CASE_1 = [
    _call("1", "func_foo", '{"arg1": "value"}'),
]

# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
    _call("1", "func_foo", ""),
    _call("", "", '{"arg1": '),
    _call("", "", '"value"}'),
    _call("2", "func_bar", ""),
    _call("", "", '{"arg2": '),
    _call("", "", '"value"}'),
]
EXPECTED_CASE_2 = [
    _call("1", "func_foo", '{"arg1": "value"}'),
    _call("2", "func_bar", '{"arg2": "value"}'),
]

# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
    _call("1", "func_foo", ""),
    _call("1", "", '{"arg1": '),
    _call("1", "", '"value"}'),
    _call("2", "func_bar", ""),
    _call("2", "", '{"arg2": '),
    _call("2", "", '"value"}'),
]
EXPECTED_CASE_3 = [
    _call("1", "func_foo", '{"arg1": "value"}'),
    _call("2", "func_bar", '{"arg2": "value"}'),
]

# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
    _call("", "func_foo", ""),
    _call("", "", '{"arg1": '),
    _call("", "", '"value"}'),
    _call("", "func_bar", ""),
    _call("", "", '{"arg2": '),
    _call("", "", '"value"}'),
]
EXPECTED_CASE_4 = [
    _call("RANDOM_ID_1", "func_foo", '{"arg1": "value"}'),
    _call("RANDOM_ID_2", "func_bar", '{"arg2": "value"}'),
]


def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
    """Merge all deltas into an empty accumulator and compare with the expected calls."""
    merged: list[ToolCall] = []
    _increase_tool_call(inputs, merged)
    assert merged == expected


def test__increase_tool_call():
    # cases 1-3: the provider supplies IDs, so no ID generation is involved
    for deltas, wanted in (
        (INPUTS_CASE_1, EXPECTED_CASE_1),
        (INPUTS_CASE_2, EXPECTED_CASE_2),
        (INPUTS_CASE_3, EXPECTED_CASE_3),
    ):
        _run_case(deltas, wanted)

    # case 4: no IDs at all, so pin the generated IDs to the expected ones
    fake_id_generator = MagicMock(side_effect=[wanted_call.id for wanted_call in EXPECTED_CASE_4])
    with patch(
        "core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id",
        new=fake_id_generator,
    ):
        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)