# Mirror of https://github.com/varun-r-mallya/py-libp2p.git
# (synced 2025-12-31 20:36:24 +00:00)
#
# Commit notes:
# - Add TLS configuration support to new_host and new_swarm functions
# - Fix WebSocket transport tests (test_wss_host_pair_data_exchange,
#   test_wss_listen_without_tls_config)
# - Integrate TLS configuration with transport registry for proper WebSocket
#   WSS support
# - Move debug files to downloads directory for future reference
# - All 47 WebSocket tests now passing including WSS functionality
# - Maintain backward compatibility with existing code
# - Resolve all type checking and linting issues
#
# File: 230 lines, 8.8 KiB, Python
import logging
|
|
import ssl
|
|
|
|
from multiaddr import Multiaddr
|
|
|
|
from libp2p.abc import IListener, ITransport
|
|
from libp2p.custom_types import THandler
|
|
from libp2p.network.connection.raw_connection import RawConnection
|
|
from libp2p.transport.exceptions import OpenConnectionError
|
|
from libp2p.transport.upgrader import TransportUpgrader
|
|
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
|
|
|
from .connection import P2PWebSocketConnection
|
|
from .listener import WebsocketListener
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WebsocketTransport(ITransport):
    """
    Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss.

    Provides:

    - Dialing plain (``/ws``) and secure (``/wss`` or ``/tls/.../ws``)
      WebSocket multiaddrs over IPv4, IPv6 and DNS hosts
    - A buffer limit (``max_buffered_amount``) handed to every dialed
      ``P2PWebSocketConnection``
    - A cap (``max_connections``) on the dials counted by this transport
    - Optional client/server TLS contexts and a handshake timeout
    - Resolution of DNS-based secure addresses to an explicit
      ``/tls/sni/<name>/ws`` form
    """

    def __init__(
        self,
        upgrader: TransportUpgrader,
        tls_client_config: ssl.SSLContext | None = None,
        tls_server_config: ssl.SSLContext | None = None,
        handshake_timeout: float = 15.0,
        max_buffered_amount: int = 4 * 1024 * 1024,
        max_connections: int = 1000,
    ) -> None:
        """
        :param upgrader: Transport upgrader handed to listeners created by
            :meth:`create_listener`
        :param tls_client_config: SSL context used when dialing secure
            addresses; when ``None`` a default context is created per dial
        :param tls_server_config: SSL context handed to listeners
        :param handshake_timeout: Seconds allowed for a dial to connect and
            complete the WebSocket handshake; also handed to listeners
        :param max_buffered_amount: Buffer limit passed to each dialed
            ``P2PWebSocketConnection`` (presumably bytes)
        :param max_connections: Number of counted dials after which further
            dials are refused
        """
        self._upgrader = upgrader
        self._tls_client_config = tls_client_config
        self._tls_server_config = tls_server_config
        self._handshake_timeout = handshake_timeout
        self._max_buffered_amount = max_buffered_amount
        self._connection_count = 0
        self._max_connections = max_connections

    @staticmethod
    def _value_for_protocol(maddr: Multiaddr, protocol_name: str) -> str | None:
        """Return the value of ``protocol_name`` in ``maddr``, or ``None`` if absent."""
        # py-multiaddr raises (ProtocolLookupError) instead of returning None
        # when the protocol is not present, so normalize both outcomes to None.
        try:
            value = maddr.value_for_protocol(protocol_name)
        except Exception:
            return None
        return None if value is None else str(value)

    @classmethod
    def _extract_host_port(cls, base: Multiaddr, maddr: Multiaddr) -> tuple[str, int]:
        """
        Return the URL host and TCP port encoded in ``base``.

        :param base: Multiaddr holding the ip4/ip6/dns*/tcp components
        :param maddr: The full multiaddr being dialed (used in error messages)
        :raises ValueError: If no host or no TCP port is present
        """
        host: str | None = None
        for protocol_name in ("ip4", "ip6", "dns", "dns4", "dns6"):
            value = cls._value_for_protocol(base, protocol_name)
            if value:
                # IPv6 literals must be bracketed inside a URL authority.
                host = f"[{value}]" if protocol_name == "ip6" else value
                break
        if host is None:
            raise ValueError(f"No host found in multiaddr: {maddr}")

        port_str = cls._value_for_protocol(base, "tcp")
        if port_str is None:
            raise ValueError(f"No TCP port found in multiaddr: {maddr}")
        return host, int(port_str)

    def _client_ssl_context(
        self, is_wss: bool, sni: str | None
    ) -> ssl.SSLContext | None:
        """Return the SSL context to dial with, or ``None`` for plain ``/ws``."""
        if not is_wss:
            return None
        if self._tls_client_config is not None:
            return self._tls_client_config

        # Create default SSL context for client
        context = ssl.create_default_context()
        if sni:
            # NOTE(review): when the multiaddr carries an SNI, certificate and
            # hostname verification are switched off for the default context.
            # The SNI value is never passed to connect_websocket, so it is not
            # what gets verified; this looks like a workaround rather than SNI
            # support. Kept for compatibility -- confirm it is intended.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        return context

    async def dial(self, maddr: Multiaddr) -> RawConnection:
        """
        Dial a WebSocket connection to the given multiaddr.

        :param maddr: A ``/ws`` or secure (``/wss``, ``/tls/.../ws``) multiaddr
        :return: A ``RawConnection`` (initiator side) wrapping the WebSocket
        :raises ValueError: If ``maddr`` is not a valid WebSocket multiaddr or
            has no host or TCP port
        :raises OpenConnectionError: If the connection limit is reached, the
            handshake times out, or connecting fails for any other reason
        """
        # Bind ``trio`` before entering the ``try``: the ``except
        # trio.TooSlowError`` clause is evaluated for every exception leaving
        # the ``try`` body. Importing trio inside it made early failures (e.g.
        # the connection-limit check) surface as UnboundLocalError.
        import trio

        logger.debug(f"WebsocketTransport.dial called with {maddr}")

        # Parse the WebSocket multiaddr to determine if it's secure
        try:
            parsed = parse_websocket_multiaddr(maddr)
        except ValueError as e:
            raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e

        host, port = self._extract_host_port(parsed.rest_multiaddr, maddr)
        scheme = "wss" if parsed.is_wss else "ws"
        ws_url = f"{scheme}://{host}:{port}/"

        logger.debug(
            f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
        )

        try:
            # NOTE(review): ``_connection_count`` is only ever incremented in
            # this class; nothing visible here decrements it when a connection
            # closes, so this effectively caps the lifetime number of
            # successful dials. TODO: decrement when the connection closes.
            if self._connection_count >= self._max_connections:
                raise OpenConnectionError(
                    f"Maximum connections reached: {self._max_connections}"
                )

            ssl_context = self._client_ssl_context(parsed.is_wss, parsed.sni)

            logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
            logger.debug("WebsocketTransport.dial establishing connection")

            from trio_websocket import connect_websocket

            # NOTE: ``_url_to_host`` is a private trio-websocket helper; it
            # splits the URL into (host, port, resource, ssl) the way
            # trio-websocket's own URL-based helpers do.
            from trio_websocket._impl import _url_to_host

            ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
                ws_url, ssl_context
            )
            logger.debug(
                f"WebsocketTransport.dial parsed URL: host={ws_host}, "
                f"port={ws_port}, resource={ws_resource}"
            )

            # connect_websocket runs the connection's background task in the
            # nursery it is given. Use the nursery that spawned the current
            # task so the connection outlives this call.
            nursery_manager = trio.lowlevel.current_task().parent_nursery
            if nursery_manager is None:
                raise OpenConnectionError(
                    f"No parent nursery available for WebSocket connection to {maddr}"
                )

            # Bound the connection attempt and WebSocket handshake.
            with trio.fail_after(self._handshake_timeout):
                logger.debug("WebsocketTransport.dial connecting WebSocket")
                ws = await connect_websocket(
                    nursery_manager,
                    ws_host,
                    ws_port,
                    ws_resource,
                    use_ssl=ws_ssl_context,
                    message_queue_size=1024,
                    max_message_size=16 * 1024 * 1024,  # 16 MiB max message
                )
            logger.debug("WebsocketTransport.dial WebSocket connection established")

            # Create our connection wrapper with both WSS support and flow control
            conn = P2PWebSocketConnection(
                ws,
                None,
                is_secure=parsed.is_wss,
                max_buffered_amount=self._max_buffered_amount,
            )
            logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")

            self._connection_count += 1
            logger.debug(f"Total connections: {self._connection_count}")

            return RawConnection(conn, initiator=True)
        except OpenConnectionError:
            # Already descriptive (limit reached / no nursery); don't re-wrap.
            raise
        except trio.TooSlowError as e:
            raise OpenConnectionError(
                f"WebSocket handshake timeout after {self._handshake_timeout}s "
                f"for {maddr}"
            ) from e
        except Exception as e:
            logger.error(f"Failed to dial WebSocket {maddr}: {e}")
            raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e

    def create_listener(self, handler: THandler) -> IListener:  # type: ignore[override]
        """
        Create a :class:`WebsocketListener` for this transport.

        The listener receives ``handler`` together with this transport's
        upgrader, server TLS context and handshake timeout.

        The ``type: ignore[override]`` suppresses a type-checker report that
        this signature is an inconsistent override of the ``ITransport``
        method; the original author judged that report to be incorrect.
        """
        logger.debug("WebsocketTransport.create_listener called")
        return WebsocketListener(
            handler, self._upgrader, self._tls_server_config, self._handshake_timeout
        )

    @staticmethod
    def _insert_sni(maddr: Multiaddr, sni: str) -> Multiaddr | None:
        """
        Rewrite the ``/wss`` or ``/tls/ws`` part of ``maddr`` as ``/tls/sni/<sni>/ws``.

        Components before and after that part are kept. Returns ``None`` when
        ``maddr`` contains neither form.
        """
        text = str(maddr)
        sni_part = Multiaddr(f"/tls/sni/{sni}/ws")
        for marker in ("/wss", "/tls/ws"):
            try:
                # decapsulate() drops the last occurrence of ``marker`` *and
                # everything after it*, so trailing components (for example
                # ``/p2p/<peer-id>``) are recovered from the string form below.
                prefix = maddr.decapsulate(Multiaddr(marker))
            except Exception:
                continue
            head = f"{prefix}{marker}"
            suffix = text[len(head):]
            # When ``marker`` is absent, decapsulate() returns an unchanged
            # copy, so ``head`` does not line up with ``text``: try the next form.
            if not text.startswith(head) or (suffix and not suffix.startswith("/")):
                continue
            resolved = prefix.encapsulate(sni_part)
            if suffix:
                resolved = resolved.encapsulate(Multiaddr(suffix))
            return resolved
        return None

    def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
        """
        Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.

        Similar to Go's Resolve() method. For example, a secure address with a
        DNS host and no SNI, such as ``/dns/example.com/tcp/8080/wss`` (or
        ``.../tls/ws``), becomes
        ``/dns/example.com/tcp/8080/tls/sni/example.com/ws``. Components after
        the WebSocket part are kept. Any other address, or one that cannot be
        rewritten, is returned unchanged.

        :param maddr: The multiaddr to resolve
        :return: List of resolved multiaddrs (always a single element)
        """
        try:
            parsed = parse_websocket_multiaddr(maddr)
        except ValueError as e:
            logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
            return [maddr]  # Return original if not a valid WebSocket multiaddr

        logger.debug(
            f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
        )

        if not parsed.is_wss:
            # Not a secure websocket multiaddr; nothing to add.
            return [maddr]
        if parsed.sni is not None:
            # Already has SNI, return as-is
            return [maddr]

        # Use the first DNS name found in the base multiaddr as the SNI.
        dns_name = None
        for protocol_name in ("dns", "dns4", "dns6"):
            dns_name = self._value_for_protocol(parsed.rest_multiaddr, protocol_name)
            if dns_name is not None:
                break
        if dns_name is None:
            # No DNS name found, return original
            return [maddr]

        try:
            resolved = self._insert_sni(maddr, dns_name)
        except Exception as e:
            logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
            return [maddr]
        if resolved is None:
            logger.debug(
                f"Failed to resolve multiaddr {maddr}: "
                "no /wss or /tls/ws component found"
            )
            return [maddr]

        logger.debug(f"Resolved {maddr} to {resolved}")
        return [resolved]
|