diff --git a/api/.env.example b/api/.env.example index 474798cef7..80ef185e51 100644 --- a/api/.env.example +++ b/api/.env.example @@ -216,6 +216,7 @@ UNSTRUCTURED_API_KEY= SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= +SSRF_DEFAULT_MAX_RETRIES=3 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 019b27f28a..63cf548ae4 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,13 +1,16 @@ """ Proxy requests to avoid SSRF """ +import logging import os +import time import httpx SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') +SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) proxies = { 'http://': SSRF_PROXY_HTTP_URL, @@ -15,34 +18,55 @@ proxies = { } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None -def make_request(method, url, **kwargs): - if SSRF_PROXY_ALL_URL: - return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) - elif proxies: - return httpx.request(method=method, url=url, proxies=proxies, **kwargs) - else: - return httpx.request(method=method, url=url, **kwargs) +BACKOFF_FACTOR = 0.5 +STATUS_FORCELIST = [429, 500, 502, 503, 504] -def get(url, **kwargs): - return make_request('GET', url, **kwargs) +def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + retries = 0 + while retries <= max_retries: + try: + if SSRF_PROXY_ALL_URL: + response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) + elif proxies: + response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) + else: + response = httpx.request(method=method, url=url, **kwargs) + + if response.status_code not in STATUS_FORCELIST: + return response + else: + logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") + + except httpx.RequestError as e: + logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}") + + retries += 1 + if retries <= max_retries: + time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) + + raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") -def post(url, **kwargs): - return make_request('POST', url, **kwargs) +def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('GET', url, max_retries=max_retries, **kwargs) -def put(url, **kwargs): - return make_request('PUT', url, **kwargs) +def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('POST', url, max_retries=max_retries, **kwargs) -def patch(url, **kwargs): - return make_request('PATCH', url, **kwargs) +def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('PUT', url, max_retries=max_retries, **kwargs) -def delete(url, **kwargs): - return make_request('DELETE', url, **kwargs) +def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('PATCH', url, max_retries=max_retries, **kwargs) -def head(url, **kwargs): - return make_request('HEAD', url, **kwargs) +def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('DELETE', url, max_retries=max_retries, **kwargs) + + +def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + return make_request('HEAD', url, max_retries=max_retries, **kwargs) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index 8409129833..cee46cee23 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -60,11 +60,13 @@ class JinaReaderTool(BuiltinTool): if tool_parameters.get('no_cache', False): headers['X-No-Cache'] = 'true' + max_retries = tool_parameters.get('max_retries', 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), + max_retries=max_retries ) if tool_parameters.get('summary', False): diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml index 072e7f0528..58ad6d8694 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -150,3 +150,17 @@ parameters: pt_BR: Habilitar resumo para a saída llm_description: enable summary form: form + - name: max_retries + type: number + required: false + default: 3 + label: + en_US: Retry + zh_Hans: 重试 + pt_BR: Repetir + human_description: + en_US: Number of times to retry the request if it fails + zh_Hans: 请求失败时重试的次数 + pt_BR: Número de vezes para repetir a solicitação se falhar + llm_description: Number of times to retry the request if it fails + form: form diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index e6bc08147f..d4a81cd096 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -40,10 +40,12 @@ class JinaSearchTool(BuiltinTool): if tool_parameters.get('no_cache', False): headers['X-No-Cache'] = 'true' + max_retries = tool_parameters.get('max_retries', 3) response = ssrf_proxy.get( str(URL(self._jina_search_endpoint + query)), headers=headers, - timeout=(10, 60) + timeout=(10, 60), + max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml index da0a300c6c..2bc70e1be1 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml @@ -91,3 +91,17 @@ parameters: pt_BR: Ignorar o cache llm_description: bypass the cache form: form + - name: max_retries + type: number + required: false + default: 3 + label: + en_US: Retry + zh_Hans: 重试 + pt_BR: Repetir + human_description: + en_US: Number of times to retry the request if it fails + zh_Hans: 请求失败时重试的次数 + pt_BR: Número de vezes para repetir a solicitação se falhar + llm_description: Number of times to retry the request if it fails + form: form diff --git a/api/tests/unit_tests/core/helper/__init__.py b/api/tests/unit_tests/core/helper/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py new file mode 100644 index 0000000000..d917bb1003 --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -0,0 +1,52 @@ +import random +from unittest.mock import MagicMock, patch + +from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request + + +@patch('httpx.request') +def test_successful_request(mock_request): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + response = make_request('GET', 'http://example.com') + assert response.status_code == 200 + + +@patch('httpx.request') +def test_retry_exceed_max_retries(mock_request): + mock_response = MagicMock() + mock_response.status_code = 500 + + side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES + mock_request.side_effect = side_effects + + try: + make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1) + raise AssertionError("Expected Exception not raised") + except Exception as e: + assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" + + +@patch('httpx.request') +def test_retry_logic_success(mock_request): + side_effects = [] + + for _ in range(SSRF_DEFAULT_MAX_RETRIES): + status_code = random.choice(STATUS_FORCELIST) + mock_response = MagicMock() + mock_response.status_code = status_code + side_effects.append(mock_response) + + mock_response_200 = MagicMock() + mock_response_200.status_code = 200 + side_effects.append(mock_response_200) + + mock_request.side_effect = side_effects + + response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES) + + assert response.status_code == 200 + assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1 + assert mock_request.call_args_list[0][1].get('method') == 'GET'