Merge branch 'main' into refactor/replace-magic-numbers-with-named-constants

This commit is contained in:
Manu Sheel Gupta
2025-09-23 21:26:19 +05:30
committed by GitHub
66 changed files with 6592 additions and 357 deletions

View File

@ -1,6 +1,7 @@
"""Libp2p Python implementation."""
import logging
import ssl
from libp2p.transport.quic.utils import is_quic_multiaddr
from typing import Any
@ -24,6 +25,7 @@ from libp2p.abc import (
IPeerRouting,
IPeerStore,
ISecureTransport,
ITransport,
)
from libp2p.crypto.keys import (
KeyPair,
@ -80,6 +82,10 @@ from libp2p.transport.tcp.tcp import (
from libp2p.transport.upgrader import (
TransportUpgrader,
)
from libp2p.transport.transport_registry import (
create_transport_for_multiaddr,
get_supported_transport_protocols,
)
from libp2p.utils.logging import (
setup_logging,
)
@ -174,7 +180,10 @@ def new_swarm(
enable_quic: bool = False,
retry_config: Optional["RetryConfig"] = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
tls_client_config: ssl.SSLContext | None = None,
tls_server_config: ssl.SSLContext | None = None,
) -> INetworkService:
logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}")
"""
Create a swarm instance based on the parameters.
@ -198,7 +207,7 @@ def new_swarm(
id_opt = generate_peer_id_from(key_pair)
transport: TCP | QUICTransport
transport: TCP | QUICTransport | ITransport
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
if listen_addrs is None:
@ -207,14 +216,39 @@ def new_swarm(
else:
transport = TCP()
else:
# Use transport registry to select the appropriate transport
from libp2p.transport.transport_registry import create_transport_for_multiaddr
# Create a temporary upgrader for transport selection
# We'll create the real upgrader later with the proper configuration
temp_upgrader = TransportUpgrader(
secure_transports_by_protocol={},
muxer_transports_by_protocol={}
)
addr = listen_addrs[0]
is_quic = is_quic_multiaddr(addr)
if addr.__contains__("tcp"):
transport = TCP()
elif is_quic:
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
else:
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
logger.debug(f"new_swarm: Creating transport for address: {addr}")
transport_maybe = create_transport_for_multiaddr(
addr,
temp_upgrader,
private_key=key_pair.private_key,
config=quic_transport_opt,
tls_client_config=tls_client_config,
tls_server_config=tls_server_config
)
if transport_maybe is None:
raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}")
transport = transport_maybe
logger.debug(f"new_swarm: Created transport: {type(transport)}")
# If enable_quic is True but we didn't get a QUIC transport, force QUIC
if enable_quic and not isinstance(transport, QUICTransport):
logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})")
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
logger.debug(f"new_swarm: Final transport type: {type(transport)}")
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
@ -255,6 +289,7 @@ def new_swarm(
muxer_transports_by_protocol=muxer_transports_by_protocol,
)
peerstore = peerstore_opt or PeerStore()
# Store our key pair in peerstore
peerstore.add_key_pair(id_opt, key_pair)
@ -282,6 +317,8 @@ def new_host(
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
enable_quic: bool = False,
quic_transport_opt: QUICTransportConfig | None = None,
tls_client_config: ssl.SSLContext | None = None,
tls_server_config: ssl.SSLContext | None = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.
@ -296,7 +333,9 @@ def new_host(
:param enable_mDNS: whether to enable mDNS discovery
:param bootstrap: optional list of bootstrap peer addresses as strings
:param enable_quic: optinal choice to use QUIC for transport
:param transport_opt: optional configuration for quic transport
:param quic_transport_opt: optional configuration for quic transport
:param tls_client_config: optional TLS client configuration for WebSocket transport
:param tls_server_config: optional TLS server configuration for WebSocket transport
:return: return a host instance
"""
@ -311,7 +350,9 @@ def new_host(
peerstore_opt=peerstore_opt,
muxer_preference=muxer_preference,
listen_addrs=listen_addrs,
connection_config=quic_transport_opt if enable_quic else None
connection_config=quic_transport_opt if enable_quic else None,
tls_client_config=tls_client_config,
tls_server_config=tls_server_config
)
if disc_opt is not None:

View File

@ -481,13 +481,16 @@ class Swarm(Service, INetworkService):
- Call listener listen with the multiaddr
- Map multiaddr to listener
"""
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
# We need to wait until `self.listener_nursery` is created.
logger.debug("Starting to listen")
await self.event_listener_nursery_created.wait()
success_count = 0
for maddr in multiaddrs:
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
if str(maddr) in self.listeners:
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
success_count += 1
continue
@ -545,13 +548,18 @@ class Swarm(Service, INetworkService):
try:
# Success
logger.debug(f"Swarm.listen: creating listener for {maddr}")
listener = self.transport.create_listener(conn_handler)
logger.debug(f"Swarm.listen: listener created for {maddr}")
self.listeners[str(maddr)] = listener
# TODO: `listener.listen` is not bounded with nursery. If we want to be
# I/O agnostic, we should change the API.
if self.listener_nursery is None:
raise SwarmException("swarm instance hasn't been run")
assert self.listener_nursery is not None # For type checker
logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
await listener.listen(maddr, self.listener_nursery)
logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
# Call notifiers since event occurred
await self.notify_listen(maddr)

View File

@ -1,3 +1,4 @@
import logging
from typing import (
cast,
)
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
FixedSizeLenMsgReadWriter,
)
logger = logging.getLogger(__name__)
SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.noise_state = noise_state
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
data_encrypted = self.encrypt(msg)
if prefix_encoded:
# Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
await self.read_writer.write_msg(data_encrypted)
logger.debug("Noise write_msg: write completed successfully")
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
logger.debug("Noise read_msg: reading encrypted message")
noise_msg_encrypted = await self.read_writer.read_msg()
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
if prefix_encoded:
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
else:
return self.decrypt(noise_msg_encrypted)
result = self.decrypt(noise_msg_encrypted)
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
return result
async def close(self) -> None:
await self.read_writer.close()

View File

@ -1,6 +1,7 @@
from dataclasses import (
dataclass,
)
import logging
from libp2p.crypto.keys import (
PrivateKey,
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
from .pb import noise_pb2 as noise_pb
logger = logging.getLogger(__name__)
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
) -> bytes:
data = make_data_to_be_signed(noise_static_pubkey)
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
return id_privkey.sign(data)
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
2. signed by the private key corresponding to `id_pubkey`
"""
expected_data = make_data_to_be_signed(noise_static_pubkey)
return payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(
f"verify_handshake_payload_sig: payload.id_pubkey type: "
f"{type(payload.id_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: noise_static_pubkey type: "
f"{type(noise_static_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
)
logger.debug(
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
)
try:
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
return result
except Exception as e:
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
return False

View File

@ -2,6 +2,7 @@ from abc import (
ABC,
abstractmethod,
)
import logging
from cryptography.hazmat.primitives import (
serialization,
@ -46,6 +47,8 @@ from .messages import (
verify_handshake_payload_sig,
)
logger = logging.getLogger(__name__)
class IPattern(ABC):
@abstractmethod
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
self.early_data = early_data
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
noise_state = self.create_noise_state()
noise_state.set_as_responder()
noise_state.start_handshake()
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
logger.debug("Noise XX handshake_inbound: reading msg#1")
await read_writer.read_msg()
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
# Send msg#2, which should include our handshake payload.
logger.debug("Noise XX handshake_inbound: preparing msg#2")
our_payload = self.make_handshake_payload()
msg_2 = our_payload.serialize()
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
await read_writer.write_msg(msg_2)
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
# Receive and consume msg#3.
logger.debug("Noise XX handshake_inbound: reading msg#3")
msg_3 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
if handshake_state.rs is None:
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn:
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
noise_state = self.create_noise_state()
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
logger.debug("Noise XX handshake_outbound: sending msg#1")
msg_1 = b""
await read_writer.write_msg(msg_1)
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
# Read msg#2 from the remote, which contains the public key of the peer.
logger.debug("Noise XX handshake_outbound: reading msg#2")
msg_2 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
if handshake_state.rs is None:
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
logger.debug(
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
)
logger.debug(
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
)
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
logger.debug(
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
f"{id_pubkey_repr}"
)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
logger.error(
f"Noise XX handshake_outbound: signature verification failed for peer "
f"{remote_peer}"
)
raise InvalidSignature
logger.debug(
f"Noise XX handshake_outbound: signature verification successful for peer "
f"{remote_peer}"
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(

View File

@ -1,5 +1,3 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -15,6 +13,7 @@ from libp2p.abc import (
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.rw_lock import ReadWriteLock
from .constants import (
HeaderTags,
@ -34,72 +33,6 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go

View File

@ -0,0 +1,70 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import trio
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()

View File

@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamError,
MuxedStreamReset,
)
from libp2p.stream_muxer.rw_lock import ReadWriteLock
# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
self.send_window = DEFAULT_WINDOW_SIZE
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.close_lock = trio.Lock()
async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
await self.close()
async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")
async with self.rw_lock.write_lock():
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
f"Stream {self.stream_id}: "
"Window is zero, waiting for update"
)
# Release lock and wait with timeout
self.window_lock.release()
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()
# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
if self.closed:
raise MuxedStreamError("Stream is closed")
# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send
# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
# Release lock and wait with timeout
self.window_lock.release()
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()
# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
if self.closed:
raise MuxedStreamError("Stream is closed")
# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send
# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
await self.conn.secured_conn.write(header + chunk)
sent += to_send
await self.conn.secured_conn.write(header + chunk)
sent += to_send
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
"""
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
return data
async def close(self) -> None:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True
async with self.close_lock:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True
# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False
# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False
async def reset(self) -> None:
if not self.closed:
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset
async with self.close_lock:
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset
def set_deadline(self, ttl: int) -> bool:
"""

View File

@ -0,0 +1,57 @@
from typing import Any
from .tcp.tcp import TCP
from .websocket.transport import WebsocketTransport
from .transport_registry import (
TransportRegistry,
create_transport_for_multiaddr,
get_transport_registry,
register_transport,
get_supported_transport_protocols,
)
from .upgrader import TransportUpgrader
from libp2p.abc import ITransport
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
"""
Convenience function to create a transport instance.
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
:param upgrader: Optional transport upgrader (required for WebSocket)
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
:return: Transport instance
"""
# First check if it's a built-in protocol
if protocol in ["ws", "wss"]:
if upgrader is None:
raise ValueError(f"WebSocket transport requires an upgrader")
return WebsocketTransport(
upgrader,
tls_client_config=kwargs.get("tls_client_config"),
tls_server_config=kwargs.get("tls_server_config"),
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
)
elif protocol == "tcp":
return TCP()
else:
# Check if it's a custom registered transport
registry = get_transport_registry()
transport_class = registry.get_transport(protocol)
if transport_class:
transport = registry.create_transport(protocol, upgrader, **kwargs)
if transport is None:
raise ValueError(f"Failed to create transport for protocol: {protocol}")
return transport
else:
raise ValueError(f"Unsupported transport protocol: {protocol}")
__all__ = [
"TCP",
"WebsocketTransport",
"TransportRegistry",
"create_transport_for_multiaddr",
"create_transport",
"get_transport_registry",
"register_transport",
"get_supported_transport_protocols",
]

View File

@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
import logging
import socket
import time
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from aioquic.quic import events
from aioquic.quic.connection import QuicConnection
@ -871,9 +871,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Process events by type
for event_type, event_list in events_by_type.items():
if event_type == type(events.StreamDataReceived).__name__:
await self._handle_stream_data_batch(
cast(list[events.StreamDataReceived], event_list)
)
# Filter to only StreamDataReceived events
stream_data_events = [
e for e in event_list if isinstance(e, events.StreamDataReceived)
]
await self._handle_stream_data_batch(stream_data_events)
else:
# Process other events individually
for event in event_list:

View File

@ -901,7 +901,7 @@ class QUICListener(IListener):
# Set socket options
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, "SO_REUSEPORT"):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # type: ignore[attr-defined]
# Bind to address
await sock.bind((host, port))

View File

@ -0,0 +1,267 @@
"""
Transport registry for dynamic transport selection based on multiaddr protocols.
"""
from collections.abc import Callable
import logging
from typing import Any
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
from libp2p.abc import ITransport
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import (
is_valid_websocket_multiaddr,
)
# Import QUIC utilities here to avoid circular imports
def _get_quic_transport() -> Any:
from libp2p.transport.quic.transport import QUICTransport
return QUICTransport
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
from libp2p.transport.quic.utils import is_quic_multiaddr
return is_quic_multiaddr
# Import WebsocketTransport here to avoid circular imports
def _get_websocket_transport() -> Any:
from libp2p.transport.websocket.transport import WebsocketTransport
return WebsocketTransport
logger = logging.getLogger("libp2p.transport.registry")
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid TCP structure.
:param maddr: The multiaddr to validate
:return: True if valid TCP structure, False otherwise
"""
try:
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
# or /ip6/::1/tcp/8080
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 2 protocols: network (ip4/ip6) + tcp
if len(protocols) < 2:
return False
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Should not have any protocols after tcp (unless it's a valid
# continuation like p2p)
# For now, we'll be strict and only allow network + tcp
if len(protocols) > 2:
# Check if the additional protocols are valid continuations
valid_continuations = ["p2p"] # Add more as needed
for i in range(2, len(protocols)):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False
class TransportRegistry:
"""
Registry for mapping multiaddr protocols to transport implementations.
"""
def __init__(self) -> None:
self._transports: dict[str, type[ITransport]] = {}
self._register_default_transports()
def _register_default_transports(self) -> None:
"""Register the default transport implementations."""
# Register TCP transport for /tcp protocol
self.register_transport("tcp", TCP)
# Register WebSocket transport for /ws and /wss protocols
WebsocketTransport = _get_websocket_transport()
self.register_transport("ws", WebsocketTransport)
self.register_transport("wss", WebsocketTransport)
# Register QUIC transport for /quic and /quic-v1 protocols
QUICTransport = _get_quic_transport()
self.register_transport("quic", QUICTransport)
self.register_transport("quic-v1", QUICTransport)
def register_transport(
self, protocol: str, transport_class: type[ITransport]
) -> None:
"""
Register a transport class for a specific protocol.
:param protocol: The protocol identifier (e.g., "tcp", "ws")
:param transport_class: The transport class to register
"""
self._transports[protocol] = transport_class
logger.debug(
f"Registered transport {transport_class.__name__} for protocol {protocol}"
)
def get_transport(self, protocol: str) -> type[ITransport] | None:
"""
Get the transport class for a specific protocol.
:param protocol: The protocol identifier
:return: The transport class or None if not found
"""
return self._transports.get(protocol)
def get_supported_protocols(self) -> list[str]:
"""Get list of supported transport protocols."""
return list(self._transports.keys())
def create_transport(
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
) -> ITransport | None:
"""
Create a transport instance for a specific protocol.
:param protocol: The protocol identifier
:param upgrader: The transport upgrader instance (required for WebSocket)
:param kwargs: Additional arguments for transport construction
:return: Transport instance or None if protocol not supported or creation fails
"""
transport_class = self.get_transport(protocol)
if transport_class is None:
return None
try:
if protocol in ["ws", "wss"]:
# WebSocket transport requires upgrader
if upgrader is None:
logger.warning(
f"WebSocket transport '{protocol}' requires upgrader"
)
return None
# Use explicit WebsocketTransport to avoid type issues
WebsocketTransport = _get_websocket_transport()
return WebsocketTransport(
upgrader,
tls_client_config=kwargs.get("tls_client_config"),
tls_server_config=kwargs.get("tls_server_config"),
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
)
elif protocol in ["quic", "quic-v1"]:
# QUIC transport requires private_key
private_key = kwargs.get("private_key")
if private_key is None:
logger.warning(f"QUIC transport '{protocol}' requires private_key")
return None
# Use explicit QUICTransport to avoid type issues
QUICTransport = _get_quic_transport()
config = kwargs.get("config")
return QUICTransport(private_key, config)
else:
# TCP transport doesn't require upgrader
return transport_class()
except Exception as e:
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
return None
# Global transport registry instance (lazy initialization)
_global_registry: TransportRegistry | None = None
def get_transport_registry() -> TransportRegistry:
"""Get the global transport registry instance."""
global _global_registry
if _global_registry is None:
_global_registry = TransportRegistry()
return _global_registry
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
"""Register a transport class in the global registry."""
registry = get_transport_registry()
registry.register_transport(protocol, transport_class)
def create_transport_for_multiaddr(
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
) -> ITransport | None:
"""
Create the appropriate transport for a given multiaddr.
:param maddr: The multiaddr to create transport for
:param upgrader: The transport upgrader instance
:param kwargs: Additional arguments for transport construction
(e.g., private_key for QUIC)
:return: Transport instance or None if no suitable transport found
"""
try:
# Get all protocols in the multiaddr
protocols = [proto.name for proto in maddr.protocols()]
# Check for supported transport protocols in order of preference
# We need to validate that the multiaddr structure is valid for our transports
if "quic" in protocols or "quic-v1" in protocols:
# For QUIC, we need a valid structure like:
# /ip4/127.0.0.1/udp/4001/quic
# /ip4/127.0.0.1/udp/4001/quic-v1
is_quic_multiaddr = _get_quic_validation()
if is_quic_multiaddr(maddr):
# Determine QUIC version
registry = get_transport_registry()
if "quic-v1" in protocols:
return registry.create_transport("quic-v1", upgrader, **kwargs)
else:
return registry.create_transport("quic", upgrader, **kwargs)
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
# For WebSocket, we need a valid structure like:
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
# /ip4/127.0.0.1/tcp/8080/wss (secure)
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
if is_valid_websocket_multiaddr(maddr):
# Determine if this is a secure WebSocket connection
registry = get_transport_registry()
if "wss" in protocols or "tls" in protocols:
return registry.create_transport("wss", upgrader, **kwargs)
else:
return registry.create_transport("ws", upgrader, **kwargs)
elif "tcp" in protocols:
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
# Check if the multiaddr has proper TCP structure
if _is_valid_tcp_multiaddr(maddr):
registry = get_transport_registry()
return registry.create_transport("tcp", upgrader)
# If no supported transport protocol found or structure is invalid, return None
logger.warning(
f"No supported transport protocol found or invalid structure in "
f"multiaddr: {maddr}"
)
return None
except Exception as e:
# Handle any errors gracefully (e.g., invalid multiaddr)
logger.warning(f"Error processing multiaddr {maddr}: {e}")
return None
def get_supported_transport_protocols() -> list[str]:
"""Get list of supported transport protocols from the global registry."""
registry = get_transport_registry()
return registry.get_supported_protocols()

View File

@ -0,0 +1,198 @@
import logging
import time
from typing import Any
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger(__name__)
class P2PWebSocketConnection(ReadWriteCloser):
"""
Wraps a WebSocketConnection to provide the raw stream interface
that libp2p protocols expect.
Implements production-ready buffer management and flow control
as recommended in the libp2p WebSocket specification.
"""
def __init__(
self,
ws_connection: Any,
ws_context: Any = None,
is_secure: bool = False,
max_buffered_amount: int = 4 * 1024 * 1024,
) -> None:
self._ws_connection = ws_connection
self._ws_context = ws_context
self._is_secure = is_secure
self._read_buffer = b""
self._read_lock = trio.Lock()
self._connection_start_time = time.time()
self._bytes_read = 0
self._bytes_written = 0
self._closed = False
self._close_lock = trio.Lock()
self._max_buffered_amount = max_buffered_amount
self._write_lock = trio.Lock()
async def write(self, data: bytes) -> None:
"""Write data with flow control and buffer management"""
if self._closed:
raise IOException("Connection is closed")
async with self._write_lock:
try:
logger.debug(f"WebSocket writing {len(data)} bytes")
# Check buffer amount for flow control
if hasattr(self._ws_connection, "bufferedAmount"):
buffered = self._ws_connection.bufferedAmount
if buffered > self._max_buffered_amount:
logger.warning(f"WebSocket buffer full: {buffered} bytes")
# In production, you might want to
# wait or implement backpressure
# For now, we'll continue but log the warning
# Send as a binary WebSocket message
await self._ws_connection.send_message(data)
self._bytes_written += len(data)
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
except Exception as e:
logger.error(f"WebSocket write failed: {e}")
self._closed = True
raise IOException from e
async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes (if n is given), else read up to 64KiB.
This implementation provides byte-level access to WebSocket messages,
which is required for libp2p protocol compatibility.
For WebSocket compatibility with libp2p protocols, this method:
1. Buffers incoming WebSocket messages
2. Returns exactly the requested number of bytes when n is specified
3. Accumulates multiple WebSocket messages if needed to satisfy the request
4. Returns empty bytes (not raises) when connection is closed and no data
available
"""
if self._closed:
raise IOException("Connection is closed")
async with self._read_lock:
try:
# If n is None, read at least one message and return all buffered data
if n is None:
if not self._read_buffer:
try:
# Use a short timeout to avoid blocking indefinitely
with trio.fail_after(1.0): # 1 second timeout
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
self._read_buffer = message
except trio.TooSlowError:
# No message available within timeout
return b""
except Exception:
# Return empty bytes if no data available
# (connection closed)
return b""
result = self._read_buffer
self._read_buffer = b""
self._bytes_read += len(result)
return result
# For specific byte count requests, return UP TO n bytes (not exactly n)
# This matches TCP semantics where read(1024) returns available data
# up to 1024 bytes
# If we don't have any data buffered, try to get at least one message
if not self._read_buffer:
try:
# Use a short timeout to avoid blocking indefinitely
with trio.fail_after(1.0): # 1 second timeout
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
self._read_buffer = message
except trio.TooSlowError:
return b"" # No data available
except Exception:
return b""
# Now return up to n bytes from the buffer (TCP-like semantics)
if len(self._read_buffer) == 0:
return b""
# Return up to n bytes (like TCP read())
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[len(result) :]
self._bytes_read += len(result)
return result
except Exception as e:
logger.error(f"WebSocket read failed: {e}")
raise IOException from e
async def close(self) -> None:
"""Close the WebSocket connection. This method is idempotent."""
async with self._close_lock:
if self._closed:
return # Already closed
logger.debug("WebSocket connection closing")
self._closed = True
try:
# Always close the connection directly, avoid context manager issues
# The context manager may be causing cancel scope corruption
logger.debug("WebSocket closing connection directly")
await self._ws_connection.aclose()
# Exit the context manager if we have one
if self._ws_context is not None:
await self._ws_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"WebSocket close error: {e}")
# Don't raise here, as close() should be idempotent
finally:
logger.debug("WebSocket connection closed")
def is_closed(self) -> bool:
"""Check if the connection is closed"""
return self._closed
def conn_state(self) -> dict[str, Any]:
"""
Return connection state information similar to Go's ConnState() method.
:return: Dictionary containing connection state information
"""
current_time = time.time()
return {
"transport": "websocket",
"secure": self._is_secure,
"connection_duration": current_time - self._connection_start_time,
"bytes_read": self._bytes_read,
"bytes_written": self._bytes_written,
"total_bytes": self._bytes_read + self._bytes_written,
}
def get_remote_address(self) -> tuple[str, int] | None:
# Try to get remote address from the WebSocket connection
try:
remote = self._ws_connection.remote
if hasattr(remote, "address") and hasattr(remote, "port"):
return str(remote.address), int(remote.port)
elif isinstance(remote, str):
# Parse address:port format
if ":" in remote:
host, port = remote.rsplit(":", 1)
return host, int(port)
except Exception:
pass
return None

View File

@ -0,0 +1,225 @@
from collections.abc import Awaitable, Callable
import logging
import ssl
from typing import Any
from multiaddr import Multiaddr
import trio
from trio_typing import TaskStatus
from trio_websocket import serve_websocket
from libp2p.abc import IListener
from libp2p.custom_types import THandler
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
from .connection import P2PWebSocketConnection
logger = logging.getLogger("libp2p.transport.websocket.listener")
class WebsocketListener(IListener):
"""
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
"""
def __init__(
self,
handler: THandler,
upgrader: TransportUpgrader,
tls_config: ssl.SSLContext | None = None,
handshake_timeout: float = 15.0,
) -> None:
self._handler = handler
self._upgrader = upgrader
self._tls_config = tls_config
self._handshake_timeout = handshake_timeout
self._server = None
self._shutdown_event = trio.Event()
self._nursery: trio.Nursery | None = None
self._listeners: Any = None
self._is_wss = False # Track whether this is a WSS listener
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}")
# Parse the WebSocket multiaddr to determine if it's secure
try:
parsed = parse_websocket_multiaddr(maddr)
except ValueError as e:
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
# Check if WSS is requested but no TLS config provided
if parsed.is_wss and self._tls_config is None:
raise ValueError(
f"Cannot listen on WSS address {maddr} without TLS configuration"
)
# Store whether this is a WSS listener
self._is_wss = parsed.is_wss
# Extract host and port from the base multiaddr
host = (
parsed.rest_multiaddr.value_for_protocol("ip4")
or parsed.rest_multiaddr.value_for_protocol("ip6")
or parsed.rest_multiaddr.value_for_protocol("dns")
or parsed.rest_multiaddr.value_for_protocol("dns4")
or parsed.rest_multiaddr.value_for_protocol("dns6")
or "0.0.0.0"
)
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
logger.debug(
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
)
async def serve_websocket_tcp(
handler: Callable[[Any], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Any],
) -> None:
"""Start TCP server and handle WebSocket connections manually"""
logger.debug(
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
)
async def websocket_handler(request: Any) -> None:
"""Handle WebSocket requests"""
logger.debug("WebSocket request received")
try:
# Apply handshake timeout
with trio.fail_after(self._handshake_timeout):
# Accept the WebSocket connection
ws_connection = await request.accept()
logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(
ws_connection, is_secure=parsed.is_wss
) # type: ignore[no-untyped-call]
# Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades
logger.debug("Calling connection handler")
await self._handler(conn)
# Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle
logger.debug(
"Handler completed, connection will be managed by handler"
)
except trio.TooSlowError:
logger.debug(
f"WebSocket handshake timeout after {self._handshake_timeout}s"
)
try:
await request.reject(408) # Request Timeout
except Exception:
pass
except Exception as e:
logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}")
import traceback
logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection
try:
await request.reject(400)
except Exception:
pass
# Use trio_websocket.serve_websocket for proper WebSocket handling
ssl_context = self._tls_config if parsed.is_wss else None
await serve_websocket(
websocket_handler, host, port, ssl_context, task_status=task_status
)
# Store the nursery for shutdown
self._nursery = nursery
# Start the server using nursery.start() like TCP does
logger.debug("Calling nursery.start()...")
started_listeners = await nursery.start(
serve_websocket_tcp,
None, # No handler needed since it's defined inside serve_websocket_tcp
port,
host,
)
logger.debug(f"nursery.start() returned: {started_listeners}")
if started_listeners is None:
logger.error(f"Failed to start WebSocket listener for {maddr}")
return False
# Store the listeners for get_addrs() and close() - these are real
# SocketListener objects
self._listeners = started_listeners
logger.debug(
"WebsocketListener.listen returning True with WebSocketServer object"
)
return True
def get_addrs(self) -> tuple[Multiaddr, ...]:
if not hasattr(self, "_listeners") or not self._listeners:
logger.debug("No listeners available for get_addrs()")
return ()
# Handle WebSocketServer objects
if hasattr(self._listeners, "port"):
# This is a WebSocketServer object
port = self._listeners.port
# Create a multiaddr from the port with correct WSS/WS protocol
protocol = "wss" if self._is_wss else "ws"
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
else:
# This is a list of listeners (like TCP)
listeners = self._listeners
# Get addresses from listeners like TCP does
return tuple(
_multiaddr_from_socket(listener.socket, self._is_wss)
for listener in listeners
)
async def close(self) -> None:
"""Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called")
if hasattr(self, "_listeners") and self._listeners:
# Signal shutdown
self._shutdown_event.set()
# Close the WebSocket server
if hasattr(self._listeners, "aclose"):
# This is a WebSocketServer object
logger.debug("Closing WebSocket server")
await self._listeners.aclose()
logger.debug("WebSocket server closed")
elif isinstance(self._listeners, (list, tuple)):
# This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners")
for listener in self._listeners:
await listener.aclose()
logger.debug("TCP listeners closed")
else:
# Unknown type, try to close it directly
logger.debug("Closing unknown listener type")
if hasattr(self._listeners, "close"):
self._listeners.close()
logger.debug("Unknown listener closed")
# Clear the listeners reference
self._listeners = None
logger.debug("WebsocketListener.close completed")
def _multiaddr_from_socket(
socket: trio.socket.SocketType, is_wss: bool = False
) -> Multiaddr:
"""Convert socket to multiaddr"""
ip, port = socket.getsockname()
protocol = "wss" if is_wss else "ws"
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")

View File

@ -0,0 +1,202 @@
"""
WebSocket multiaddr parsing utilities.
"""
from typing import NamedTuple
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
class ParsedWebSocketMultiaddr(NamedTuple):
"""Parsed WebSocket multiaddr information."""
is_wss: bool
sni: str | None
rest_multiaddr: Multiaddr
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
"""
Parse a WebSocket multiaddr and extract security information.
:param maddr: The multiaddr to parse
:return: Parsed WebSocket multiaddr information
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
"""
# First validate that this is a valid WebSocket multiaddr
if not is_valid_websocket_multiaddr(maddr):
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
protocols = list(maddr.protocols())
# Find the WebSocket protocol and check for security
is_wss = False
sni = None
ws_index = -1
tls_index = -1
sni_index = -1
# Find protocol indices
for i, protocol in enumerate(protocols):
if protocol.name == "ws":
ws_index = i
elif protocol.name == "wss":
ws_index = i
is_wss = True
elif protocol.name == "tls":
tls_index = i
elif protocol.name == "sni":
sni_index = i
sni = protocol.value
if ws_index == -1:
raise ValueError("Not a WebSocket multiaddr")
# Handle /wss protocol (convert to /tls/ws internally)
if is_wss and tls_index == -1:
# Convert /wss to /tls/ws format
# Remove /wss to get the base multiaddr
without_wss = maddr.decapsulate(Multiaddr("/wss"))
return ParsedWebSocketMultiaddr(
is_wss=True, sni=None, rest_multiaddr=without_wss
)
# Handle /tls/ws and /tls/sni/.../ws formats
if tls_index != -1:
is_wss = True
# Extract the base multiaddr (everything before /tls)
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
# Use multiaddr methods to properly extract the base
rest_multiaddr = maddr
# Remove /tls/ws or /tls/sni/.../ws from the end
if sni_index != -1:
# /tls/sni/example.com/ws format
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
else:
# /tls/ws format
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
return ParsedWebSocketMultiaddr(
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
)
# Regular /ws multiaddr - remove /ws and any additional protocols
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
return ParsedWebSocketMultiaddr(
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
)
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid WebSocket structure.
:param maddr: The multiaddr to validate
:return: True if valid WebSocket structure, False otherwise
"""
try:
# WebSocket multiaddr should have structure like:
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
# /ip4/127.0.0.1/tcp/8080/wss (secure)
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
if len(protocols) < 3:
return False
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Check for valid WebSocket protocols
ws_protocols = ["ws", "wss"]
tls_protocols = ["tls"]
sni_protocols = ["sni"]
# Find the WebSocket protocol
ws_protocol_found = False
tls_found = False
# sni_found = False # Not used currently
for i, protocol in enumerate(protocols[2:], start=2):
if protocol.name in ws_protocols:
ws_protocol_found = True
break
elif protocol.name in tls_protocols:
tls_found = True
elif protocol.name in sni_protocols:
pass # sni_found = True # Not used in current implementation
if not ws_protocol_found:
return False
# Validate protocol sequence
# For /ws: network + tcp + ws
# For /wss: network + tcp + wss
# For /tls/ws: network + tcp + tls + ws
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
# Check if it's a simple /ws or /wss
if len(protocols) == 3:
return protocols[2].name in ["ws", "wss"]
# Check for /tls/ws or /tls/sni/.../ws patterns
if tls_found:
# Must end with /ws (not /wss when using /tls)
if protocols[-1].name != "ws":
return False
# Check for valid TLS sequence
tls_index = None
for i, protocol in enumerate(protocols[2:], start=2):
if protocol.name == "tls":
tls_index = i
break
if tls_index is None:
return False
# After tls, we can have sni, then ws
remaining_protocols = protocols[tls_index + 1 :]
if len(remaining_protocols) == 1:
# /tls/ws
return remaining_protocols[0].name == "ws"
elif len(remaining_protocols) == 2:
# /tls/sni/example.com/ws
return (
remaining_protocols[0].name == "sni"
and remaining_protocols[1].name == "ws"
)
else:
return False
# If we have more than 3 protocols but no TLS, check for valid continuations
# Allow additional protocols after the WebSocket protocol (like /p2p)
valid_continuations = ["p2p"]
# Find the WebSocket protocol index
ws_index = None
for i, protocol in enumerate(protocols):
if protocol.name in ["ws", "wss"]:
ws_index = i
break
if ws_index is not None:
# Check protocols after the WebSocket protocol
for i in range(ws_index + 1, len(protocols)):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False

