fix: Fix some type error in http executor. (#5915)

This commit is contained in:
-LAN- 2024-07-04 19:34:37 +08:00 committed by GitHub
parent 421a24c38d
commit 02982df0d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 132 additions and 108 deletions

View File

@ -9,24 +9,20 @@ MAX_CONNECT_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_CONNECT_TIMEOUT', '30
MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600')) MAX_READ_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_READ_TIMEOUT', '600'))
MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600')) MAX_WRITE_TIMEOUT = int(os.environ.get('HTTP_REQUEST_MAX_WRITE_TIMEOUT', '600'))
class HttpRequestNodeData(BaseNodeData):
""" class HttpRequestNodeAuthorizationConfig(BaseModel):
Code Node Data.
"""
class Authorization(BaseModel):
# TODO[pydantic]: The `Config` class inherits from another class, please create the `model_config` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
class Config(BaseModel):
type: Literal[None, 'basic', 'bearer', 'custom'] type: Literal[None, 'basic', 'bearer', 'custom']
api_key: Union[None, str] = None api_key: Union[None, str] = None
header: Union[None, str] = None header: Union[None, str] = None
class HttpRequestNodeAuthorization(BaseModel):
type: Literal['no-auth', 'api-key'] type: Literal['no-auth', 'api-key']
config: Optional[Config] = None config: Optional[HttpRequestNodeAuthorizationConfig] = None
@field_validator('config', mode='before') @field_validator('config', mode='before')
@classmethod @classmethod
def check_config(cls, v: Config, values: ValidationInfo): def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo):
""" """
Check config, if type is no-auth, config should be None, otherwise it should be a dict. Check config, if type is no-auth, config should be None, otherwise it should be a dict.
""" """
@ -38,20 +34,28 @@ class HttpRequestNodeData(BaseNodeData):
return v return v
class Body(BaseModel):
class HttpRequestNodeBody(BaseModel):
type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json']
data: Union[None, str] = None data: Union[None, str] = None
class Timeout(BaseModel):
connect: Optional[int] = MAX_CONNECT_TIMEOUT class HttpRequestNodeTimeout(BaseModel):
read: Optional[int] = MAX_READ_TIMEOUT connect: int = MAX_CONNECT_TIMEOUT
write: Optional[int] = MAX_WRITE_TIMEOUT read: int = MAX_READ_TIMEOUT
write: int = MAX_WRITE_TIMEOUT
class HttpRequestNodeData(BaseNodeData):
"""
Code Node Data.
"""
method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] method: Literal['get', 'post', 'put', 'patch', 'delete', 'head']
url: str url: str
authorization: Authorization authorization: HttpRequestNodeAuthorization
headers: str headers: str
params: str params: str
body: Optional[Body] = None body: Optional[HttpRequestNodeBody] = None
timeout: Optional[Timeout] = None timeout: Optional[HttpRequestNodeTimeout] = None
mask_authorization_header: Optional[bool] = True mask_authorization_header: Optional[bool] = True

View File

