Merge branch 'main' into chore01

This commit is contained in:
Manu Sheel Gupta
2025-09-24 22:40:38 +05:30
committed by GitHub
119 changed files with 16551 additions and 520 deletions

View File

@ -1,5 +1,12 @@
"""Libp2p Python implementation."""
import logging
import ssl
from libp2p.transport.quic.utils import is_quic_multiaddr
from typing import Any
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.quic.config import QUICTransportConfig
from collections.abc import (
Mapping,
Sequence,
@ -18,6 +25,7 @@ from libp2p.abc import (
IPeerRouting,
IPeerStore,
ISecureTransport,
ITransport,
)
from libp2p.crypto.keys import (
KeyPair,
@ -38,10 +46,12 @@ from libp2p.host.routed_host import (
RoutedHost,
)
from libp2p.network.swarm import (
ConnectionConfig,
RetryConfig,
Swarm,
)
from libp2p.network.config import (
ConnectionConfig,
RetryConfig
)
from libp2p.peer.id import (
ID,
)
@ -72,6 +82,10 @@ from libp2p.transport.tcp.tcp import (
from libp2p.transport.upgrader import (
TransportUpgrader,
)
from libp2p.transport.transport_registry import (
create_transport_for_multiaddr,
get_supported_transport_protocols,
)
from libp2p.utils.logging import (
setup_logging,
)
@ -87,6 +101,7 @@ MUXER_YAMUX = "YAMUX"
MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5
logger = logging.getLogger(__name__)
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
"""
@ -162,9 +177,13 @@ def new_swarm(
peerstore_opt: IPeerStore | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
enable_quic: bool = False,
retry_config: Optional["RetryConfig"] = None,
connection_config: Optional["ConnectionConfig"] = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
tls_client_config: ssl.SSLContext | None = None,
tls_server_config: ssl.SSLContext | None = None,
) -> INetworkService:
logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}")
"""
Create a swarm instance based on the parameters.
@ -174,6 +193,8 @@ def new_swarm(
:param peerstore_opt: optional peerstore
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_quic: enable quic for transport
:param quic_transport_opt: options for transport
:return: return a default swarm instance
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@ -186,16 +207,48 @@ def new_swarm(
id_opt = generate_peer_id_from(key_pair)
transport: TCP | QUICTransport | ITransport
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
if listen_addrs is None:
transport = TCP()
else:
addr = listen_addrs[0]
if addr.__contains__("tcp"):
transport = TCP()
elif addr.__contains__("quic"):
raise ValueError("QUIC not yet supported")
if enable_quic:
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
else:
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
transport = TCP()
else:
# Use transport registry to select the appropriate transport
from libp2p.transport.transport_registry import create_transport_for_multiaddr
# Create a temporary upgrader for transport selection
# We'll create the real upgrader later with the proper configuration
temp_upgrader = TransportUpgrader(
secure_transports_by_protocol={},
muxer_transports_by_protocol={}
)
addr = listen_addrs[0]
logger.debug(f"new_swarm: Creating transport for address: {addr}")
transport_maybe = create_transport_for_multiaddr(
addr,
temp_upgrader,
private_key=key_pair.private_key,
config=quic_transport_opt,
tls_client_config=tls_client_config,
tls_server_config=tls_server_config
)
if transport_maybe is None:
raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}")
transport = transport_maybe
logger.debug(f"new_swarm: Created transport: {type(transport)}")
# If enable_quic is True but we didn't get a QUIC transport, force QUIC
if enable_quic and not isinstance(transport, QUICTransport):
logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})")
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
logger.debug(f"new_swarm: Final transport type: {type(transport)}")
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
@ -236,6 +289,7 @@ def new_swarm(
muxer_transports_by_protocol=muxer_transports_by_protocol,
)
peerstore = peerstore_opt or PeerStore()
# Store our key pair in peerstore
peerstore.add_key_pair(id_opt, key_pair)
@ -261,6 +315,10 @@ def new_host(
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
enable_quic: bool = False,
quic_transport_opt: QUICTransportConfig | None = None,
tls_client_config: ssl.SSLContext | None = None,
tls_server_config: ssl.SSLContext | None = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.
@ -274,15 +332,27 @@ def new_host(
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_mDNS: whether to enable mDNS discovery
:param bootstrap: optional list of bootstrap peer addresses as strings
:param enable_quic: optinal choice to use QUIC for transport
:param quic_transport_opt: optional configuration for quic transport
:param tls_client_config: optional TLS client configuration for WebSocket transport
:param tls_server_config: optional TLS server configuration for WebSocket transport
:return: return a host instance
"""
if not enable_quic and quic_transport_opt is not None:
logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config")
swarm = new_swarm(
enable_quic=enable_quic,
key_pair=key_pair,
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
muxer_preference=muxer_preference,
listen_addrs=listen_addrs,
connection_config=quic_transport_opt if enable_quic else None,
tls_client_config=tls_client_config,
tls_server_config=tls_server_config
)
if disc_opt is not None:

View File

@ -5,17 +5,17 @@ from collections.abc import (
)
from typing import TYPE_CHECKING, NewType, Union, cast
from libp2p.transport.quic.stream import QUICStream
if TYPE_CHECKING:
from libp2p.abc import (
IMuxedConn,
INetStream,
ISecureTransport,
)
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
from libp2p.transport.quic.connection import QUICConnection
else:
IMuxedConn = cast(type, object)
INetStream = cast(type, object)
ISecureTransport = cast(type, object)
IMuxedStream = cast(type, object)
QUICConnection = cast(type, object)
from libp2p.io.abc import (
ReadWriteCloser,
@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
MessageID = NewType("MessageID", str)

View File

@ -213,7 +213,6 @@ class BasicHost(IHost):
self,
peer_id: ID,
protocol_ids: Sequence[TProtocol],
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> INetStream:
"""
:param peer_id: peer_id that host is connecting
@ -227,7 +226,7 @@ class BasicHost(IHost):
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids),
MultiselectCommunicator(net_stream),
negotitate_timeout,
self.negotiate_timeout,
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)

70
libp2p/network/config.py Normal file
View File

@ -0,0 +1,70 @@
from dataclasses import dataclass
@dataclass
class RetryConfig:
"""
Configuration for retry logic with exponential backoff.
This configuration controls how connection attempts are retried when they fail.
The retry mechanism uses exponential backoff with jitter to prevent thundering
herd problems in distributed systems.
Attributes:
max_retries: Maximum number of retry attempts before giving up.
Default: 3 attempts
initial_delay: Initial delay in seconds before the first retry.
Default: 0.1 seconds (100ms)
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
Default: 30.0 seconds
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
the delay by this factor). Default: 2.0 (doubles each time)
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
and prevent synchronized retries. Default: 0.1 (10% jitter)
"""
max_retries: int = 3
initial_delay: float = 0.1
max_delay: float = 30.0
backoff_multiplier: float = 2.0
jitter_factor: float = 0.1
@dataclass
class ConnectionConfig:
"""
Configuration for multi-connection support.
This configuration controls how multiple connections per peer are managed,
including connection limits, timeouts, and load balancing strategies.
Attributes:
max_connections_per_peer: Maximum number of connections allowed to a single
peer. Default: 3 connections
connection_timeout: Timeout in seconds for establishing new connections.
Default: 30.0 seconds
load_balancing_strategy: Strategy for distributing streams across connections.
Options: "round_robin" (default) or "least_loaded"
"""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
load_balancing_strategy: str = "round_robin" # or "least_loaded"
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (
self.load_balancing_strategy == "round_robin"
or self.load_balancing_strategy == "least_loaded"
):
raise ValueError(
"Load balancing strategy can only be 'round_robin' or 'least_loaded'"
)
if self.max_connections_per_peer < 1:
raise ValueError("Max connection per peer should be atleast 1")
if self.connection_timeout < 0:
raise ValueError("Connection timeout should be positive")

View File

@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamError,
MuxedStreamReset,
)
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError
from .exceptions import (
StreamClosed,
@ -170,7 +171,7 @@ class NetStream(INetStream):
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_READ
raise StreamEOF() from error
except MuxedStreamReset as error:
except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
@ -199,7 +200,12 @@ class NetStream(INetStream):
try:
await self.muxed_stream.write(data)
except (MuxedStreamClosed, MuxedStreamError) as error:
except (
MuxedStreamClosed,
MuxedStreamError,
QUICStreamClosedError,
QUICStreamResetError,
) as error:
async with self._state_lock:
if self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE

View File

@ -2,9 +2,9 @@ from collections.abc import (
Awaitable,
Callable,
)
from dataclasses import dataclass
import logging
import random
from typing import cast
from multiaddr import (
Multiaddr,
@ -27,6 +27,7 @@ from libp2p.custom_types import (
from libp2p.io.abc import (
ReadWriteCloser,
)
from libp2p.network.config import ConnectionConfig, RetryConfig
from libp2p.peer.id import (
ID,
)
@ -41,6 +42,9 @@ from libp2p.transport.exceptions import (
OpenConnectionError,
SecurityUpgradeFailure,
)
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.upgrader import (
TransportUpgrader,
)
@ -61,59 +65,6 @@ from .exceptions import (
logger = logging.getLogger("libp2p.network.swarm")
@dataclass
class RetryConfig:
"""
Configuration for retry logic with exponential backoff.
This configuration controls how connection attempts are retried when they fail.
The retry mechanism uses exponential backoff with jitter to prevent thundering
herd problems in distributed systems.
Attributes:
max_retries: Maximum number of retry attempts before giving up.
Default: 3 attempts
initial_delay: Initial delay in seconds before the first retry.
Default: 0.1 seconds (100ms)
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
Default: 30.0 seconds
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
the delay by this factor). Default: 2.0 (doubles each time)
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
and prevent synchronized retries. Default: 0.1 (10% jitter)
"""
max_retries: int = 3
initial_delay: float = 0.1
max_delay: float = 30.0
backoff_multiplier: float = 2.0
jitter_factor: float = 0.1
@dataclass
class ConnectionConfig:
"""
Configuration for multi-connection support.
This configuration controls how multiple connections per peer are managed,
including connection limits, timeouts, and load balancing strategies.
Attributes:
max_connections_per_peer: Maximum number of connections allowed to a single
peer. Default: 3 connections
connection_timeout: Timeout in seconds for establishing new connections.
Default: 30.0 seconds
load_balancing_strategy: Strategy for distributing streams across connections.
Options: "round_robin" (default) or "least_loaded"
"""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
load_balancing_strategy: str = "round_robin" # or "least_loaded"
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
@ -126,8 +77,7 @@ class Swarm(Service, INetworkService):
peerstore: IPeerStore
upgrader: TransportUpgrader
transport: ITransport
# Enhanced: Support for multiple connections per peer
connections: dict[ID, list[INetConn]] # Multiple connections per peer
connections: dict[ID, list[INetConn]]
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: trio.Nursery | None
@ -137,7 +87,7 @@ class Swarm(Service, INetworkService):
# Enhanced: New configuration
retry_config: RetryConfig
connection_config: ConnectionConfig
connection_config: ConnectionConfig | QUICTransportConfig
_round_robin_index: dict[ID, int]
def __init__(
@ -147,7 +97,7 @@ class Swarm(Service, INetworkService):
upgrader: TransportUpgrader,
transport: ITransport,
retry_config: RetryConfig | None = None,
connection_config: ConnectionConfig | None = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
):
self.self_id = peer_id
self.peerstore = peerstore
@ -178,6 +128,11 @@ class Swarm(Service, INetworkService):
# Create a nursery for listener tasks.
self.listener_nursery = nursery
self.event_listener_nursery_created.set()
if isinstance(self.transport, QUICTransport):
self.transport.set_background_nursery(nursery)
self.transport.set_swarm(self)
try:
await self.manager.wait_finished()
finally:
@ -370,6 +325,7 @@ class Swarm(Service, INetworkService):
# Dial peer (connection to peer does not yet exist)
# Transport dials peer (gets back a raw conn)
try:
addr = Multiaddr(f"{addr}/p2p/{peer_id}")
raw_conn = await self.transport.dial(addr)
except OpenConnectionError as error:
logger.debug("fail to dial peer %s over base transport", peer_id)
@ -377,6 +333,15 @@ class Swarm(Service, INetworkService):
f"fail to open connection to peer {peer_id}"
) from error
if isinstance(self.transport, QUICTransport) and isinstance(
raw_conn, IMuxedConn
):
logger.info(
"Skipping upgrade for QUIC, QUIC connections are already multiplexed"
)
swarm_conn = await self.add_conn(raw_conn)
return swarm_conn
logger.debug("dialed peer %s over base transport", peer_id)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
@ -402,9 +367,7 @@ class Swarm(Service, INetworkService):
logger.debug("upgraded mux for peer %s", peer_id)
swarm_conn = await self.add_conn(muxed_conn)
logger.debug("successfully dialed peer %s", peer_id)
return swarm_conn
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
@ -427,7 +390,6 @@ class Swarm(Service, INetworkService):
:return: net stream instance
"""
logger.debug("attempting to open a stream to peer %s", peer_id)
# Get existing connections or dial new ones
connections = self.get_connections(peer_id)
if not connections:
@ -436,6 +398,10 @@ class Swarm(Service, INetworkService):
# Load balancing strategy at interface level
connection = self._select_connection(connections, peer_id)
if isinstance(self.transport, QUICTransport) and connection is not None:
conn = cast(SwarmConn, connection)
return await conn.new_stream()
try:
net_stream = await connection.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
@ -515,18 +481,38 @@ class Swarm(Service, INetworkService):
- Call listener listen with the multiaddr
- Map multiaddr to listener
"""
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
# We need to wait until `self.listener_nursery` is created.
logger.debug("Starting to listen")
await self.event_listener_nursery_created.wait()
success_count = 0
for maddr in multiaddrs:
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
if str(maddr) in self.listeners:
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
success_count += 1
continue
async def conn_handler(
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
) -> None:
# No need to upgrade QUIC Connection
if isinstance(self.transport, QUICTransport):
try:
quic_conn = cast(QUICConnection, read_write_closer)
await self.add_conn(quic_conn)
peer_id = quic_conn.peer_id
logger.debug(
f"successfully opened quic connection to peer {peer_id}"
)
# NOTE: This is a intentional barrier to prevent from the
# handler exiting and closing the connection.
await self.manager.wait_finished()
except Exception:
await read_write_closer.close()
return
raw_conn = RawConnection(read_write_closer, False)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
@ -562,13 +548,18 @@ class Swarm(Service, INetworkService):
try:
# Success
logger.debug(f"Swarm.listen: creating listener for {maddr}")
listener = self.transport.create_listener(conn_handler)
logger.debug(f"Swarm.listen: listener created for {maddr}")
self.listeners[str(maddr)] = listener
# TODO: `listener.listen` is not bounded with nursery. If we want to be
# I/O agnostic, we should change the API.
if self.listener_nursery is None:
raise SwarmException("swarm instance hasn't been run")
assert self.listener_nursery is not None # For type checker
logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
await listener.listen(maddr, self.listener_nursery)
logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
# Call notifiers since event occurred
await self.notify_listen(maddr)
@ -660,9 +651,10 @@ class Swarm(Service, INetworkService):
muxed_conn,
self,
)
logger.debug("Swarm::add_conn | starting muxed connection")
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
logger.debug("Swarm::add_conn | starting swarm connection")
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()

View File

@ -1,3 +1,5 @@
from builtins import AssertionError
from libp2p.abc import (
IMultiselectCommunicator,
)
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:
# Handle for connection close during ongoing negotiation in QUIC
except (IOException, AssertionError, ValueError) as error:
raise MultiselectCommunicatorError(
"fail to write to multiselect communicator"
) from error

View File

@ -1,6 +1,3 @@
from ast import (
literal_eval,
)
from collections import (
defaultdict,
)
@ -22,6 +19,7 @@ from libp2p.abc import (
IPubsubRouter,
)
from libp2p.custom_types import (
MessageID,
TProtocol,
)
from libp2p.peer.id import (
@ -56,6 +54,10 @@ from .pb import (
from .pubsub import (
Pubsub,
)
from .utils import (
parse_message_id_safe,
safe_parse_message_id,
)
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@ -306,7 +308,8 @@ class GossipSub(IPubsubRouter, Service):
floodsub_peers: set[ID] = {
peer_id
for peer_id in self.pubsub.peer_topics[topic]
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
if peer_id in self.peer_protocol
and self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
}
send_to.update(floodsub_peers)
@ -794,8 +797,8 @@ class GossipSub(IPubsubRouter, Service):
# Add all unknown message ids (ids that appear in ihave_msg but not in
# seen_seqnos) to list of messages we want to request
msg_ids_wanted: list[str] = [
msg_id
msg_ids_wanted: list[MessageID] = [
parse_message_id_safe(msg_id)
for msg_id in ihave_msg.messageIDs
if msg_id not in seen_seqnos_and_peers
]
@ -811,9 +814,9 @@ class GossipSub(IPubsubRouter, Service):
Forwards all request messages that are present in mcache to the
requesting peer.
"""
# FIXME: Update type of message ID
# FIXME: Find a better way to parse the msg ids
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
msg_ids: list[tuple[bytes, bytes]] = [
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
]
msgs_to_forward: list[rpc_pb2.Message] = []
for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache

View File

@ -1,6 +1,10 @@
import ast
import logging
from libp2p.abc import IHost
from libp2p.custom_types import (
MessageID,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ID
from libp2p.pubsub.pb.rpc_pb2 import RPC
@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
return True
def parse_message_id_safe(msg_id_str: str) -> MessageID:
"""Safely handle message ID as string."""
return MessageID(msg_id_str)
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
"""
Safely parse message ID using ast.literal_eval with validation.
:param msg_id_str: String representation of message ID
:return: Tuple of (seqno, from_id) as bytes
:raises ValueError: If parsing fails
"""
try:
parsed = ast.literal_eval(msg_id_str)
if not isinstance(parsed, tuple) or len(parsed) != 2:
raise ValueError("Invalid message ID format")
seqno, from_id = parsed
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
raise ValueError("Message ID components must be bytes")
return (seqno, from_id)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Invalid message ID format: {e}")

View File

@ -9,6 +9,7 @@ from dataclasses import (
dataclass,
field,
)
from enum import Flag, auto
from libp2p.peer.peerinfo import (
PeerInfo,
@ -18,29 +19,118 @@ from .resources import (
RelayLimits,
)
DEFAULT_MIN_RELAYS = 3
DEFAULT_MAX_RELAYS = 20
DEFAULT_DISCOVERY_INTERVAL = 300 # seconds
DEFAULT_RESERVATION_TTL = 3600 # seconds
DEFAULT_MAX_CIRCUIT_DURATION = 3600 # seconds
DEFAULT_MAX_CIRCUIT_BYTES = 1024 * 1024 * 1024 # 1GB
DEFAULT_MAX_CIRCUIT_CONNS = 8
DEFAULT_MAX_RESERVATIONS = 4
MAX_RESERVATIONS_PER_IP = 8
MAX_CIRCUITS_PER_IP = 16
RESERVATION_RATE_PER_IP = 4 # per minute
CIRCUIT_RATE_PER_IP = 8 # per minute
MAX_CIRCUITS_TOTAL = 64
MAX_RESERVATIONS_TOTAL = 32
MAX_BANDWIDTH_PER_CIRCUIT = 1024 * 1024 # 1MB/s
MAX_BANDWIDTH_TOTAL = 10 * 1024 * 1024 # 10MB/s
MIN_RELAY_SCORE = 0.5
MAX_RELAY_LATENCY = 1.0 # seconds
ENABLE_AUTO_RELAY = True
AUTO_RELAY_TIMEOUT = 30 # seconds
MAX_AUTO_RELAY_ATTEMPTS = 3
RESERVATION_REFRESH_THRESHOLD = 0.8 # Refresh at 80% of TTL
MAX_CONCURRENT_RESERVATIONS = 2
# Timeout constants for different components
DEFAULT_DISCOVERY_STREAM_TIMEOUT = 10 # seconds
DEFAULT_PEER_PROTOCOL_TIMEOUT = 5 # seconds
DEFAULT_PROTOCOL_READ_TIMEOUT = 15 # seconds
DEFAULT_PROTOCOL_WRITE_TIMEOUT = 15 # seconds
DEFAULT_PROTOCOL_CLOSE_TIMEOUT = 10 # seconds
DEFAULT_DCUTR_READ_TIMEOUT = 30 # seconds
DEFAULT_DCUTR_WRITE_TIMEOUT = 30 # seconds
DEFAULT_DIAL_TIMEOUT = 10 # seconds
@dataclass
class TimeoutConfig:
"""Timeout configuration for different Circuit Relay v2 components."""
# Discovery timeouts
discovery_stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT
peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT
# Core protocol timeouts
protocol_read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT
protocol_write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT
protocol_close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT
# DCUtR timeouts
dcutr_read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT
dcutr_write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT
dial_timeout: int = DEFAULT_DIAL_TIMEOUT
# Relay roles enum
class RelayRole(Flag):
"""
Bit-flag enum that captures the three possible relay capabilities.
A node can combine multiple roles using bit-wise OR, for example::
RelayRole.HOP | RelayRole.STOP
"""
HOP = auto() # Act as a relay for others ("hop")
STOP = auto() # Accept relayed connections ("stop")
CLIENT = auto() # Dial through existing relays ("client")
@dataclass
class RelayConfig:
"""Configuration for Circuit Relay v2."""
# Role configuration
enable_hop: bool = False # Whether to act as a relay (hop)
enable_stop: bool = True # Whether to accept relayed connections (stop)
enable_client: bool = True # Whether to use relays for dialing
# Role configuration (bit-flags)
roles: RelayRole = RelayRole.STOP | RelayRole.CLIENT
# Resource limits
limits: RelayLimits | None = None
# Discovery configuration
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
min_relays: int = 3
max_relays: int = 20
discovery_interval: int = 300 # seconds
min_relays: int = DEFAULT_MIN_RELAYS
max_relays: int = DEFAULT_MAX_RELAYS
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL
# Connection configuration
reservation_ttl: int = 3600 # seconds
max_circuit_duration: int = 3600 # seconds
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
reservation_ttl: int = DEFAULT_RESERVATION_TTL
max_circuit_duration: int = DEFAULT_MAX_CIRCUIT_DURATION
max_circuit_bytes: int = DEFAULT_MAX_CIRCUIT_BYTES
# Timeout configuration
timeouts: TimeoutConfig = field(default_factory=TimeoutConfig)
# ---------------------------------------------------------------------
# Backwards-compat boolean helpers. Existing code that still accesses
# ``cfg.enable_hop, cfg.enable_stop, cfg.enable_client`` will continue to work.
# ---------------------------------------------------------------------
@property
def enable_hop(self) -> bool: # pragma: no cover helper
return bool(self.roles & RelayRole.HOP)
@property
def enable_stop(self) -> bool: # pragma: no cover helper
return bool(self.roles & RelayRole.STOP)
@property
def enable_client(self) -> bool: # pragma: no cover helper
return bool(self.roles & RelayRole.CLIENT)
def __post_init__(self) -> None:
"""Initialize default values."""
@ -48,8 +138,8 @@ class RelayConfig:
self.limits = RelayLimits(
duration=self.max_circuit_duration,
data=self.max_circuit_bytes,
max_circuit_conns=8,
max_reservations=4,
max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS,
max_reservations=DEFAULT_MAX_RESERVATIONS,
)
@ -58,20 +148,20 @@ class HopConfig:
"""Configuration specific to relay (hop) nodes."""
# Resource limits per IP
max_reservations_per_ip: int = 8
max_circuits_per_ip: int = 16
max_reservations_per_ip: int = MAX_RESERVATIONS_PER_IP
max_circuits_per_ip: int = MAX_CIRCUITS_PER_IP
# Rate limiting
reservation_rate_per_ip: int = 4 # per minute
circuit_rate_per_ip: int = 8 # per minute
reservation_rate_per_ip: int = RESERVATION_RATE_PER_IP
circuit_rate_per_ip: int = CIRCUIT_RATE_PER_IP
# Resource quotas
max_circuits_total: int = 64
max_reservations_total: int = 32
max_circuits_total: int = MAX_CIRCUITS_TOTAL
max_reservations_total: int = MAX_RESERVATIONS_TOTAL
# Bandwidth limits
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
max_bandwidth_per_circuit: int = MAX_BANDWIDTH_PER_CIRCUIT
max_bandwidth_total: int = MAX_BANDWIDTH_TOTAL
@dataclass
@ -79,14 +169,14 @@ class ClientConfig:
"""Configuration specific to relay clients."""
# Relay selection
min_relay_score: float = 0.5
max_relay_latency: float = 1.0 # seconds
min_relay_score: float = MIN_RELAY_SCORE
max_relay_latency: float = MAX_RELAY_LATENCY
# Auto-relay settings
enable_auto_relay: bool = True
auto_relay_timeout: int = 30 # seconds
max_auto_relay_attempts: int = 3
enable_auto_relay: bool = ENABLE_AUTO_RELAY
auto_relay_timeout: int = AUTO_RELAY_TIMEOUT
max_auto_relay_attempts: int = MAX_AUTO_RELAY_ATTEMPTS
# Reservation management
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
max_concurrent_reservations: int = 2
reservation_refresh_threshold: float = RESERVATION_REFRESH_THRESHOLD
max_concurrent_reservations: int = MAX_CONCURRENT_RESERVATIONS

View File

@ -29,6 +29,11 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.relay.circuit_v2.config import (
DEFAULT_DCUTR_READ_TIMEOUT,
DEFAULT_DCUTR_WRITE_TIMEOUT,
DEFAULT_DIAL_TIMEOUT,
)
from libp2p.relay.circuit_v2.nat import (
ReachabilityChecker,
)
@ -47,11 +52,7 @@ PROTOCOL_ID = TProtocol("/libp2p/dcutr")
# Maximum message size for DCUtR (4KiB as per spec)
MAX_MESSAGE_SIZE = 4 * 1024
# Timeouts
STREAM_READ_TIMEOUT = 30 # seconds
STREAM_WRITE_TIMEOUT = 30 # seconds
DIAL_TIMEOUT = 10 # seconds
# DCUtR protocol constants
# Maximum number of hole punch attempts per peer
MAX_HOLE_PUNCH_ATTEMPTS = 5
@ -70,7 +71,13 @@ class DCUtRProtocol(Service):
hole punching, after they have established an initial connection through a relay.
"""
def __init__(self, host: IHost):
def __init__(
self,
host: IHost,
read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT,
write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT,
dial_timeout: int = DEFAULT_DIAL_TIMEOUT,
):
"""
Initialize the DCUtR protocol.
@ -78,10 +85,19 @@ class DCUtRProtocol(Service):
----------
host : IHost
The libp2p host this protocol is running on
read_timeout : int
Timeout for stream read operations, in seconds
write_timeout : int
Timeout for stream write operations, in seconds
dial_timeout : int
Timeout for dial operations, in seconds
"""
super().__init__()
self.host = host
self.read_timeout = read_timeout
self.write_timeout = write_timeout
self.dial_timeout = dial_timeout
self.event_started = trio.Event()
self._hole_punch_attempts: dict[ID, int] = {}
self._direct_connections: set[ID] = set()
@ -161,7 +177,7 @@ class DCUtRProtocol(Service):
try:
# Read the CONNECT message
with trio.fail_after(STREAM_READ_TIMEOUT):
with trio.fail_after(self.read_timeout):
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
# Parse the message
@ -196,7 +212,7 @@ class DCUtRProtocol(Service):
response.type = HolePunch.CONNECT
response.ObsAddrs.extend(our_addrs)
with trio.fail_after(STREAM_WRITE_TIMEOUT):
with trio.fail_after(self.write_timeout):
await stream.write(response.SerializeToString())
logger.debug(
@ -206,7 +222,7 @@ class DCUtRProtocol(Service):
)
# Wait for SYNC message
with trio.fail_after(STREAM_READ_TIMEOUT):
with trio.fail_after(self.read_timeout):
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
# Parse the SYNC message
@ -300,7 +316,7 @@ class DCUtRProtocol(Service):
connect_msg.ObsAddrs.extend(our_addrs)
start_time = time.time()
with trio.fail_after(STREAM_WRITE_TIMEOUT):
with trio.fail_after(self.write_timeout):
await stream.write(connect_msg.SerializeToString())
logger.debug(
@ -310,7 +326,7 @@ class DCUtRProtocol(Service):
)
# Receive the peer's CONNECT message
with trio.fail_after(STREAM_READ_TIMEOUT):
with trio.fail_after(self.read_timeout):
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
# Calculate RTT
@ -349,7 +365,7 @@ class DCUtRProtocol(Service):
sync_msg = HolePunch()
sync_msg.type = HolePunch.SYNC
with trio.fail_after(STREAM_WRITE_TIMEOUT):
with trio.fail_after(self.write_timeout):
await stream.write(sync_msg.SerializeToString())
logger.debug("Sent SYNC message to %s", peer_id)
@ -468,7 +484,7 @@ class DCUtRProtocol(Service):
peer_info = PeerInfo(peer_id, [addr])
# Try to connect with timeout
with trio.fail_after(DIAL_TIMEOUT):
with trio.fail_after(self.dial_timeout):
await self.host.connect(peer_info)
logger.info("Successfully connected to %s at %s", peer_id, addr)

View File

@ -31,6 +31,11 @@ from libp2p.tools.async_service import (
Service,
)
from .config import (
DEFAULT_DISCOVERY_INTERVAL,
DEFAULT_DISCOVERY_STREAM_TIMEOUT,
DEFAULT_PEER_PROTOCOL_TIMEOUT,
)
from .pb.circuit_pb2 import (
HopMessage,
)
@ -43,10 +48,8 @@ from .protocol_buffer import (
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
# Constants
# Discovery constants
MAX_RELAYS_TO_TRACK = 10
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
STREAM_TIMEOUT = 10 # seconds
# Extended interfaces for type checking
@ -86,6 +89,8 @@ class RelayDiscovery(Service):
auto_reserve: bool = False,
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
max_relays: int = MAX_RELAYS_TO_TRACK,
stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT,
peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT,
) -> None:
"""
Initialize the discovery service.
@ -100,6 +105,10 @@ class RelayDiscovery(Service):
How often to run discovery, in seconds
max_relays : int
Maximum number of relays to track
stream_timeout : int
Timeout for stream operations during discovery, in seconds
peer_protocol_timeout : int
Timeout for checking peer protocol support, in seconds
"""
super().__init__()
@ -107,6 +116,8 @@ class RelayDiscovery(Service):
self.auto_reserve = auto_reserve
self.discovery_interval = discovery_interval
self.max_relays = max_relays
self.stream_timeout = stream_timeout
self.peer_protocol_timeout = peer_protocol_timeout
self._discovered_relays: dict[ID, RelayInfo] = {}
self._protocol_cache: dict[
ID, set[str]
@ -165,8 +176,8 @@ class RelayDiscovery(Service):
self._discovered_relays[peer_id].last_seen = time.time()
continue
# Check if peer supports the relay protocol
with trio.move_on_after(5): # Don't wait too long for protocol info
# Don't wait too long for protocol info
with trio.move_on_after(self.peer_protocol_timeout):
if await self._supports_relay_protocol(peer_id):
await self._add_relay(peer_id)
@ -264,7 +275,7 @@ class RelayDiscovery(Service):
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
"""Check protocol support via direct connection."""
try:
with trio.fail_after(STREAM_TIMEOUT):
with trio.fail_after(self.stream_timeout):
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if stream:
await stream.close()
@ -370,7 +381,7 @@ class RelayDiscovery(Service):
# Open a stream to the relay with timeout
try:
with trio.fail_after(STREAM_TIMEOUT):
with trio.fail_after(self.stream_timeout):
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if not stream:
logger.error("Failed to open stream to relay %s", peer_id)
@ -386,7 +397,7 @@ class RelayDiscovery(Service):
peer=self.host.get_id().to_bytes(),
)
with trio.fail_after(STREAM_TIMEOUT):
with trio.fail_after(self.stream_timeout):
await stream.write(request.SerializeToString())
# Wait for response

View File

@ -5,6 +5,7 @@ This module implements the Circuit Relay v2 protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
"""
from enum import Enum, auto
import logging
import time
from typing import (
@ -37,6 +38,15 @@ from libp2p.tools.async_service import (
Service,
)
from .config import (
DEFAULT_MAX_CIRCUIT_BYTES,
DEFAULT_MAX_CIRCUIT_CONNS,
DEFAULT_MAX_CIRCUIT_DURATION,
DEFAULT_MAX_RESERVATIONS,
DEFAULT_PROTOCOL_CLOSE_TIMEOUT,
DEFAULT_PROTOCOL_READ_TIMEOUT,
DEFAULT_PROTOCOL_WRITE_TIMEOUT,
)
from .pb.circuit_pb2 import (
HopMessage,
Limit,
@ -58,18 +68,22 @@ logger = logging.getLogger("libp2p.relay.circuit_v2")
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
# Direction enum for data piping
class Pipe(Enum):
SRC_TO_DST = auto()
DST_TO_SRC = auto()
# Default limits for relay resources
DEFAULT_RELAY_LIMITS = RelayLimits(
duration=60 * 60, # 1 hour
data=1024 * 1024 * 1024, # 1GB
max_circuit_conns=8,
max_reservations=4,
duration=DEFAULT_MAX_CIRCUIT_DURATION,
data=DEFAULT_MAX_CIRCUIT_BYTES,
max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS,
max_reservations=DEFAULT_MAX_RESERVATIONS,
)
# Stream operation timeouts
STREAM_READ_TIMEOUT = 15 # seconds
STREAM_WRITE_TIMEOUT = 15 # seconds
STREAM_CLOSE_TIMEOUT = 10 # seconds
# Stream operation constants
MAX_READ_RETRIES = 5 # Maximum number of read retries
@ -113,6 +127,9 @@ class CircuitV2Protocol(Service):
host: IHost,
limits: RelayLimits | None = None,
allow_hop: bool = False,
read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT,
write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT,
close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT,
) -> None:
"""
Initialize a Circuit Relay v2 protocol instance.
@ -125,11 +142,20 @@ class CircuitV2Protocol(Service):
Resource limits for the relay
allow_hop : bool
Whether to allow this node to act as a relay
read_timeout : int
Timeout for stream read operations, in seconds
write_timeout : int
Timeout for stream write operations, in seconds
close_timeout : int
Timeout for stream close operations, in seconds
"""
self.host = host
self.limits = limits or DEFAULT_RELAY_LIMITS
self.allow_hop = allow_hop
self.read_timeout = read_timeout
self.write_timeout = write_timeout
self.close_timeout = close_timeout
self.resource_manager = RelayResourceManager(self.limits)
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
self.event_started = trio.Event()
@ -174,7 +200,7 @@ class CircuitV2Protocol(Service):
return
try:
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
with trio.fail_after(self.close_timeout):
await stream.close()
except Exception:
try:
@ -216,7 +242,7 @@ class CircuitV2Protocol(Service):
while retries < max_retries:
try:
with trio.fail_after(STREAM_READ_TIMEOUT):
with trio.fail_after(self.read_timeout):
# Try reading with timeout
logger.debug(
"Attempting to read from stream (attempt %d/%d)",
@ -293,7 +319,7 @@ class CircuitV2Protocol(Service):
# First, handle the read timeout gracefully
try:
with trio.fail_after(
STREAM_READ_TIMEOUT * 2
self.read_timeout * 2
): # Double the timeout for reading
msg_bytes = await stream.read()
if not msg_bytes:
@ -414,7 +440,7 @@ class CircuitV2Protocol(Service):
"""
try:
# Read the incoming message with timeout
with trio.fail_after(STREAM_READ_TIMEOUT):
with trio.fail_after(self.read_timeout):
msg_bytes = await stream.read()
stop_msg = StopMessage()
stop_msg.ParseFromString(msg_bytes)
@ -458,8 +484,20 @@ class CircuitV2Protocol(Service):
# Start relaying data
async with trio.open_nursery() as nursery:
nursery.start_soon(self._relay_data, src_stream, stream, peer_id)
nursery.start_soon(self._relay_data, stream, src_stream, peer_id)
nursery.start_soon(
self._relay_data,
src_stream,
stream,
peer_id,
Pipe.SRC_TO_DST,
)
nursery.start_soon(
self._relay_data,
stream,
src_stream,
peer_id,
Pipe.DST_TO_SRC,
)
except trio.TooSlowError:
logger.error("Timeout reading from stop stream")
@ -509,7 +547,7 @@ class CircuitV2Protocol(Service):
ttl = self.resource_manager.reserve(peer_id)
# Send reservation success response
with trio.fail_after(STREAM_WRITE_TIMEOUT):
with trio.fail_after(self.write_timeout):
status = create_status(
code=StatusCode.OK, message="Reservation accepted"
)
@ -560,7 +598,7 @@ class CircuitV2Protocol(Service):
# Always close the stream when done with reservation
if cast(INetStreamWithExtras, stream).is_open():
try:
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
with trio.fail_after(self.close_timeout):
await stream.close()
except Exception as close_err:
logger.error("Error closing stream: %s", str(close_err))
@ -596,7 +634,7 @@ class CircuitV2Protocol(Service):
self._active_relays[peer_id] = (stream, None)
# Try to connect to the destination with timeout
with trio.fail_after(STREAM_READ_TIMEOUT):
with trio.fail_after(self.read_timeout):
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
if not dst_stream:
raise ConnectionError("Could not connect to destination")
@ -648,8 +686,20 @@ class CircuitV2Protocol(Service):
# Start relaying data
async with trio.open_nursery() as nursery:
nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)
nursery.start_soon(
self._relay_data,
stream,
dst_stream,
peer_id,
Pipe.SRC_TO_DST,
)
nursery.start_soon(
self._relay_data,
dst_stream,
stream,
peer_id,
Pipe.DST_TO_SRC,
)
except (trio.TooSlowError, ConnectionError) as e:
logger.error("Error establishing relay connection: %s", str(e))
@ -685,6 +735,7 @@ class CircuitV2Protocol(Service):
src_stream: INetStream,
dst_stream: INetStream,
peer_id: ID,
direction: Pipe,
) -> None:
"""
Relay data between two streams.
@ -698,24 +749,27 @@ class CircuitV2Protocol(Service):
peer_id : ID
ID of the peer being relayed
direction : Pipe
Direction of data flow (``Pipe.SRC_TO_DST`` or ``Pipe.DST_TO_SRC``)
"""
try:
while True:
# Read data with retries
data = await self._read_stream_with_retry(src_stream)
if not data:
logger.info("Source stream closed/reset")
logger.info("%s closed/reset", direction.name)
break
# Write data with timeout
try:
with trio.fail_after(STREAM_WRITE_TIMEOUT):
with trio.fail_after(self.write_timeout):
await dst_stream.write(data)
except trio.TooSlowError:
logger.error("Timeout writing to destination stream")
logger.error("Timeout writing in %s", direction.name)
break
except Exception as e:
logger.error("Error writing to destination stream: %s", str(e))
logger.error("Error writing in %s: %s", direction.name, str(e))
break
# Update resource usage
@ -744,7 +798,7 @@ class CircuitV2Protocol(Service):
"""Send a status message."""
try:
logger.debug("Sending status message with code %s: %s", code, message)
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
with trio.fail_after(self.write_timeout * 2): # Double the timeout
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(
@ -782,7 +836,7 @@ class CircuitV2Protocol(Service):
"""Send a status message on a STOP stream."""
try:
logger.debug("Sending stop status message with code %s: %s", code, message)
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
with trio.fail_after(self.write_timeout * 2): # Double the timeout
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(

View File

@ -8,6 +8,7 @@ including reservations and connection limits.
from dataclasses import (
dataclass,
)
from enum import Enum, auto
import hashlib
import os
import time
@ -19,6 +20,18 @@ from libp2p.peer.id import (
# Import the protobuf definitions
from .pb.circuit_pb2 import Reservation as PbReservation
RANDOM_BYTES_LENGTH = 16 # 128 bits of randomness
TIMESTAMP_MULTIPLIER = 1000000 # To convert seconds to microseconds
# Reservation status enum
class ReservationStatus(Enum):
"""Lifecycle status of a relay reservation."""
ACTIVE = auto()
EXPIRED = auto()
REJECTED = auto()
@dataclass
class RelayLimits:
@ -68,8 +81,8 @@ class Reservation:
# - Peer ID to bind it to the specific peer
# - Timestamp for uniqueness
# - Hash everything for a fixed size output
random_bytes = os.urandom(16) # 128 bits of randomness
timestamp = str(int(self.created_at * 1000000)).encode()
random_bytes = os.urandom(RANDOM_BYTES_LENGTH)
timestamp = str(int(self.created_at * TIMESTAMP_MULTIPLIER)).encode()
peer_bytes = self.peer_id.to_bytes()
# Combine all elements and hash them
@ -84,6 +97,15 @@ class Reservation:
"""Check if the reservation has expired."""
return time.time() > self.expires_at
# Expose a friendly status enum
@property
def status(self) -> ReservationStatus:
"""Return the current status as a ``ReservationStatus`` enum."""
return (
ReservationStatus.EXPIRED if self.is_expired() else ReservationStatus.ACTIVE
)
def can_accept_connection(self) -> bool:
"""Check if a new connection can be accepted."""
return (

View File

@ -89,6 +89,8 @@ class CircuitV2Transport(ITransport):
auto_reserve=config.enable_client,
discovery_interval=config.discovery_interval,
max_relays=config.max_relays,
stream_timeout=config.timeouts.discovery_stream_timeout,
peer_protocol_timeout=config.timeouts.peer_protocol_timeout,
)
self.relay_counter = 0 # for round robin load balancing

View File

@ -1,3 +1,4 @@
import logging
from typing import (
cast,
)
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
FixedSizeLenMsgReadWriter,
)
logger = logging.getLogger(__name__)
SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.noise_state = noise_state
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
data_encrypted = self.encrypt(msg)
if prefix_encoded:
# Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
await self.read_writer.write_msg(data_encrypted)
logger.debug("Noise write_msg: write completed successfully")
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
logger.debug("Noise read_msg: reading encrypted message")
noise_msg_encrypted = await self.read_writer.read_msg()
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
if prefix_encoded:
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
else:
return self.decrypt(noise_msg_encrypted)
result = self.decrypt(noise_msg_encrypted)
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
return result
async def close(self) -> None:
await self.read_writer.close()

View File

@ -1,6 +1,7 @@
from dataclasses import (
dataclass,
)
import logging
from libp2p.crypto.keys import (
PrivateKey,
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
from .pb import noise_pb2 as noise_pb
logger = logging.getLogger(__name__)
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
) -> bytes:
data = make_data_to_be_signed(noise_static_pubkey)
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
return id_privkey.sign(data)
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
2. signed by the private key corresponding to `id_pubkey`
"""
expected_data = make_data_to_be_signed(noise_static_pubkey)
return payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(
f"verify_handshake_payload_sig: payload.id_pubkey type: "
f"{type(payload.id_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: noise_static_pubkey type: "
f"{type(noise_static_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
)
logger.debug(
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
)
try:
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
return result
except Exception as e:
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
return False

View File

@ -2,6 +2,7 @@ from abc import (
ABC,
abstractmethod,
)
import logging
from cryptography.hazmat.primitives import (
serialization,
@ -46,6 +47,8 @@ from .messages import (
verify_handshake_payload_sig,
)
logger = logging.getLogger(__name__)
class IPattern(ABC):
@abstractmethod
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
self.early_data = early_data
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
noise_state = self.create_noise_state()
noise_state.set_as_responder()
noise_state.start_handshake()
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
logger.debug("Noise XX handshake_inbound: reading msg#1")
await read_writer.read_msg()
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
# Send msg#2, which should include our handshake payload.
logger.debug("Noise XX handshake_inbound: preparing msg#2")
our_payload = self.make_handshake_payload()
msg_2 = our_payload.serialize()
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
await read_writer.write_msg(msg_2)
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
# Receive and consume msg#3.
logger.debug("Noise XX handshake_inbound: reading msg#3")
msg_3 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
if handshake_state.rs is None:
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn:
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
noise_state = self.create_noise_state()
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
logger.debug("Noise XX handshake_outbound: sending msg#1")
msg_1 = b""
await read_writer.write_msg(msg_1)
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
# Read msg#2 from the remote, which contains the public key of the peer.
logger.debug("Noise XX handshake_outbound: reading msg#2")
msg_2 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
if handshake_state.rs is None:
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
logger.debug(
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
)
logger.debug(
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
)
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
logger.debug(
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
f"{id_pubkey_repr}"
)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
logger.error(
f"Noise XX handshake_outbound: signature verification failed for peer "
f"{remote_peer}"
)
raise InvalidSignature
logger.debug(
f"Noise XX handshake_outbound: signature verification successful for peer "
f"{remote_peer}"
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(

View File

@ -1,5 +1,3 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -15,6 +13,7 @@ from libp2p.abc import (
from libp2p.stream_muxer.exceptions import (
MuxedConnUnavailable,
)
from libp2p.stream_muxer.rw_lock import ReadWriteLock
from .constants import (
HeaderTags,
@ -34,72 +33,6 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go

View File

@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
DEFAULT_NEGOTIATE_TIMEOUT,
Multiselect,
)
from libp2p.protocol_muxer.multiselect_client import (
@ -46,11 +47,17 @@ class MuxerMultistream:
transports: "OrderedDict[TProtocol, TMuxerClass]"
multiselect: Multiselect
multiselect_client: MultiselectClient
negotiate_timeout: int
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
def __init__(
self,
muxer_transports_by_protocol: TMuxerOptions,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> None:
self.transports = OrderedDict()
self.multiselect = Multiselect()
self.multistream_client = MultiselectClient()
self.negotiate_timeout = negotiate_timeout
for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport)
@ -80,10 +87,12 @@ class MuxerMultistream:
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
tuple(self.transports.keys()), communicator
tuple(self.transports.keys()), communicator, self.negotiate_timeout
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
protocol, _ = await self.multiselect.negotiate(
communicator, self.negotiate_timeout
)
if protocol is None:
raise MultiselectError(
"Fail to negotiate a stream muxer protocol: no protocol selected"
@ -93,7 +102,7 @@ class MuxerMultistream:
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
communicator = MultiselectCommunicator(conn)
protocol = await self.multistream_client.select_one_of(
tuple(self.transports.keys()), communicator
tuple(self.transports.keys()), communicator, self.negotiate_timeout
)
transport_class = self.transports[protocol]
if protocol == PROTOCOL_ID:

View File

@ -0,0 +1,70 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import trio
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()

View File

@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamError,
MuxedStreamReset,
)
from libp2p.stream_muxer.rw_lock import ReadWriteLock
# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
self.send_window = DEFAULT_WINDOW_SIZE
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.close_lock = trio.Lock()
async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
await self.close()
async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")
async with self.rw_lock.write_lock():
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logger.debug(
f"Stream {self.stream_id}: "
"Window is zero, waiting for update"
)
# Release lock and wait with timeout
self.window_lock.release()
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()
# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
if self.closed:
raise MuxedStreamError("Stream is closed")
# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send
# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
# Release lock and wait with timeout
self.window_lock.release()
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()
# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
if self.closed:
raise MuxedStreamError("Stream is closed")
# Calculate how much we can send now
to_send = min(self.send_window, total_len - sent)
chunk = data[sent : sent + to_send]
self.send_window -= to_send
# Send the data
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
)
await self.conn.secured_conn.write(header + chunk)
sent += to_send
await self.conn.secured_conn.write(header + chunk)
sent += to_send
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
"""
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
return data
async def close(self) -> None:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True
async with self.close_lock:
if not self.send_closed:
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.send_closed = True
# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False
# Only set fully closed if both directions are closed
if self.send_closed and self.recv_closed:
self.closed = True
else:
# Stream is half-closed but not fully closed
self.closed = False
async def reset(self) -> None:
if not self.closed:
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset
async with self.close_lock:
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
await self.conn.secured_conn.write(header)
self.closed = True
self.reset_received = True # Mark as reset
def set_deadline(self, ttl: int) -> bool:
"""

View File

@ -0,0 +1,57 @@
from typing import Any
from .tcp.tcp import TCP
from .websocket.transport import WebsocketTransport
from .transport_registry import (
TransportRegistry,
create_transport_for_multiaddr,
get_transport_registry,
register_transport,
get_supported_transport_protocols,
)
from .upgrader import TransportUpgrader
from libp2p.abc import ITransport
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
"""
Convenience function to create a transport instance.
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
:param upgrader: Optional transport upgrader (required for WebSocket)
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
:return: Transport instance
"""
# First check if it's a built-in protocol
if protocol in ["ws", "wss"]:
if upgrader is None:
raise ValueError(f"WebSocket transport requires an upgrader")
return WebsocketTransport(
upgrader,
tls_client_config=kwargs.get("tls_client_config"),
tls_server_config=kwargs.get("tls_server_config"),
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
)
elif protocol == "tcp":
return TCP()
else:
# Check if it's a custom registered transport
registry = get_transport_registry()
transport_class = registry.get_transport(protocol)
if transport_class:
transport = registry.create_transport(protocol, upgrader, **kwargs)
if transport is None:
raise ValueError(f"Failed to create transport for protocol: {protocol}")
return transport
else:
raise ValueError(f"Unsupported transport protocol: {protocol}")
__all__ = [
"TCP",
"WebsocketTransport",
"TransportRegistry",
"create_transport_for_multiaddr",
"create_transport",
"get_transport_registry",
"register_transport",
"get_supported_transport_protocols",
]

View File

View File

@ -0,0 +1,345 @@
"""
Configuration classes for QUIC transport.
"""
from dataclasses import (
dataclass,
field,
)
import ssl
from typing import Any, Literal, TypedDict
from libp2p.custom_types import TProtocol
from libp2p.network.config import ConnectionConfig
class QUICTransportKwargs(TypedDict, total=False):
"""Type definition for kwargs accepted by new_transport function."""
# Connection settings
idle_timeout: float
max_datagram_size: int
local_port: int | None
# Protocol version support
enable_draft29: bool
enable_v1: bool
# TLS settings
verify_mode: ssl.VerifyMode
alpn_protocols: list[str]
# Performance settings
max_concurrent_streams: int
connection_window: int
stream_window: int
# Logging and debugging
enable_qlog: bool
qlog_dir: str | None
# Connection management
max_connections: int
connection_timeout: float
# Protocol identifiers
PROTOCOL_QUIC_V1: TProtocol
PROTOCOL_QUIC_DRAFT29: TProtocol
@dataclass
class QUICTransportConfig(ConnectionConfig):
"""Configuration for QUIC transport."""
# Connection settings
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
max_datagram_size: int = (
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
)
local_port: int | None = (
None # Local port to bind to. If None, a random port is chosen.
)
# Protocol version support
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
# TLS settings
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
# Performance settings
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
connection_window: int = 1024 * 1024 # Connection flow control window
stream_window: int = 64 * 1024 # Stream flow control window
# Logging and debugging
enable_qlog: bool = False # Enable QUIC logging
qlog_dir: str | None = None # Directory for QUIC logs
# Connection management
max_connections: int = 1000 # Maximum number of connections
connection_timeout: float = 10.0 # Connection establishment timeout
MAX_CONCURRENT_STREAMS: int = 1000
"""Maximum number of concurrent streams per connection."""
MAX_INCOMING_STREAMS: int = 1000
"""Maximum number of incoming streams per connection."""
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
"""Timeout for connection handshake (seconds)."""
MAX_OUTGOING_STREAMS: int = 1000
"""Maximum number of outgoing streams per connection."""
CONNECTION_CLOSE_TIMEOUT: int = 10
"""Timeout for opening new connection (seconds)."""
# Stream timeouts
STREAM_OPEN_TIMEOUT: float = 5.0
"""Timeout for opening new streams (seconds)."""
STREAM_ACCEPT_TIMEOUT: float = 30.0
"""Timeout for accepting incoming streams (seconds)."""
STREAM_READ_TIMEOUT: float = 30.0
"""Default timeout for stream read operations (seconds)."""
STREAM_WRITE_TIMEOUT: float = 30.0
"""Default timeout for stream write operations (seconds)."""
STREAM_CLOSE_TIMEOUT: float = 10.0
"""Timeout for graceful stream close (seconds)."""
# Flow control configuration
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
"""Per-stream flow control window size."""
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
"""Connection-wide flow control window size."""
# Buffer management
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
"""Maximum receive buffer size per stream."""
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
"""Low watermark for stream receive buffer."""
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
"""High watermark for stream receive buffer."""
# Stream lifecycle configuration
ENABLE_STREAM_RESET_ON_ERROR: bool = True
"""Whether to automatically reset streams on errors."""
STREAM_RESET_ERROR_CODE: int = 1
"""Default error code for stream resets."""
ENABLE_STREAM_KEEP_ALIVE: bool = False
"""Whether to enable stream keep-alive mechanisms."""
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
"""Interval for stream keep-alive pings (seconds)."""
# Resource management
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
"""Whether to track stream resource usage."""
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
"""Memory limit per individual stream."""
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
"""Total memory limit for all streams per connection."""
# Concurrency and performance
ENABLE_STREAM_BATCHING: bool = True
"""Whether to batch multiple stream operations."""
STREAM_BATCH_SIZE: int = 10
"""Number of streams to process in a batch."""
STREAM_PROCESSING_CONCURRENCY: int = 100
"""Maximum concurrent stream processing tasks."""
# Debugging and monitoring
ENABLE_STREAM_METRICS: bool = True
"""Whether to collect stream metrics."""
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
"""Whether to track stream lifecycle timelines."""
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
"""Interval for collecting stream metrics (seconds)."""
# Error handling configuration
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
"""Number of retry attempts for recoverable stream errors."""
STREAM_ERROR_RETRY_DELAY: float = 1.0
"""Initial delay between stream error retries (seconds)."""
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
"""Backoff factor for stream error retries."""
# Protocol identifiers matching go-libp2p
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (self.enable_draft29 or self.enable_v1):
raise ValueError("At least one QUIC version must be enabled")
if self.idle_timeout <= 0:
raise ValueError("Idle timeout must be positive")
if self.max_datagram_size < 1200:
raise ValueError("Max datagram size must be at least 1200 bytes")
# Validate timeouts
timeout_fields = [
"STREAM_OPEN_TIMEOUT",
"STREAM_ACCEPT_TIMEOUT",
"STREAM_READ_TIMEOUT",
"STREAM_WRITE_TIMEOUT",
"STREAM_CLOSE_TIMEOUT",
]
for timeout_field in timeout_fields:
if getattr(self, timeout_field) <= 0:
raise ValueError(f"{timeout_field} must be positive")
# Validate flow control windows
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
raise ValueError(
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
)
# Validate buffer sizes
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
raise ValueError(
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
"exceed MAX_STREAM_RECEIVE_BUFFER"
)
)
if (
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
):
raise ValueError(
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
)
# Validate memory limits
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
expected_stream_memory = (
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
)
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
# Allow some headroom, but warn if configuration seems inconsistent
import logging
logger = logging.getLogger(__name__)
logger.warning(
"Stream memory configuration may be inconsistent: "
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
"could exceed connection limit of"
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
)
def get_stream_config_dict(self) -> dict[str, Any]:
"""Get stream-specific configuration as dictionary."""
stream_config = {}
for attr_name in dir(self):
if attr_name.startswith(
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
):
stream_config[attr_name.lower()] = getattr(self, attr_name)
return stream_config
# Additional configuration classes for specific stream features
class QUICStreamFlowControlConfig:
"""Configuration for QUIC stream flow control."""
def __init__(
self,
initial_window_size: int = 512 * 1024,
max_window_size: int = 2 * 1024 * 1024,
window_update_threshold: float = 0.5,
enable_auto_tuning: bool = True,
):
self.initial_window_size = initial_window_size
self.max_window_size = max_window_size
self.window_update_threshold = window_update_threshold
self.enable_auto_tuning = enable_auto_tuning
def create_stream_config_for_use_case(
use_case: Literal[
"high_throughput", "low_latency", "many_streams", "memory_constrained"
],
) -> QUICTransportConfig:
"""
Create optimized stream configuration for specific use cases.
Args:
use_case: One of "high_throughput", "low_latency", "many_streams","
"memory_constrained"
Returns:
Optimized QUICTransportConfig
"""
base_config = QUICTransportConfig()
if use_case == "high_throughput":
# Optimize for high throughput
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
base_config.STREAM_PROCESSING_CONCURRENCY = 200
elif use_case == "low_latency":
# Optimize for low latency
base_config.STREAM_OPEN_TIMEOUT = 1.0
base_config.STREAM_READ_TIMEOUT = 5.0
base_config.STREAM_WRITE_TIMEOUT = 5.0
base_config.ENABLE_STREAM_BATCHING = False
base_config.STREAM_BATCH_SIZE = 1
elif use_case == "many_streams":
# Optimize for many concurrent streams
base_config.MAX_CONCURRENT_STREAMS = 5000
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
base_config.STREAM_PROCESSING_CONCURRENCY = 500
elif use_case == "memory_constrained":
# Optimize for low memory usage
base_config.MAX_CONCURRENT_STREAMS = 100
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
base_config.STREAM_PROCESSING_CONCURRENCY = 50
else:
raise ValueError(f"Unknown use case: {use_case}")
return base_config

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,391 @@
"""
QUIC Transport exceptions
"""
from typing import Any, Literal
class QUICError(Exception):
"""Base exception for all QUIC transport errors."""
def __init__(self, message: str, error_code: int | None = None):
super().__init__(message)
self.error_code = error_code
# Transport-level exceptions
class QUICTransportError(QUICError):
"""Base exception for QUIC transport operations."""
pass
class QUICDialError(QUICTransportError):
"""Error occurred during QUIC connection establishment."""
pass
class QUICListenError(QUICTransportError):
"""Error occurred during QUIC listener operations."""
pass
class QUICSecurityError(QUICTransportError):
"""Error related to QUIC security/TLS operations."""
pass
# Connection-level exceptions
class QUICConnectionError(QUICError):
"""Base exception for QUIC connection operations."""
pass
class QUICConnectionClosedError(QUICConnectionError):
"""QUIC connection has been closed."""
pass
class QUICConnectionTimeoutError(QUICConnectionError):
"""QUIC connection operation timed out."""
pass
class QUICHandshakeError(QUICConnectionError):
"""Error during QUIC handshake process."""
pass
class QUICPeerVerificationError(QUICConnectionError):
"""Error verifying peer identity during handshake."""
pass
# Stream-level exceptions
class QUICStreamError(QUICError):
"""Base exception for QUIC stream operations."""
def __init__(
self,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
):
super().__init__(message, error_code)
self.stream_id = stream_id
class QUICStreamClosedError(QUICStreamError):
"""Stream is closed and cannot be used for I/O operations."""
pass
class QUICStreamResetError(QUICStreamError):
"""Stream was reset by local or remote peer."""
def __init__(
self,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
reset_by_peer: bool = False,
):
super().__init__(message, stream_id, error_code)
self.reset_by_peer = reset_by_peer
class QUICStreamTimeoutError(QUICStreamError):
"""Stream operation timed out."""
pass
class QUICStreamBackpressureError(QUICStreamError):
"""Stream write blocked due to flow control."""
pass
class QUICStreamLimitError(QUICStreamError):
"""Stream limit reached (too many concurrent streams)."""
pass
class QUICStreamStateError(QUICStreamError):
"""Invalid operation for current stream state."""
def __init__(
self,
message: str,
stream_id: str | None = None,
current_state: str | None = None,
attempted_operation: str | None = None,
):
super().__init__(message, stream_id)
self.current_state = current_state
self.attempted_operation = attempted_operation
# Flow control exceptions
class QUICFlowControlError(QUICError):
"""Base exception for flow control related errors."""
pass
class QUICFlowControlViolationError(QUICFlowControlError):
"""Flow control limits were violated."""
pass
class QUICFlowControlDeadlockError(QUICFlowControlError):
"""Flow control deadlock detected."""
pass
# Resource management exceptions
class QUICResourceError(QUICError):
"""Base exception for resource management errors."""
pass
class QUICMemoryLimitError(QUICResourceError):
"""Memory limit exceeded."""
pass
class QUICConnectionLimitError(QUICResourceError):
"""Connection limit exceeded."""
pass
# Multiaddr and addressing exceptions
class QUICAddressError(QUICError):
"""Base exception for QUIC addressing errors."""
pass
class QUICInvalidMultiaddrError(QUICAddressError):
"""Invalid multiaddr format for QUIC transport."""
pass
class QUICAddressResolutionError(QUICAddressError):
"""Failed to resolve QUIC address."""
pass
class QUICProtocolError(QUICError):
"""Base exception for QUIC protocol errors."""
pass
class QUICVersionNegotiationError(QUICProtocolError):
"""QUIC version negotiation failed."""
pass
class QUICUnsupportedVersionError(QUICProtocolError):
"""Unsupported QUIC version."""
pass
# Configuration exceptions
class QUICConfigurationError(QUICError):
"""Base exception for QUIC configuration errors."""
pass
class QUICInvalidConfigError(QUICConfigurationError):
"""Invalid QUIC configuration parameters."""
pass
class QUICCertificateError(QUICConfigurationError):
"""Error with TLS certificate configuration."""
pass
def map_quic_error_code(error_code: int) -> str:
"""
Map QUIC error codes to human-readable descriptions.
Based on RFC 9000 Transport Error Codes.
"""
error_codes = {
0x00: "NO_ERROR",
0x01: "INTERNAL_ERROR",
0x02: "CONNECTION_REFUSED",
0x03: "FLOW_CONTROL_ERROR",
0x04: "STREAM_LIMIT_ERROR",
0x05: "STREAM_STATE_ERROR",
0x06: "FINAL_SIZE_ERROR",
0x07: "FRAME_ENCODING_ERROR",
0x08: "TRANSPORT_PARAMETER_ERROR",
0x09: "CONNECTION_ID_LIMIT_ERROR",
0x0A: "PROTOCOL_VIOLATION",
0x0B: "INVALID_TOKEN",
0x0C: "APPLICATION_ERROR",
0x0D: "CRYPTO_BUFFER_EXCEEDED",
0x0E: "KEY_UPDATE_ERROR",
0x0F: "AEAD_LIMIT_REACHED",
0x10: "NO_VIABLE_PATH",
}
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
def create_stream_error(
error_type: str,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
) -> QUICStreamError:
"""
Factory function to create appropriate stream error based on type.
Args:
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
message: Error message
stream_id: Stream identifier
error_code: QUIC error code
Returns:
Appropriate QUICStreamError subclass
"""
error_type = error_type.lower()
if error_type in ("closed", "close"):
return QUICStreamClosedError(message, stream_id, error_code)
elif error_type == "reset":
return QUICStreamResetError(message, stream_id, error_code)
elif error_type == "timeout":
return QUICStreamTimeoutError(message, stream_id, error_code)
elif error_type in ("backpressure", "flow_control"):
return QUICStreamBackpressureError(message, stream_id, error_code)
elif error_type in ("limit", "stream_limit"):
return QUICStreamLimitError(message, stream_id, error_code)
elif error_type == "state":
return QUICStreamStateError(message, stream_id)
else:
return QUICStreamError(message, stream_id, error_code)
def create_connection_error(
error_type: str, message: str, error_code: int | None = None
) -> QUICConnectionError:
"""
Factory function to create appropriate connection error based on type.
Args:
error_type: Type of error ("closed", "timeout", "handshake", etc.)
message: Error message
error_code: QUIC error code
Returns:
Appropriate QUICConnectionError subclass
"""
error_type = error_type.lower()
if error_type in ("closed", "close"):
return QUICConnectionClosedError(message, error_code)
elif error_type == "timeout":
return QUICConnectionTimeoutError(message, error_code)
elif error_type == "handshake":
return QUICHandshakeError(message, error_code)
elif error_type in ("peer_verification", "verification"):
return QUICPeerVerificationError(message, error_code)
else:
return QUICConnectionError(message, error_code)
class QUICErrorContext:
"""
Context manager for handling QUIC errors with automatic error mapping.
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
"""
def __init__(self, operation: str, component: str = "quic") -> None:
self.operation = operation
self.component = component
def __enter__(self) -> "QUICErrorContext":
return self
# TODO: Fix types for exc_type
def __exit__(
self,
exc_type: type[BaseException] | None | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
if exc_type is None:
return False
if exc_val is None:
return False
# Map common aioquic exceptions to our exceptions
if "ConnectionClosed" in str(exc_type):
raise QUICConnectionClosedError(
f"Connection closed during {self.operation}: {exc_val}"
) from exc_val
elif "StreamReset" in str(exc_type):
raise QUICStreamResetError(
f"Stream reset during {self.operation}: {exc_val}"
) from exc_val
elif "timeout" in str(exc_val).lower():
if "stream" in self.component.lower():
raise QUICStreamTimeoutError(
f"Timeout during {self.operation}: {exc_val}"
) from exc_val
else:
raise QUICConnectionTimeoutError(
f"Timeout during {self.operation}: {exc_val}"
) from exc_val
elif "flow control" in str(exc_val).lower():
raise QUICStreamBackpressureError(
f"Flow control error during {self.operation}: {exc_val}"
) from exc_val
# Let other exceptions propagate
return False

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,656 @@
"""
QUIC Stream implementation
Provides stream interface over QUIC's native multiplexing.
"""
from enum import Enum
import logging
import time
from types import TracebackType
from typing import TYPE_CHECKING, Any, cast
import trio
from .exceptions import (
QUICStreamBackpressureError,
QUICStreamClosedError,
QUICStreamResetError,
QUICStreamTimeoutError,
)
if TYPE_CHECKING:
from libp2p.abc import IMuxedStream
from libp2p.custom_types import TProtocol
from .connection import QUICConnection
else:
IMuxedStream = cast(type, object)
TProtocol = cast(type, object)
logger = logging.getLogger(__name__)
class StreamState(Enum):
"""Stream lifecycle states following libp2p patterns."""
OPEN = "open"
WRITE_CLOSED = "write_closed"
READ_CLOSED = "read_closed"
CLOSED = "closed"
RESET = "reset"
class StreamDirection(Enum):
"""Stream direction for tracking initiator."""
INBOUND = "inbound"
OUTBOUND = "outbound"
class StreamTimeline:
"""Track stream lifecycle events for debugging and monitoring."""
def __init__(self) -> None:
self.created_at = time.time()
self.opened_at: float | None = None
self.first_data_at: float | None = None
self.closed_at: float | None = None
self.reset_at: float | None = None
self.error_code: int | None = None
def record_open(self) -> None:
self.opened_at = time.time()
def record_first_data(self) -> None:
if self.first_data_at is None:
self.first_data_at = time.time()
def record_close(self) -> None:
self.closed_at = time.time()
def record_reset(self, error_code: int) -> None:
self.reset_at = time.time()
self.error_code = error_code
class QUICStream(IMuxedStream):
"""
QUIC Stream implementation following libp2p IMuxedStream interface.
Based on patterns from go-libp2p and js-libp2p, this implementation:
- Leverages QUIC's native multiplexing and flow control
- Integrates with libp2p resource management
- Provides comprehensive error handling with QUIC-specific codes
- Supports bidirectional communication with independent close semantics
- Implements proper stream lifecycle management
"""
def __init__(
self,
connection: "QUICConnection",
stream_id: int,
direction: StreamDirection,
remote_addr: tuple[str, int],
resource_scope: Any | None = None,
):
"""
Initialize QUIC stream.
Args:
connection: Parent QUIC connection
stream_id: QUIC stream identifier
direction: Stream direction (inbound/outbound)
resource_scope: Resource manager scope for memory accounting
remote_addr: Remote addr stream is connected to
"""
self._connection = connection
self._stream_id = stream_id
self._direction = direction
self._resource_scope = resource_scope
# libp2p interface compliance
self._protocol: TProtocol | None = None
self._metadata: dict[str, Any] = {}
self._remote_addr = remote_addr
# Stream state management
self._state = StreamState.OPEN
self._state_lock = trio.Lock()
# Flow control and buffering
self._receive_buffer = bytearray()
self._receive_buffer_lock = trio.Lock()
self._receive_event = trio.Event()
self._backpressure_event = trio.Event()
self._backpressure_event.set() # Initially no backpressure
# Close/reset state
self._write_closed = False
self._read_closed = False
self._close_event = trio.Event()
self._reset_error_code: int | None = None
# Lifecycle tracking
self._timeline = StreamTimeline()
self._timeline.record_open()
# Resource accounting
self._memory_reserved = 0
# Stream constant configurations
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
self.FLOW_CONTROL_WINDOW_SIZE = (
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
)
self.MAX_RECEIVE_BUFFER_SIZE = (
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
)
if self._resource_scope:
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
logger.debug(
f"Created QUIC stream {stream_id} "
f"({direction.value}, connection: {connection.remote_peer_id()})"
)
# Properties for libp2p interface compliance
@property
def protocol(self) -> TProtocol | None:
"""Get the protocol identifier for this stream."""
return self._protocol
@protocol.setter
def protocol(self, protocol_id: TProtocol) -> None:
"""Set the protocol identifier for this stream."""
self._protocol = protocol_id
self._metadata["protocol"] = protocol_id
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
@property
def stream_id(self) -> str:
"""Get stream ID as string for libp2p compatibility."""
return str(self._stream_id)
@property
def muxed_conn(self) -> "QUICConnection": # type: ignore
"""Get the parent muxed connection."""
return self._connection
@property
def state(self) -> StreamState:
"""Get current stream state."""
return self._state
@property
def direction(self) -> StreamDirection:
"""Get stream direction."""
return self._direction
@property
def is_initiator(self) -> bool:
"""Check if this stream was locally initiated."""
return self._direction == StreamDirection.OUTBOUND
# Core stream operations
async def read(self, n: int | None = None) -> bytes:
"""
Read data from the stream with QUIC flow control.
Args:
n: Maximum number of bytes to read. If None or -1, read all available.
Returns:
Data read from stream
Raises:
QUICStreamClosedError: Stream is closed
QUICStreamResetError: Stream was reset
QUICStreamTimeoutError: Read timeout exceeded
"""
if n is None:
n = -1
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
if self._read_closed:
# Return any remaining buffered data, then EOF
async with self._receive_buffer_lock:
if self._receive_buffer:
data = self._extract_data_from_buffer(n)
self._timeline.record_first_data()
return data
return b""
# Wait for data with timeout
timeout = self.READ_TIMEOUT
try:
with trio.move_on_after(timeout) as cancel_scope:
while True:
async with self._receive_buffer_lock:
if self._receive_buffer:
data = self._extract_data_from_buffer(n)
self._timeline.record_first_data()
return data
# Check if stream was closed while waiting
if self._read_closed:
return b""
# Wait for more data
await self._receive_event.wait()
self._receive_event = trio.Event() # Reset for next wait
if cancel_scope.cancelled_caught:
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
return b""
except QUICStreamResetError:
# Stream was reset while reading
raise
except Exception as e:
logger.error(f"Error reading from stream {self.stream_id}: {e}")
await self._handle_stream_error(e)
raise
async def write(self, data: bytes) -> None:
"""
Write data to the stream with QUIC flow control.
Args:
data: Data to write
Raises:
QUICStreamClosedError: Stream is closed for writing
QUICStreamBackpressureError: Flow control window exhausted
QUICStreamResetError: Stream was reset
"""
if not data:
return
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
if self._write_closed:
raise QUICStreamClosedError(
f"Stream {self.stream_id} write side is closed"
)
try:
# Handle flow control backpressure
await self._backpressure_event.wait()
# Send data through QUIC connection
self._connection._quic.send_stream_data(self._stream_id, data)
await self._connection._transmit()
self._timeline.record_first_data()
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
except Exception as e:
logger.error(f"Error writing to stream {self.stream_id}: {e}")
# Convert QUIC-specific errors
if "flow control" in str(e).lower():
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
await self._handle_stream_error(e)
raise
async def close(self) -> None:
"""
Close the stream gracefully (both read and write sides).
This implements proper close semantics where both sides
are closed and resources are cleaned up.
"""
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
return
logger.debug(f"Closing stream {self.stream_id}")
# Close both sides
if not self._write_closed:
await self.close_write()
if not self._read_closed:
await self.close_read()
# Update state and cleanup
async with self._state_lock:
self._state = StreamState.CLOSED
await self._cleanup_resources()
self._timeline.record_close()
self._close_event.set()
logger.debug(f"Stream {self.stream_id} closed")
async def close_write(self) -> None:
"""Close the write side of the stream."""
if self._write_closed:
return
try:
# Send FIN to close write side
self._connection._quic.send_stream_data(
self._stream_id, b"", end_stream=True
)
await self._connection._transmit()
self._write_closed = True
async with self._state_lock:
if self._read_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.WRITE_CLOSED
logger.debug(f"Stream {self.stream_id} write side closed")
except Exception as e:
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
async def close_read(self) -> None:
"""Close the read side of the stream."""
if self._read_closed:
return
try:
self._read_closed = True
async with self._state_lock:
if self._write_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.READ_CLOSED
# Wake up any pending reads
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} read side closed")
except Exception as e:
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
async def reset(self, error_code: int = 0) -> None:
"""
Reset the stream with the given error code.
Args:
error_code: QUIC error code for the reset
"""
async with self._state_lock:
if self._state == StreamState.RESET:
return
logger.debug(
f"Resetting stream {self.stream_id} with error code {error_code}"
)
self._state = StreamState.RESET
self._reset_error_code = error_code
try:
# Send QUIC reset frame
self._connection._quic.reset_stream(self._stream_id, error_code)
await self._connection._transmit()
except Exception as e:
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
finally:
# Always cleanup resources
await self._cleanup_resources()
self._timeline.record_reset(error_code)
self._close_event.set()
def is_closed(self) -> bool:
"""Check if stream is completely closed."""
return self._state in (StreamState.CLOSED, StreamState.RESET)
def is_reset(self) -> bool:
"""Check if stream was reset."""
return self._state == StreamState.RESET
def can_read(self) -> bool:
"""Check if stream can be read from."""
return not self._read_closed and self._state not in (
StreamState.CLOSED,
StreamState.RESET,
)
def can_write(self) -> bool:
"""Check if stream can be written to."""
return not self._write_closed and self._state not in (
StreamState.CLOSED,
StreamState.RESET,
)
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
"""
Handle data received from the QUIC connection.
Args:
data: Received data
end_stream: Whether this is the last data (FIN received)
"""
if self._state == StreamState.RESET:
return
if data:
async with self._receive_buffer_lock:
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
logger.warning(
f"Stream {self.stream_id} receive buffer overflow, "
f"dropping {len(data)} bytes"
)
return
self._receive_buffer.extend(data)
self._timeline.record_first_data()
# Notify waiting readers
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
if end_stream:
self._read_closed = True
async with self._state_lock:
if self._write_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.READ_CLOSED
# Wake up readers to process remaining data and EOF
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} received FIN")
async def handle_stop_sending(self, error_code: int) -> None:
"""
Handle STOP_SENDING frame from remote peer.
When a STOP_SENDING frame is received, the peer is requesting that we
stop sending data on this stream. We respond by resetting the stream.
Args:
error_code: Error code from the STOP_SENDING frame
"""
logger.debug(
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
)
self._write_closed = True
# Wake up any pending write operations
self._backpressure_event.set()
async with self._state_lock:
if self.direction == StreamDirection.OUTBOUND:
self._state = StreamState.CLOSED
elif self._read_closed:
self._state = StreamState.CLOSED
else:
# Only write side closed - add WRITE_CLOSED state if needed
self._state = StreamState.WRITE_CLOSED
# Send RESET_STREAM in response (QUIC protocol requirement)
try:
self._connection._quic.reset_stream(int(self.stream_id), error_code)
await self._connection._transmit()
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
except Exception as e:
logger.warning(
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
)
async def handle_reset(self, error_code: int) -> None:
"""
Handle stream reset from remote peer.
Args:
error_code: QUIC error code from reset frame
"""
logger.debug(
f"Stream {self.stream_id} reset by peer with error code {error_code}"
)
async with self._state_lock:
self._state = StreamState.RESET
self._reset_error_code = error_code
await self._cleanup_resources()
self._timeline.record_reset(error_code)
self._close_event.set()
# Wake up any pending operations
self._receive_event.set()
self._backpressure_event.set()
async def handle_flow_control_update(self, available_window: int) -> None:
"""
Handle flow control window updates.
Args:
available_window: Available flow control window size
"""
if available_window > 0:
self._backpressure_event.set()
logger.debug(
f"Stream {self.stream_id} flow control".__add__(
f"window updated: {available_window}"
)
)
else:
self._backpressure_event = trio.Event() # Reset to blocking state
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
def _extract_data_from_buffer(self, n: int) -> bytes:
"""Extract data from receive buffer with specified limit."""
if n == -1:
# Read all available data
data = bytes(self._receive_buffer)
self._receive_buffer.clear()
else:
# Read up to n bytes
data = bytes(self._receive_buffer[:n])
self._receive_buffer = self._receive_buffer[n:]
return data
async def _handle_stream_error(self, error: Exception) -> None:
"""Handle errors by resetting the stream."""
logger.error(f"Stream {self.stream_id} error: {error}")
await self.reset(error_code=1) # Generic error code
def _reserve_memory(self, size: int) -> None:
"""Reserve memory with resource manager."""
if self._resource_scope:
try:
self._resource_scope.reserve_memory(size)
self._memory_reserved += size
except Exception as e:
logger.warning(
f"Failed to reserve memory for stream {self.stream_id}: {e}"
)
def _release_memory(self, size: int) -> None:
"""Release memory with resource manager."""
if self._resource_scope and size > 0:
try:
self._resource_scope.release_memory(size)
self._memory_reserved = max(0, self._memory_reserved - size)
except Exception as e:
logger.warning(
f"Failed to release memory for stream {self.stream_id}: {e}"
)
async def _cleanup_resources(self) -> None:
"""Clean up stream resources."""
# Release all reserved memory
if self._memory_reserved > 0:
self._release_memory(self._memory_reserved)
# Clear receive buffer
async with self._receive_buffer_lock:
self._receive_buffer.clear()
# Remove from connection's stream registry
self._connection._remove_stream(self._stream_id)
logger.debug(f"Stream {self.stream_id} resources cleaned up")
# Abstact implementations
def get_remote_address(self) -> tuple[str, int]:
return self._remote_addr
async def __aenter__(self) -> "QUICStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
logger.debug("Exiting the context and closing the stream")
await self.close()
def set_deadline(self, ttl: int) -> bool:
"""
Set a deadline for the stream. QUIC does not support deadlines natively,
so this method always returns False to indicate the operation is unsupported.
:param ttl: Time-to-live in seconds (ignored).
:return: False, as deadlines are not supported.
"""
raise NotImplementedError("QUIC does not support setting read deadlines")
# String representation for debugging
def __repr__(self) -> str:
return (
f"QUICStream(id={self.stream_id}, "
f"state={self._state.value}, "
f"direction={self._direction.value}, "
f"protocol={self._protocol})"
)
def __str__(self) -> str:
return f"QUICStream({self.stream_id})"

View File

@ -0,0 +1,491 @@
"""
QUIC Transport implementation
"""
import copy
import logging
import ssl
from typing import TYPE_CHECKING, cast
from aioquic.quic.configuration import (
QuicConfiguration,
)
from aioquic.quic.connection import (
QuicConnection as NativeQUICConnection,
)
from aioquic.quic.logger import QuicLogger
import multiaddr
import trio
from libp2p.abc import (
ITransport,
)
from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
from libp2p.peer.id import (
ID,
)
from libp2p.transport.quic.security import QUICTLSSecurityConfig
from libp2p.transport.quic.utils import (
create_client_config_from_base,
create_server_config_from_base,
get_alpn_protocols,
is_quic_multiaddr,
multiaddr_to_quic_version,
quic_multiaddr_to_endpoint,
quic_version_to_wire_format,
)
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm
else:
Swarm = cast(type, object)
from .config import (
QUICTransportConfig,
)
from .connection import (
QUICConnection,
)
from .exceptions import (
QUICDialError,
QUICListenError,
QUICSecurityError,
)
from .listener import (
QUICListener,
)
from .security import (
QUICTLSConfigManager,
create_quic_security_transport,
)
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
logger = logging.getLogger(__name__)
class QUICTransport(ITransport):
"""
QUIC Stream implementation following libp2p IMuxedStream interface.
"""
def __init__(
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
) -> None:
"""
Initialize QUIC transport with security integration.
Args:
private_key: libp2p private key for identity and TLS cert generation
config: QUIC transport configuration options
"""
self._private_key = private_key
self._peer_id = ID.from_pubkey(private_key.get_public_key())
self._config = config or QUICTransportConfig()
# Connection management
self._connections: dict[str, QUICConnection] = {}
self._listeners: list[QUICListener] = []
# Security manager for TLS integration
self._security_manager = create_quic_security_transport(
self._private_key, self._peer_id
)
# QUIC configurations for different versions
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
self._setup_quic_configurations()
# Resource management
self._closed = False
self._nursery_manager = trio.CapacityLimiter(1)
self._background_nursery: trio.Nursery | None = None
self._swarm: Swarm | None = None
logger.debug(
f"Initialized QUIC transport with security for peer {self._peer_id}"
)
def set_background_nursery(self, nursery: trio.Nursery) -> None:
"""Set the nursery to use for background tasks (called by swarm)."""
self._background_nursery = nursery
logger.debug("Transport background nursery set")
def set_swarm(self, swarm: Swarm) -> None:
"""Set the swarm for adding incoming connections."""
self._swarm = swarm
def _setup_quic_configurations(self) -> None:
"""Setup QUIC configurations."""
try:
# Get TLS configuration from security manager
server_tls_config = self._security_manager.create_server_config()
client_tls_config = self._security_manager.create_client_config()
# Base server configuration
base_server_config = QuicConfiguration(
is_client=False,
alpn_protocols=get_alpn_protocols(),
verify_mode=self._config.verify_mode,
max_datagram_frame_size=self._config.max_datagram_size,
idle_timeout=self._config.idle_timeout,
)
# Base client configuration
base_client_config = QuicConfiguration(
is_client=True,
alpn_protocols=get_alpn_protocols(),
verify_mode=self._config.verify_mode,
max_datagram_frame_size=self._config.max_datagram_size,
idle_timeout=self._config.idle_timeout,
)
# Apply TLS configuration
self._apply_tls_configuration(base_server_config, server_tls_config)
self._apply_tls_configuration(base_client_config, client_tls_config)
# QUIC v1 (RFC 9000) configurations
if self._config.enable_v1:
quic_v1_server_config = create_server_config_from_base(
base_server_config, self._security_manager, self._config
)
quic_v1_server_config.supported_versions = [
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
]
quic_v1_client_config = create_client_config_from_base(
base_client_config, self._security_manager, self._config
)
quic_v1_client_config.supported_versions = [
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
]
# Store both server and client configs for v1
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
quic_v1_server_config
)
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
quic_v1_client_config
)
# QUIC draft-29 configurations for compatibility
if self._config.enable_draft29:
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
draft29_server_config.supported_versions = [
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
]
draft29_client_config = copy.copy(base_client_config)
draft29_client_config.supported_versions = [
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
]
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
draft29_server_config
)
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
draft29_client_config
)
logger.debug("QUIC configurations initialized with libp2p TLS security")
except Exception as e:
raise QUICSecurityError(
f"Failed to setup QUIC TLS configurations: {e}"
) from e
def _apply_tls_configuration(
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
) -> None:
"""
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
Args:
config: QuicConfiguration to update
tls_config: TLS configuration dictionary from security manager
"""
try:
config.certificate = tls_config.certificate
config.private_key = tls_config.private_key
config.certificate_chain = tls_config.certificate_chain
config.alpn_protocols = tls_config.alpn_protocols
config.verify_mode = ssl.CERT_NONE
logger.debug("Successfully applied TLS configuration to QUIC config")
except Exception as e:
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
async def dial(
self,
maddr: multiaddr.Multiaddr,
) -> QUICConnection:
"""
Dial a remote peer using QUIC transport with security verification.
Args:
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
peer_id: Expected peer ID for verification
nursery: Nursery to execute the background tasks
Returns:
Raw connection interface to the remote peer
Raises:
QUICDialError: If dialing fails
QUICSecurityError: If security verification fails
"""
if self._closed:
raise QUICDialError("Transport is closed")
if not is_quic_multiaddr(maddr):
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
try:
# Extract connection details from multiaddr
host, port = quic_multiaddr_to_endpoint(maddr)
remote_peer_id = maddr.get_peer_id()
if remote_peer_id is not None:
remote_peer_id = ID.from_base58(remote_peer_id)
if remote_peer_id is None:
logger.error("Unable to derive peer id from multiaddr")
raise QUICDialError("Unable to derive peer id from multiaddr")
quic_version = multiaddr_to_quic_version(maddr)
# Get appropriate QUIC client configuration
config_key = TProtocol(f"{quic_version}_client")
logger.debug("config_key", config_key, self._quic_configs.keys())
config = self._quic_configs.get(config_key)
if not config:
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
config.is_client = True
config.quic_logger = QuicLogger()
# Ensure client certificate is properly set for mutual authentication
if not config.certificate or not config.private_key:
logger.warning(
"Client config missing certificate - applying TLS config"
)
client_tls_config = self._security_manager.create_client_config()
self._apply_tls_configuration(config, client_tls_config)
# Debug log to verify certificate is present
logger.info(
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
)
logger.debug("Starting QUIC Connection")
# Create QUIC connection using aioquic's sans-IO core
native_quic_connection = NativeQUICConnection(configuration=config)
# Create trio-based QUIC connection wrapper with security
connection = QUICConnection(
quic_connection=native_quic_connection,
remote_addr=(host, port),
remote_peer_id=remote_peer_id,
local_peer_id=self._peer_id,
is_initiator=True,
maddr=maddr,
transport=self,
security_manager=self._security_manager,
)
logger.debug("QUIC Connection Created")
if self._background_nursery is None:
logger.error("No nursery set to execute background tasks")
raise QUICDialError("No nursery found to execute tasks")
await connection.connect(self._background_nursery)
# Store connection for management
conn_id = f"{host}:{port}"
self._connections[conn_id] = connection
return connection
except Exception as e:
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
raise QUICDialError(f"Dial failed: {e}") from e
async def _verify_peer_identity(
self, connection: QUICConnection, expected_peer_id: ID
) -> None:
"""
Verify remote peer identity after TLS handshake.
Args:
connection: The established QUIC connection
expected_peer_id: Expected peer ID
Raises:
QUICSecurityError: If peer verification fails
"""
try:
# Get peer certificate from the connection
peer_certificate = await connection.get_peer_certificate()
if not peer_certificate:
raise QUICSecurityError("No peer certificate available")
# Verify peer identity using security manager
verified_peer_id = self._security_manager.verify_peer_identity(
peer_certificate, expected_peer_id
)
if verified_peer_id != expected_peer_id:
raise QUICSecurityError(
"Peer ID verification failed: expected "
f"{expected_peer_id}, got {verified_peer_id}"
)
logger.debug(f"Peer identity verified: {verified_peer_id}")
logger.debug(f"Peer identity verified: {verified_peer_id}")
except Exception as e:
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
"""
Create a QUIC listener with integrated security.
Args:
handler_function: Function to handle new connections
Returns:
QUIC listener instance
Raises:
QUICListenError: If transport is closed
"""
if self._closed:
raise QUICListenError("Transport is closed")
# Get server configurations for the listener
server_configs = {
version: config
for version, config in self._quic_configs.items()
if version.endswith("_server")
}
listener = QUICListener(
transport=self,
handler_function=handler_function,
quic_configs=server_configs,
config=self._config,
security_manager=self._security_manager,
)
self._listeners.append(listener)
logger.debug("Created QUIC listener with security")
return listener
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
"""
Check if this transport can dial the given multiaddr.
Args:
maddr: Multiaddr to check
Returns:
True if this transport can dial the address
"""
return is_quic_multiaddr(maddr)
def protocols(self) -> list[TProtocol]:
"""
Get supported protocol identifiers.
Returns:
List of supported protocol strings
"""
protocols = [QUIC_V1_PROTOCOL]
if self._config.enable_draft29:
protocols.append(QUIC_DRAFT29_PROTOCOL)
return protocols
def listen_order(self) -> int:
"""
Get the listen order priority for this transport.
Matches go-libp2p's ListenOrder = 1 for QUIC.
Returns:
Priority order for listening (lower = higher priority)
"""
return 1
async def close(self) -> None:
"""Close the transport and cleanup resources."""
if self._closed:
return
self._closed = True
logger.debug("Closing QUIC transport")
# Close all active connections and listeners concurrently using trio nursery
async with trio.open_nursery() as nursery:
# Close all connections
for connection in self._connections.values():
nursery.start_soon(connection.close)
# Close all listeners
for listener in self._listeners:
nursery.start_soon(listener.close)
self._connections.clear()
self._listeners.clear()
logger.debug("QUIC transport closed")
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
"""Clean up a terminated connection from all listeners."""
try:
for listener in self._listeners:
await listener._remove_connection_by_object(connection)
logger.debug(
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
)
except Exception as e:
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
def get_stats(self) -> dict[str, int | list[str] | object]:
"""Get transport statistics including security info."""
return {
"active_connections": len(self._connections),
"active_listeners": len(self._listeners),
"supported_protocols": self.protocols(),
"local_peer_id": str(self._peer_id),
"security_enabled": True,
"tls_configured": True,
}
def get_security_manager(self) -> QUICTLSConfigManager:
"""
Get the security manager for this transport.
Returns:
The QUIC TLS configuration manager
"""
return self._security_manager
def get_listener_socket(self) -> trio.socket.SocketType | None:
"""Get the socket from the first active listener."""
for listener in self._listeners:
if listener.is_listening() and listener._socket:
return listener._socket
return None

View File

@ -0,0 +1,466 @@
"""
Multiaddr utilities for QUIC transport - Module 4.
Essential utilities required for QUIC transport implementation.
Based on go-libp2p and js-libp2p QUIC implementations.
"""
import ipaddress
import logging
import ssl
from aioquic.quic.configuration import QuicConfiguration
import multiaddr
from libp2p.custom_types import TProtocol
from libp2p.transport.quic.security import QUICTLSConfigManager
from .config import QUICTransportConfig
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
logger = logging.getLogger(__name__)
# Protocol constants
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
UDP_PROTOCOL = "udp"
IP4_PROTOCOL = "ip4"
IP6_PROTOCOL = "ip6"
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
}
# QUIC version to wire format mappings (required for aioquic)
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
}
# ALPN protocols for libp2p over QUIC
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
"""
Check if a multiaddr represents a QUIC address.
Valid QUIC multiaddrs:
- /ip4/127.0.0.1/udp/4001/quic-v1
- /ip4/127.0.0.1/udp/4001/quic
- /ip6/::1/udp/4001/quic-v1
- /ip6/::1/udp/4001/quic
Args:
maddr: Multiaddr to check
Returns:
True if the multiaddr represents a QUIC address
"""
try:
addr_str = str(maddr)
# Check for required components
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
has_quic = (
f"/{QUIC_V1_PROTOCOL}" in addr_str
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
or "/quic" in addr_str
)
return has_ip and has_udp and has_quic
except Exception:
return False
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
"""
Extract host and port from a QUIC multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
Tuple of (host, port)
Raises:
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
"""
if not is_quic_multiaddr(maddr):
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
try:
host = None
port = None
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except Exception:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except Exception:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str)
except Exception:
pass
if host is None or port is None:
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
return host, port
except Exception as e:
raise QUICInvalidMultiaddrError(
f"Failed to parse QUIC multiaddr {maddr}: {e}"
) from e
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
"""
Determine QUIC version from multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
QUIC version identifier ("quic-v1" or "quic")
Raises:
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
"""
try:
addr_str = str(maddr)
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
return QUIC_V1_PROTOCOL # RFC 9000
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
return QUIC_DRAFT29_PROTOCOL # draft-29
else:
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
except Exception as e:
raise QUICInvalidMultiaddrError(
f"Failed to determine QUIC version from {maddr}: {e}"
) from e
def create_quic_multiaddr(
host: str, port: int, version: str = "quic-v1"
) -> multiaddr.Multiaddr:
"""
Create a QUIC multiaddr from host, port, and version.
Args:
host: IP address (IPv4 or IPv6)
port: UDP port number
version: QUIC version ("quic-v1" or "quic")
Returns:
QUIC multiaddr
Raises:
QUICInvalidMultiaddrError: If invalid parameters provided
"""
try:
# Determine IP version
try:
ip = ipaddress.ip_address(host)
if isinstance(ip, ipaddress.IPv4Address):
ip_proto = IP4_PROTOCOL
else:
ip_proto = IP6_PROTOCOL
except ValueError:
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
# Validate port
if not (0 <= port <= 65535):
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
# Validate and normalize QUIC version
if version == "quic-v1" or version == "/quic-v1":
quic_proto = QUIC_V1_PROTOCOL
elif version == "quic" or version == "/quic":
quic_proto = QUIC_DRAFT29_PROTOCOL
else:
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
# Construct multiaddr
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
return multiaddr.Multiaddr(addr_str)
except Exception as e:
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
def quic_version_to_wire_format(version: TProtocol) -> int:
"""
Convert QUIC version string to wire format integer for aioquic.
Args:
version: QUIC version string ("quic-v1" or "quic")
Returns:
Wire format version number
Raises:
QUICUnsupportedVersionError: If version is not supported
"""
wire_version = QUIC_VERSION_MAPPINGS.get(version)
if wire_version is None:
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
return wire_version
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
"""
Convert QUIC version string to wire format integer for aioquic.
Args:
version: QUIC version string ("quic-v1" or "quic")
Returns:
Wire format version number
Raises:
QUICUnsupportedVersionError: If version is not supported
"""
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
if wire_version is None:
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
return wire_version
def get_alpn_protocols() -> list[str]:
"""
Get ALPN protocols for libp2p over QUIC.
Returns:
List of ALPN protocol identifiers
"""
return LIBP2P_ALPN_PROTOCOLS.copy()
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
"""
Normalize a QUIC multiaddr to canonical form.
Args:
maddr: Input QUIC multiaddr
Returns:
Normalized multiaddr
Raises:
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
"""
if not is_quic_multiaddr(maddr):
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
host, port = quic_multiaddr_to_endpoint(maddr)
version = multiaddr_to_quic_version(maddr)
return create_quic_multiaddr(host, port, version)
def create_server_config_from_base(
base_config: QuicConfiguration,
security_manager: QUICTLSConfigManager | None = None,
transport_config: QUICTransportConfig | None = None,
) -> QuicConfiguration:
"""
Create a server configuration without using deepcopy.
Manually copies attributes while handling cryptography objects properly.
"""
try:
# Create new server configuration from scratch
server_config = QuicConfiguration(is_client=False)
server_config.verify_mode = ssl.CERT_NONE
# Copy basic configuration attributes (these are safe to copy)
copyable_attrs = [
"alpn_protocols",
"verify_mode",
"max_datagram_frame_size",
"idle_timeout",
"max_concurrent_streams",
"supported_versions",
"max_data",
"max_stream_data",
"stateless_retry",
"quantum_readiness_test",
]
for attr in copyable_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(server_config, attr, value)
# Handle cryptography objects - these need direct reference, not copying
crypto_attrs = [
"certificate",
"private_key",
"certificate_chain",
"ca_certs",
]
for attr in crypto_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(server_config, attr, value)
# Apply security manager configuration if available
if security_manager:
try:
server_tls_config = security_manager.create_server_config()
# Override with security manager's TLS configuration
if server_tls_config.certificate:
server_config.certificate = server_tls_config.certificate
if server_tls_config.private_key:
server_config.private_key = server_tls_config.private_key
if server_tls_config.certificate_chain:
server_config.certificate_chain = (
server_tls_config.certificate_chain
)
if server_tls_config.alpn_protocols:
server_config.alpn_protocols = server_tls_config.alpn_protocols
server_tls_config.request_client_certificate = True
if getattr(server_tls_config, "request_client_certificate", False):
server_config._libp2p_request_client_cert = True # type: ignore
else:
logger.error(
"🔧 Failed to set request_client_certificate in server config"
)
except Exception as e:
logger.warning(f"Failed to apply security manager config: {e}")
# Set transport-specific defaults if provided
if transport_config:
if server_config.idle_timeout == 0:
server_config.idle_timeout = getattr(
transport_config, "idle_timeout", 30.0
)
if server_config.max_datagram_frame_size is None:
server_config.max_datagram_frame_size = getattr(
transport_config, "max_datagram_size", 1200
)
# Ensure we have ALPN protocols
if not server_config.alpn_protocols:
server_config.alpn_protocols = ["libp2p"]
logger.debug("Successfully created server config without deepcopy")
return server_config
except Exception as e:
logger.error(f"Failed to create server config: {e}")
raise
def create_client_config_from_base(
base_config: QuicConfiguration,
security_manager: QUICTLSConfigManager | None = None,
transport_config: QUICTransportConfig | None = None,
) -> QuicConfiguration:
"""
Create a client configuration without using deepcopy.
"""
try:
# Create new client configuration from scratch
client_config = QuicConfiguration(is_client=True)
client_config.verify_mode = ssl.CERT_NONE
# Copy basic configuration attributes
copyable_attrs = [
"alpn_protocols",
"verify_mode",
"max_datagram_frame_size",
"idle_timeout",
"max_concurrent_streams",
"supported_versions",
"max_data",
"max_stream_data",
"quantum_readiness_test",
]
for attr in copyable_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(client_config, attr, value)
# Handle cryptography objects - these need direct reference, not copying
crypto_attrs = [
"certificate",
"private_key",
"certificate_chain",
"ca_certs",
]
for attr in crypto_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(client_config, attr, value)
# Apply security manager configuration if available
if security_manager:
try:
client_tls_config = security_manager.create_client_config()
# Override with security manager's TLS configuration
if client_tls_config.certificate:
client_config.certificate = client_tls_config.certificate
if client_tls_config.private_key:
client_config.private_key = client_tls_config.private_key
if client_tls_config.certificate_chain:
client_config.certificate_chain = (
client_tls_config.certificate_chain
)
if client_tls_config.alpn_protocols:
client_config.alpn_protocols = client_tls_config.alpn_protocols
except Exception as e:
logger.warning(f"Failed to apply security manager config: {e}")
# Ensure we have ALPN protocols
if not client_config.alpn_protocols:
client_config.alpn_protocols = ["libp2p"]
logger.debug("Successfully created client config without deepcopy")
return client_config
except Exception as e:
logger.error(f"Failed to create client config: {e}")
raise

View File

@ -0,0 +1,267 @@
"""
Transport registry for dynamic transport selection based on multiaddr protocols.
"""
from collections.abc import Callable
import logging
from typing import Any
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
from libp2p.abc import ITransport
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import (
is_valid_websocket_multiaddr,
)
# Import QUIC utilities here to avoid circular imports
def _get_quic_transport() -> Any:
from libp2p.transport.quic.transport import QUICTransport
return QUICTransport
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
from libp2p.transport.quic.utils import is_quic_multiaddr
return is_quic_multiaddr
# Import WebsocketTransport here to avoid circular imports
def _get_websocket_transport() -> Any:
from libp2p.transport.websocket.transport import WebsocketTransport
return WebsocketTransport
logger = logging.getLogger("libp2p.transport.registry")
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid TCP structure.
:param maddr: The multiaddr to validate
:return: True if valid TCP structure, False otherwise
"""
try:
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
# or /ip6/::1/tcp/8080
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 2 protocols: network (ip4/ip6) + tcp
if len(protocols) < 2:
return False
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Should not have any protocols after tcp (unless it's a valid
# continuation like p2p)
# For now, we'll be strict and only allow network + tcp
if len(protocols) > 2:
# Check if the additional protocols are valid continuations
valid_continuations = ["p2p"] # Add more as needed
for i in range(2, len(protocols)):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False
class TransportRegistry:
"""
Registry for mapping multiaddr protocols to transport implementations.
"""
def __init__(self) -> None:
self._transports: dict[str, type[ITransport]] = {}
self._register_default_transports()
def _register_default_transports(self) -> None:
"""Register the default transport implementations."""
# Register TCP transport for /tcp protocol
self.register_transport("tcp", TCP)
# Register WebSocket transport for /ws and /wss protocols
WebsocketTransport = _get_websocket_transport()
self.register_transport("ws", WebsocketTransport)
self.register_transport("wss", WebsocketTransport)
# Register QUIC transport for /quic and /quic-v1 protocols
QUICTransport = _get_quic_transport()
self.register_transport("quic", QUICTransport)
self.register_transport("quic-v1", QUICTransport)
def register_transport(
self, protocol: str, transport_class: type[ITransport]
) -> None:
"""
Register a transport class for a specific protocol.
:param protocol: The protocol identifier (e.g., "tcp", "ws")
:param transport_class: The transport class to register
"""
self._transports[protocol] = transport_class
logger.debug(
f"Registered transport {transport_class.__name__} for protocol {protocol}"
)
def get_transport(self, protocol: str) -> type[ITransport] | None:
"""
Get the transport class for a specific protocol.
:param protocol: The protocol identifier
:return: The transport class or None if not found
"""
return self._transports.get(protocol)
def get_supported_protocols(self) -> list[str]:
"""Get list of supported transport protocols."""
return list(self._transports.keys())
def create_transport(
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
) -> ITransport | None:
"""
Create a transport instance for a specific protocol.
:param protocol: The protocol identifier
:param upgrader: The transport upgrader instance (required for WebSocket)
:param kwargs: Additional arguments for transport construction
:return: Transport instance or None if protocol not supported or creation fails
"""
transport_class = self.get_transport(protocol)
if transport_class is None:
return None
try:
if protocol in ["ws", "wss"]:
# WebSocket transport requires upgrader
if upgrader is None:
logger.warning(
f"WebSocket transport '{protocol}' requires upgrader"
)
return None
# Use explicit WebsocketTransport to avoid type issues
WebsocketTransport = _get_websocket_transport()
return WebsocketTransport(
upgrader,
tls_client_config=kwargs.get("tls_client_config"),
tls_server_config=kwargs.get("tls_server_config"),
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
)
elif protocol in ["quic", "quic-v1"]:
# QUIC transport requires private_key
private_key = kwargs.get("private_key")
if private_key is None:
logger.warning(f"QUIC transport '{protocol}' requires private_key")
return None
# Use explicit QUICTransport to avoid type issues
QUICTransport = _get_quic_transport()
config = kwargs.get("config")
return QUICTransport(private_key, config)
else:
# TCP transport doesn't require upgrader
return transport_class()
except Exception as e:
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
return None
# Global transport registry instance (lazy initialization)
_global_registry: TransportRegistry | None = None
def get_transport_registry() -> TransportRegistry:
"""Get the global transport registry instance."""
global _global_registry
if _global_registry is None:
_global_registry = TransportRegistry()
return _global_registry
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
"""Register a transport class in the global registry."""
registry = get_transport_registry()
registry.register_transport(protocol, transport_class)
def create_transport_for_multiaddr(
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
) -> ITransport | None:
"""
Create the appropriate transport for a given multiaddr.
:param maddr: The multiaddr to create transport for
:param upgrader: The transport upgrader instance
:param kwargs: Additional arguments for transport construction
(e.g., private_key for QUIC)
:return: Transport instance or None if no suitable transport found
"""
try:
# Get all protocols in the multiaddr
protocols = [proto.name for proto in maddr.protocols()]
# Check for supported transport protocols in order of preference
# We need to validate that the multiaddr structure is valid for our transports
if "quic" in protocols or "quic-v1" in protocols:
# For QUIC, we need a valid structure like:
# /ip4/127.0.0.1/udp/4001/quic
# /ip4/127.0.0.1/udp/4001/quic-v1
is_quic_multiaddr = _get_quic_validation()
if is_quic_multiaddr(maddr):
# Determine QUIC version
registry = get_transport_registry()
if "quic-v1" in protocols:
return registry.create_transport("quic-v1", upgrader, **kwargs)
else:
return registry.create_transport("quic", upgrader, **kwargs)
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
# For WebSocket, we need a valid structure like:
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
# /ip4/127.0.0.1/tcp/8080/wss (secure)
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
if is_valid_websocket_multiaddr(maddr):
# Determine if this is a secure WebSocket connection
registry = get_transport_registry()
if "wss" in protocols or "tls" in protocols:
return registry.create_transport("wss", upgrader, **kwargs)
else:
return registry.create_transport("ws", upgrader, **kwargs)
elif "tcp" in protocols:
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
# Check if the multiaddr has proper TCP structure
if _is_valid_tcp_multiaddr(maddr):
registry = get_transport_registry()
return registry.create_transport("tcp", upgrader)
# If no supported transport protocol found or structure is invalid, return None
logger.warning(
f"No supported transport protocol found or invalid structure in "
f"multiaddr: {maddr}"
)
return None
except Exception as e:
# Handle any errors gracefully (e.g., invalid multiaddr)
logger.warning(f"Error processing multiaddr {maddr}: {e}")
return None
def get_supported_transport_protocols() -> list[str]:
"""Get list of supported transport protocols from the global registry."""
registry = get_transport_registry()
return registry.get_supported_protocols()

View File

@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
DEFAULT_NEGOTIATE_TIMEOUT,
)
from libp2p.security.exceptions import (
HandshakeFailure,
)
@ -37,9 +40,12 @@ class TransportUpgrader:
self,
secure_transports_by_protocol: TSecurityOptions,
muxer_transports_by_protocol: TMuxerOptions,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
):
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(
muxer_transports_by_protocol, negotiate_timeout
)
async def upgrade_security(
self,

View File

@ -0,0 +1,198 @@
import logging
import time
from typing import Any
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger(__name__)
class P2PWebSocketConnection(ReadWriteCloser):
"""
Wraps a WebSocketConnection to provide the raw stream interface
that libp2p protocols expect.
Implements production-ready buffer management and flow control
as recommended in the libp2p WebSocket specification.
"""
def __init__(
self,
ws_connection: Any,
ws_context: Any = None,
is_secure: bool = False,
max_buffered_amount: int = 4 * 1024 * 1024,
) -> None:
self._ws_connection = ws_connection
self._ws_context = ws_context
self._is_secure = is_secure
self._read_buffer = b""
self._read_lock = trio.Lock()
self._connection_start_time = time.time()
self._bytes_read = 0
self._bytes_written = 0
self._closed = False
self._close_lock = trio.Lock()
self._max_buffered_amount = max_buffered_amount
self._write_lock = trio.Lock()
async def write(self, data: bytes) -> None:
"""Write data with flow control and buffer management"""
if self._closed:
raise IOException("Connection is closed")
async with self._write_lock:
try:
logger.debug(f"WebSocket writing {len(data)} bytes")
# Check buffer amount for flow control
if hasattr(self._ws_connection, "bufferedAmount"):
buffered = self._ws_connection.bufferedAmount
if buffered > self._max_buffered_amount:
logger.warning(f"WebSocket buffer full: {buffered} bytes")
# In production, you might want to
# wait or implement backpressure
# For now, we'll continue but log the warning
# Send as a binary WebSocket message
await self._ws_connection.send_message(data)
self._bytes_written += len(data)
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
except Exception as e:
logger.error(f"WebSocket write failed: {e}")
self._closed = True
raise IOException from e
async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes (if n is given), else read up to 64KiB.
This implementation provides byte-level access to WebSocket messages,
which is required for libp2p protocol compatibility.
For WebSocket compatibility with libp2p protocols, this method:
1. Buffers incoming WebSocket messages
2. Returns exactly the requested number of bytes when n is specified
3. Accumulates multiple WebSocket messages if needed to satisfy the request
4. Returns empty bytes (not raises) when connection is closed and no data
available
"""
if self._closed:
raise IOException("Connection is closed")
async with self._read_lock:
try:
# If n is None, read at least one message and return all buffered data
if n is None:
if not self._read_buffer:
try:
# Use a short timeout to avoid blocking indefinitely
with trio.fail_after(1.0): # 1 second timeout
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
self._read_buffer = message
except trio.TooSlowError:
# No message available within timeout
return b""
except Exception:
# Return empty bytes if no data available
# (connection closed)
return b""
result = self._read_buffer
self._read_buffer = b""
self._bytes_read += len(result)
return result
# For specific byte count requests, return UP TO n bytes (not exactly n)
# This matches TCP semantics where read(1024) returns available data
# up to 1024 bytes
# If we don't have any data buffered, try to get at least one message
if not self._read_buffer:
try:
# Use a short timeout to avoid blocking indefinitely
with trio.fail_after(1.0): # 1 second timeout
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
self._read_buffer = message
except trio.TooSlowError:
return b"" # No data available
except Exception:
return b""
# Now return up to n bytes from the buffer (TCP-like semantics)
if len(self._read_buffer) == 0:
return b""
# Return up to n bytes (like TCP read())
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[len(result) :]
self._bytes_read += len(result)
return result
except Exception as e:
logger.error(f"WebSocket read failed: {e}")
raise IOException from e
async def close(self) -> None:
"""Close the WebSocket connection. This method is idempotent."""
async with self._close_lock:
if self._closed:
return # Already closed
logger.debug("WebSocket connection closing")
self._closed = True
try:
# Always close the connection directly, avoid context manager issues
# The context manager may be causing cancel scope corruption
logger.debug("WebSocket closing connection directly")
await self._ws_connection.aclose()
# Exit the context manager if we have one
if self._ws_context is not None:
await self._ws_context.__aexit__(None, None, None)
except Exception as e:
logger.error(f"WebSocket close error: {e}")
# Don't raise here, as close() should be idempotent
finally:
logger.debug("WebSocket connection closed")
def is_closed(self) -> bool:
"""Check if the connection is closed"""
return self._closed
def conn_state(self) -> dict[str, Any]:
"""
Return connection state information similar to Go's ConnState() method.
:return: Dictionary containing connection state information
"""
current_time = time.time()
return {
"transport": "websocket",
"secure": self._is_secure,
"connection_duration": current_time - self._connection_start_time,
"bytes_read": self._bytes_read,
"bytes_written": self._bytes_written,
"total_bytes": self._bytes_read + self._bytes_written,
}
def get_remote_address(self) -> tuple[str, int] | None:
# Try to get remote address from the WebSocket connection
try:
remote = self._ws_connection.remote
if hasattr(remote, "address") and hasattr(remote, "port"):
return str(remote.address), int(remote.port)
elif isinstance(remote, str):
# Parse address:port format
if ":" in remote:
host, port = remote.rsplit(":", 1)
return host, int(port)
except Exception:
pass
return None

View File

@ -0,0 +1,225 @@
from collections.abc import Awaitable, Callable
import logging
import ssl
from typing import Any
from multiaddr import Multiaddr
import trio
from trio_typing import TaskStatus
from trio_websocket import serve_websocket
from libp2p.abc import IListener
from libp2p.custom_types import THandler
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
from .connection import P2PWebSocketConnection
logger = logging.getLogger("libp2p.transport.websocket.listener")
class WebsocketListener(IListener):
"""
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
"""
def __init__(
self,
handler: THandler,
upgrader: TransportUpgrader,
tls_config: ssl.SSLContext | None = None,
handshake_timeout: float = 15.0,
) -> None:
self._handler = handler
self._upgrader = upgrader
self._tls_config = tls_config
self._handshake_timeout = handshake_timeout
self._server = None
self._shutdown_event = trio.Event()
self._nursery: trio.Nursery | None = None
self._listeners: Any = None
self._is_wss = False # Track whether this is a WSS listener
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}")
# Parse the WebSocket multiaddr to determine if it's secure
try:
parsed = parse_websocket_multiaddr(maddr)
except ValueError as e:
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
# Check if WSS is requested but no TLS config provided
if parsed.is_wss and self._tls_config is None:
raise ValueError(
f"Cannot listen on WSS address {maddr} without TLS configuration"
)
# Store whether this is a WSS listener
self._is_wss = parsed.is_wss
# Extract host and port from the base multiaddr
host = (
parsed.rest_multiaddr.value_for_protocol("ip4")
or parsed.rest_multiaddr.value_for_protocol("ip6")
or parsed.rest_multiaddr.value_for_protocol("dns")
or parsed.rest_multiaddr.value_for_protocol("dns4")
or parsed.rest_multiaddr.value_for_protocol("dns6")
or "0.0.0.0"
)
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
logger.debug(
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
)
async def serve_websocket_tcp(
handler: Callable[[Any], Awaitable[None]],
port: int,
host: str,
task_status: TaskStatus[Any],
) -> None:
"""Start TCP server and handle WebSocket connections manually"""
logger.debug(
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
)
async def websocket_handler(request: Any) -> None:
"""Handle WebSocket requests"""
logger.debug("WebSocket request received")
try:
# Apply handshake timeout
with trio.fail_after(self._handshake_timeout):
# Accept the WebSocket connection
ws_connection = await request.accept()
logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(
ws_connection, is_secure=parsed.is_wss
) # type: ignore[no-untyped-call]
# Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades
logger.debug("Calling connection handler")
await self._handler(conn)
# Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle
logger.debug(
"Handler completed, connection will be managed by handler"
)
except trio.TooSlowError:
logger.debug(
f"WebSocket handshake timeout after {self._handshake_timeout}s"
)
try:
await request.reject(408) # Request Timeout
except Exception:
pass
except Exception as e:
logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}")
import traceback
logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection
try:
await request.reject(400)
except Exception:
pass
# Use trio_websocket.serve_websocket for proper WebSocket handling
ssl_context = self._tls_config if parsed.is_wss else None
await serve_websocket(
websocket_handler, host, port, ssl_context, task_status=task_status
)
# Store the nursery for shutdown
self._nursery = nursery
# Start the server using nursery.start() like TCP does
logger.debug("Calling nursery.start()...")
started_listeners = await nursery.start(
serve_websocket_tcp,
None, # No handler needed since it's defined inside serve_websocket_tcp
port,
host,
)
logger.debug(f"nursery.start() returned: {started_listeners}")
if started_listeners is None:
logger.error(f"Failed to start WebSocket listener for {maddr}")
return False
# Store the listeners for get_addrs() and close() - these are real
# SocketListener objects
self._listeners = started_listeners
logger.debug(
"WebsocketListener.listen returning True with WebSocketServer object"
)
return True
def get_addrs(self) -> tuple[Multiaddr, ...]:
if not hasattr(self, "_listeners") or not self._listeners:
logger.debug("No listeners available for get_addrs()")
return ()
# Handle WebSocketServer objects
if hasattr(self._listeners, "port"):
# This is a WebSocketServer object
port = self._listeners.port
# Create a multiaddr from the port with correct WSS/WS protocol
protocol = "wss" if self._is_wss else "ws"
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
else:
# This is a list of listeners (like TCP)
listeners = self._listeners
# Get addresses from listeners like TCP does
return tuple(
_multiaddr_from_socket(listener.socket, self._is_wss)
for listener in listeners
)
async def close(self) -> None:
"""Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called")
if hasattr(self, "_listeners") and self._listeners:
# Signal shutdown
self._shutdown_event.set()
# Close the WebSocket server
if hasattr(self._listeners, "aclose"):
# This is a WebSocketServer object
logger.debug("Closing WebSocket server")
await self._listeners.aclose()
logger.debug("WebSocket server closed")
elif isinstance(self._listeners, (list, tuple)):
# This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners")
for listener in self._listeners:
await listener.aclose()
logger.debug("TCP listeners closed")
else:
# Unknown type, try to close it directly
logger.debug("Closing unknown listener type")
if hasattr(self._listeners, "close"):
self._listeners.close()
logger.debug("Unknown listener closed")
# Clear the listeners reference
self._listeners = None
logger.debug("WebsocketListener.close completed")
def _multiaddr_from_socket(
socket: trio.socket.SocketType, is_wss: bool = False
) -> Multiaddr:
"""Convert socket to multiaddr"""
ip, port = socket.getsockname()
protocol = "wss" if is_wss else "ws"
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")

View File

@ -0,0 +1,202 @@
"""
WebSocket multiaddr parsing utilities.
"""
from typing import NamedTuple
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
class ParsedWebSocketMultiaddr(NamedTuple):
"""Parsed WebSocket multiaddr information."""
is_wss: bool
sni: str | None
rest_multiaddr: Multiaddr
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
"""
Parse a WebSocket multiaddr and extract security information.
:param maddr: The multiaddr to parse
:return: Parsed WebSocket multiaddr information
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
"""
# First validate that this is a valid WebSocket multiaddr
if not is_valid_websocket_multiaddr(maddr):
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
protocols = list(maddr.protocols())
# Find the WebSocket protocol and check for security
is_wss = False
sni = None
ws_index = -1
tls_index = -1
sni_index = -1
# Find protocol indices
for i, protocol in enumerate(protocols):
if protocol.name == "ws":
ws_index = i
elif protocol.name == "wss":
ws_index = i
is_wss = True
elif protocol.name == "tls":
tls_index = i
elif protocol.name == "sni":
sni_index = i
sni = protocol.value
if ws_index == -1:
raise ValueError("Not a WebSocket multiaddr")
# Handle /wss protocol (convert to /tls/ws internally)
if is_wss and tls_index == -1:
# Convert /wss to /tls/ws format
# Remove /wss to get the base multiaddr
without_wss = maddr.decapsulate(Multiaddr("/wss"))
return ParsedWebSocketMultiaddr(
is_wss=True, sni=None, rest_multiaddr=without_wss
)
# Handle /tls/ws and /tls/sni/.../ws formats
if tls_index != -1:
is_wss = True
# Extract the base multiaddr (everything before /tls)
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
# Use multiaddr methods to properly extract the base
rest_multiaddr = maddr
# Remove /tls/ws or /tls/sni/.../ws from the end
if sni_index != -1:
# /tls/sni/example.com/ws format
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
else:
# /tls/ws format
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
return ParsedWebSocketMultiaddr(
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
)
# Regular /ws multiaddr - remove /ws and any additional protocols
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
return ParsedWebSocketMultiaddr(
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
)
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid WebSocket structure.
:param maddr: The multiaddr to validate
:return: True if valid WebSocket structure, False otherwise
"""
try:
# WebSocket multiaddr should have structure like:
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
# /ip4/127.0.0.1/tcp/8080/wss (secure)
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
if len(protocols) < 3:
return False
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Check for valid WebSocket protocols
ws_protocols = ["ws", "wss"]
tls_protocols = ["tls"]
sni_protocols = ["sni"]
# Find the WebSocket protocol
ws_protocol_found = False
tls_found = False
# sni_found = False # Not used currently
for i, protocol in enumerate(protocols[2:], start=2):
if protocol.name in ws_protocols:
ws_protocol_found = True
break
elif protocol.name in tls_protocols:
tls_found = True
elif protocol.name in sni_protocols:
pass # sni_found = True # Not used in current implementation
if not ws_protocol_found:
return False
# Validate protocol sequence
# For /ws: network + tcp + ws
# For /wss: network + tcp + wss
# For /tls/ws: network + tcp + tls + ws
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
# Check if it's a simple /ws or /wss
if len(protocols) == 3:
return protocols[2].name in ["ws", "wss"]
# Check for /tls/ws or /tls/sni/.../ws patterns
if tls_found:
# Must end with /ws (not /wss when using /tls)
if protocols[-1].name != "ws":
return False
# Check for valid TLS sequence
tls_index = None
for i, protocol in enumerate(protocols[2:], start=2):
if protocol.name == "tls":
tls_index = i
break
if tls_index is None:
return False
# After tls, we can have sni, then ws
remaining_protocols = protocols[tls_index + 1 :]
if len(remaining_protocols) == 1:
# /tls/ws
return remaining_protocols[0].name == "ws"
elif len(remaining_protocols) == 2:
# /tls/sni/example.com/ws
return (
remaining_protocols[0].name == "sni"
and remaining_protocols[1].name == "ws"
)
else:
return False
# If we have more than 3 protocols but no TLS, check for valid continuations
# Allow additional protocols after the WebSocket protocol (like /p2p)
valid_continuations = ["p2p"]
# Find the WebSocket protocol index
ws_index = None
for i, protocol in enumerate(protocols):
if protocol.name in ["ws", "wss"]:
ws_index = i
break
if ws_index is not None:
# Check protocols after the WebSocket protocol
for i in range(ws_index + 1, len(protocols)):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False

View File

@ -0,0 +1,229 @@
import logging
import ssl
from multiaddr import Multiaddr
from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
from .connection import P2PWebSocketConnection
from .listener import WebsocketListener
logger = logging.getLogger(__name__)
class WebsocketTransport(ITransport):
"""
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
Implements production-ready WebSocket transport with:
- Flow control and buffer management
- Connection limits and rate limiting
- Proper error handling and cleanup
- Support for both WS and WSS protocols
- TLS configuration and handshake timeout
"""
def __init__(
self,
upgrader: TransportUpgrader,
tls_client_config: ssl.SSLContext | None = None,
tls_server_config: ssl.SSLContext | None = None,
handshake_timeout: float = 15.0,
max_buffered_amount: int = 4 * 1024 * 1024,
):
self._upgrader = upgrader
self._tls_client_config = tls_client_config
self._tls_server_config = tls_server_config
self._handshake_timeout = handshake_timeout
self._max_buffered_amount = max_buffered_amount
self._connection_count = 0
self._max_connections = 1000 # Production limit
async def dial(self, maddr: Multiaddr) -> RawConnection:
"""Dial a WebSocket connection to the given multiaddr."""
logger.debug(f"WebsocketTransport.dial called with {maddr}")
# Parse the WebSocket multiaddr to determine if it's secure
try:
parsed = parse_websocket_multiaddr(maddr)
except ValueError as e:
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
# Extract host and port from the base multiaddr
host = (
parsed.rest_multiaddr.value_for_protocol("ip4")
or parsed.rest_multiaddr.value_for_protocol("ip6")
or parsed.rest_multiaddr.value_for_protocol("dns")
or parsed.rest_multiaddr.value_for_protocol("dns4")
or parsed.rest_multiaddr.value_for_protocol("dns6")
)
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
# Build WebSocket URL based on security
if parsed.is_wss:
ws_url = f"wss://{host}:{port}/"
else:
ws_url = f"ws://{host}:{port}/"
logger.debug(
f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
)
try:
# Check connection limits
if self._connection_count >= self._max_connections:
raise OpenConnectionError(
f"Maximum connections reached: {self._max_connections}"
)
# Prepare SSL context for WSS connections
ssl_context = None
if parsed.is_wss:
if self._tls_client_config:
ssl_context = self._tls_client_config
else:
# Create default SSL context for client
ssl_context = ssl.create_default_context()
# Set SNI if available
if parsed.sni:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
# Use a different approach: start background nursery that will persist
logger.debug("WebsocketTransport.dial establishing connection")
# Import trio-websocket functions
from trio_websocket import connect_websocket
from trio_websocket._impl import _url_to_host
# Parse the WebSocket URL to get host, port, resource
# like trio-websocket does
ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
ws_url, ssl_context
)
logger.debug(
f"WebsocketTransport.dial parsed URL: host={ws_host}, "
f"port={ws_port}, resource={ws_resource}"
)
# Create a background task manager for this connection
import trio
nursery_manager = trio.lowlevel.current_task().parent_nursery
if nursery_manager is None:
raise OpenConnectionError(
f"No parent nursery available for WebSocket connection to {maddr}"
)
# Apply timeout to the connection process
with trio.fail_after(self._handshake_timeout):
logger.debug("WebsocketTransport.dial connecting WebSocket")
ws = await connect_websocket(
nursery_manager, # Use the existing nursery from libp2p
ws_host,
ws_port,
ws_resource,
use_ssl=ws_ssl_context,
message_queue_size=1024, # Reasonable defaults
max_message_size=16 * 1024 * 1024, # 16MB max message
)
logger.debug("WebsocketTransport.dial WebSocket connection established")
# Create our connection wrapper with both WSS support and flow control
conn = P2PWebSocketConnection(
ws,
None,
is_secure=parsed.is_wss,
max_buffered_amount=self._max_buffered_amount,
)
logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")
self._connection_count += 1
logger.debug(f"Total connections: {self._connection_count}")
return RawConnection(conn, initiator=True)
except trio.TooSlowError as e:
raise OpenConnectionError(
f"WebSocket handshake timeout after {self._handshake_timeout}s "
f"for {maddr}"
) from e
except Exception as e:
logger.error(f"Failed to dial WebSocket {maddr}: {e}")
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
def create_listener(self, handler: THandler) -> IListener: # type: ignore[override]
"""
The type checker is incorrectly reporting this as an inconsistent override.
"""
logger.debug("WebsocketTransport.create_listener called")
return WebsocketListener(
handler, self._upgrader, self._tls_server_config, self._handshake_timeout
)
def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
"""
Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
Similar to Go's Resolve() method.
:param maddr: The multiaddr to resolve
:return: List of resolved multiaddrs
"""
try:
parsed = parse_websocket_multiaddr(maddr)
except ValueError as e:
logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
return [maddr] # Return original if not a valid WebSocket multiaddr
logger.debug(
f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
)
if not parsed.is_wss:
# No /tls/ws component, this isn't a secure websocket multiaddr
return [maddr]
if parsed.sni is not None:
# Already has SNI, return as-is
return [maddr]
# Try to extract DNS name from the base multiaddr
dns_name = None
for protocol_name in ["dns", "dns4", "dns6"]:
try:
dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
break
except Exception:
continue
if dns_name is None:
# No DNS name found, return original
return [maddr]
# Create new multiaddr with SNI
# For /dns/example.com/tcp/8080/wss ->
# /dns/example.com/tcp/8080/tls/sni/example.com/ws
try:
# Remove /wss and add /tls/sni/example.com/ws
without_wss = maddr.decapsulate(Multiaddr("/wss"))
sni_component = Multiaddr(f"/sni/{dns_name}")
resolved = (
without_wss.encapsulate(Multiaddr("/tls"))
.encapsulate(sni_component)
.encapsulate(Multiaddr("/ws"))
)
logger.debug(f"Resolved {maddr} to {resolved}")
return [resolved]
except Exception as e:
logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
return [maddr]

View File

@ -3,38 +3,24 @@ from __future__ import annotations
import socket
from multiaddr import Multiaddr
try:
from multiaddr.utils import ( # type: ignore
get_network_addrs,
get_thin_waist_addresses,
)
_HAS_THIN_WAIST = True
except ImportError: # pragma: no cover - only executed in older environments
_HAS_THIN_WAIST = False
get_thin_waist_addresses = None # type: ignore
get_network_addrs = None # type: ignore
from multiaddr.utils import get_network_addrs, get_thin_waist_addresses
def _safe_get_network_addrs(ip_version: int) -> list[str]:
"""
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
Falls back to minimal defaults when Thin Waist helpers are missing.
:param ip_version: 4 or 6
"""
if _HAS_THIN_WAIST and get_network_addrs:
try:
return get_network_addrs(ip_version) or []
except Exception: # pragma: no cover - defensive
return []
# Fallback behavior (very conservative)
if ip_version == 4:
return ["127.0.0.1"]
if ip_version == 6:
return ["::1"]
return []
try:
return get_network_addrs(ip_version) or []
except Exception: # pragma: no cover - defensive
# Fallback behavior (very conservative)
if ip_version == 4:
return ["127.0.0.1"]
if ip_version == 6:
return ["::1"]
return []
def find_free_port() -> int:
@ -47,16 +33,13 @@ def find_free_port() -> int:
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
"""
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
If Thin Waist isn't available, returns [addr] (identity).
"""
if _HAS_THIN_WAIST and get_thin_waist_addresses:
try:
if port is not None:
return get_thin_waist_addresses(addr, port=port) or []
return get_thin_waist_addresses(addr) or []
except Exception: # pragma: no cover - defensive
return [addr]
return [addr]
try:
if port is not None:
return get_thin_waist_addresses(addr, port=port) or []
return get_thin_waist_addresses(addr) or []
except Exception: # pragma: no cover - defensive
return [addr]
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
@ -73,8 +56,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
seen_v4: set[str] = set()
for ip in _safe_get_network_addrs(4):
seen_v4.add(ip)
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
if ip not in seen_v4: # Avoid duplicates
seen_v4.add(ip)
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
if seen_v4 and "127.0.0.1" not in seen_v4:
@ -89,8 +73,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
#
# seen_v6: set[str] = set()
# for ip in _safe_get_network_addrs(6):
# seen_v6.add(ip)
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
# if ip not in seen_v6: # Avoid duplicates
# seen_v6.add(ip)
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
#
# # Always include IPv6 loopback for testing purposes when IPv6 is available
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
@ -99,7 +84,7 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
# Fallback if nothing discovered
if not addrs:
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
return addrs
@ -120,6 +105,20 @@ def expand_wildcard_address(
return expanded
def get_wildcard_address(port: int, protocol: str = "tcp") -> Multiaddr:
"""
Get wildcard address (0.0.0.0) when explicitly needed.
This function provides access to wildcard binding as a feature when
explicitly required, preserving the ability to bind to all interfaces.
:param port: Port number.
:param protocol: Transport protocol.
:return: A Multiaddr with wildcard binding (0.0.0.0).
"""
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
"""
Choose an optimal address for an example to bind to:
@ -148,13 +147,14 @@ def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
return c
# As a final fallback, produce a wildcard
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
# As a final fallback, produce a loopback address
return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")
__all__ = [
"get_available_interfaces",
"get_optimal_binding_address",
"get_wildcard_address",
"expand_wildcard_address",
"find_free_port",
]