mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 02:19:05 +08:00
feat: support max_retries in jina requests (#6585)
This commit is contained in:
parent
55c2b61921
commit
ebcc07e3e9
@ -216,6 +216,7 @@ UNSTRUCTURED_API_KEY=
|
||||
|
||||
SSRF_PROXY_HTTP_URL=
|
||||
SSRF_PROXY_HTTPS_URL=
|
||||
SSRF_DEFAULT_MAX_RETRIES=3
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
|
@ -1,13 +1,16 @@
|
||||
"""
|
||||
Proxy requests to avoid SSRF
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
|
||||
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
|
||||
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')
|
||||
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3'))
|
||||
|
||||
proxies = {
|
||||
'http://': SSRF_PROXY_HTTP_URL,
|
||||
@ -15,34 +18,55 @@ proxies = {
|
||||
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
|
||||
|
||||
|
||||
def make_request(method, url, **kwargs):
|
||||
if SSRF_PROXY_ALL_URL:
|
||||
return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
|
||||
elif proxies:
|
||||
return httpx.request(method=method, url=url, proxies=proxies, **kwargs)
|
||||
else:
|
||||
return httpx.request(method=method, url=url, **kwargs)
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
def get(url, **kwargs):
|
||||
return make_request('GET', url, **kwargs)
|
||||
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
retries = 0
|
||||
while retries <= max_retries:
|
||||
try:
|
||||
if SSRF_PROXY_ALL_URL:
|
||||
response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
|
||||
elif proxies:
|
||||
response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
|
||||
else:
|
||||
response = httpx.request(method=method, url=url, **kwargs)
|
||||
|
||||
if response.status_code not in STATUS_FORCELIST:
|
||||
return response
|
||||
else:
|
||||
logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
|
||||
|
||||
retries += 1
|
||||
if retries <= max_retries:
|
||||
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
|
||||
|
||||
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
|
||||
|
||||
|
||||
def post(url, **kwargs):
|
||||
return make_request('POST', url, **kwargs)
|
||||
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request('GET', url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def put(url, **kwargs):
|
||||
return make_request('PUT', url, **kwargs)
|
||||
def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request('POST', url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def patch(url, **kwargs):
|
||||
return make_request('PATCH', url, **kwargs)
|
||||
def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request('PUT', url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def delete(url, **kwargs):
|
||||
return make_request('DELETE', url, **kwargs)
|
||||
def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request('PATCH', url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def head(url, **kwargs):
|
||||
return make_request('HEAD', url, **kwargs)
|
||||
def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request('DELETE', url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request('HEAD', url, max_retries=max_retries, **kwargs)
|
||||
|
@ -60,11 +60,13 @@ class JinaReaderTool(BuiltinTool):
|
||||
if tool_parameters.get('no_cache', False):
|
||||
headers['X-No-Cache'] = 'true'
|
||||
|
||||
max_retries = tool_parameters.get('max_retries', 3)
|
||||
response = ssrf_proxy.get(
|
||||
str(URL(self._jina_reader_endpoint + url)),
|
||||
headers=headers,
|
||||
params=request_params,
|
||||
timeout=(10, 60),
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
if tool_parameters.get('summary', False):
|
||||
|
@ -150,3 +150,17 @@ parameters:
|
||||
pt_BR: Habilitar resumo para a saída
|
||||
llm_description: enable summary
|
||||
form: form
|
||||
- name: max_retries
|
||||
type: number
|
||||
required: false
|
||||
default: 3
|
||||
label:
|
||||
en_US: Retry
|
||||
zh_Hans: 重试
|
||||
pt_BR: Repetir
|
||||
human_description:
|
||||
en_US: Number of times to retry the request if it fails
|
||||
zh_Hans: 请求失败时重试的次数
|
||||
pt_BR: Número de vezes para repetir a solicitação se falhar
|
||||
llm_description: Number of times to retry the request if it fails
|
||||
form: form
|
||||
|
@ -40,10 +40,12 @@ class JinaSearchTool(BuiltinTool):
|
||||
if tool_parameters.get('no_cache', False):
|
||||
headers['X-No-Cache'] = 'true'
|
||||
|
||||
max_retries = tool_parameters.get('max_retries', 3)
|
||||
response = ssrf_proxy.get(
|
||||
str(URL(self._jina_search_endpoint + query)),
|
||||
headers=headers,
|
||||
timeout=(10, 60)
|
||||
timeout=(10, 60),
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
return self.create_text_message(response.text)
|
||||
|
@ -91,3 +91,17 @@ parameters:
|
||||
pt_BR: Ignorar o cache
|
||||
llm_description: bypass the cache
|
||||
form: form
|
||||
- name: max_retries
|
||||
type: number
|
||||
required: false
|
||||
default: 3
|
||||
label:
|
||||
en_US: Retry
|
||||
zh_Hans: 重试
|
||||
pt_BR: Repetir
|
||||
human_description:
|
||||
en_US: Number of times to retry the request if it fails
|
||||
zh_Hans: 请求失败时重试的次数
|
||||
pt_BR: Número de vezes para repetir a solicitação se falhar
|
||||
llm_description: Number of times to retry the request if it fails
|
||||
form: form
|
||||
|
0
api/tests/unit_tests/core/helper/__init__.py
Normal file
0
api/tests/unit_tests/core/helper/__init__.py
Normal file
52
api/tests/unit_tests/core/helper/test_ssrf_proxy.py
Normal file
52
api/tests/unit_tests/core/helper/test_ssrf_proxy.py
Normal file
@ -0,0 +1,52 @@
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
|
||||
|
||||
@patch('httpx.request')
|
||||
def test_successful_request(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
response = make_request('GET', 'http://example.com')
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch('httpx.request')
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
side_effects = [mock_response] * SSRF_DEFAULT_MAX_RETRIES
|
||||
mock_request.side_effect = side_effects
|
||||
|
||||
try:
|
||||
make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||
raise AssertionError("Expected Exception not raised")
|
||||
except Exception as e:
|
||||
assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch('httpx.request')
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
|
||||
for _ in range(SSRF_DEFAULT_MAX_RETRIES):
|
||||
status_code = random.choice(STATUS_FORCELIST)
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = status_code
|
||||
side_effects.append(mock_response)
|
||||
|
||||
mock_response_200 = MagicMock()
|
||||
mock_response_200.status_code = 200
|
||||
side_effects.append(mock_response_200)
|
||||
|
||||
mock_request.side_effect = side_effects
|
||||
|
||||
response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||
assert mock_request.call_args_list[0][1].get('method') == 'GET'
|
Loading…
x
Reference in New Issue
Block a user