@ -10,7 +10,12 @@ import httpx
import core.helper.ssrf_proxy as ssrf_proxy import core.helper.ssrf_proxy as ssrf_proxy
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.http_request.entities import HttpRequestNodeData from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10)) # 10MB MAX_BINARY_SIZE = int(os.environ.get('HTTP_REQUEST_NODE_MAX_BINARY_SIZE', 1024 * 1024 * 10)) # 10MB
@ -23,7 +28,7 @@ class HttpExecutorResponse:
headers: dict[str, str] headers: dict[str, str]
response: httpx.Response response: httpx.Response
def __init__(self, response: httpx.Response = None): def __init__(self, response: httpx.Response):
self.response = response self.response = response
self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {} self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {}
@ -40,7 +45,6 @@ class HttpExecutorResponse:
def get_content_type(self) -> str: def get_content_type(self) -> str:
return self.headers.get('content-type', '') return self.headers.get('content-type', '')
def extract_file(self) -> tuple[str, bytes]: def extract_file(self) -> tuple[str, bytes]:
""" """
extract file from response if content type is file related extract file from response if content type is file related
@ -88,17 +92,21 @@ class HttpExecutorResponse:
class HttpExecutor: class HttpExecutor:
server_url: str server_url: str
method: str method: str
authorization: HttpRequestNodeData.Authorization authorization: HttpRequestNodeAuthorization
params: dict[str, Any] params: dict[str, Any]
headers: dict[str, Any] headers: dict[str, Any]
body: Union[None, str] body: Union[None, str]
files: Union[None, dict[str, Any]] files: Union[None, dict[str, Any]]
boundary: str boundary: str
variable_selectors: list[VariableSelector] variable_selectors: list[VariableSelector]
timeout: HttpRequestNodeData.Timeout timeout: HttpRequestNodeTimeout
def __init__(self, node_data: HttpRequestNodeData, timeout: HttpRequestNodeData.Timeout, def __init__(
variable_pool: Optional[VariablePool] = None): self,
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: Optional[VariablePool] = None,
):
self.server_url = node_data.url self.server_url = node_data.url
self.method = node_data.method self.method = node_data.method
self.authorization = node_data.authorization self.authorization = node_data.authorization
@ -113,11 +121,11 @@ class HttpExecutor:
self._init_template(node_data, variable_pool) self._init_template(node_data, variable_pool)
@staticmethod @staticmethod
def _is_json_body(body: HttpRequestNodeData.Body): def _is_json_body(body: HttpRequestNodeBody):
""" """
check if body is json check if body is json
""" """
if body and body.type == 'json': if body and body.type == 'json' and body.data:
try: try:
json.loads(body.data) json.loads(body.data)
return True return True
@ -146,7 +154,6 @@ class HttpExecutor:
return result return result
def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None):
# extract all template in url # extract all template in url
self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool)
@ -178,9 +185,7 @@ class HttpExecutor:
body = self._to_dict(body_data) body = self._to_dict(body_data)
if node_data.body.type == 'form-data': if node_data.body.type == 'form-data':
self.files = { self.files = {k: ('', v) for k, v in body.items()}
k: ('', v) for k, v in body.items()
}
random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)]) random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)])
self.boundary = f'----WebKitFormBoundary{random_str(16)}' self.boundary = f'----WebKitFormBoundary{random_str(16)}'
@ -192,13 +197,24 @@ class HttpExecutor:
elif node_data.body.type == 'none': elif node_data.body.type == 'none':
self.body = '' self.body = ''
self.variable_selectors = (server_url_variable_selectors + params_variable_selectors self.variable_selectors = (
+ headers_variable_selectors + body_data_variable_selectors) server_url_variable_selectors
+ params_variable_selectors
+ headers_variable_selectors
+ body_data_variable_selectors
)
def _assembling_headers(self) -> dict[str, Any]: def _assembling_headers(self) -> dict[str, Any]:
authorization = deepcopy(self.authorization) authorization = deepcopy(self.authorization)
headers = deepcopy(self.headers) or {} headers = deepcopy(self.headers) or {}
if self.authorization.type == 'api-key': if self.authorization.type == 'api-key':
if self.authorization.config is None:
raise ValueError('self.authorization config is required')
if authorization.config is None:
raise ValueError('authorization config is required')
if authorization.config.header is None:
raise ValueError('authorization config header is required')
if self.authorization.config.api_key is None: if self.authorization.config.api_key is None:
raise ValueError('api_key is required') raise ValueError('api_key is required')
@ -226,11 +242,13 @@ class HttpExecutor:
if executor_response.is_file: if executor_response.is_file:
if executor_response.size > MAX_BINARY_SIZE: if executor_response.size > MAX_BINARY_SIZE:
raise ValueError( raise ValueError(
f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.') f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.'
)
else: else:
if executor_response.size > MAX_TEXT_SIZE: if executor_response.size > MAX_TEXT_SIZE:
raise ValueError( raise ValueError(
f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.') f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.'
)
return executor_response return executor_response
@ -243,7 +261,7 @@ class HttpExecutor:
'headers': headers, 'headers': headers,
'params': self.params, 'params': self.params,
'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write), 'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write),
'follow_redirects': True 'follow_redirects': True,
} }
if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'): if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
@ -306,8 +324,9 @@ class HttpExecutor:
return raw_request return raw_request
def _format_template(self, template: str, variable_pool: VariablePool, escape_quotes: bool = False) \ def _format_template(
-> tuple[str, list[VariableSelector]]: self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False
) -> tuple[str, list[VariableSelector]]:
""" """
format template format template
""" """
@ -318,14 +337,13 @@ class HttpExecutor:
variable_value_mapping = {} variable_value_mapping = {}
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
value = variable_pool.get_variable_value( value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector, variable_selector=variable_selector.value_selector, target_value_type=ValueType.STRING
target_value_type=ValueType.STRING
) )
if value is None: if value is None:
raise ValueError(f'Variable {variable_selector.variable} not found') raise ValueError(f'Variable {variable_selector.variable} not found')
if escape_quotes: if escape_quotes and isinstance(value, str):
value = value.replace('"', '\\"') value = value.replace('"', '\\"')
variable_value_mapping[variable_selector.variable] = value variable_value_mapping[variable_selector.variable] = value

View File

