diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index cb9b804c77..cf840880bf 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,6 +7,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.errors import ToolInvokeError from extensions.ext_database import db from factories.file_factory import build_from_mapping from models.account import Account @@ -96,11 +97,8 @@ class WorkflowTool(Tool): assert isinstance(result, dict) data = result.get("data", {}) - if data.get("error"): - raise Exception(data.get("error")) - - if data.get("error"): - raise Exception(data.get("error")) + if err := data.get("error"): + raise ToolInvokeError(err) outputs = data.get("outputs") if outputs is None: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 7ec092cfdd..6f0cc3f6d2 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -9,6 +9,7 @@ from core.file import File, FileTransferMethod from core.plugin.manager.exc import PluginDaemonClientSideError from core.plugin.manager.plugin import PluginInstallationManager from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayAnySegment @@ -119,13 +120,14 @@ class ToolNode(BaseNode[ToolNodeData]): try: # convert tool messages yield from self._transform_message(message_stream, tool_info, parameters_for_log) - except PluginDaemonClientSideError as e: + except (PluginDaemonClientSideError, ToolInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to transform tool message: {str(e)}", + error_type=type(e).__name__, ) ) diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/__init__.py b/api/tests/unit_tests/core/tools/workflow_as_tool/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py new file mode 100644 index 0000000000..15a9e8e9f4 --- /dev/null +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -0,0 +1,49 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity +from core.tools.errors import ToolInvokeError +from core.tools.workflow_as_tool.tool import WorkflowTool + + +def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch): + """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when + `WorkflowAppGenerator.generate` returns a result with `error` key inside + the `data` element. + """ + entity = ToolEntity( + identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), + parameters=[], + description=None, + output_schema=None, + has_runtime_parameters=False, + ) + runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) + tool = WorkflowTool( + workflow_app_id="", + workflow_as_tool_id="", + version="1", + workflow_entities={}, + workflow_call_depth=1, + entity=entity, + runtime=runtime, + ) + + # needs to patch those methods to avoid database access. + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None) + + # replace `WorkflowAppGenerator.generate` 's return value. + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", + lambda *args, **kwargs: {"data": {"error": "oops"}}, + ) + + with pytest.raises(ToolInvokeError) as exc_info: + # WorkflowTool always returns a generator, so we need to iterate to + # actually `run` the tool. + list(tool.invoke("test_user", {})) + assert exc_info.value.args == ("oops",) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py b/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py new file mode 100644 index 0000000000..f593510830 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -0,0 +1,110 @@ +from collections.abc import Generator + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType +from core.tools.errors import ToolInvokeError +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.enums import ErrorStrategy +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.tool import ToolNode +from core.workflow.nodes.tool.entities import ToolNodeData +from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType + + +def _create_tool_node(): + data = ToolNodeData( + title="Test Tool", + tool_parameters={}, + provider_id="test_tool", + provider_type=ToolProviderType.WORKFLOW, + provider_name="test tool", + tool_name="test tool", + tool_label="test tool", + tool_configurations={}, + plugin_unique_identifier=None, + desc="Exception handling test tool", + error_strategy=ErrorStrategy.FAIL_BRANCH, + version="1", + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = ToolNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + +class MockToolRuntime: + def get_merged_runtime_parameters(self): + pass + + +def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]: + yield from [] + raise ToolInvokeError("oops") + + +def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch): + """Ensure that ToolNode can handle ToolInvokeError when transforming + messages generated by ToolEngine.generic_invoke. + """ + tool_node = _create_tool_node() + + # Need to patch ToolManager and ToolEngine so that we don't + # have to set up a database. + monkeypatch.setattr( + "core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime() + ) + monkeypatch.setattr( + "core.tools.tool_engine.ToolEngine.generic_invoke", + lambda *args, **kwargs: mock_message_stream(), + ) + + streams = list(tool_node._run()) + assert len(streams) == 1 + stream = streams[0] + assert isinstance(stream, RunCompletedEvent) + result = stream.run_result + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "oops" in result.error + assert "Failed to transform tool message:" in result.error + assert result.error_type == "ToolInvokeError"