View File

@ -0,0 +1,229 @@
import logging
import ssl
from multiaddr import Multiaddr
from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
from .connection import P2PWebSocketConnection
from .listener import WebsocketListener
logger = logging.getLogger(__name__)
class WebsocketTransport(ITransport):
"""
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
Implements production-ready WebSocket transport with:
- Flow control and buffer management
- Connection limits and rate limiting
- Proper error handling and cleanup
- Support for both WS and WSS protocols
- TLS configuration and handshake timeout
"""
def __init__(
self,
upgrader: TransportUpgrader,
tls_client_config: ssl.SSLContext | None = None,
tls_server_config: ssl.SSLContext | None = None,
handshake_timeout: float = 15.0,
max_buffered_amount: int = 4 * 1024 * 1024,
):
self._upgrader = upgrader
self._tls_client_config = tls_client_config
self._tls_server_config = tls_server_config
self._handshake_timeout = handshake_timeout
self._max_buffered_amount = max_buffered_amount
self._connection_count = 0
self._max_connections = 1000 # Production limit
async def dial(self, maddr: Multiaddr) -> RawConnection:
"""Dial a WebSocket connection to the given multiaddr."""
logger.debug(f"WebsocketTransport.dial called with {maddr}")
# Parse the WebSocket multiaddr to determine if it's secure
try:
parsed = parse_websocket_multiaddr(maddr)
except ValueError as e:
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
# Extract host and port from the base multiaddr
host = (
parsed.rest_multiaddr.value_for_protocol("ip4")
or parsed.rest_multiaddr.value_for_protocol("ip6")
or parsed.rest_multiaddr.value_for_protocol("dns")
or parsed.rest_multiaddr.value_for_protocol("dns4")
or parsed.rest_multiaddr.value_for_protocol("dns6")
)
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
# Build WebSocket URL based on security
if parsed.is_wss:
ws_url = f"wss://{host}:{port}/"
else:
ws_url = f"ws://{host}:{port}/"
logger.debug(
f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
)
try:
# Check connection limits
if self._connection_count >= self._max_connections:
raise OpenConnectionError(
f"Maximum connections reached: {self._max_connections}"
)
# Prepare SSL context for WSS connections
ssl_context = None
if parsed.is_wss:
if self._tls_client_config:
ssl_context = self._tls_client_config
else:
# Create default SSL context for client
ssl_context = ssl.create_default_context()
# Set SNI if available
if parsed.sni:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
# Use a different approach: start background nursery that will persist
logger.debug("WebsocketTransport.dial establishing connection")
# Import trio-websocket functions
from trio_websocket import connect_websocket
from trio_websocket._impl import _url_to_host
# Parse the WebSocket URL to get host, port, resource
# like trio-websocket does
ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
ws_url, ssl_context
)
logger.debug(
f"WebsocketTransport.dial parsed URL: host={ws_host}, "
f"port={ws_port}, resource={ws_resource}"
)
# Create a background task manager for this connection
import trio
nursery_manager = trio.lowlevel.current_task().parent_nursery
if nursery_manager is None:
raise OpenConnectionError(
f"No parent nursery available for WebSocket connection to {maddr}"
)
# Apply timeout to the connection process
with trio.fail_after(self._handshake_timeout):
logger.debug("WebsocketTransport.dial connecting WebSocket")
ws = await connect_websocket(
nursery_manager, # Use the existing nursery from libp2p
ws_host,
ws_port,
ws_resource,
use_ssl=ws_ssl_context,
message_queue_size=1024, # Reasonable defaults
max_message_size=16 * 1024 * 1024, # 16MB max message
)
logger.debug("WebsocketTransport.dial WebSocket connection established")
# Create our connection wrapper with both WSS support and flow control
conn = P2PWebSocketConnection(
ws,
None,
is_secure=parsed.is_wss,
max_buffered_amount=self._max_buffered_amount,
)
logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")
self._connection_count += 1
logger.debug(f"Total connections: {self._connection_count}")
return RawConnection(conn, initiator=True)
except trio.TooSlowError as e:
raise OpenConnectionError(
f"WebSocket handshake timeout after {self._handshake_timeout}s "
f"for {maddr}"
) from e
except Exception as e:
logger.error(f"Failed to dial WebSocket {maddr}: {e}")
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
def create_listener(self, handler: THandler) -> IListener: # type: ignore[override]
"""
The type checker is incorrectly reporting this as an inconsistent override.
"""
logger.debug("WebsocketTransport.create_listener called")
return WebsocketListener(
handler, self._upgrader, self._tls_server_config, self._handshake_timeout
)
def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
"""
Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
Similar to Go's Resolve() method.
:param maddr: The multiaddr to resolve
:return: List of resolved multiaddrs
"""
try:
parsed = parse_websocket_multiaddr(maddr)
except ValueError as e:
logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
return [maddr] # Return original if not a valid WebSocket multiaddr
logger.debug(
f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
)
if not parsed.is_wss:
# No /tls/ws component, this isn't a secure websocket multiaddr
return [maddr]
if parsed.sni is not None:
# Already has SNI, return as-is
return [maddr]
# Try to extract DNS name from the base multiaddr
dns_name = None
for protocol_name in ["dns", "dns4", "dns6"]:
try:
dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
break
except Exception:
continue
if dns_name is None:
# No DNS name found, return original
return [maddr]
# Create new multiaddr with SNI
# For /dns/example.com/tcp/8080/wss ->
# /dns/example.com/tcp/8080/tls/sni/example.com/ws
try:
# Remove /wss and add /tls/sni/example.com/ws
without_wss = maddr.decapsulate(Multiaddr("/wss"))
sni_component = Multiaddr(f"/sni/{dns_name}")
resolved = (
without_wss.encapsulate(Multiaddr("/tls"))
.encapsulate(sni_component)
.encapsulate(Multiaddr("/ws"))
)
logger.debug(f"Resolved {maddr} to {resolved}")
return [resolved]
except Exception as e:
logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
return [maddr]