@ -5,6 +5,7 @@ from typing import cast
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
@ -13,49 +14,50 @@ from core.workflow.nodes.http_request.entities import (
MAX_READ_TIMEOUT, MAX_READ_TIMEOUT,
MAX_WRITE_TIMEOUT, MAX_WRITE_TIMEOUT,
HttpRequestNodeData, HttpRequestNodeData,
HttpRequestNodeTimeout,
) )
from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeData.Timeout(connect=min(10, MAX_CONNECT_TIMEOUT), HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=min(10, MAX_CONNECT_TIMEOUT),
read=min(60, MAX_READ_TIMEOUT), read=min(60, MAX_READ_TIMEOUT),
write=min(20, MAX_WRITE_TIMEOUT)) write=min(20, MAX_WRITE_TIMEOUT),
)
class HttpRequestNode(BaseNode): class HttpRequestNode(BaseNode):
_node_data_cls = HttpRequestNodeData _node_data_cls = HttpRequestNodeData
node_type = NodeType.HTTP_REQUEST _node_type = NodeType.HTTP_REQUEST
@classmethod @classmethod
def get_default_config(cls) -> dict: def get_default_config(cls, filters: dict | None = None) -> dict:
return { return {
"type": "http-request", 'type': 'http-request',
"config": { 'config': {
"method": "get", 'method': 'get',
"authorization": { 'authorization': {
"type": "no-auth", 'type': 'no-auth',
}, },
"body": { 'body': {'type': 'none'},
"type": "none" 'timeout': {
},
"timeout": {
**HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(),
"max_connect_timeout": MAX_CONNECT_TIMEOUT, 'max_connect_timeout': MAX_CONNECT_TIMEOUT,
"max_read_timeout": MAX_READ_TIMEOUT, 'max_read_timeout': MAX_READ_TIMEOUT,
"max_write_timeout": MAX_WRITE_TIMEOUT, 'max_write_timeout': MAX_WRITE_TIMEOUT,
} },
}, },
} }
def _run(self, variable_pool: VariablePool) -> NodeRunResult: def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data) node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# init http executor # init http executor
http_executor = None http_executor = None
try: try:
http_executor = HttpExecutor(node_data=node_data, http_executor = HttpExecutor(
timeout=self._get_request_timeout(node_data), node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
variable_pool=variable_pool) )
# invoke http executor # invoke http executor
response = http_executor.invoke() response = http_executor.invoke()
@ -70,7 +72,7 @@ class HttpRequestNode(BaseNode):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
error=str(e), error=str(e),
process_data=process_data process_data=process_data,
) )
files = self.extract_files(http_executor.server_url, response) files = self.extract_files(http_executor.server_url, response)
@ -85,34 +87,32 @@ class HttpRequestNode(BaseNode):
}, },
process_data={ process_data={
'request': http_executor.to_raw_request( 'request': http_executor.to_raw_request(
mask_authorization_header=node_data.mask_authorization_header mask_authorization_header=node_data.mask_authorization_header,
), ),
} },
) )
def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeData.Timeout: def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout:
timeout = node_data.timeout timeout = node_data.timeout
if timeout is None: if timeout is None:
return HTTP_REQUEST_DEFAULT_TIMEOUT return HTTP_REQUEST_DEFAULT_TIMEOUT
if timeout.connect is None: timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect
timeout.connect = HTTP_REQUEST_DEFAULT_TIMEOUT.connect
timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT) timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT)
if timeout.read is None: timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read
timeout.read = HTTP_REQUEST_DEFAULT_TIMEOUT.read
timeout.read = min(timeout.read, MAX_READ_TIMEOUT) timeout.read = min(timeout.read, MAX_READ_TIMEOUT)
if timeout.write is None: timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write
timeout.write = HTTP_REQUEST_DEFAULT_TIMEOUT.write
timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT) timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT)
return timeout return timeout
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
""" """
Extract variable selector to variable mapping Extract variable selector to variable mapping
:param node_data: node data :param node_data: node data
:return: :return:
""" """
node_data = cast(HttpRequestNodeData, node_data)
try: try:
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
@ -124,7 +124,7 @@ class HttpRequestNode(BaseNode):
return variable_mapping return variable_mapping
except Exception as e: except Exception as e:
logging.exception(f"Failed to extract variable selector to variable mapping: {e}") logging.exception(f'Failed to extract variable selector to variable mapping: {e}')
return {} return {}
def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]:
@ -151,7 +151,8 @@ class HttpRequestNode(BaseNode):
mimetype=mimetype, mimetype=mimetype,
) )
files.append(FileVar( files.append(
FileVar(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE, type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE, transfer_method=FileTransferMethod.TOOL_FILE,
@ -159,6 +160,7 @@ class HttpRequestNode(BaseNode):
filename=filename, filename=filename,
extension=extension, extension=extension,
mime_type=mimetype, mime_type=mimetype,
)) )
)
return files return files