mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Merge branch 'main' into fix/885-Update-default-Bind-address
This commit is contained in:
@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
|
||||
from .tcp.tcp import TCP
|
||||
from .websocket.transport import WebsocketTransport
|
||||
from .transport_registry import (
|
||||
TransportRegistry,
|
||||
create_transport_for_multiaddr,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from .upgrader import TransportUpgrader
|
||||
from libp2p.abc import ITransport
|
||||
|
||||
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
|
||||
"""
|
||||
Convenience function to create a transport instance.
|
||||
|
||||
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
|
||||
:param upgrader: Optional transport upgrader (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
|
||||
:return: Transport instance
|
||||
"""
|
||||
# First check if it's a built-in protocol
|
||||
if protocol in ["ws", "wss"]:
|
||||
if upgrader is None:
|
||||
raise ValueError(f"WebSocket transport requires an upgrader")
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
|
||||
)
|
||||
elif protocol == "tcp":
|
||||
return TCP()
|
||||
else:
|
||||
# Check if it's a custom registered transport
|
||||
registry = get_transport_registry()
|
||||
transport_class = registry.get_transport(protocol)
|
||||
if transport_class:
|
||||
transport = registry.create_transport(protocol, upgrader, **kwargs)
|
||||
if transport is None:
|
||||
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||
return transport
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||
|
||||
__all__ = [
|
||||
"TCP",
|
||||
"WebsocketTransport",
|
||||
"TransportRegistry",
|
||||
"create_transport_for_multiaddr",
|
||||
"create_transport",
|
||||
"get_transport_registry",
|
||||
"register_transport",
|
||||
"get_supported_transport_protocols",
|
||||
]
|
||||
|
||||
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from aioquic.quic import events
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
@ -871,9 +871,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Process events by type
|
||||
for event_type, event_list in events_by_type.items():
|
||||
if event_type == type(events.StreamDataReceived).__name__:
|
||||
await self._handle_stream_data_batch(
|
||||
cast(list[events.StreamDataReceived], event_list)
|
||||
)
|
||||
# Filter to only StreamDataReceived events
|
||||
stream_data_events = [
|
||||
e for e in event_list if isinstance(e, events.StreamDataReceived)
|
||||
]
|
||||
await self._handle_stream_data_batch(stream_data_events)
|
||||
else:
|
||||
# Process other events individually
|
||||
for event in event_list:
|
||||
|
||||
267
libp2p/transport/transport_registry.py
Normal file
267
libp2p/transport/transport_registry.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Transport registry for dynamic transport selection based on multiaddr protocols.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
from libp2p.abc import ITransport
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import (
|
||||
is_valid_websocket_multiaddr,
|
||||
)
|
||||
|
||||
|
||||
# Import QUIC utilities here to avoid circular imports
|
||||
def _get_quic_transport() -> Any:
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
return QUICTransport
|
||||
|
||||
|
||||
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
|
||||
return is_quic_multiaddr
|
||||
|
||||
|
||||
# Import WebsocketTransport here to avoid circular imports
|
||||
def _get_websocket_transport() -> Any:
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
return WebsocketTransport
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.registry")
|
||||
|
||||
|
||||
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid TCP structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid TCP structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||
# or /ip6/::1/tcp/8080
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||
if len(protocols) < 2:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Should not have any protocols after tcp (unless it's a valid
|
||||
# continuation like p2p)
|
||||
# For now, we'll be strict and only allow network + tcp
|
||||
if len(protocols) > 2:
|
||||
# Check if the additional protocols are valid continuations
|
||||
valid_continuations = ["p2p"] # Add more as needed
|
||||
for i in range(2, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class TransportRegistry:
|
||||
"""
|
||||
Registry for mapping multiaddr protocols to transport implementations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._transports: dict[str, type[ITransport]] = {}
|
||||
self._register_default_transports()
|
||||
|
||||
def _register_default_transports(self) -> None:
|
||||
"""Register the default transport implementations."""
|
||||
# Register TCP transport for /tcp protocol
|
||||
self.register_transport("tcp", TCP)
|
||||
|
||||
# Register WebSocket transport for /ws and /wss protocols
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
self.register_transport("ws", WebsocketTransport)
|
||||
self.register_transport("wss", WebsocketTransport)
|
||||
|
||||
# Register QUIC transport for /quic and /quic-v1 protocols
|
||||
QUICTransport = _get_quic_transport()
|
||||
self.register_transport("quic", QUICTransport)
|
||||
self.register_transport("quic-v1", QUICTransport)
|
||||
|
||||
def register_transport(
|
||||
self, protocol: str, transport_class: type[ITransport]
|
||||
) -> None:
|
||||
"""
|
||||
Register a transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier (e.g., "tcp", "ws")
|
||||
:param transport_class: The transport class to register
|
||||
"""
|
||||
self._transports[protocol] = transport_class
|
||||
logger.debug(
|
||||
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||
)
|
||||
|
||||
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||
"""
|
||||
Get the transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:return: The transport class or None if not found
|
||||
"""
|
||||
return self._transports.get(protocol)
|
||||
|
||||
def get_supported_protocols(self) -> list[str]:
|
||||
"""Get list of supported transport protocols."""
|
||||
return list(self._transports.keys())
|
||||
|
||||
def create_transport(
|
||||
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create a transport instance for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:param upgrader: The transport upgrader instance (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
:return: Transport instance or None if protocol not supported or creation fails
|
||||
"""
|
||||
transport_class = self.get_transport(protocol)
|
||||
if transport_class is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if protocol in ["ws", "wss"]:
|
||||
# WebSocket transport requires upgrader
|
||||
if upgrader is None:
|
||||
logger.warning(
|
||||
f"WebSocket transport '{protocol}' requires upgrader"
|
||||
)
|
||||
return None
|
||||
# Use explicit WebsocketTransport to avoid type issues
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
|
||||
)
|
||||
elif protocol in ["quic", "quic-v1"]:
|
||||
# QUIC transport requires private_key
|
||||
private_key = kwargs.get("private_key")
|
||||
if private_key is None:
|
||||
logger.warning(f"QUIC transport '{protocol}' requires private_key")
|
||||
return None
|
||||
# Use explicit QUICTransport to avoid type issues
|
||||
QUICTransport = _get_quic_transport()
|
||||
config = kwargs.get("config")
|
||||
return QUICTransport(private_key, config)
|
||||
else:
|
||||
# TCP transport doesn't require upgrader
|
||||
return transport_class()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global transport registry instance (lazy initialization)
|
||||
_global_registry: TransportRegistry | None = None
|
||||
|
||||
|
||||
def get_transport_registry() -> TransportRegistry:
|
||||
"""Get the global transport registry instance."""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = TransportRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||
"""Register a transport class in the global registry."""
|
||||
registry = get_transport_registry()
|
||||
registry.register_transport(protocol, transport_class)
|
||||
|
||||
|
||||
def create_transport_for_multiaddr(
|
||||
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create the appropriate transport for a given multiaddr.
|
||||
|
||||
:param maddr: The multiaddr to create transport for
|
||||
:param upgrader: The transport upgrader instance
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
(e.g., private_key for QUIC)
|
||||
:return: Transport instance or None if no suitable transport found
|
||||
"""
|
||||
try:
|
||||
# Get all protocols in the multiaddr
|
||||
protocols = [proto.name for proto in maddr.protocols()]
|
||||
|
||||
# Check for supported transport protocols in order of preference
|
||||
# We need to validate that the multiaddr structure is valid for our transports
|
||||
if "quic" in protocols or "quic-v1" in protocols:
|
||||
# For QUIC, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/udp/4001/quic
|
||||
# /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
is_quic_multiaddr = _get_quic_validation()
|
||||
if is_quic_multiaddr(maddr):
|
||||
# Determine QUIC version
|
||||
registry = get_transport_registry()
|
||||
if "quic-v1" in protocols:
|
||||
return registry.create_transport("quic-v1", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("quic", upgrader, **kwargs)
|
||||
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
|
||||
# For WebSocket, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
if is_valid_websocket_multiaddr(maddr):
|
||||
# Determine if this is a secure WebSocket connection
|
||||
registry = get_transport_registry()
|
||||
if "wss" in protocols or "tls" in protocols:
|
||||
return registry.create_transport("wss", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("ws", upgrader, **kwargs)
|
||||
elif "tcp" in protocols:
|
||||
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
|
||||
# Check if the multiaddr has proper TCP structure
|
||||
if _is_valid_tcp_multiaddr(maddr):
|
||||
registry = get_transport_registry()
|
||||
return registry.create_transport("tcp", upgrader)
|
||||
|
||||
# If no supported transport protocol found or structure is invalid, return None
|
||||
logger.warning(
|
||||
f"No supported transport protocol found or invalid structure in "
|
||||
f"multiaddr: {maddr}"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors gracefully (e.g., invalid multiaddr)
|
||||
logger.warning(f"Error processing multiaddr {maddr}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_supported_transport_protocols() -> list[str]:
|
||||
"""Get list of supported transport protocols from the global registry."""
|
||||
registry = get_transport_registry()
|
||||
return registry.get_supported_protocols()
|
||||
198
libp2p/transport/websocket/connection.py
Normal file
198
libp2p/transport/websocket/connection.py
Normal file
@ -0,0 +1,198 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
Wraps a WebSocketConnection to provide the raw stream interface
|
||||
that libp2p protocols expect.
|
||||
|
||||
Implements production-ready buffer management and flow control
|
||||
as recommended in the libp2p WebSocket specification.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ws_connection: Any,
|
||||
ws_context: Any = None,
|
||||
is_secure: bool = False,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
) -> None:
|
||||
self._ws_connection = ws_connection
|
||||
self._ws_context = ws_context
|
||||
self._is_secure = is_secure
|
||||
self._read_buffer = b""
|
||||
self._read_lock = trio.Lock()
|
||||
self._connection_start_time = time.time()
|
||||
self._bytes_read = 0
|
||||
self._bytes_written = 0
|
||||
self._closed = False
|
||||
self._close_lock = trio.Lock()
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._write_lock = trio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write data with flow control and buffer management"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._write_lock:
|
||||
try:
|
||||
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||
|
||||
# Check buffer amount for flow control
|
||||
if hasattr(self._ws_connection, "bufferedAmount"):
|
||||
buffered = self._ws_connection.bufferedAmount
|
||||
if buffered > self._max_buffered_amount:
|
||||
logger.warning(f"WebSocket buffer full: {buffered} bytes")
|
||||
# In production, you might want to
|
||||
# wait or implement backpressure
|
||||
# For now, we'll continue but log the warning
|
||||
|
||||
# Send as a binary WebSocket message
|
||||
await self._ws_connection.send_message(data)
|
||||
self._bytes_written += len(data)
|
||||
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket write failed: {e}")
|
||||
self._closed = True
|
||||
raise IOException from e
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||
This implementation provides byte-level access to WebSocket messages,
|
||||
which is required for libp2p protocol compatibility.
|
||||
|
||||
For WebSocket compatibility with libp2p protocols, this method:
|
||||
1. Buffers incoming WebSocket messages
|
||||
2. Returns exactly the requested number of bytes when n is specified
|
||||
3. Accumulates multiple WebSocket messages if needed to satisfy the request
|
||||
4. Returns empty bytes (not raises) when connection is closed and no data
|
||||
available
|
||||
"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._read_lock:
|
||||
try:
|
||||
# If n is None, read at least one message and return all buffered data
|
||||
if n is None:
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
# No message available within timeout
|
||||
return b""
|
||||
except Exception:
|
||||
# Return empty bytes if no data available
|
||||
# (connection closed)
|
||||
return b""
|
||||
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
# For specific byte count requests, return UP TO n bytes (not exactly n)
|
||||
# This matches TCP semantics where read(1024) returns available data
|
||||
# up to 1024 bytes
|
||||
|
||||
# If we don't have any data buffered, try to get at least one message
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
return b"" # No data available
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
# Now return up to n bytes from the buffer (TCP-like semantics)
|
||||
if len(self._read_buffer) == 0:
|
||||
return b""
|
||||
|
||||
# Return up to n bytes (like TCP read())
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[len(result) :]
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket read failed: {e}")
|
||||
raise IOException from e
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection. This method is idempotent."""
|
||||
async with self._close_lock:
|
||||
if self._closed:
|
||||
return # Already closed
|
||||
|
||||
logger.debug("WebSocket connection closing")
|
||||
self._closed = True
|
||||
try:
|
||||
# Always close the connection directly, avoid context manager issues
|
||||
# The context manager may be causing cancel scope corruption
|
||||
logger.debug("WebSocket closing connection directly")
|
||||
await self._ws_connection.aclose()
|
||||
# Exit the context manager if we have one
|
||||
if self._ws_context is not None:
|
||||
await self._ws_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket close error: {e}")
|
||||
# Don't raise here, as close() should be idempotent
|
||||
finally:
|
||||
logger.debug("WebSocket connection closed")
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the connection is closed"""
|
||||
return self._closed
|
||||
|
||||
def conn_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
Return connection state information similar to Go's ConnState() method.
|
||||
|
||||
:return: Dictionary containing connection state information
|
||||
"""
|
||||
current_time = time.time()
|
||||
return {
|
||||
"transport": "websocket",
|
||||
"secure": self._is_secure,
|
||||
"connection_duration": current_time - self._connection_start_time,
|
||||
"bytes_read": self._bytes_read,
|
||||
"bytes_written": self._bytes_written,
|
||||
"total_bytes": self._bytes_read + self._bytes_written,
|
||||
}
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
# Try to get remote address from the WebSocket connection
|
||||
try:
|
||||
remote = self._ws_connection.remote
|
||||
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||
return str(remote.address), int(remote.port)
|
||||
elif isinstance(remote, str):
|
||||
# Parse address:port format
|
||||
if ":" in remote:
|
||||
host, port = remote.rsplit(":", 1)
|
||||
return host, int(port)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
225
libp2p/transport/websocket/listener.py
Normal file
225
libp2p/transport/websocket/listener.py
Normal file
@ -0,0 +1,225 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
from trio_typing import TaskStatus
|
||||
from trio_websocket import serve_websocket
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.websocket.listener")
|
||||
|
||||
|
||||
class WebsocketListener(IListener):
|
||||
"""
|
||||
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: THandler,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
) -> None:
|
||||
self._handler = handler
|
||||
self._upgrader = upgrader
|
||||
self._tls_config = tls_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._server = None
|
||||
self._shutdown_event = trio.Event()
|
||||
self._nursery: trio.Nursery | None = None
|
||||
self._listeners: Any = None
|
||||
self._is_wss = False # Track whether this is a WSS listener
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Check if WSS is requested but no TLS config provided
|
||||
if parsed.is_wss and self._tls_config is None:
|
||||
raise ValueError(
|
||||
f"Cannot listen on WSS address {maddr} without TLS configuration"
|
||||
)
|
||||
|
||||
# Store whether this is a WSS listener
|
||||
self._is_wss = parsed.is_wss
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
or "0.0.0.0"
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
|
||||
)
|
||||
|
||||
async def serve_websocket_tcp(
|
||||
handler: Callable[[Any], Awaitable[None]],
|
||||
port: int,
|
||||
host: str,
|
||||
task_status: TaskStatus[Any],
|
||||
) -> None:
|
||||
"""Start TCP server and handle WebSocket connections manually"""
|
||||
logger.debug(
|
||||
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
|
||||
)
|
||||
|
||||
async def websocket_handler(request: Any) -> None:
|
||||
"""Handle WebSocket requests"""
|
||||
logger.debug("WebSocket request received")
|
||||
try:
|
||||
# Apply handshake timeout
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
# Accept the WebSocket connection
|
||||
ws_connection = await request.accept()
|
||||
logger.debug("WebSocket handshake successful")
|
||||
|
||||
# Create the WebSocket connection wrapper
|
||||
conn = P2PWebSocketConnection(
|
||||
ws_connection, is_secure=parsed.is_wss
|
||||
) # type: ignore[no-untyped-call]
|
||||
|
||||
# Call the handler function that was passed to create_listener
|
||||
# This handler will handle the security and muxing upgrades
|
||||
logger.debug("Calling connection handler")
|
||||
await self._handler(conn)
|
||||
|
||||
# Don't keep the connection alive indefinitely
|
||||
# Let the handler manage the connection lifecycle
|
||||
logger.debug(
|
||||
"Handler completed, connection will be managed by handler"
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s"
|
||||
)
|
||||
try:
|
||||
await request.reject(408) # Request Timeout
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket connection error: {e}")
|
||||
logger.debug(f"Error type: {type(e)}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# Reject the connection
|
||||
try:
|
||||
await request.reject(400)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||
ssl_context = self._tls_config if parsed.is_wss else None
|
||||
await serve_websocket(
|
||||
websocket_handler, host, port, ssl_context, task_status=task_status
|
||||
)
|
||||
|
||||
# Store the nursery for shutdown
|
||||
self._nursery = nursery
|
||||
|
||||
# Start the server using nursery.start() like TCP does
|
||||
logger.debug("Calling nursery.start()...")
|
||||
started_listeners = await nursery.start(
|
||||
serve_websocket_tcp,
|
||||
None, # No handler needed since it's defined inside serve_websocket_tcp
|
||||
port,
|
||||
host,
|
||||
)
|
||||
logger.debug(f"nursery.start() returned: {started_listeners}")
|
||||
|
||||
if started_listeners is None:
|
||||
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||
return False
|
||||
|
||||
# Store the listeners for get_addrs() and close() - these are real
|
||||
# SocketListener objects
|
||||
self._listeners = started_listeners
|
||||
logger.debug(
|
||||
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||
)
|
||||
return True
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
if not hasattr(self, "_listeners") or not self._listeners:
|
||||
logger.debug("No listeners available for get_addrs()")
|
||||
return ()
|
||||
|
||||
# Handle WebSocketServer objects
|
||||
if hasattr(self._listeners, "port"):
|
||||
# This is a WebSocketServer object
|
||||
port = self._listeners.port
|
||||
# Create a multiaddr from the port with correct WSS/WS protocol
|
||||
protocol = "wss" if self._is_wss else "ws"
|
||||
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
|
||||
else:
|
||||
# This is a list of listeners (like TCP)
|
||||
listeners = self._listeners
|
||||
# Get addresses from listeners like TCP does
|
||||
return tuple(
|
||||
_multiaddr_from_socket(listener.socket, self._is_wss)
|
||||
for listener in listeners
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket listener and stop accepting new connections"""
|
||||
logger.debug("WebsocketListener.close called")
|
||||
if hasattr(self, "_listeners") and self._listeners:
|
||||
# Signal shutdown
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Close the WebSocket server
|
||||
if hasattr(self._listeners, "aclose"):
|
||||
# This is a WebSocketServer object
|
||||
logger.debug("Closing WebSocket server")
|
||||
await self._listeners.aclose()
|
||||
logger.debug("WebSocket server closed")
|
||||
elif isinstance(self._listeners, (list, tuple)):
|
||||
# This is a list of listeners (like TCP)
|
||||
logger.debug("Closing TCP listeners")
|
||||
for listener in self._listeners:
|
||||
await listener.aclose()
|
||||
logger.debug("TCP listeners closed")
|
||||
else:
|
||||
# Unknown type, try to close it directly
|
||||
logger.debug("Closing unknown listener type")
|
||||
if hasattr(self._listeners, "close"):
|
||||
self._listeners.close()
|
||||
logger.debug("Unknown listener closed")
|
||||
|
||||
# Clear the listeners reference
|
||||
self._listeners = None
|
||||
logger.debug("WebsocketListener.close completed")
|
||||
|
||||
|
||||
def _multiaddr_from_socket(
|
||||
socket: trio.socket.SocketType, is_wss: bool = False
|
||||
) -> Multiaddr:
|
||||
"""Convert socket to multiaddr"""
|
||||
ip, port = socket.getsockname()
|
||||
protocol = "wss" if is_wss else "ws"
|
||||
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")
|
||||
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
@ -0,0 +1,202 @@
|
||||
"""
|
||||
WebSocket multiaddr parsing utilities.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
|
||||
class ParsedWebSocketMultiaddr(NamedTuple):
|
||||
"""Parsed WebSocket multiaddr information."""
|
||||
|
||||
is_wss: bool
|
||||
sni: str | None
|
||||
rest_multiaddr: Multiaddr
|
||||
|
||||
|
||||
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
|
||||
"""
|
||||
Parse a WebSocket multiaddr and extract security information.
|
||||
|
||||
:param maddr: The multiaddr to parse
|
||||
:return: Parsed WebSocket multiaddr information
|
||||
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
|
||||
"""
|
||||
# First validate that this is a valid WebSocket multiaddr
|
||||
if not is_valid_websocket_multiaddr(maddr):
|
||||
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
|
||||
|
||||
protocols = list(maddr.protocols())
|
||||
|
||||
# Find the WebSocket protocol and check for security
|
||||
is_wss = False
|
||||
sni = None
|
||||
ws_index = -1
|
||||
tls_index = -1
|
||||
sni_index = -1
|
||||
|
||||
# Find protocol indices
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name == "ws":
|
||||
ws_index = i
|
||||
elif protocol.name == "wss":
|
||||
ws_index = i
|
||||
is_wss = True
|
||||
elif protocol.name == "tls":
|
||||
tls_index = i
|
||||
elif protocol.name == "sni":
|
||||
sni_index = i
|
||||
sni = protocol.value
|
||||
|
||||
if ws_index == -1:
|
||||
raise ValueError("Not a WebSocket multiaddr")
|
||||
|
||||
# Handle /wss protocol (convert to /tls/ws internally)
|
||||
if is_wss and tls_index == -1:
|
||||
# Convert /wss to /tls/ws format
|
||||
# Remove /wss to get the base multiaddr
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=True, sni=None, rest_multiaddr=without_wss
|
||||
)
|
||||
|
||||
# Handle /tls/ws and /tls/sni/.../ws formats
|
||||
if tls_index != -1:
|
||||
is_wss = True
|
||||
# Extract the base multiaddr (everything before /tls)
|
||||
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
|
||||
# Use multiaddr methods to properly extract the base
|
||||
rest_multiaddr = maddr
|
||||
# Remove /tls/ws or /tls/sni/.../ws from the end
|
||||
if sni_index != -1:
|
||||
# /tls/sni/example.com/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
else:
|
||||
# /tls/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
# Regular /ws multiaddr - remove /ws and any additional protocols
|
||||
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
|
||||
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid WebSocket structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid WebSocket structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# WebSocket multiaddr should have structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
|
||||
if len(protocols) < 3:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Check for valid WebSocket protocols
|
||||
ws_protocols = ["ws", "wss"]
|
||||
tls_protocols = ["tls"]
|
||||
sni_protocols = ["sni"]
|
||||
|
||||
# Find the WebSocket protocol
|
||||
ws_protocol_found = False
|
||||
tls_found = False
|
||||
# sni_found = False # Not used currently
|
||||
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name in ws_protocols:
|
||||
ws_protocol_found = True
|
||||
break
|
||||
elif protocol.name in tls_protocols:
|
||||
tls_found = True
|
||||
elif protocol.name in sni_protocols:
|
||||
pass # sni_found = True # Not used in current implementation
|
||||
|
||||
if not ws_protocol_found:
|
||||
return False
|
||||
|
||||
# Validate protocol sequence
|
||||
# For /ws: network + tcp + ws
|
||||
# For /wss: network + tcp + wss
|
||||
# For /tls/ws: network + tcp + tls + ws
|
||||
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
|
||||
|
||||
# Check if it's a simple /ws or /wss
|
||||
if len(protocols) == 3:
|
||||
return protocols[2].name in ["ws", "wss"]
|
||||
|
||||
# Check for /tls/ws or /tls/sni/.../ws patterns
|
||||
if tls_found:
|
||||
# Must end with /ws (not /wss when using /tls)
|
||||
if protocols[-1].name != "ws":
|
||||
return False
|
||||
|
||||
# Check for valid TLS sequence
|
||||
tls_index = None
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name == "tls":
|
||||
tls_index = i
|
||||
break
|
||||
|
||||
if tls_index is None:
|
||||
return False
|
||||
|
||||
# After tls, we can have sni, then ws
|
||||
remaining_protocols = protocols[tls_index + 1 :]
|
||||
if len(remaining_protocols) == 1:
|
||||
# /tls/ws
|
||||
return remaining_protocols[0].name == "ws"
|
||||
elif len(remaining_protocols) == 2:
|
||||
# /tls/sni/example.com/ws
|
||||
return (
|
||||
remaining_protocols[0].name == "sni"
|
||||
and remaining_protocols[1].name == "ws"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
# If we have more than 3 protocols but no TLS, check for valid continuations
|
||||
# Allow additional protocols after the WebSocket protocol (like /p2p)
|
||||
valid_continuations = ["p2p"]
|
||||
|
||||
# Find the WebSocket protocol index
|
||||
ws_index = None
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name in ["ws", "wss"]:
|
||||
ws_index = i
|
||||
break
|
||||
|
||||
if ws_index is not None:
|
||||
# Check protocols after the WebSocket protocol
|
||||
for i in range(ws_index + 1, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
229
libp2p/transport/websocket/transport.py
Normal file
229
libp2p/transport/websocket/transport.py
Normal file
@ -0,0 +1,229 @@
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.abc import IListener, ITransport
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
from .listener import WebsocketListener
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketTransport(ITransport):
|
||||
"""
|
||||
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
|
||||
|
||||
Implements production-ready WebSocket transport with:
|
||||
- Flow control and buffer management
|
||||
- Connection limits and rate limiting
|
||||
- Proper error handling and cleanup
|
||||
- Support for both WS and WSS protocols
|
||||
- TLS configuration and handshake timeout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
):
|
||||
self._upgrader = upgrader
|
||||
self._tls_client_config = tls_client_config
|
||||
self._tls_server_config = tls_server_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._connection_count = 0
|
||||
self._max_connections = 1000 # Production limit
|
||||
|
||||
async def dial(self, maddr: Multiaddr) -> RawConnection:
|
||||
"""Dial a WebSocket connection to the given multiaddr."""
|
||||
logger.debug(f"WebsocketTransport.dial called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
# Build WebSocket URL based on security
|
||||
if parsed.is_wss:
|
||||
ws_url = f"wss://{host}:{port}/"
|
||||
else:
|
||||
ws_url = f"ws://{host}:{port}/"
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
|
||||
)
|
||||
|
||||
try:
|
||||
# Check connection limits
|
||||
if self._connection_count >= self._max_connections:
|
||||
raise OpenConnectionError(
|
||||
f"Maximum connections reached: {self._max_connections}"
|
||||
)
|
||||
|
||||
# Prepare SSL context for WSS connections
|
||||
ssl_context = None
|
||||
if parsed.is_wss:
|
||||
if self._tls_client_config:
|
||||
ssl_context = self._tls_client_config
|
||||
else:
|
||||
# Create default SSL context for client
|
||||
ssl_context = ssl.create_default_context()
|
||||
# Set SNI if available
|
||||
if parsed.sni:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
|
||||
|
||||
# Use a different approach: start background nursery that will persist
|
||||
logger.debug("WebsocketTransport.dial establishing connection")
|
||||
|
||||
# Import trio-websocket functions
|
||||
from trio_websocket import connect_websocket
|
||||
from trio_websocket._impl import _url_to_host
|
||||
|
||||
# Parse the WebSocket URL to get host, port, resource
|
||||
# like trio-websocket does
|
||||
ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
|
||||
ws_url, ssl_context
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketTransport.dial parsed URL: host={ws_host}, "
|
||||
f"port={ws_port}, resource={ws_resource}"
|
||||
)
|
||||
|
||||
# Create a background task manager for this connection
|
||||
import trio
|
||||
|
||||
nursery_manager = trio.lowlevel.current_task().parent_nursery
|
||||
if nursery_manager is None:
|
||||
raise OpenConnectionError(
|
||||
f"No parent nursery available for WebSocket connection to {maddr}"
|
||||
)
|
||||
|
||||
# Apply timeout to the connection process
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
logger.debug("WebsocketTransport.dial connecting WebSocket")
|
||||
ws = await connect_websocket(
|
||||
nursery_manager, # Use the existing nursery from libp2p
|
||||
ws_host,
|
||||
ws_port,
|
||||
ws_resource,
|
||||
use_ssl=ws_ssl_context,
|
||||
message_queue_size=1024, # Reasonable defaults
|
||||
max_message_size=16 * 1024 * 1024, # 16MB max message
|
||||
)
|
||||
logger.debug("WebsocketTransport.dial WebSocket connection established")
|
||||
|
||||
# Create our connection wrapper with both WSS support and flow control
|
||||
conn = P2PWebSocketConnection(
|
||||
ws,
|
||||
None,
|
||||
is_secure=parsed.is_wss,
|
||||
max_buffered_amount=self._max_buffered_amount,
|
||||
)
|
||||
logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")
|
||||
|
||||
self._connection_count += 1
|
||||
logger.debug(f"Total connections: {self._connection_count}")
|
||||
|
||||
return RawConnection(conn, initiator=True)
|
||||
except trio.TooSlowError as e:
|
||||
raise OpenConnectionError(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s "
|
||||
f"for {maddr}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dial WebSocket {maddr}: {e}")
|
||||
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
|
||||
|
||||
def create_listener(self, handler: THandler) -> IListener: # type: ignore[override]
|
||||
"""
|
||||
The type checker is incorrectly reporting this as an inconsistent override.
|
||||
"""
|
||||
logger.debug("WebsocketTransport.create_listener called")
|
||||
return WebsocketListener(
|
||||
handler, self._upgrader, self._tls_server_config, self._handshake_timeout
|
||||
)
|
||||
|
||||
def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
|
||||
"""
|
||||
Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
|
||||
Similar to Go's Resolve() method.
|
||||
|
||||
:param maddr: The multiaddr to resolve
|
||||
:return: List of resolved multiaddrs
|
||||
"""
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
|
||||
return [maddr] # Return original if not a valid WebSocket multiaddr
|
||||
|
||||
logger.debug(
|
||||
f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
|
||||
)
|
||||
|
||||
if not parsed.is_wss:
|
||||
# No /tls/ws component, this isn't a secure websocket multiaddr
|
||||
return [maddr]
|
||||
|
||||
if parsed.sni is not None:
|
||||
# Already has SNI, return as-is
|
||||
return [maddr]
|
||||
|
||||
# Try to extract DNS name from the base multiaddr
|
||||
dns_name = None
|
||||
for protocol_name in ["dns", "dns4", "dns6"]:
|
||||
try:
|
||||
dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if dns_name is None:
|
||||
# No DNS name found, return original
|
||||
return [maddr]
|
||||
|
||||
# Create new multiaddr with SNI
|
||||
# For /dns/example.com/tcp/8080/wss ->
|
||||
# /dns/example.com/tcp/8080/tls/sni/example.com/ws
|
||||
try:
|
||||
# Remove /wss and add /tls/sni/example.com/ws
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
sni_component = Multiaddr(f"/sni/{dns_name}")
|
||||
resolved = (
|
||||
without_wss.encapsulate(Multiaddr("/tls"))
|
||||
.encapsulate(sni_component)
|
||||
.encapsulate(Multiaddr("/ws"))
|
||||
)
|
||||
logger.debug(f"Resolved {maddr} to {resolved}")
|
||||
return [resolved]
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
|
||||
return [maddr]
|
||||
Reference in New Issue
Block a user