View File

@ -3,38 +3,24 @@ from __future__ import annotations
import socket
from multiaddr import Multiaddr
try:
from multiaddr.utils import ( # type: ignore
get_network_addrs,
get_thin_waist_addresses,
)
_HAS_THIN_WAIST = True
except ImportError: # pragma: no cover - only executed in older environments
_HAS_THIN_WAIST = False
get_thin_waist_addresses = None # type: ignore
get_network_addrs = None # type: ignore
from multiaddr.utils import get_network_addrs, get_thin_waist_addresses
def _safe_get_network_addrs(ip_version: int) -> list[str]:
"""
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
Falls back to minimal defaults when Thin Waist helpers are missing.
:param ip_version: 4 or 6
"""
if _HAS_THIN_WAIST and get_network_addrs:
try:
return get_network_addrs(ip_version) or []
except Exception: # pragma: no cover - defensive
return []
# Fallback behavior (very conservative)
if ip_version == 4:
return ["127.0.0.1"]
if ip_version == 6:
return ["::1"]
return []
try:
return get_network_addrs(ip_version) or []
except Exception: # pragma: no cover - defensive
# Fallback behavior (very conservative)
if ip_version == 4:
return ["127.0.0.1"]
if ip_version == 6:
return ["::1"]
return []
def find_free_port() -> int:
@ -47,16 +33,13 @@ def find_free_port() -> int:
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
"""
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
If Thin Waist isn't available, returns [addr] (identity).
"""
if _HAS_THIN_WAIST and get_thin_waist_addresses:
try:
if port is not None:
return get_thin_waist_addresses(addr, port=port) or []
return get_thin_waist_addresses(addr) or []
except Exception: # pragma: no cover - defensive
return [addr]
return [addr]
try:
if port is not None:
return get_thin_waist_addresses(addr, port=port) or []
return get_thin_waist_addresses(addr) or []
except Exception: # pragma: no cover - defensive
return [addr]
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
@ -73,8 +56,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
seen_v4: set[str] = set()
for ip in _safe_get_network_addrs(4):
seen_v4.add(ip)
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
if ip not in seen_v4: # Avoid duplicates
seen_v4.add(ip)
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
if seen_v4 and "127.0.0.1" not in seen_v4:
@ -89,8 +73,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
#
# seen_v6: set[str] = set()
# for ip in _safe_get_network_addrs(6):
# seen_v6.add(ip)
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
# if ip not in seen_v6: # Avoid duplicates
# seen_v6.add(ip)
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
#
# # Always include IPv6 loopback for testing purposes when IPv6 is available
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
@ -99,7 +84,7 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
# Fallback if nothing discovered
if not addrs:
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
return addrs
@ -120,6 +105,20 @@ def expand_wildcard_address(
return expanded
def get_wildcard_address(port: int, protocol: str = "tcp") -> Multiaddr:
"""
Get wildcard address (0.0.0.0) when explicitly needed.
This function provides access to wildcard binding as a feature when
explicitly required, preserving the ability to bind to all interfaces.
:param port: Port number.
:param protocol: Transport protocol.
:return: A Multiaddr with wildcard binding (0.0.0.0).
"""
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
"""
Choose an optimal address for an example to bind to:
@ -148,13 +147,14 @@ def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
return c
# As a final fallback, produce a wildcard
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
# As a final fallback, produce a loopback address
return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")
__all__ = [
"get_available_interfaces",
"get_optimal_binding_address",
"get_wildcard_address",
"expand_wildcard_address",
"find_free_port",
]