chore: refactor the http executor node (#5212)

This commit is contained in:
非法操作 2024-06-24 16:14:59 +08:00 committed by GitHub
parent 1e28a8c033
commit f7900f298f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 249 additions and 230 deletions

View File

@ -1,65 +1,48 @@
""" """
Proxy requests to avoid SSRF Proxy requests to avoid SSRF
""" """
import os import os
from httpx import get as _get import httpx
from httpx import head as _head
from httpx import options as _options
from httpx import patch as _patch
from httpx import post as _post
from httpx import put as _put
from requests import delete as _delete
SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')
requests_proxies = { proxies = {
'http': SSRF_PROXY_HTTP_URL,
'https': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
httpx_proxies = {
'http://': SSRF_PROXY_HTTP_URL, 'http://': SSRF_PROXY_HTTP_URL,
'https://': SSRF_PROXY_HTTPS_URL 'https://': SSRF_PROXY_HTTPS_URL
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
def get(url, *args, **kwargs):
return _get(url=url, *args, proxies=httpx_proxies, **kwargs)
def post(url, *args, **kwargs): def make_request(method, url, **kwargs):
return _post(url=url, *args, proxies=httpx_proxies, **kwargs) if SSRF_PROXY_ALL_URL:
return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
elif proxies:
return httpx.request(method=method, url=url, proxies=proxies, **kwargs)
else:
return httpx.request(method=method, url=url, **kwargs)
def put(url, *args, **kwargs):
return _put(url=url, *args, proxies=httpx_proxies, **kwargs)
def patch(url, *args, **kwargs): def get(url, **kwargs):
return _patch(url=url, *args, proxies=httpx_proxies, **kwargs) return make_request('GET', url, **kwargs)
def delete(url, *args, **kwargs):
if 'follow_redirects' in kwargs:
if kwargs['follow_redirects']:
kwargs['allow_redirects'] = kwargs['follow_redirects']
kwargs.pop('follow_redirects')
if 'timeout' in kwargs:
timeout = kwargs['timeout']
if timeout is None:
kwargs.pop('timeout')
elif isinstance(timeout, tuple):
# check length of tuple
if len(timeout) == 2:
kwargs['timeout'] = timeout
elif len(timeout) == 1:
kwargs['timeout'] = timeout[0]
elif len(timeout) > 2:
kwargs['timeout'] = (timeout[0], timeout[1])
else:
kwargs['timeout'] = (timeout, timeout)
return _delete(url=url, *args, proxies=requests_proxies, **kwargs)
def head(url, *args, **kwargs): def post(url, **kwargs):
return _head(url=url, *args, proxies=httpx_proxies, **kwargs) return make_request('POST', url, **kwargs)
def options(url, *args, **kwargs):
return _options(url=url, *args, proxies=httpx_proxies, **kwargs) def put(url, **kwargs):
return make_request('PUT', url, **kwargs)
def patch(url, **kwargs):
return make_request('PATCH', url, **kwargs)
def delete(url, **kwargs):
return make_request('DELETE', url, **kwargs)
def head(url, **kwargs):
return make_request('HEAD', url, **kwargs)

View File

@ -1,11 +1,9 @@
import json import json
from json import dumps
from os import getenv from os import getenv
from typing import Any, Union from typing import Any
from urllib.parse import urlencode from urllib.parse import urlencode
import httpx import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy import core.helper.ssrf_proxy as ssrf_proxy
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
@ -18,12 +16,14 @@ API_TOOL_DEFAULT_TIMEOUT = (
int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60')) int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60'))
) )
class ApiTool(Tool): class ApiTool(Tool):
api_bundle: ApiToolBundle api_bundle: ApiToolBundle
""" """
Api tool Api tool
""" """
def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
""" """
fork a new tool with meta data fork a new tool with meta data
@ -38,8 +38,9 @@ class ApiTool(Tool):
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None, api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
runtime=Tool.Runtime(**runtime) runtime=Tool.Runtime(**runtime)
) )
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str: def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any],
format_only: bool = False) -> str:
""" """
validate the credentials for Api tool validate the credentials for Api tool
""" """
@ -47,7 +48,7 @@ class ApiTool(Tool):
headers = self.assembling_request(parameters) headers = self.assembling_request(parameters)
if format_only: if format_only:
return return ''
response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters) response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
# validate response # validate response
@ -68,12 +69,12 @@ class ApiTool(Tool):
if 'api_key_header' in credentials: if 'api_key_header' in credentials:
api_key_header = credentials['api_key_header'] api_key_header = credentials['api_key_header']
if 'api_key_value' not in credentials: if 'api_key_value' not in credentials:
raise ToolProviderCredentialValidationError('Missing api_key_value') raise ToolProviderCredentialValidationError('Missing api_key_value')
elif not isinstance(credentials['api_key_value'], str): elif not isinstance(credentials['api_key_value'], str):
raise ToolProviderCredentialValidationError('api_key_value must be a string') raise ToolProviderCredentialValidationError('api_key_value must be a string')
if 'api_key_header_prefix' in credentials: if 'api_key_header_prefix' in credentials:
api_key_header_prefix = credentials['api_key_header_prefix'] api_key_header_prefix = credentials['api_key_header_prefix']
if api_key_header_prefix == 'basic' and credentials['api_key_value']: if api_key_header_prefix == 'basic' and credentials['api_key_value']:
@ -82,20 +83,20 @@ class ApiTool(Tool):
credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}' credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
elif api_key_header_prefix == 'custom': elif api_key_header_prefix == 'custom':
pass pass
headers[api_key_header] = credentials['api_key_value'] headers[api_key_header] = credentials['api_key_value']
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required] needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
for parameter in needed_parameters: for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters: if parameter.required and parameter.name not in parameters:
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}") raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
if parameter.default is not None and parameter.name not in parameters: if parameter.default is not None and parameter.name not in parameters:
parameters[parameter.name] = parameter.default parameters[parameter.name] = parameter.default
return headers return headers
def validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> str: def validate_and_parse_response(self, response: httpx.Response) -> str:
""" """
validate the response validate the response
""" """
@ -112,23 +113,20 @@ class ApiTool(Tool):
return json.dumps(response) return json.dumps(response)
except Exception as e: except Exception as e:
return response.text return response.text
elif isinstance(response, requests.Response):
if not response.ok:
raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
if not response.content:
return 'Empty response from the tool, please check your parameters and try again.'
try:
response = response.json()
try:
return json.dumps(response, ensure_ascii=False)
except Exception as e:
return json.dumps(response)
except Exception as e:
return response.text
else: else:
raise ValueError(f'Invalid response type {type(response)}') raise ValueError(f'Invalid response type {type(response)}')
def do_http_request(self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]) -> httpx.Response: @staticmethod
def get_parameter_value(parameter, parameters):
if parameter['name'] in parameters:
return parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
return (parameter.get('schema', {}) or {}).get('default', '')
def do_http_request(self, url: str, method: str, headers: dict[str, Any],
parameters: dict[str, Any]) -> httpx.Response:
""" """
do http request depending on api bundle do http request depending on api bundle
""" """
@ -141,44 +139,17 @@ class ApiTool(Tool):
# check parameters # check parameters
for parameter in self.api_bundle.openapi.get('parameters', []): for parameter in self.api_bundle.openapi.get('parameters', []):
value = self.get_parameter_value(parameter, parameters)
if parameter['in'] == 'path': if parameter['in'] == 'path':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter['required']:
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
path_params[parameter['name']] = value path_params[parameter['name']] = value
elif parameter['in'] == 'query': elif parameter['in'] == 'query':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
params[parameter['name']] = value params[parameter['name']] = value
elif parameter['in'] == 'cookie': elif parameter['in'] == 'cookie':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
cookies[parameter['name']] = value cookies[parameter['name']] = value
elif parameter['in'] == 'header': elif parameter['in'] == 'header':
value = ''
if parameter['name'] in parameters:
value = parameters[parameter['name']]
elif parameter.get('required', False):
raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
else:
value = (parameter.get('schema', {}) or {}).get('default', '')
headers[parameter['name']] = value headers[parameter['name']] = value
# check if there is a request body and handle it # check if there is a request body and handle it
@ -203,7 +174,7 @@ class ApiTool(Tool):
else: else:
body[name] = None body[name] = None
break break
# replace path parameters # replace path parameters
for name, value in path_params.items(): for name, value in path_params.items():
url = url.replace(f'{{{name}}}', f'{value}') url = url.replace(f'{{{name}}}', f'{value}')
@ -211,33 +182,21 @@ class ApiTool(Tool):
# parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
if 'Content-Type' in headers: if 'Content-Type' in headers:
if headers['Content-Type'] == 'application/json': if headers['Content-Type'] == 'application/json':
body = dumps(body) body = json.dumps(body)
elif headers['Content-Type'] == 'application/x-www-form-urlencoded': elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
body = urlencode(body) body = urlencode(body)
else: else:
body = body body = body
# do http request if method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
if method == 'get': response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body,
response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'post': return response
response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'put':
response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'delete':
response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, allow_redirects=True)
elif method == 'patch':
response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'head':
response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
elif method == 'options':
response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
else: else:
raise ValueError(f'Invalid http method {method}') raise ValueError(f'Invalid http method {self.method}')
return response def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]],
max_recursive=10) -> Any:
def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10) -> Any:
if max_recursive <= 0: if max_recursive <= 0:
raise Exception("Max recursion depth reached") raise Exception("Max recursion depth reached")
for option in any_of or []: for option in any_of or []:
@ -322,4 +281,3 @@ class ApiTool(Tool):
# assemble invoke message # assemble invoke message
return self.create_text_message(response) return self.create_text_message(response)

