fix(api): fix fail branch functionality for WorkflowTool (#15966)

This commit is contained in:
QuantumGhost 2025-03-17 11:53:32 +08:00 committed by GitHub
parent fe76dfe1f8
commit 2b4d1cf1db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 166 additions and 6 deletions

View File

@ -7,6 +7,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from models.account import Account
@ -96,11 +97,8 @@ class WorkflowTool(Tool):
assert isinstance(result, dict)
data = result.get("data", {})
if data.get("error"):
raise Exception(data.get("error"))
if data.get("error"):
raise Exception(data.get("error"))
if err := data.get("error"):
raise ToolInvokeError(err)
outputs = data.get("outputs")
if outputs is None:

View File

@ -9,6 +9,7 @@ from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment
@ -119,13 +120,14 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
except PluginDaemonClientSideError as e:
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to transform tool message: {str(e)}",
error_type=type(e).__name__,
)
)

View File

@ -0,0 +1,49 @@
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
from core.tools.errors import ToolInvokeError
from core.tools.workflow_as_tool.tool import WorkflowTool
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
`WorkflowAppGenerator.generate` returns a result with `error` key inside
the `data` element.
"""
entity = ToolEntity(
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
parameters=[],
description=None,
output_schema=None,
has_runtime_parameters=False,
)
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
tool = WorkflowTool(
workflow_app_id="",
workflow_as_tool_id="",
version="1",
workflow_entities={},
workflow_call_depth=1,
entity=entity,
runtime=runtime,
)
# needs to patch those methods to avoid database access.
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
# replace `WorkflowAppGenerator.generate` 's return value.
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
lambda *args, **kwargs: {"data": {"error": "oops"}},
)
with pytest.raises(ToolInvokeError) as exc_info:
# WorkflowTool always returns a generator, so we need to iterate to
# actually `run` the tool.
list(tool.invoke("test_user", {}))
assert exc_info.value.args == ("oops",)

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,110 @@
from collections.abc import Generator
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
from core.workflow.nodes.end import EndStreamParam
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.tool import ToolNode
from core.workflow.nodes.tool.entities import ToolNodeData
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
def _create_tool_node():
data = ToolNodeData(
title="Test Tool",
tool_parameters={},
provider_id="test_tool",
provider_type=ToolProviderType.WORKFLOW,
provider_name="test tool",
tool_name="test tool",
tool_label="test tool",
tool_configurations={},
plugin_unique_identifier=None,
desc="Exception handling test tool",
error_strategy=ErrorStrategy.FAIL_BRANCH,
version="1",
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = ToolNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config={},
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
answer_generate_route={},
),
end_stream_param=EndStreamParam(
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
return node
class MockToolRuntime:
def get_merged_runtime_parameters(self):
pass
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
yield from []
raise ToolInvokeError("oops")
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
"""Ensure that ToolNode can handle ToolInvokeError when transforming
messages generated by ToolEngine.generic_invoke.
"""
tool_node = _create_tool_node()
# Need to patch ToolManager and ToolEngine so that we don't
# have to set up a database.
monkeypatch.setattr(
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
)
monkeypatch.setattr(
"core.tools.tool_engine.ToolEngine.generic_invoke",
lambda *args, **kwargs: mock_message_stream(),
)
streams = list(tool_node._run())
assert len(streams) == 1
stream = streams[0]
assert isinstance(stream, RunCompletedEvent)
result = stream.run_result
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "oops" in result.error
assert "Failed to transform tool message:" in result.error
assert result.error_type == "ToolInvokeError"