mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 06:19:03 +08:00
feat: support max_retries in jina requests (#6585)
This commit is contained in:
parent
55c2b61921
commit
ebcc07e3e9
@ -216,6 +216,7 @@ UNSTRUCTURED_API_KEY=
|
|||||||
|
|
||||||
SSRF_PROXY_HTTP_URL=
|
SSRF_PROXY_HTTP_URL=
|
||||||
SSRF_PROXY_HTTPS_URL=
|
SSRF_PROXY_HTTPS_URL=
|
||||||
|
SSRF_DEFAULT_MAX_RETRIES=3
|
||||||
|
|
||||||
BATCH_UPLOAD_LIMIT=10
|
BATCH_UPLOAD_LIMIT=10
|
||||||
KEYWORD_DATA_SOURCE_TYPE=database
|
KEYWORD_DATA_SOURCE_TYPE=database
|
||||||
|
@ -1,13 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
Proxy requests to avoid SSRF
|
Proxy requests to avoid SSRF
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
|
SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
|
||||||
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
|
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
|
||||||
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')
|
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')
|
||||||
|
# Default number of retries for proxied requests. `or '3'` also covers the
# variable being set but empty (e.g. `SSRF_DEFAULT_MAX_RETRIES=` in a .env
# file), which would otherwise make int('') raise ValueError at import time.
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES') or '3')
|
||||||
|
|
||||||
proxies = {
|
proxies = {
|
||||||
'http://': SSRF_PROXY_HTTP_URL,
|
'http://': SSRF_PROXY_HTTP_URL,
|
||||||
@ -15,34 +18,55 @@ proxies = {
|
|||||||
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
|
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
|
||||||
|
|
||||||
|
|
||||||
def make_request(method, url, **kwargs):
|
# Base delay in seconds for the exponential backoff between retry attempts;
# make_request sleeps BACKOFF_FACTOR * 2 ** (retries - 1) before each retry.
BACKOFF_FACTOR = 0.5
|
||||||
if SSRF_PROXY_ALL_URL:
|
# HTTP status codes for which make_request retries instead of returning the response.
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||||
return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
|
|
||||||
elif proxies:
|
|
||||||
return httpx.request(method=method, url=url, proxies=proxies, **kwargs)
|
|
||||||
else:
|
|
||||||
return httpx.request(method=method, url=url, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get(url, **kwargs):
|
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """
    Send an HTTP request through the configured SSRF proxy, retrying on failure.

    An attempt counts as failed when ``httpx`` raises ``httpx.RequestError`` or
    when the response status code is listed in ``STATUS_FORCELIST``. Between
    attempts the function sleeps ``BACKOFF_FACTOR * 2 ** (n - 1)`` seconds,
    where ``n`` is the number of attempts made so far.

    :param method: HTTP method name, e.g. ``'GET'``.
    :param url: target URL.
    :param max_retries: number of retries after the first attempt. ``None``
        falls back to ``SSRF_DEFAULT_MAX_RETRIES``; negative values are
        treated as ``0`` so that at least one request is always made.
    :param kwargs: extra keyword arguments forwarded to ``httpx.request``.
    :return: the first response whose status code is not in ``STATUS_FORCELIST``.
    :raises Exception: when every attempt failed; chained to the
        ``httpx.RequestError`` of the final attempt if that attempt raised one.
    """
    if max_retries is None:
        # Callers forwarding optional tool parameters may pass None explicitly.
        max_retries = SSRF_DEFAULT_MAX_RETRIES
    # A negative budget would otherwise skip the request entirely.
    max_retries = max(max_retries, 0)

    # Proxy settings are module-level constants, so resolve them once
    # instead of re-deciding on every attempt.
    if SSRF_PROXY_ALL_URL:
        proxy_kwargs = {'proxy': SSRF_PROXY_ALL_URL}
    elif proxies:
        proxy_kwargs = {'proxies': proxies}
    else:
        proxy_kwargs = {}

    last_error = None
    retries = 0
    while retries <= max_retries:
        try:
            response = httpx.request(method=method, url=url, **proxy_kwargs, **kwargs)
        except httpx.RequestError as e:
            last_error = e
            logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
        else:
            if response.status_code not in STATUS_FORCELIST:
                return response
            # The latest failure was a status code, not a transport error.
            last_error = None
            logging.warning(
                "Received status code %s for URL %s which is in the force list",
                response.status_code, url,
            )

        retries += 1
        if retries <= max_retries:
            # No sleep after the final attempt; we are about to give up.
            time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))

    raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") from last_error
|
||||||
|
|
||||||
|
|
||||||
def post(url, **kwargs):
|
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send a GET request to ``url`` via ``make_request``.

    ``max_retries`` and all extra keyword arguments are forwarded unchanged.
    """
    kwargs['max_retries'] = max_retries
    return make_request('GET', url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def put(url, **kwargs):
|
def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send a POST request to ``url`` via ``make_request``.

    ``max_retries`` and all extra keyword arguments are forwarded unchanged.
    """
    kwargs['max_retries'] = max_retries
    return make_request('POST', url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def patch(url, **kwargs):
|
def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send a PUT request to ``url`` via ``make_request``.

    ``max_retries`` and all extra keyword arguments are forwarded unchanged.
    """
    kwargs['max_retries'] = max_retries
    return make_request('PUT', url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def delete(url, **kwargs):
|
def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send a PATCH request to ``url`` via ``make_request``.

    ``max_retries`` and all extra keyword arguments are forwarded unchanged.
    """
    kwargs['max_retries'] = max_retries
    return make_request('PATCH', url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def head(url, **kwargs):
|
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send a DELETE request to ``url`` via ``make_request``.

    ``max_retries`` and all extra keyword arguments are forwarded unchanged.
    """
    kwargs['max_retries'] = max_retries
    return make_request('DELETE', url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Send a HEAD request to ``url`` via ``make_request``.

    ``max_retries`` and all extra keyword arguments are forwarded unchanged.
    """
    kwargs['max_retries'] = max_retries
    return make_request('HEAD', url, **kwargs)
|
||||||
|
@ -60,11 +60,13 @@ class JinaReaderTool(BuiltinTool):
|
|||||||
if tool_parameters.get('no_cache', False):
|
if tool_parameters.get('no_cache', False):
|
||||||
headers['X-No-Cache'] = 'true'
|
headers['X-No-Cache'] = 'true'
|
||||||
|
|
||||||
|
max_retries = tool_parameters.get('max_retries', 3)
|
||||||
response = ssrf_proxy.get(
|
response = ssrf_proxy.get(
|
||||||
str(URL(self._jina_reader_endpoint + url)),
|
str(URL(self._jina_reader_endpoint + url)),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params=request_params,
|
params=request_params,
|
||||||
timeout=(10, 60),
|
timeout=(10, 60),
|
||||||
|
max_retries=max_retries
|
||||||
)
|
)
|
||||||
|
|
||||||
if tool_parameters.get('summary', False):
|
if tool_parameters.get('summary', False):
|
||||||
|
@ -150,3 +150,17 @@ parameters:
|
|||||||
pt_BR: Habilitar resumo para a saída
|
pt_BR: Habilitar resumo para a saída
|
||||||
llm_description: enable summary
|
llm_description: enable summary
|
||||||
form: form
|
form: form
|
||||||
|
- name: max_retries
|
||||||
|
type: number
|
||||||
|
required: false
|
||||||
|
default: 3
|
||||||
|
label:
|
||||||
|
en_US: Retry
|
||||||
|
zh_Hans: 重试
|
||||||
|
pt_BR: Repetir
|
||||||
|
human_description:
|
||||||
|
en_US: Number of times to retry the request if it fails
|
||||||
|
zh_Hans: 请求失败时重试的次数
|
||||||
|
pt_BR: Número de vezes para repetir a solicitação se falhar
|
||||||
|
llm_description: Number of times to retry the request if it fails
|
||||||
|
form: form
|
||||||
|
@ -40,10 +40,12 @@ class JinaSearchTool(BuiltinTool):
|
|||||||
if tool_parameters.get('no_cache', False):
|
if tool_parameters.get('no_cache', False):
|
||||||
headers['X-No-Cache'] = 'true'
|
headers['X-No-Cache'] = 'true'
|
||||||
|
|
||||||
|
max_retries = tool_parameters.get('max_retries', 3)
|
||||||
response = ssrf_proxy.get(
|
response = ssrf_proxy.get(
|
||||||
str(URL(self._jina_search_endpoint + query)),
|
str(URL(self._jina_search_endpoint + query)),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=(10, 60)
|
timeout=(10, 60),
|
||||||
|
max_retries=max_retries
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.create_text_message(response.text)
|
return self.create_text_message(response.text)
|
||||||
|
@ -91,3 +91,17 @@ parameters:
|
|||||||
pt_BR: Ignorar o cache
|
pt_BR: Ignorar o cache
|
||||||
llm_description: bypass the cache
|
llm_description: bypass the cache
|
||||||
form: form
|
form: form
|
||||||
|
- name: max_retries
|
||||||
|
type: number
|
||||||
|
required: false
|
||||||
|
default: 3
|
||||||
|
label:
|
||||||
|
en_US: Retry
|
||||||
|
zh_Hans: 重试
|
||||||
|
pt_BR: Repetir
|
||||||
|
human_description:
|
||||||
|
en_US: Number of times to retry the request if it fails
|
||||||
|
zh_Hans: 请求失败时重试的次数
|
||||||
|
pt_BR: Número de vezes para repetir a solicitação se falhar
|
||||||
|
llm_description: Number of times to retry the request if it fails
|
||||||
|
form: form
|
||||||
|
0
api/tests/unit_tests/core/helper/__init__.py
Normal file
0
api/tests/unit_tests/core/helper/__init__.py
Normal file
52
api/tests/unit_tests/core/helper/test_ssrf_proxy.py
Normal file
52
api/tests/unit_tests/core/helper/test_ssrf_proxy.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import random
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||||
|
|
||||||
|
|
||||||
|
@patch('httpx.request')
def test_successful_request(mock_request):
    """make_request returns the response when the first attempt answers 200."""
    # Stub httpx so the only attempt yields a status outside the retry list.
    mock_request.return_value = MagicMock(status_code=200)

    result = make_request('GET', 'http://example.com')

    assert result.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@patch('httpx.request')
def test_retry_exceed_max_retries(mock_request):
    """make_request raises once every allowed attempt returns a retryable status."""
    max_retries = SSRF_DEFAULT_MAX_RETRIES - 1
    # One initial attempt plus `max_retries` retries, all answered with 500.
    mock_request.side_effect = [MagicMock(status_code=500)] * (max_retries + 1)

    # Stub out the backoff delay so the test does not actually sleep.
    with patch('time.sleep'):
        try:
            make_request('GET', 'http://example.com', max_retries=max_retries)
        except Exception as e:
            assert str(e) == f"Reached maximum retries ({max_retries}) for URL http://example.com"
        else:
            # Raised in `else`, not inside `try`, so the broad `except Exception`
            # above cannot swallow it and misreport it as a message mismatch.
            raise AssertionError("Expected Exception not raised")

    assert mock_request.call_count == max_retries + 1
|
||||||
|
|
||||||
|
|
||||||
|
@patch('httpx.request')
def test_retry_logic_success(mock_request):
    """make_request keeps retrying on retryable statuses and returns the final 200."""
    # Walk STATUS_FORCELIST in order rather than sampling at random, so every
    # run exercises the same sequence and failures are reproducible.
    failures = [
        MagicMock(status_code=STATUS_FORCELIST[i % len(STATUS_FORCELIST)])
        for i in range(SSRF_DEFAULT_MAX_RETRIES)
    ]
    mock_request.side_effect = [*failures, MagicMock(status_code=200)]

    # Stub out the backoff delay so the test does not actually sleep.
    with patch('time.sleep') as mock_sleep:
        response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES)

    assert response.status_code == 200
    assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
    # One backoff pause follows each failed attempt that still has a retry left.
    assert mock_sleep.call_count == SSRF_DEFAULT_MAX_RETRIES
    assert mock_request.call_args_list[0][1].get('method') == 'GET'
|
Loading…
x
Reference in New Issue
Block a user