View File

@ -6,7 +6,6 @@ from typing import Any, Optional, Union
from urllib.parse import urlencode from urllib.parse import urlencode
import httpx import httpx
import requests
import core.helper.ssrf_proxy as ssrf_proxy import core.helper.ssrf_proxy as ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
@ -22,14 +21,11 @@ READABLE_MAX_TEXT_SIZE = f'{MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
class HttpExecutorResponse: class HttpExecutorResponse:
headers: dict[str, str] headers: dict[str, str]
response: Union[httpx.Response, requests.Response] response: httpx.Response
def __init__(self, response: Union[httpx.Response, requests.Response] = None): def __init__(self, response: httpx.Response = None):
self.headers = {}
if isinstance(response, httpx.Response | requests.Response):
for k, v in response.headers.items():
self.headers[k] = v
self.response = response self.response = response
self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}
@property @property
def is_file(self) -> bool: def is_file(self) -> bool:
@ -42,10 +38,8 @@ class HttpExecutorResponse:
return any(v in content_type for v in file_content_types) return any(v in content_type for v in file_content_types)
def get_content_type(self) -> str: def get_content_type(self) -> str:
if 'content-type' in self.headers: return self.headers.get('content-type', '')
return self.headers.get('content-type')
else:
return self.headers.get('Content-Type') or ""
def extract_file(self) -> tuple[str, bytes]: def extract_file(self) -> tuple[str, bytes]:
""" """
@ -58,46 +52,31 @@ class HttpExecutorResponse:
@property @property
def content(self) -> str: def content(self) -> str:
""" if isinstance(self.response, httpx.Response):
get content
"""
if isinstance(self.response, httpx.Response | requests.Response):
return self.response.text return self.response.text
else: else:
raise ValueError(f'Invalid response type {type(self.response)}') raise ValueError(f'Invalid response type {type(self.response)}')
@property @property
def body(self) -> bytes: def body(self) -> bytes:
""" if isinstance(self.response, httpx.Response):
get body
"""
if isinstance(self.response, httpx.Response | requests.Response):
return self.response.content return self.response.content
else: else:
raise ValueError(f'Invalid response type {type(self.response)}') raise ValueError(f'Invalid response type {type(self.response)}')
@property @property
def status_code(self) -> int: def status_code(self) -> int:
""" if isinstance(self.response, httpx.Response):
get status code
"""
if isinstance(self.response, httpx.Response | requests.Response):
return self.response.status_code return self.response.status_code
else: else:
raise ValueError(f'Invalid response type {type(self.response)}') raise ValueError(f'Invalid response type {type(self.response)}')
@property @property
def size(self) -> int: def size(self) -> int:
"""
get size
"""
return len(self.body) return len(self.body)
@property @property
def readable_size(self) -> str: def readable_size(self) -> str:
"""
get readable size
"""
if self.size < 1024: if self.size < 1024:
return f'{self.size} bytes' return f'{self.size} bytes'
elif self.size < 1024 * 1024: elif self.size < 1024 * 1024:
@ -148,13 +127,9 @@ class HttpExecutor:
return False return False
@staticmethod @staticmethod
def _to_dict(convert_item: str, convert_text: str, maxsplit: int = -1): def _to_dict(convert_text: str):
""" """
Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}`
:param convert_item: A label for what item to be converted, params, headers or body.
:param convert_text: The string containing key-value pairs separated by '\n'.
:param maxsplit: The maximum number of splits allowed for the ':' character in each key-value pair. Default is -1 (no limit).
:return: A dictionary containing the key-value pairs from the input string.
""" """
kv_paris = convert_text.split('\n') kv_paris = convert_text.split('\n')
result = {} result = {}
@ -162,15 +137,11 @@ class HttpExecutor:
if not kv.strip(): if not kv.strip():
continue continue
kv = kv.split(':', maxsplit=maxsplit) kv = kv.split(':', maxsplit=1)
if len(kv) >= 3: if len(kv) == 1:
k, v = kv[0], ":".join(kv[1:])
elif len(kv) == 2:
k, v = kv
elif len(kv) == 1:
k, v = kv[0], '' k, v = kv[0], ''
else: else:
raise ValueError(f'Invalid {convert_item} {kv}') k, v = kv
result[k.strip()] = v result[k.strip()] = v
return result return result
@ -181,11 +152,11 @@ class HttpExecutor:
# extract all template in params # extract all template in params
params, params_variable_selectors = self._format_template(node_data.params, variable_pool) params, params_variable_selectors = self._format_template(node_data.params, variable_pool)
self.params = self._to_dict("params", params) self.params = self._to_dict(params)
# extract all template in headers # extract all template in headers
headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool)
self.headers = self._to_dict("headers", headers) self.headers = self._to_dict(headers)
# extract all template in body # extract all template in body
body_data_variable_selectors = [] body_data_variable_selectors = []
@ -203,7 +174,7 @@ class HttpExecutor:
self.headers['Content-Type'] = 'application/x-www-form-urlencoded' self.headers['Content-Type'] = 'application/x-www-form-urlencoded'
if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: if node_data.body.type in ['form-data', 'x-www-form-urlencoded']:
body = self._to_dict("body", body_data, 1) body = self._to_dict(body_data)
if node_data.body.type == 'form-data': if node_data.body.type == 'form-data':
self.files = { self.files = {
@ -242,11 +213,11 @@ class HttpExecutor:
return headers return headers
def _validate_and_parse_response(self, response: Union[httpx.Response, requests.Response]) -> HttpExecutorResponse: def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse:
""" """
validate the response validate the response
""" """
if isinstance(response, httpx.Response | requests.Response): if isinstance(response, httpx.Response):
executor_response = HttpExecutorResponse(response) executor_response = HttpExecutorResponse(response)
else: else:
raise ValueError(f'Invalid response type {type(response)}') raise ValueError(f'Invalid response type {type(response)}')
@ -274,9 +245,7 @@ class HttpExecutor:
'follow_redirects': True 'follow_redirects': True
} }
if self.method in ('get', 'head', 'options'): if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, self.method)(**kwargs)
elif self.method in ('post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs)
else: else:
raise ValueError(f'Invalid http method {self.method}') raise ValueError(f'Invalid http method {self.method}')

View File

@ -0,0 +1,36 @@
import json
from typing import Literal
import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch
class MockedHttp:
def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
url: str, **kwargs) -> httpx.Response:
"""
Mocked httpx.request
"""
request = httpx.Request(
method,
url,
params=kwargs.get('params'),
headers=kwargs.get('headers'),
cookies=kwargs.get('cookies')
)
data = kwargs.get('data', None)
resp = json.dumps(data).encode('utf-8') if data else b'OK'
response = httpx.Response(
status_code=200,
request=request,
content=resp,
)
return response
@pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
yield
monkeypatch.undo()

View File

@ -0,0 +1,39 @@
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.tool import Tool
from tests.integration_tests.tools.__mock.http import setup_http_mock
tool_bundle = {
'server_url': 'http://www.example.com/{path_param}',
'method': 'post',
'author': '',
'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'},
{'in': 'query', 'name': 'query_param'},
{'in': 'cookie', 'name': 'cookie_param'},
{'in': 'header', 'name': 'header_param'},
],
'requestBody': {
'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}}
},
'parameters': []
}
parameters = {
'path_param': 'p_param',
'query_param': 'q_param',
'cookie_param': 'c_param',
'header_param': 'h_param',
'body_param': 'b_param',
}
def test_api_tool(setup_http_mock):
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'}))
headers = tool.assembling_request(parameters)
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
assert response.status_code == 200
assert '/p_param' == response.request.url.path
assert b'query_param=q_param' == response.request.url.query
assert 'h_param' == response.request.headers.get('header_param')
assert 'application/json' == response.request.headers.get('content-type')
assert 'cookie_param=c_param' == response.request.headers.get('cookie')
assert 'b_param' in response.content.decode()

View File

@ -2,84 +2,52 @@ import os
from json import dumps from json import dumps
from typing import Literal from typing import Literal
import httpx._api as httpx import httpx
import pytest import pytest
import requests.api as requests
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from httpx import Request as HttpxRequest
from requests import Response as RequestsResponse
from yarl import URL
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
class MockedHttp: class MockedHttp:
def requests_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'], url: str, def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
**kwargs) -> RequestsResponse:
"""
Mocked requests.request
"""
response = RequestsResponse()
response.url = str(URL(url) % kwargs.get('params', {}))
response.headers = kwargs.get('headers', {})
if url == 'http://404.com':
response.status_code = 404
response._content = b'Not Found'
return response
# get data, files
data = kwargs.get('data', None)
files = kwargs.get('files', None)
if data is not None:
resp = dumps(data).encode('utf-8')
if files is not None:
resp = dumps(files).encode('utf-8')
else:
resp = b'OK'
response.status_code = 200
response._content = resp
return response
def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'OPTIONS'],
url: str, **kwargs) -> httpx.Response: url: str, **kwargs) -> httpx.Response:
""" """
Mocked httpx.request Mocked httpx.request
""" """
response = httpx.Response(
status_code=200,
request=HttpxRequest(method, url)
)
response.headers = kwargs.get('headers', {})
if url == 'http://404.com': if url == 'http://404.com':
response.status_code = 404 response = httpx.Response(
response.content = b'Not Found' status_code=404,
request=httpx.Request(method, url),
content=b'Not Found'
)
return response return response
# get data, files # get data, files
data = kwargs.get('data', None) data = kwargs.get('data', None)
files = kwargs.get('files', None) files = kwargs.get('files', None)
if data is not None: if data is not None:
resp = dumps(data).encode('utf-8') resp = dumps(data).encode('utf-8')
if files is not None: elif files is not None:
resp = dumps(files).encode('utf-8') resp = dumps(files).encode('utf-8')
else: else:
resp = b'OK' resp = b'OK'
response.status_code = 200 response = httpx.Response(
response._content = resp status_code=200,
request=httpx.Request(method, url),
headers=kwargs.get('headers', {}),
content=resp
)
return response return response
@pytest.fixture @pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch): def setup_http_mock(request, monkeypatch: MonkeyPatch):
if not MOCK: if not MOCK:
yield yield
return return
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request) monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
yield yield
monkeypatch.undo() monkeypatch.undo()

View File

@ -1,3 +1,5 @@
from urllib.parse import urlencode
import pytest import pytest
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -20,6 +22,7 @@ pool = VariablePool(system_variables={}, user_inputs={})
pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1) pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1)
pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2) pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2)
@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_get(setup_http_mock): def test_get(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
@ -33,7 +36,7 @@ def test_get(setup_http_mock):
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx', 'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
@ -52,6 +55,7 @@ def test_get(setup_http_mock):
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_no_auth(setup_http_mock): def test_no_auth(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
@ -78,6 +82,7 @@ def test_no_auth(setup_http_mock):
assert '?A=b' in data assert '?A=b' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_custom_authorization_header(setup_http_mock): def test_custom_authorization_header(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
@ -110,6 +115,7 @@ def test_custom_authorization_header(setup_http_mock):
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
assert 'X-Auth: Auth' in data assert 'X-Auth: Auth' in data
@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_template(setup_http_mock): def test_template(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
@ -123,7 +129,7 @@ def test_template(setup_http_mock):
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx', 'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
@ -143,6 +149,7 @@ def test_template(setup_http_mock):
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
assert 'X-Header2: 2' in data assert 'X-Header2: 2' in data
@pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True)
def test_json(setup_http_mock): def test_json(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
@ -156,7 +163,7 @@ def test_json(setup_http_mock):
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx', 'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
@ -177,6 +184,7 @@ def test_json(setup_http_mock):
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
def test_x_www_form_urlencoded(setup_http_mock): def test_x_www_form_urlencoded(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'id': '1', 'id': '1',
@ -189,7 +197,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx', 'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
@ -210,6 +218,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
def test_form_data(setup_http_mock): def test_form_data(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'id': '1', 'id': '1',
@ -222,7 +231,7 @@ def test_form_data(setup_http_mock):
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx', 'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
@ -246,6 +255,7 @@ def test_form_data(setup_http_mock):
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
def test_none_data(setup_http_mock): def test_none_data(setup_http_mock):
node = HttpRequestNode(config={ node = HttpRequestNode(config={
'id': '1', 'id': '1',
@ -258,7 +268,7 @@ def test_none_data(setup_http_mock):
'type': 'api-key', 'type': 'api-key',
'config': { 'config': {
'type': 'basic', 'type': 'basic',
'api_key':'ak-xxx', 'api_key': 'ak-xxx',
'header': 'api-key', 'header': 'api-key',
} }
}, },
@ -278,3 +288,59 @@ def test_none_data(setup_http_mock):
assert 'api-key: Basic ak-xxx' in data assert 'api-key: Basic ak-xxx' in data
assert 'X-Header: 123' in data assert 'X-Header: 123' in data
assert '123123123' not in data assert '123123123' not in data
def test_mock_404(setup_http_mock):
node = HttpRequestNode(config={
'id': '1',
'data': {
'title': 'http',
'desc': '',
'method': 'get',
'url': 'http://404.com',
'authorization': {
'type': 'no-auth',
'config': None,
},
'body': None,
'params': '',
'headers': 'X-Header:123',
'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
result = node.run(pool)
resp = result.outputs
assert 404 == resp.get('status_code')
assert 'Not Found' in resp.get('body')
def test_multi_colons_parse(setup_http_mock):
node = HttpRequestNode(config={
'id': '1',
'data': {
'title': 'http',
'desc': '',
'method': 'get',
'url': 'http://example.com',
'authorization': {
'type': 'no-auth',
'config': None,
},
'params': 'Referer:http://example1.com\nRedirect:http://example2.com',
'headers': 'Referer:http://example3.com\nRedirect:http://example4.com',
'body': {
'type': 'form-data',
'data': 'Referer:http://example5.com\nRedirect:http://example6.com'
},
'mask_authorization_header': False,
}
}, **BASIC_NODE_DATA)
result = node.run(pool)
resp = result.outputs
assert urlencode({'Redirect': 'http://example2.com'}) in result.process_data.get('request')
assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get('request')
assert 'http://example3.com' == resp.get('headers').get('referer')