mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 16:45:57 +08:00
fix(api): fix fail branch functionality for WorkflowTool
(#15966)
This commit is contained in:
parent
fe76dfe1f8
commit
2b4d1cf1db
@ -7,6 +7,7 @@ from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
|||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||||
|
from core.tools.errors import ToolInvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories.file_factory import build_from_mapping
|
from factories.file_factory import build_from_mapping
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
@ -96,11 +97,8 @@ class WorkflowTool(Tool):
|
|||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
data = result.get("data", {})
|
data = result.get("data", {})
|
||||||
|
|
||||||
if data.get("error"):
|
if err := data.get("error"):
|
||||||
raise Exception(data.get("error"))
|
raise ToolInvokeError(err)
|
||||||
|
|
||||||
if data.get("error"):
|
|
||||||
raise Exception(data.get("error"))
|
|
||||||
|
|
||||||
outputs = data.get("outputs")
|
outputs = data.get("outputs")
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
|
@ -9,6 +9,7 @@ from core.file import File, FileTransferMethod
|
|||||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||||
from core.plugin.manager.plugin import PluginInstallationManager
|
from core.plugin.manager.plugin import PluginInstallationManager
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
|
from core.tools.errors import ToolInvokeError
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
from core.variables.segments import ArrayAnySegment
|
from core.variables.segments import ArrayAnySegment
|
||||||
@ -119,13 +120,14 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
try:
|
try:
|
||||||
# convert tool messages
|
# convert tool messages
|
||||||
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
yield from self._transform_message(message_stream, tool_info, parameters_for_log)
|
||||||
except PluginDaemonClientSideError as e:
|
except (PluginDaemonClientSideError, ToolInvokeError) as e:
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
inputs=parameters_for_log,
|
inputs=parameters_for_log,
|
||||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||||
error=f"Failed to transform tool message: {str(e)}",
|
error=f"Failed to transform tool message: {str(e)}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity
|
||||||
|
from core.tools.errors import ToolInvokeError
|
||||||
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch):
|
||||||
|
"""Ensure that WorkflowTool will throw a `ToolInvokeError` exception when
|
||||||
|
`WorkflowAppGenerator.generate` returns a result with `error` key inside
|
||||||
|
the `data` element.
|
||||||
|
"""
|
||||||
|
entity = ToolEntity(
|
||||||
|
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||||
|
parameters=[],
|
||||||
|
description=None,
|
||||||
|
output_schema=None,
|
||||||
|
has_runtime_parameters=False,
|
||||||
|
)
|
||||||
|
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||||
|
tool = WorkflowTool(
|
||||||
|
workflow_app_id="",
|
||||||
|
workflow_as_tool_id="",
|
||||||
|
version="1",
|
||||||
|
workflow_entities={},
|
||||||
|
workflow_call_depth=1,
|
||||||
|
entity=entity,
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
# needs to patch those methods to avoid database access.
|
||||||
|
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
# replace `WorkflowAppGenerator.generate` 's return value.
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
|
||||||
|
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ToolInvokeError) as exc_info:
|
||||||
|
# WorkflowTool always returns a generator, so we need to iterate to
|
||||||
|
# actually `run` the tool.
|
||||||
|
list(tool.invoke("test_user", {}))
|
||||||
|
assert exc_info.value.args == ("oops",)
|
@ -0,0 +1 @@
|
|||||||
|
|
110
api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
Normal file
110
api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||||
|
from core.tools.errors import ToolInvokeError
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||||
|
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||||
|
from core.workflow.nodes.end import EndStreamParam
|
||||||
|
from core.workflow.nodes.enums import ErrorStrategy
|
||||||
|
from core.workflow.nodes.event import RunCompletedEvent
|
||||||
|
from core.workflow.nodes.tool import ToolNode
|
||||||
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
|
from models import UserFrom, WorkflowNodeExecutionStatus, WorkflowType
|
||||||
|
|
||||||
|
|
||||||
|
def _create_tool_node():
|
||||||
|
data = ToolNodeData(
|
||||||
|
title="Test Tool",
|
||||||
|
tool_parameters={},
|
||||||
|
provider_id="test_tool",
|
||||||
|
provider_type=ToolProviderType.WORKFLOW,
|
||||||
|
provider_name="test tool",
|
||||||
|
tool_name="test tool",
|
||||||
|
tool_label="test tool",
|
||||||
|
tool_configurations={},
|
||||||
|
plugin_unique_identifier=None,
|
||||||
|
desc="Exception handling test tool",
|
||||||
|
error_strategy=ErrorStrategy.FAIL_BRANCH,
|
||||||
|
version="1",
|
||||||
|
)
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={},
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
node = ToolNode(
|
||||||
|
id="1",
|
||||||
|
config={
|
||||||
|
"id": "1",
|
||||||
|
"data": data.model_dump(),
|
||||||
|
},
|
||||||
|
graph_init_params=GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config={},
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
),
|
||||||
|
graph=Graph(
|
||||||
|
root_node_id="1",
|
||||||
|
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||||
|
answer_dependencies={},
|
||||||
|
answer_generate_route={},
|
||||||
|
),
|
||||||
|
end_stream_param=EndStreamParam(
|
||||||
|
end_dependencies={},
|
||||||
|
end_stream_variable_selector_mapping={},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
class MockToolRuntime:
|
||||||
|
def get_merged_runtime_parameters(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def mock_message_stream() -> Generator[ToolInvokeMessage, None, None]:
|
||||||
|
yield from []
|
||||||
|
raise ToolInvokeError("oops")
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_node_on_tool_invoke_error(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""Ensure that ToolNode can handle ToolInvokeError when transforming
|
||||||
|
messages generated by ToolEngine.generic_invoke.
|
||||||
|
"""
|
||||||
|
tool_node = _create_tool_node()
|
||||||
|
|
||||||
|
# Need to patch ToolManager and ToolEngine so that we don't
|
||||||
|
# have to set up a database.
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.tools.tool_manager.ToolManager.get_workflow_tool_runtime", lambda *args, **kwargs: MockToolRuntime()
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"core.tools.tool_engine.ToolEngine.generic_invoke",
|
||||||
|
lambda *args, **kwargs: mock_message_stream(),
|
||||||
|
)
|
||||||
|
|
||||||
|
streams = list(tool_node._run())
|
||||||
|
assert len(streams) == 1
|
||||||
|
stream = streams[0]
|
||||||
|
assert isinstance(stream, RunCompletedEvent)
|
||||||
|
result = stream.run_result
|
||||||
|
assert isinstance(result, NodeRunResult)
|
||||||
|
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||||
|
assert "oops" in result.error
|
||||||
|
assert "Failed to transform tool message:" in result.error
|
||||||
|
assert result.error_type == "ToolInvokeError"
|
Loading…
x
Reference in New Issue
Block a user