mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:48:58 +08:00
fix(api): fix fail branch functionality for WorkflowTool
(#15966)
This commit is contained in:
parent
fe76dfe1f8
commit
2b4d1cf1db
@ -7,6 +7,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from models.account import Account
|
||||
@ -96,11 +97,8 @@ class WorkflowTool(Tool):
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
||||
if data.get("error"):
|
||||
raise Exception(data.get("error"))
|
||||
|
||||
if data.get("error"):
|
||||
raise Exception(data.get("error"))
|
||||
if err := data.get("error"):
|
||||
raise ToolInvokeError(err)
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if outputs is None:
|
||||
|
@ -9,6 +9,7 @@ from core.file import File, FileTransferMethod
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from core.plugin.manager.plugin import PluginInstallationManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
@ -119,13 +120,14 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||
except PluginDaemonClientSideError as e:
|
||||
except (PluginDaemonClientSideError, ToolInvokeError) as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to transform tool message: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
|
||||
|
||||
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
|
||||
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||
the `data` element.
|
||||
"""
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
output_schema=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
# needs to patch those methods to avoid database access.
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
|
||||
|
||||
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||
)
|
||||
|
||||
with pytest.raises(ToolInvokeError) as exc_info:
|
||||
# WorkflowTool always returns a generator, so we need to iterate to
|
||||
# actually `run` the tool.
|
||||
list(tool.invoke("test_user", {}))
|
||||
assert exc_info.value.args == ("oops",)
|
@ -0,0 +1 @@
|
||||
|
110
api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
Normal file
110
api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
Normal file
@ -0,0 +1,110 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.enums import ErrorStrategy
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.tool import ToolNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
def _create_tool_node():
|
||||
data = ToolNodeData(
|
||||
title="Test Tool",
|
||||
tool_parameters={},
|
||||
provider_id="test_tool",
|
||||
provider_type=ToolProviderType.WORKFLOW,
|
||||
provider_name="test tool",
|
||||
tool_name="test tool",
|
||||
tool_label="test tool",
|
||||
tool_configurations={},
|
||||
plugin_unique_identifier=None,
|
||||
desc="Exception handling test tool",
|
||||
error_strategy=ErrorStrategy.FAIL_BRANCH,
|
||||
version="1",
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
node = ToolNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
class MockToolRuntime:
|
||||
def get_merged_runtime_parameters(self):
|
||||
pass
|
||||
|
||||
|
||||
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
|
||||
yield from []
|
||||
raise ToolInvokeError("oops")
|
||||
|
||||
|
||||
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure that ToolNode can handle ToolInvokeError when transforming
|
||||
messages generated by ToolEngine.generic_invoke.
|
||||
"""
|
||||
tool_node = _create_tool_node()
|
||||
|
||||
# Need to patch ToolManager and ToolEngine so that we don't
|
||||
# have to set up a database.
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_engine.ToolEngine.generic_invoke",
|
||||
lambda *args, **kwargs: mock_message_stream(),
|
||||
)
|
||||
|
||||
streams = list(tool_node._run())
|
||||
assert len(streams) == 1
|
||||
stream = streams[0]
|
||||
assert isinstance(stream, RunCompletedEvent)
|
||||
result = stream.run_result
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert "oops" in result.error
|
||||
assert "Failed to transform tool message:" in result.error
|
||||
assert result.error_type == "ToolInvokeError"
|
Loading…
x
Reference in New Issue
Block a user