mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 21:48:58 +08:00
fix: invalid new tool call creation logic during response handling in OAI-Compat model (#17781)
This commit is contained in:
parent
b6b608219a
commit
defd5520ea
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from collections.abc import Generator, Sequence
|
from collections.abc import Generator, Sequence
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_tool_call_id() -> str:
|
||||||
|
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _increase_tool_call(
|
||||||
|
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Merge incremental tool call updates into existing tool calls.
|
||||||
|
|
||||||
|
:param new_tool_calls: List of new tool call deltas to be merged.
|
||||||
|
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_tool_call(tool_call_id: str):
|
||||||
|
"""
|
||||||
|
Get or create a tool call by ID
|
||||||
|
|
||||||
|
:param tool_call_id: tool call ID
|
||||||
|
:return: existing or new tool call
|
||||||
|
"""
|
||||||
|
if not tool_call_id:
|
||||||
|
return existing_tools_calls[-1]
|
||||||
|
|
||||||
|
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
|
||||||
|
if _tool_call is None:
|
||||||
|
_tool_call = AssistantPromptMessage.ToolCall(
|
||||||
|
id=tool_call_id,
|
||||||
|
type="function",
|
||||||
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||||
|
)
|
||||||
|
existing_tools_calls.append(_tool_call)
|
||||||
|
|
||||||
|
return _tool_call
|
||||||
|
|
||||||
|
for new_tool_call in new_tool_calls:
|
||||||
|
# generate ID for tool calls with function name but no ID to track them
|
||||||
|
if new_tool_call.function.name and not new_tool_call.id:
|
||||||
|
new_tool_call.id = _gen_tool_call_id()
|
||||||
|
# get tool call
|
||||||
|
tool_call = get_tool_call(new_tool_call.id)
|
||||||
|
# update tool call
|
||||||
|
if new_tool_call.id:
|
||||||
|
tool_call.id = new_tool_call.id
|
||||||
|
if new_tool_call.type:
|
||||||
|
tool_call.type = new_tool_call.type
|
||||||
|
if new_tool_call.function.name:
|
||||||
|
tool_call.function.name = new_tool_call.function.name
|
||||||
|
if new_tool_call.function.arguments:
|
||||||
|
tool_call.function.arguments += new_tool_call.function.arguments
|
||||||
|
|
||||||
|
|
||||||
class LargeLanguageModel(AIModel):
|
class LargeLanguageModel(AIModel):
|
||||||
"""
|
"""
|
||||||
Model class for large language model.
|
Model class for large language model.
|
||||||
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
|
|||||||
system_fingerprint = None
|
system_fingerprint = None
|
||||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||||
|
|
||||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
|
||||||
def get_tool_call(tool_name: str):
|
|
||||||
if not tool_name:
|
|
||||||
return tools_calls[-1]
|
|
||||||
|
|
||||||
tool_call = next(
|
|
||||||
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
|
|
||||||
)
|
|
||||||
if tool_call is None:
|
|
||||||
tool_call = AssistantPromptMessage.ToolCall(
|
|
||||||
id="",
|
|
||||||
type="",
|
|
||||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
|
|
||||||
)
|
|
||||||
tools_calls.append(tool_call)
|
|
||||||
|
|
||||||
return tool_call
|
|
||||||
|
|
||||||
for new_tool_call in new_tool_calls:
|
|
||||||
# get tool call
|
|
||||||
tool_call = get_tool_call(new_tool_call.function.name)
|
|
||||||
# update tool call
|
|
||||||
if new_tool_call.id:
|
|
||||||
tool_call.id = new_tool_call.id
|
|
||||||
if new_tool_call.type:
|
|
||||||
tool_call.type = new_tool_call.type
|
|
||||||
if new_tool_call.function.name:
|
|
||||||
tool_call.function.name = new_tool_call.function.name
|
|
||||||
if new_tool_call.function.arguments:
|
|
||||||
tool_call.function.arguments += new_tool_call.function.arguments
|
|
||||||
|
|
||||||
for chunk in result:
|
for chunk in result:
|
||||||
if isinstance(chunk.delta.message.content, str):
|
if isinstance(chunk.delta.message.content, str):
|
||||||
content += chunk.delta.message.content
|
content += chunk.delta.message.content
|
||||||
elif isinstance(chunk.delta.message.content, list):
|
elif isinstance(chunk.delta.message.content, list):
|
||||||
content_list.extend(chunk.delta.message.content)
|
content_list.extend(chunk.delta.message.content)
|
||||||
if chunk.delta.message.tool_calls:
|
if chunk.delta.message.tool_calls:
|
||||||
increase_tool_call(chunk.delta.message.tool_calls)
|
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
|
||||||
|
|
||||||
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
usage = chunk.delta.usage or LLMUsage.empty_usage()
|
||||||
system_fingerprint = chunk.system_fingerprint
|
system_fingerprint = chunk.system_fingerprint
|
||||||
|
@ -0,0 +1,99 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||||
|
|
||||||
|
ToolCall = AssistantPromptMessage.ToolCall
|
||||||
|
|
||||||
|
# CASE 1: Single tool call
|
||||||
|
INPUTS_CASE_1 = [
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_1 = [
|
||||||
|
ToolCall(
|
||||||
|
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
|
||||||
|
INPUTS_CASE_2 = [
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_2 = [
|
||||||
|
ToolCall(
|
||||||
|
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
|
||||||
|
INPUTS_CASE_3 = [
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_3 = [
|
||||||
|
ToolCall(
|
||||||
|
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# CASE 4: Tool call sequences with no IDs
|
||||||
|
INPUTS_CASE_4 = [
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_4 = [
|
||||||
|
ToolCall(
|
||||||
|
id="RANDOM_ID_1",
|
||||||
|
type="function",
|
||||||
|
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
id="RANDOM_ID_2",
|
||||||
|
type="function",
|
||||||
|
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
|
||||||
|
actual = []
|
||||||
|
_increase_tool_call(inputs, actual)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test__increase_tool_call():
|
||||||
|
# case 1:
|
||||||
|
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
|
||||||
|
|
||||||
|
# case 2:
|
||||||
|
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
|
||||||
|
|
||||||
|
# case 3:
|
||||||
|
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
|
||||||
|
|
||||||
|
# case 4:
|
||||||
|
mock_id_generator = MagicMock()
|
||||||
|
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
|
||||||
|
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
|
||||||
|
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
Loading…
x
Reference in New Issue
Block a user