diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 4cc0d4ae6e..7ec13ef048 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -157,6 +157,7 @@ class ToolInvokeMessage(BaseModel): BLOB = "blob" JSON = "json" IMAGE_LINK = "image_link" + BINARY_LINK = "binary_link" VARIABLE = "variable" FILE = "file" diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 09a7ef9d46..3d19c14cd8 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -85,7 +85,7 @@ class ToolFileMessageTransformer: ) else: yield ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, + type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), meta=message.meta.copy() if message.meta is not None else {}, ) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 63858323bd..8681f5fbee 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -190,7 +190,11 @@ class ToolNode(BaseNode[ToolNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: assert isinstance(message.message, ToolInvokeMessage.TextMessage) url = message.message.text @@ -209,7 +213,7 @@ class ToolNode(BaseNode[ToolNodeData]): mapping = { "tool_file_id": tool_file_id, - "type": FileType.IMAGE, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), "transfer_method": transfer_method, "url": url, } diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 8538775a67..dccae186f0 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -289,3 +289,7 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: else: file_type = FileType.CUSTOM return file_type + + +def get_file_type_by_mime_type(mime_type: str) -> FileType: + return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM