From bf4b6e5f8018d2d1f1e64f492d5e15212901dffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Wed, 20 Nov 2024 13:26:42 +0800 Subject: [PATCH] feat: support custom tool upload file (#10796) --- api/core/tools/tool/api_tool.py | 13 ++++++++++--- api/core/tools/utils/parser.py | 3 +++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index c779d704c3..0b4c5bd2c6 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -5,6 +5,7 @@ from urllib.parse import urlencode import httpx +from core.file.file_manager import download from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType @@ -138,6 +139,7 @@ class ApiTool(Tool): path_params = {} body = {} cookies = {} + files = [] # check parameters for parameter in self.api_bundle.openapi.get("parameters", []): @@ -166,8 +168,12 @@ class ApiTool(Tool): properties = body_schema.get("properties", {}) for name, property in properties.items(): if name in parameters: - # convert type - body[name] = self._convert_body_property_type(property, parameters[name]) + if property.get("format") == "binary": + f = parameters[name] + files.append((name, (f.filename, download(f), f.mime_type))) + else: + # convert type + body[name] = self._convert_body_property_type(property, parameters[name]) elif name in required: raise ToolParameterValidationError( f"Missing required parameter {name} in operation {self.api_bundle.operation_id}" @@ -182,7 +188,7 @@ class ApiTool(Tool): for name, value in path_params.items(): url = url.replace(f"{{{name}}}", f"{value}") - # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored + # parse http body data if needed if "Content-Type" in headers: if headers["Content-Type"] == "application/json": body = json.dumps(body) @@ -198,6 +204,7 @@ class ApiTool(Tool): headers=headers, cookies=cookies, data=body, + files=files, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True, ) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 5867a11bb3..ae44b1b99d 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -161,6 +161,9 @@ class ApiBasedToolSchemaParser: def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: parameter = parameter or {} typ = None + if parameter.get("format") == "binary": + return ToolParameter.ToolParameterType.FILE + if "type" in parameter: typ = parameter["type"] elif "schema" in parameter and "type" in parameter["schema"]: