mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 06:35:53 +08:00
temp
This commit is contained in:
parent
8464ad0b35
commit
1fd4839eca
@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
user = current_user
|
user = current_user
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return jsonable_encoder(
|
return jsonable_encoder(
|
||||||
@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource):
|
|||||||
authed=False,
|
authed=False,
|
||||||
authorization_code=args["authorization_code"],
|
authorization_code=args["authorization_code"],
|
||||||
):
|
):
|
||||||
|
MCPToolManageService.update_mcp_provider_credentials(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
credentials={},
|
||||||
|
authed=True,
|
||||||
|
)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
except MCPAuthError as e:
|
|
||||||
|
except MCPAuthError:
|
||||||
auth_provider = OAuthClientProvider(provider_id, tenant_id)
|
auth_provider = OAuthClientProvider(provider_id, tenant_id)
|
||||||
return auth(auth_provider, provider.server_url, args["authorization_code"])
|
return auth(auth_provider, provider.server_url, args["authorization_code"])
|
||||||
|
|
||||||
|
@ -108,3 +108,85 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
|||||||
|
|
||||||
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def create_ssrf_proxy_mcp_http_client(
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
timeout: httpx.Timeout | None = None,
|
||||||
|
) -> httpx.Client:
|
||||||
|
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers: Optional headers to include in the client
|
||||||
|
timeout: Optional timeout configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured httpx.Client with proxy settings
|
||||||
|
"""
|
||||||
|
client_kwargs = {
|
||||||
|
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||||
|
"headers": headers or {},
|
||||||
|
"timeout": timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dify_config.SSRF_PROXY_ALL_URL:
|
||||||
|
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
|
||||||
|
return httpx.Client(**client_kwargs)
|
||||||
|
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||||
|
proxy_mounts = {
|
||||||
|
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||||
|
"https://": httpx.HTTPTransport(
|
||||||
|
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||||
|
),
|
||||||
|
}
|
||||||
|
client_kwargs["mounts"] = proxy_mounts
|
||||||
|
return httpx.Client(**client_kwargs)
|
||||||
|
else:
|
||||||
|
return httpx.Client(**client_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
|
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||||
|
|
||||||
|
This function creates an SSE connection using the configured proxy settings
|
||||||
|
to prevent SSRF attacks when connecting to external endpoints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The SSE endpoint URL
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
**kwargs: Additional arguments passed to the SSE connection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EventSource object for SSE streaming
|
||||||
|
"""
|
||||||
|
from httpx_sse import connect_sse
|
||||||
|
|
||||||
|
# Extract client if provided, otherwise create one
|
||||||
|
client = kwargs.pop("client", None)
|
||||||
|
if client is None:
|
||||||
|
# Create client with SSRF proxy configuration
|
||||||
|
timeout = kwargs.pop(
|
||||||
|
"timeout",
|
||||||
|
httpx.Timeout(
|
||||||
|
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||||
|
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||||
|
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||||
|
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
headers = kwargs.pop("headers", {})
|
||||||
|
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||||
|
client_provided = False
|
||||||
|
else:
|
||||||
|
client_provided = True
|
||||||
|
|
||||||
|
# Extract method if provided, default to GET
|
||||||
|
method = kwargs.pop("method", "GET")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return connect_sse(client, method, url, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
# If we created the client, we need to clean it up on error
|
||||||
|
if not client_provided:
|
||||||
|
client.close()
|
||||||
|
raise
|
||||||
|
@ -8,15 +8,21 @@ from typing import Any
|
|||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from httpx_sse import connect_sse
|
|
||||||
from sseclient import SSEClient
|
from sseclient import SSEClient
|
||||||
|
|
||||||
|
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||||
from core.mcp import types
|
from core.mcp import types
|
||||||
|
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||||
from core.mcp.types import SessionMessage
|
from core.mcp.types import SessionMessage
|
||||||
from core.mcp.utils import create_mcp_http_client, remove_request_params
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||||
|
|
||||||
|
|
||||||
|
def remove_request_params(url: str) -> str:
|
||||||
|
return urljoin(url, urlparse(url).path)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def sse_client(
|
def sse_client(
|
||||||
@ -40,9 +46,9 @@ def sse_client(
|
|||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||||
with create_mcp_http_client(headers=headers) as client:
|
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
|
||||||
with connect_sse(
|
with ssrf_proxy_sse_connect(
|
||||||
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
|
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||||
) as event_source:
|
) as event_source:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("SSE connection established")
|
logger.debug("SSE connection established")
|
||||||
@ -94,7 +100,7 @@ def sse_client(
|
|||||||
try:
|
try:
|
||||||
while not cancel_event.is_set():
|
while not cancel_event.is_set():
|
||||||
try:
|
try:
|
||||||
message = write_queue.get(timeout=5)
|
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||||
if message is None:
|
if message is None:
|
||||||
break
|
break
|
||||||
response = client.post(
|
response = client.post(
|
||||||
@ -130,6 +136,10 @@ def sse_client(
|
|||||||
yield read_queue, write_queue
|
yield read_queue, write_queue
|
||||||
finally:
|
finally:
|
||||||
cancel_event.set()
|
cancel_event.set()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code == 401:
|
||||||
|
raise MCPAuthError()
|
||||||
|
raise MCPConnectionError()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Error connecting to SSE endpoint")
|
logger.exception("Error connecting to SSE endpoint")
|
||||||
raise exc
|
raise exc
|
||||||
|
@ -17,8 +17,9 @@ from datetime import timedelta
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from httpx_sse import EventSource, ServerSentEvent, connect_sse
|
from httpx_sse import EventSource, ServerSentEvent
|
||||||
|
|
||||||
|
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||||
from core.mcp.types import (
|
from core.mcp.types import (
|
||||||
ClientMessageMetadata,
|
ClientMessageMetadata,
|
||||||
ErrorData,
|
ErrorData,
|
||||||
@ -30,7 +31,6 @@ from core.mcp.types import (
|
|||||||
RequestId,
|
RequestId,
|
||||||
SessionMessage,
|
SessionMessage,
|
||||||
)
|
)
|
||||||
from core.mcp.utils import create_mcp_http_client
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -50,6 +50,8 @@ ACCEPT = "Accept"
|
|||||||
JSON = "application/json"
|
JSON = "application/json"
|
||||||
SSE = "text/event-stream"
|
SSE = "text/event-stream"
|
||||||
|
|
||||||
|
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||||
|
|
||||||
|
|
||||||
class StreamableHTTPError(Exception):
|
class StreamableHTTPError(Exception):
|
||||||
"""Base exception for StreamableHTTP transport errors."""
|
"""Base exception for StreamableHTTP transport errors."""
|
||||||
@ -184,12 +186,13 @@ class StreamableHTTPTransport:
|
|||||||
|
|
||||||
headers = self._update_headers_with_session(self.request_headers)
|
headers = self._update_headers_with_session(self.request_headers)
|
||||||
|
|
||||||
with connect_sse(
|
with ssrf_proxy_sse_connect(
|
||||||
client,
|
|
||||||
"GET",
|
|
||||||
self.url,
|
self.url,
|
||||||
|
2,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||||
|
client=client,
|
||||||
|
method="GET",
|
||||||
) as event_source:
|
) as event_source:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("GET SSE connection established")
|
logger.debug("GET SSE connection established")
|
||||||
@ -215,12 +218,13 @@ class StreamableHTTPTransport:
|
|||||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||||
original_request_id = ctx.session_message.message.root.id
|
original_request_id = ctx.session_message.message.root.id
|
||||||
|
|
||||||
with connect_sse(
|
with ssrf_proxy_sse_connect(
|
||||||
ctx.client,
|
|
||||||
"GET",
|
|
||||||
self.url,
|
self.url,
|
||||||
|
2,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||||
|
client=ctx.client,
|
||||||
|
method="GET",
|
||||||
) as event_source:
|
) as event_source:
|
||||||
event_source.response.raise_for_status()
|
event_source.response.raise_for_status()
|
||||||
logger.debug("Resumption GET SSE connection established")
|
logger.debug("Resumption GET SSE connection established")
|
||||||
@ -304,7 +308,6 @@ class StreamableHTTPTransport:
|
|||||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error reading SSE stream:")
|
|
||||||
ctx.server_to_client_queue.put(e)
|
ctx.server_to_client_queue.put(e)
|
||||||
|
|
||||||
def _handle_unexpected_content_type(
|
def _handle_unexpected_content_type(
|
||||||
@ -346,7 +349,7 @@ class StreamableHTTPTransport:
|
|||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
# Read message from client queue with timeout to check stop_event periodically
|
# Read message from client queue with timeout to check stop_event periodically
|
||||||
session_message = client_to_server_queue.get(timeout=5)
|
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||||
if session_message is None:
|
if session_message is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -444,7 +447,7 @@ def streamablehttp_client(
|
|||||||
try:
|
try:
|
||||||
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
|
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
|
||||||
|
|
||||||
with create_mcp_http_client(
|
with create_ssrf_proxy_mcp_http_client(
|
||||||
headers=transport.request_headers,
|
headers=transport.request_headers,
|
||||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||||
) as client:
|
) as client:
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from core.mcp.client.sse_client import sse_client
|
from core.mcp.client.sse_client import sse_client
|
||||||
from core.mcp.client.streamable_client import streamablehttp_client
|
from core.mcp.client.streamable_client import streamablehttp_client
|
||||||
@ -59,10 +61,23 @@ class MCPClient:
|
|||||||
first_try: bool = True,
|
first_try: bool = True,
|
||||||
):
|
):
|
||||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||||
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)]
|
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
|
||||||
|
|
||||||
|
parsed_url = urlparse(self.server_url)
|
||||||
|
path = parsed_url.path
|
||||||
|
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||||
|
try:
|
||||||
|
client_factory = connection_methods[method_name]
|
||||||
|
self.connect_server(client_factory, method_name)
|
||||||
|
except KeyError:
|
||||||
|
try:
|
||||||
|
self.connect_server(streamablehttp_client, "sse")
|
||||||
|
except MCPConnectionError:
|
||||||
|
self.connect_server(sse_client, "mcp")
|
||||||
|
|
||||||
|
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
|
||||||
from core.mcp.auth.auth_flow import auth
|
from core.mcp.auth.auth_flow import auth
|
||||||
|
|
||||||
for method_name, client_factory in connection_methods:
|
|
||||||
try:
|
try:
|
||||||
headers = (
|
headers = (
|
||||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||||
@ -70,7 +85,7 @@ class MCPClient:
|
|||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||||
if method_name == "streamablehttp_client":
|
if method_name == "mcp":
|
||||||
read_stream, write_stream, _ = self._streams_context.__enter__()
|
read_stream, write_stream, _ = self._streams_context.__enter__()
|
||||||
streams = (read_stream, write_stream)
|
streams = (read_stream, write_stream)
|
||||||
else: # sse_client
|
else: # sse_client
|
||||||
@ -88,11 +103,9 @@ class MCPClient:
|
|||||||
|
|
||||||
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
||||||
if first_try:
|
if first_try:
|
||||||
return self._initialize(first_try=False)
|
return self.connect_server(client_factory, method_name, first_try=False)
|
||||||
|
|
||||||
except MCPConnectionError:
|
except MCPConnectionError:
|
||||||
if method_name == "streamablehttp_client":
|
|
||||||
continue
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def list_tools(self) -> list[Tool]:
|
def list_tools(self) -> list[Tool]:
|
||||||
|
@ -40,7 +40,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
|
|||||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||||
DEFAULT_RESPONSE_READ_TIMEOUT = 5
|
DEFAULT_RESPONSE_READ_TIMEOUT = 1
|
||||||
|
|
||||||
|
|
||||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||||
@ -210,7 +210,7 @@ class BaseSession(
|
|||||||
|
|
||||||
future.result(timeout=remaining)
|
future.result(timeout=remaining)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error waiting for task: {e}")
|
logging.exception(f"Error waiting for task: {e}")
|
||||||
|
|
||||||
def send_request(
|
def send_request(
|
||||||
self,
|
self,
|
||||||
@ -247,8 +247,12 @@ class BaseSession(
|
|||||||
timeout = request_read_timeout_seconds.total_seconds()
|
timeout = request_read_timeout_seconds.total_seconds()
|
||||||
elif self._session_read_timeout_seconds is not None:
|
elif self._session_read_timeout_seconds is not None:
|
||||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
try:
|
||||||
response_or_error = response_queue.get(timeout=timeout)
|
response_or_error = response_queue.get(timeout=timeout)
|
||||||
|
break
|
||||||
|
except queue.Empty:
|
||||||
|
continue
|
||||||
|
|
||||||
if response_or_error is None:
|
if response_or_error is None:
|
||||||
raise MCPConnectionError(
|
raise MCPConnectionError(
|
||||||
@ -315,7 +319,7 @@ class BaseSession(
|
|||||||
while not self._stop_event.is_set():
|
while not self._stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||||
message = self._read_stream.get(timeout=5)
|
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||||
if message is None:
|
if message is None:
|
||||||
break
|
break
|
||||||
if isinstance(message, HTTPStatusError):
|
if isinstance(message, HTTPStatusError):
|
||||||
|
@ -1,28 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
from urllib.parse import urljoin, urlparse
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_http_client(
|
|
||||||
headers: dict[str, str] | None = None,
|
|
||||||
timeout: httpx.Timeout | None = None,
|
|
||||||
) -> httpx.Client:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"follow_redirects": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Handle timeout
|
|
||||||
if timeout is None:
|
|
||||||
kwargs["timeout"] = httpx.Timeout(30.0)
|
|
||||||
else:
|
|
||||||
kwargs["timeout"] = timeout
|
|
||||||
|
|
||||||
# Handle headers
|
|
||||||
if headers is not None:
|
|
||||||
kwargs["headers"] = headers
|
|
||||||
return httpx.Client(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_request_params(url: str) -> str:
|
|
||||||
return urljoin(url, urlparse(url).path)
|
|
Loading…
x
Reference in New Issue
Block a user