mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-09 22:50:54 +00:00
Merge branch 'main' into chore01
This commit is contained in:
@ -1,5 +1,12 @@
|
||||
"""Libp2p Python implementation."""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
from typing import Any
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
@ -18,6 +25,7 @@ from libp2p.abc import (
|
||||
IPeerRouting,
|
||||
IPeerStore,
|
||||
ISecureTransport,
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
KeyPair,
|
||||
@ -38,10 +46,12 @@ from libp2p.host.routed_host import (
|
||||
RoutedHost,
|
||||
)
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.network.config import (
|
||||
ConnectionConfig,
|
||||
RetryConfig
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -72,6 +82,10 @@ from libp2p.transport.tcp.tcp import (
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
from libp2p.transport.transport_registry import (
|
||||
create_transport_for_multiaddr,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from libp2p.utils.logging import (
|
||||
setup_logging,
|
||||
)
|
||||
@ -87,6 +101,7 @@ MUXER_YAMUX = "YAMUX"
|
||||
MUXER_MPLEX = "MPLEX"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||
"""
|
||||
@ -162,9 +177,13 @@ def new_swarm(
|
||||
peerstore_opt: IPeerStore | None = None,
|
||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||
enable_quic: bool = False,
|
||||
retry_config: Optional["RetryConfig"] = None,
|
||||
connection_config: Optional["ConnectionConfig"] = None,
|
||||
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
) -> INetworkService:
|
||||
logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}")
|
||||
"""
|
||||
Create a swarm instance based on the parameters.
|
||||
|
||||
@ -174,6 +193,8 @@ def new_swarm(
|
||||
:param peerstore_opt: optional peerstore
|
||||
:param muxer_preference: optional explicit muxer preference
|
||||
:param listen_addrs: optional list of multiaddrs to listen on
|
||||
:param enable_quic: enable quic for transport
|
||||
:param quic_transport_opt: options for transport
|
||||
:return: return a default swarm instance
|
||||
|
||||
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
|
||||
@ -186,16 +207,48 @@ def new_swarm(
|
||||
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
transport: TCP | QUICTransport | ITransport
|
||||
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
|
||||
|
||||
if listen_addrs is None:
|
||||
transport = TCP()
|
||||
else:
|
||||
addr = listen_addrs[0]
|
||||
if addr.__contains__("tcp"):
|
||||
transport = TCP()
|
||||
elif addr.__contains__("quic"):
|
||||
raise ValueError("QUIC not yet supported")
|
||||
if enable_quic:
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
else:
|
||||
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
|
||||
transport = TCP()
|
||||
else:
|
||||
# Use transport registry to select the appropriate transport
|
||||
from libp2p.transport.transport_registry import create_transport_for_multiaddr
|
||||
|
||||
# Create a temporary upgrader for transport selection
|
||||
# We'll create the real upgrader later with the proper configuration
|
||||
temp_upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={},
|
||||
muxer_transports_by_protocol={}
|
||||
)
|
||||
|
||||
addr = listen_addrs[0]
|
||||
logger.debug(f"new_swarm: Creating transport for address: {addr}")
|
||||
transport_maybe = create_transport_for_multiaddr(
|
||||
addr,
|
||||
temp_upgrader,
|
||||
private_key=key_pair.private_key,
|
||||
config=quic_transport_opt,
|
||||
tls_client_config=tls_client_config,
|
||||
tls_server_config=tls_server_config
|
||||
)
|
||||
|
||||
if transport_maybe is None:
|
||||
raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}")
|
||||
|
||||
transport = transport_maybe
|
||||
logger.debug(f"new_swarm: Created transport: {type(transport)}")
|
||||
|
||||
# If enable_quic is True but we didn't get a QUIC transport, force QUIC
|
||||
if enable_quic and not isinstance(transport, QUICTransport):
|
||||
logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})")
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
|
||||
logger.debug(f"new_swarm: Final transport type: {type(transport)}")
|
||||
|
||||
# Generate X25519 keypair for Noise
|
||||
noise_key_pair = create_new_x25519_key_pair()
|
||||
@ -236,6 +289,7 @@ def new_swarm(
|
||||
muxer_transports_by_protocol=muxer_transports_by_protocol,
|
||||
)
|
||||
|
||||
|
||||
peerstore = peerstore_opt or PeerStore()
|
||||
# Store our key pair in peerstore
|
||||
peerstore.add_key_pair(id_opt, key_pair)
|
||||
@ -261,6 +315,10 @@ def new_host(
|
||||
enable_mDNS: bool = False,
|
||||
bootstrap: list[str] | None = None,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
enable_quic: bool = False,
|
||||
quic_transport_opt: QUICTransportConfig | None = None,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
) -> IHost:
|
||||
"""
|
||||
Create a new libp2p host based on the given parameters.
|
||||
@ -274,15 +332,27 @@ def new_host(
|
||||
:param listen_addrs: optional list of multiaddrs to listen on
|
||||
:param enable_mDNS: whether to enable mDNS discovery
|
||||
:param bootstrap: optional list of bootstrap peer addresses as strings
|
||||
:param enable_quic: optinal choice to use QUIC for transport
|
||||
:param quic_transport_opt: optional configuration for quic transport
|
||||
:param tls_client_config: optional TLS client configuration for WebSocket transport
|
||||
:param tls_server_config: optional TLS server configuration for WebSocket transport
|
||||
:return: return a host instance
|
||||
"""
|
||||
|
||||
if not enable_quic and quic_transport_opt is not None:
|
||||
logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config")
|
||||
|
||||
swarm = new_swarm(
|
||||
enable_quic=enable_quic,
|
||||
key_pair=key_pair,
|
||||
muxer_opt=muxer_opt,
|
||||
sec_opt=sec_opt,
|
||||
peerstore_opt=peerstore_opt,
|
||||
muxer_preference=muxer_preference,
|
||||
listen_addrs=listen_addrs,
|
||||
connection_config=quic_transport_opt if enable_quic else None,
|
||||
tls_client_config=tls_client_config,
|
||||
tls_server_config=tls_server_config
|
||||
)
|
||||
|
||||
if disc_opt is not None:
|
||||
|
||||
@ -5,17 +5,17 @@ from collections.abc import (
|
||||
)
|
||||
from typing import TYPE_CHECKING, NewType, Union, cast
|
||||
|
||||
from libp2p.transport.quic.stream import QUICStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import (
|
||||
IMuxedConn,
|
||||
INetStream,
|
||||
ISecureTransport,
|
||||
)
|
||||
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
else:
|
||||
IMuxedConn = cast(type, object)
|
||||
INetStream = cast(type, object)
|
||||
ISecureTransport = cast(type, object)
|
||||
|
||||
IMuxedStream = cast(type, object)
|
||||
QUICConnection = cast(type, object)
|
||||
|
||||
from libp2p.io.abc import (
|
||||
ReadWriteCloser,
|
||||
@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
||||
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
|
||||
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
|
||||
MessageID = NewType("MessageID", str)
|
||||
|
||||
@ -213,7 +213,6 @@ class BasicHost(IHost):
|
||||
self,
|
||||
peer_id: ID,
|
||||
protocol_ids: Sequence[TProtocol],
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id that host is connecting
|
||||
@ -227,7 +226,7 @@ class BasicHost(IHost):
|
||||
selected_protocol = await self.multiselect_client.select_one_of(
|
||||
list(protocol_ids),
|
||||
MultiselectCommunicator(net_stream),
|
||||
negotitate_timeout,
|
||||
self.negotiate_timeout,
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
|
||||
70
libp2p/network/config.py
Normal file
70
libp2p/network/config.py
Normal file
@ -0,0 +1,70 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""
|
||||
Configuration for retry logic with exponential backoff.
|
||||
|
||||
This configuration controls how connection attempts are retried when they fail.
|
||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||
herd problems in distributed systems.
|
||||
|
||||
Attributes:
|
||||
max_retries: Maximum number of retry attempts before giving up.
|
||||
Default: 3 attempts
|
||||
initial_delay: Initial delay in seconds before the first retry.
|
||||
Default: 0.1 seconds (100ms)
|
||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||
Default: 30.0 seconds
|
||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||
the delay by this factor). Default: 2.0 (doubles each time)
|
||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 0.1
|
||||
max_delay: float = 30.0
|
||||
backoff_multiplier: float = 2.0
|
||||
jitter_factor: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""
|
||||
Configuration for multi-connection support.
|
||||
|
||||
This configuration controls how multiple connections per peer are managed,
|
||||
including connection limits, timeouts, and load balancing strategies.
|
||||
|
||||
Attributes:
|
||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||
peer. Default: 3 connections
|
||||
connection_timeout: Timeout in seconds for establishing new connections.
|
||||
Default: 30.0 seconds
|
||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||
Options: "round_robin" (default) or "least_loaded"
|
||||
|
||||
"""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration after initialization."""
|
||||
if not (
|
||||
self.load_balancing_strategy == "round_robin"
|
||||
or self.load_balancing_strategy == "least_loaded"
|
||||
):
|
||||
raise ValueError(
|
||||
"Load balancing strategy can only be 'round_robin' or 'least_loaded'"
|
||||
)
|
||||
|
||||
if self.max_connections_per_peer < 1:
|
||||
raise ValueError("Max connection per peer should be atleast 1")
|
||||
|
||||
if self.connection_timeout < 0:
|
||||
raise ValueError("Connection timeout should be positive")
|
||||
@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamError,
|
||||
MuxedStreamReset,
|
||||
)
|
||||
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError
|
||||
|
||||
from .exceptions import (
|
||||
StreamClosed,
|
||||
@ -170,7 +171,7 @@ class NetStream(INetStream):
|
||||
elif self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_READ
|
||||
raise StreamEOF() from error
|
||||
except MuxedStreamReset as error:
|
||||
except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.OPEN,
|
||||
@ -199,7 +200,12 @@ class NetStream(INetStream):
|
||||
|
||||
try:
|
||||
await self.muxed_stream.write(data)
|
||||
except (MuxedStreamClosed, MuxedStreamError) as error:
|
||||
except (
|
||||
MuxedStreamClosed,
|
||||
MuxedStreamError,
|
||||
QUICStreamClosedError,
|
||||
QUICStreamResetError,
|
||||
) as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_WRITE
|
||||
|
||||
@ -2,9 +2,9 @@ from collections.abc import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import random
|
||||
from typing import cast
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -27,6 +27,7 @@ from libp2p.custom_types import (
|
||||
from libp2p.io.abc import (
|
||||
ReadWriteCloser,
|
||||
)
|
||||
from libp2p.network.config import ConnectionConfig, RetryConfig
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -41,6 +42,9 @@ from libp2p.transport.exceptions import (
|
||||
OpenConnectionError,
|
||||
SecurityUpgradeFailure,
|
||||
)
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
@ -61,59 +65,6 @@ from .exceptions import (
|
||||
logger = logging.getLogger("libp2p.network.swarm")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""
|
||||
Configuration for retry logic with exponential backoff.
|
||||
|
||||
This configuration controls how connection attempts are retried when they fail.
|
||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||
herd problems in distributed systems.
|
||||
|
||||
Attributes:
|
||||
max_retries: Maximum number of retry attempts before giving up.
|
||||
Default: 3 attempts
|
||||
initial_delay: Initial delay in seconds before the first retry.
|
||||
Default: 0.1 seconds (100ms)
|
||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||
Default: 30.0 seconds
|
||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||
the delay by this factor). Default: 2.0 (doubles each time)
|
||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 0.1
|
||||
max_delay: float = 30.0
|
||||
backoff_multiplier: float = 2.0
|
||||
jitter_factor: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""
|
||||
Configuration for multi-connection support.
|
||||
|
||||
This configuration controls how multiple connections per peer are managed,
|
||||
including connection limits, timeouts, and load balancing strategies.
|
||||
|
||||
Attributes:
|
||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||
peer. Default: 3 connections
|
||||
connection_timeout: Timeout in seconds for establishing new connections.
|
||||
Default: 30.0 seconds
|
||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||
Options: "round_robin" (default) or "least_loaded"
|
||||
|
||||
"""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
await network.get_manager().wait_finished()
|
||||
@ -126,8 +77,7 @@ class Swarm(Service, INetworkService):
|
||||
peerstore: IPeerStore
|
||||
upgrader: TransportUpgrader
|
||||
transport: ITransport
|
||||
# Enhanced: Support for multiple connections per peer
|
||||
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
||||
connections: dict[ID, list[INetConn]]
|
||||
listeners: dict[str, IListener]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: trio.Nursery | None
|
||||
@ -137,7 +87,7 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
# Enhanced: New configuration
|
||||
retry_config: RetryConfig
|
||||
connection_config: ConnectionConfig
|
||||
connection_config: ConnectionConfig | QUICTransportConfig
|
||||
_round_robin_index: dict[ID, int]
|
||||
|
||||
def __init__(
|
||||
@ -147,7 +97,7 @@ class Swarm(Service, INetworkService):
|
||||
upgrader: TransportUpgrader,
|
||||
transport: ITransport,
|
||||
retry_config: RetryConfig | None = None,
|
||||
connection_config: ConnectionConfig | None = None,
|
||||
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||
):
|
||||
self.self_id = peer_id
|
||||
self.peerstore = peerstore
|
||||
@ -178,6 +128,11 @@ class Swarm(Service, INetworkService):
|
||||
# Create a nursery for listener tasks.
|
||||
self.listener_nursery = nursery
|
||||
self.event_listener_nursery_created.set()
|
||||
|
||||
if isinstance(self.transport, QUICTransport):
|
||||
self.transport.set_background_nursery(nursery)
|
||||
self.transport.set_swarm(self)
|
||||
|
||||
try:
|
||||
await self.manager.wait_finished()
|
||||
finally:
|
||||
@ -370,6 +325,7 @@ class Swarm(Service, INetworkService):
|
||||
# Dial peer (connection to peer does not yet exist)
|
||||
# Transport dials peer (gets back a raw conn)
|
||||
try:
|
||||
addr = Multiaddr(f"{addr}/p2p/{peer_id}")
|
||||
raw_conn = await self.transport.dial(addr)
|
||||
except OpenConnectionError as error:
|
||||
logger.debug("fail to dial peer %s over base transport", peer_id)
|
||||
@ -377,6 +333,15 @@ class Swarm(Service, INetworkService):
|
||||
f"fail to open connection to peer {peer_id}"
|
||||
) from error
|
||||
|
||||
if isinstance(self.transport, QUICTransport) and isinstance(
|
||||
raw_conn, IMuxedConn
|
||||
):
|
||||
logger.info(
|
||||
"Skipping upgrade for QUIC, QUIC connections are already multiplexed"
|
||||
)
|
||||
swarm_conn = await self.add_conn(raw_conn)
|
||||
return swarm_conn
|
||||
|
||||
logger.debug("dialed peer %s over base transport", peer_id)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
@ -402,9 +367,7 @@ class Swarm(Service, INetworkService):
|
||||
logger.debug("upgraded mux for peer %s", peer_id)
|
||||
|
||||
swarm_conn = await self.add_conn(muxed_conn)
|
||||
|
||||
logger.debug("successfully dialed peer %s", peer_id)
|
||||
|
||||
return swarm_conn
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
@ -427,7 +390,6 @@ class Swarm(Service, INetworkService):
|
||||
:return: net stream instance
|
||||
"""
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
# Get existing connections or dial new ones
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
@ -436,6 +398,10 @@ class Swarm(Service, INetworkService):
|
||||
# Load balancing strategy at interface level
|
||||
connection = self._select_connection(connections, peer_id)
|
||||
|
||||
if isinstance(self.transport, QUICTransport) and connection is not None:
|
||||
conn = cast(SwarmConn, connection)
|
||||
return await conn.new_stream()
|
||||
|
||||
try:
|
||||
net_stream = await connection.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
@ -515,18 +481,38 @@ class Swarm(Service, INetworkService):
|
||||
- Call listener listen with the multiaddr
|
||||
- Map multiaddr to listener
|
||||
"""
|
||||
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
logger.debug("Starting to listen")
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
for maddr in multiaddrs:
|
||||
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
|
||||
if str(maddr) in self.listeners:
|
||||
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
|
||||
success_count += 1
|
||||
continue
|
||||
|
||||
async def conn_handler(
|
||||
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
||||
) -> None:
|
||||
# No need to upgrade QUIC Connection
|
||||
if isinstance(self.transport, QUICTransport):
|
||||
try:
|
||||
quic_conn = cast(QUICConnection, read_write_closer)
|
||||
await self.add_conn(quic_conn)
|
||||
peer_id = quic_conn.peer_id
|
||||
logger.debug(
|
||||
f"successfully opened quic connection to peer {peer_id}"
|
||||
)
|
||||
# NOTE: This is a intentional barrier to prevent from the
|
||||
# handler exiting and closing the connection.
|
||||
await self.manager.wait_finished()
|
||||
except Exception:
|
||||
await read_write_closer.close()
|
||||
return
|
||||
|
||||
raw_conn = RawConnection(read_write_closer, False)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
|
||||
@ -562,13 +548,18 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
try:
|
||||
# Success
|
||||
logger.debug(f"Swarm.listen: creating listener for {maddr}")
|
||||
listener = self.transport.create_listener(conn_handler)
|
||||
logger.debug(f"Swarm.listen: listener created for {maddr}")
|
||||
self.listeners[str(maddr)] = listener
|
||||
# TODO: `listener.listen` is not bounded with nursery. If we want to be
|
||||
# I/O agnostic, we should change the API.
|
||||
if self.listener_nursery is None:
|
||||
raise SwarmException("swarm instance hasn't been run")
|
||||
assert self.listener_nursery is not None # For type checker
|
||||
logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
|
||||
await listener.listen(maddr, self.listener_nursery)
|
||||
logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
|
||||
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_listen(maddr)
|
||||
@ -660,9 +651,10 @@ class Swarm(Service, INetworkService):
|
||||
muxed_conn,
|
||||
self,
|
||||
)
|
||||
|
||||
logger.debug("Swarm::add_conn | starting muxed connection")
|
||||
self.manager.run_task(muxed_conn.start)
|
||||
await muxed_conn.event_started.wait()
|
||||
logger.debug("Swarm::add_conn | starting swarm connection")
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
await swarm_conn.event_started.wait()
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from builtins import AssertionError
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
# Handle for connection close during ongoing negotiation in QUIC
|
||||
except (IOException, AssertionError, ValueError) as error:
|
||||
raise MultiselectCommunicatorError(
|
||||
"fail to write to multiselect communicator"
|
||||
) from error
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
from ast import (
|
||||
literal_eval,
|
||||
)
|
||||
from collections import (
|
||||
defaultdict,
|
||||
)
|
||||
@ -22,6 +19,7 @@ from libp2p.abc import (
|
||||
IPubsubRouter,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
MessageID,
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
@ -56,6 +54,10 @@ from .pb import (
|
||||
from .pubsub import (
|
||||
Pubsub,
|
||||
)
|
||||
from .utils import (
|
||||
parse_message_id_safe,
|
||||
safe_parse_message_id,
|
||||
)
|
||||
|
||||
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
|
||||
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
|
||||
@ -306,7 +308,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
floodsub_peers: set[ID] = {
|
||||
peer_id
|
||||
for peer_id in self.pubsub.peer_topics[topic]
|
||||
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
|
||||
if peer_id in self.peer_protocol
|
||||
and self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
|
||||
}
|
||||
send_to.update(floodsub_peers)
|
||||
|
||||
@ -794,8 +797,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
||||
# seen_seqnos) to list of messages we want to request
|
||||
msg_ids_wanted: list[str] = [
|
||||
msg_id
|
||||
msg_ids_wanted: list[MessageID] = [
|
||||
parse_message_id_safe(msg_id)
|
||||
for msg_id in ihave_msg.messageIDs
|
||||
if msg_id not in seen_seqnos_and_peers
|
||||
]
|
||||
@ -811,9 +814,9 @@ class GossipSub(IPubsubRouter, Service):
|
||||
Forwards all request messages that are present in mcache to the
|
||||
requesting peer.
|
||||
"""
|
||||
# FIXME: Update type of message ID
|
||||
# FIXME: Find a better way to parse the msg ids
|
||||
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
|
||||
msg_ids: list[tuple[bytes, bytes]] = [
|
||||
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
|
||||
]
|
||||
msgs_to_forward: list[rpc_pb2.Message] = []
|
||||
for msg_id_iwant in msg_ids:
|
||||
# Check if the wanted message ID is present in mcache
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
import ast
|
||||
import logging
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.custom_types import (
|
||||
MessageID,
|
||||
)
|
||||
from libp2p.peer.envelope import consume_envelope
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.pubsub.pb.rpc_pb2 import RPC
|
||||
@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
|
||||
logger.error("Failed to update the Certified-Addr-Book: %s", e)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_message_id_safe(msg_id_str: str) -> MessageID:
|
||||
"""Safely handle message ID as string."""
|
||||
return MessageID(msg_id_str)
|
||||
|
||||
|
||||
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
|
||||
"""
|
||||
Safely parse message ID using ast.literal_eval with validation.
|
||||
:param msg_id_str: String representation of message ID
|
||||
:return: Tuple of (seqno, from_id) as bytes
|
||||
:raises ValueError: If parsing fails
|
||||
"""
|
||||
try:
|
||||
parsed = ast.literal_eval(msg_id_str)
|
||||
if not isinstance(parsed, tuple) or len(parsed) != 2:
|
||||
raise ValueError("Invalid message ID format")
|
||||
|
||||
seqno, from_id = parsed
|
||||
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
|
||||
raise ValueError("Message ID components must be bytes")
|
||||
|
||||
return (seqno, from_id)
|
||||
except (ValueError, SyntaxError) as e:
|
||||
raise ValueError(f"Invalid message ID format: {e}")
|
||||
|
||||
@ -9,6 +9,7 @@ from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
from enum import Flag, auto
|
||||
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
@ -18,29 +19,118 @@ from .resources import (
|
||||
RelayLimits,
|
||||
)
|
||||
|
||||
DEFAULT_MIN_RELAYS = 3
|
||||
DEFAULT_MAX_RELAYS = 20
|
||||
DEFAULT_DISCOVERY_INTERVAL = 300 # seconds
|
||||
DEFAULT_RESERVATION_TTL = 3600 # seconds
|
||||
DEFAULT_MAX_CIRCUIT_DURATION = 3600 # seconds
|
||||
DEFAULT_MAX_CIRCUIT_BYTES = 1024 * 1024 * 1024 # 1GB
|
||||
|
||||
DEFAULT_MAX_CIRCUIT_CONNS = 8
|
||||
DEFAULT_MAX_RESERVATIONS = 4
|
||||
|
||||
MAX_RESERVATIONS_PER_IP = 8
|
||||
MAX_CIRCUITS_PER_IP = 16
|
||||
RESERVATION_RATE_PER_IP = 4 # per minute
|
||||
CIRCUIT_RATE_PER_IP = 8 # per minute
|
||||
MAX_CIRCUITS_TOTAL = 64
|
||||
MAX_RESERVATIONS_TOTAL = 32
|
||||
MAX_BANDWIDTH_PER_CIRCUIT = 1024 * 1024 # 1MB/s
|
||||
MAX_BANDWIDTH_TOTAL = 10 * 1024 * 1024 # 10MB/s
|
||||
|
||||
MIN_RELAY_SCORE = 0.5
|
||||
MAX_RELAY_LATENCY = 1.0 # seconds
|
||||
ENABLE_AUTO_RELAY = True
|
||||
AUTO_RELAY_TIMEOUT = 30 # seconds
|
||||
MAX_AUTO_RELAY_ATTEMPTS = 3
|
||||
RESERVATION_REFRESH_THRESHOLD = 0.8 # Refresh at 80% of TTL
|
||||
MAX_CONCURRENT_RESERVATIONS = 2
|
||||
|
||||
# Timeout constants for different components
|
||||
DEFAULT_DISCOVERY_STREAM_TIMEOUT = 10 # seconds
|
||||
DEFAULT_PEER_PROTOCOL_TIMEOUT = 5 # seconds
|
||||
DEFAULT_PROTOCOL_READ_TIMEOUT = 15 # seconds
|
||||
DEFAULT_PROTOCOL_WRITE_TIMEOUT = 15 # seconds
|
||||
DEFAULT_PROTOCOL_CLOSE_TIMEOUT = 10 # seconds
|
||||
DEFAULT_DCUTR_READ_TIMEOUT = 30 # seconds
|
||||
DEFAULT_DCUTR_WRITE_TIMEOUT = 30 # seconds
|
||||
DEFAULT_DIAL_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeoutConfig:
|
||||
"""Timeout configuration for different Circuit Relay v2 components."""
|
||||
|
||||
# Discovery timeouts
|
||||
discovery_stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT
|
||||
peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT
|
||||
|
||||
# Core protocol timeouts
|
||||
protocol_read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT
|
||||
protocol_write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT
|
||||
protocol_close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT
|
||||
|
||||
# DCUtR timeouts
|
||||
dcutr_read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT
|
||||
dcutr_write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT
|
||||
dial_timeout: int = DEFAULT_DIAL_TIMEOUT
|
||||
|
||||
|
||||
# Relay roles enum
|
||||
class RelayRole(Flag):
|
||||
"""
|
||||
Bit-flag enum that captures the three possible relay capabilities.
|
||||
|
||||
A node can combine multiple roles using bit-wise OR, for example::
|
||||
|
||||
RelayRole.HOP | RelayRole.STOP
|
||||
"""
|
||||
|
||||
HOP = auto() # Act as a relay for others ("hop")
|
||||
STOP = auto() # Accept relayed connections ("stop")
|
||||
CLIENT = auto() # Dial through existing relays ("client")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayConfig:
|
||||
"""Configuration for Circuit Relay v2."""
|
||||
|
||||
# Role configuration
|
||||
enable_hop: bool = False # Whether to act as a relay (hop)
|
||||
enable_stop: bool = True # Whether to accept relayed connections (stop)
|
||||
enable_client: bool = True # Whether to use relays for dialing
|
||||
# Role configuration (bit-flags)
|
||||
roles: RelayRole = RelayRole.STOP | RelayRole.CLIENT
|
||||
|
||||
# Resource limits
|
||||
limits: RelayLimits | None = None
|
||||
|
||||
# Discovery configuration
|
||||
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
|
||||
min_relays: int = 3
|
||||
max_relays: int = 20
|
||||
discovery_interval: int = 300 # seconds
|
||||
min_relays: int = DEFAULT_MIN_RELAYS
|
||||
max_relays: int = DEFAULT_MAX_RELAYS
|
||||
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL
|
||||
|
||||
# Connection configuration
|
||||
reservation_ttl: int = 3600 # seconds
|
||||
max_circuit_duration: int = 3600 # seconds
|
||||
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
|
||||
reservation_ttl: int = DEFAULT_RESERVATION_TTL
|
||||
max_circuit_duration: int = DEFAULT_MAX_CIRCUIT_DURATION
|
||||
max_circuit_bytes: int = DEFAULT_MAX_CIRCUIT_BYTES
|
||||
|
||||
# Timeout configuration
|
||||
timeouts: TimeoutConfig = field(default_factory=TimeoutConfig)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Backwards-compat boolean helpers. Existing code that still accesses
|
||||
# ``cfg.enable_hop, cfg.enable_stop, cfg.enable_client`` will continue to work.
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def enable_hop(self) -> bool: # pragma: no cover – helper
|
||||
return bool(self.roles & RelayRole.HOP)
|
||||
|
||||
@property
|
||||
def enable_stop(self) -> bool: # pragma: no cover – helper
|
||||
return bool(self.roles & RelayRole.STOP)
|
||||
|
||||
@property
|
||||
def enable_client(self) -> bool: # pragma: no cover – helper
|
||||
return bool(self.roles & RelayRole.CLIENT)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize default values."""
|
||||
@ -48,8 +138,8 @@ class RelayConfig:
|
||||
self.limits = RelayLimits(
|
||||
duration=self.max_circuit_duration,
|
||||
data=self.max_circuit_bytes,
|
||||
max_circuit_conns=8,
|
||||
max_reservations=4,
|
||||
max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS,
|
||||
max_reservations=DEFAULT_MAX_RESERVATIONS,
|
||||
)
|
||||
|
||||
|
||||
@ -58,20 +148,20 @@ class HopConfig:
|
||||
"""Configuration specific to relay (hop) nodes."""
|
||||
|
||||
# Resource limits per IP
|
||||
max_reservations_per_ip: int = 8
|
||||
max_circuits_per_ip: int = 16
|
||||
max_reservations_per_ip: int = MAX_RESERVATIONS_PER_IP
|
||||
max_circuits_per_ip: int = MAX_CIRCUITS_PER_IP
|
||||
|
||||
# Rate limiting
|
||||
reservation_rate_per_ip: int = 4 # per minute
|
||||
circuit_rate_per_ip: int = 8 # per minute
|
||||
reservation_rate_per_ip: int = RESERVATION_RATE_PER_IP
|
||||
circuit_rate_per_ip: int = CIRCUIT_RATE_PER_IP
|
||||
|
||||
# Resource quotas
|
||||
max_circuits_total: int = 64
|
||||
max_reservations_total: int = 32
|
||||
max_circuits_total: int = MAX_CIRCUITS_TOTAL
|
||||
max_reservations_total: int = MAX_RESERVATIONS_TOTAL
|
||||
|
||||
# Bandwidth limits
|
||||
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
|
||||
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
|
||||
max_bandwidth_per_circuit: int = MAX_BANDWIDTH_PER_CIRCUIT
|
||||
max_bandwidth_total: int = MAX_BANDWIDTH_TOTAL
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -79,14 +169,14 @@ class ClientConfig:
|
||||
"""Configuration specific to relay clients."""
|
||||
|
||||
# Relay selection
|
||||
min_relay_score: float = 0.5
|
||||
max_relay_latency: float = 1.0 # seconds
|
||||
min_relay_score: float = MIN_RELAY_SCORE
|
||||
max_relay_latency: float = MAX_RELAY_LATENCY
|
||||
|
||||
# Auto-relay settings
|
||||
enable_auto_relay: bool = True
|
||||
auto_relay_timeout: int = 30 # seconds
|
||||
max_auto_relay_attempts: int = 3
|
||||
enable_auto_relay: bool = ENABLE_AUTO_RELAY
|
||||
auto_relay_timeout: int = AUTO_RELAY_TIMEOUT
|
||||
max_auto_relay_attempts: int = MAX_AUTO_RELAY_ATTEMPTS
|
||||
|
||||
# Reservation management
|
||||
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
|
||||
max_concurrent_reservations: int = 2
|
||||
reservation_refresh_threshold: float = RESERVATION_REFRESH_THRESHOLD
|
||||
max_concurrent_reservations: int = MAX_CONCURRENT_RESERVATIONS
|
||||
|
||||
@ -29,6 +29,11 @@ from libp2p.peer.id import (
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.config import (
|
||||
DEFAULT_DCUTR_READ_TIMEOUT,
|
||||
DEFAULT_DCUTR_WRITE_TIMEOUT,
|
||||
DEFAULT_DIAL_TIMEOUT,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.nat import (
|
||||
ReachabilityChecker,
|
||||
)
|
||||
@ -47,11 +52,7 @@ PROTOCOL_ID = TProtocol("/libp2p/dcutr")
|
||||
# Maximum message size for DCUtR (4KiB as per spec)
|
||||
MAX_MESSAGE_SIZE = 4 * 1024
|
||||
|
||||
# Timeouts
|
||||
STREAM_READ_TIMEOUT = 30 # seconds
|
||||
STREAM_WRITE_TIMEOUT = 30 # seconds
|
||||
DIAL_TIMEOUT = 10 # seconds
|
||||
|
||||
# DCUtR protocol constants
|
||||
# Maximum number of hole punch attempts per peer
|
||||
MAX_HOLE_PUNCH_ATTEMPTS = 5
|
||||
|
||||
@ -70,7 +71,13 @@ class DCUtRProtocol(Service):
|
||||
hole punching, after they have established an initial connection through a relay.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost):
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT,
|
||||
write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT,
|
||||
dial_timeout: int = DEFAULT_DIAL_TIMEOUT,
|
||||
):
|
||||
"""
|
||||
Initialize the DCUtR protocol.
|
||||
|
||||
@ -78,10 +85,19 @@ class DCUtRProtocol(Service):
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this protocol is running on
|
||||
read_timeout : int
|
||||
Timeout for stream read operations, in seconds
|
||||
write_timeout : int
|
||||
Timeout for stream write operations, in seconds
|
||||
dial_timeout : int
|
||||
Timeout for dial operations, in seconds
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.read_timeout = read_timeout
|
||||
self.write_timeout = write_timeout
|
||||
self.dial_timeout = dial_timeout
|
||||
self.event_started = trio.Event()
|
||||
self._hole_punch_attempts: dict[ID, int] = {}
|
||||
self._direct_connections: set[ID] = set()
|
||||
@ -161,7 +177,7 @@ class DCUtRProtocol(Service):
|
||||
|
||||
try:
|
||||
# Read the CONNECT message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
with trio.fail_after(self.read_timeout):
|
||||
msg_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Parse the message
|
||||
@ -196,7 +212,7 @@ class DCUtRProtocol(Service):
|
||||
response.type = HolePunch.CONNECT
|
||||
response.ObsAddrs.extend(our_addrs)
|
||||
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
with trio.fail_after(self.write_timeout):
|
||||
await stream.write(response.SerializeToString())
|
||||
|
||||
logger.debug(
|
||||
@ -206,7 +222,7 @@ class DCUtRProtocol(Service):
|
||||
)
|
||||
|
||||
# Wait for SYNC message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
with trio.fail_after(self.read_timeout):
|
||||
sync_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Parse the SYNC message
|
||||
@ -300,7 +316,7 @@ class DCUtRProtocol(Service):
|
||||
connect_msg.ObsAddrs.extend(our_addrs)
|
||||
|
||||
start_time = time.time()
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
with trio.fail_after(self.write_timeout):
|
||||
await stream.write(connect_msg.SerializeToString())
|
||||
|
||||
logger.debug(
|
||||
@ -310,7 +326,7 @@ class DCUtRProtocol(Service):
|
||||
)
|
||||
|
||||
# Receive the peer's CONNECT message
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
with trio.fail_after(self.read_timeout):
|
||||
resp_bytes = await stream.read(MAX_MESSAGE_SIZE)
|
||||
|
||||
# Calculate RTT
|
||||
@ -349,7 +365,7 @@ class DCUtRProtocol(Service):
|
||||
sync_msg = HolePunch()
|
||||
sync_msg.type = HolePunch.SYNC
|
||||
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
with trio.fail_after(self.write_timeout):
|
||||
await stream.write(sync_msg.SerializeToString())
|
||||
|
||||
logger.debug("Sent SYNC message to %s", peer_id)
|
||||
@ -468,7 +484,7 @@ class DCUtRProtocol(Service):
|
||||
peer_info = PeerInfo(peer_id, [addr])
|
||||
|
||||
# Try to connect with timeout
|
||||
with trio.fail_after(DIAL_TIMEOUT):
|
||||
with trio.fail_after(self.dial_timeout):
|
||||
await self.host.connect(peer_info)
|
||||
|
||||
logger.info("Successfully connected to %s at %s", peer_id, addr)
|
||||
|
||||
@ -31,6 +31,11 @@ from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
DEFAULT_DISCOVERY_INTERVAL,
|
||||
DEFAULT_DISCOVERY_STREAM_TIMEOUT,
|
||||
DEFAULT_PEER_PROTOCOL_TIMEOUT,
|
||||
)
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
)
|
||||
@ -43,10 +48,8 @@ from .protocol_buffer import (
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
|
||||
|
||||
# Constants
|
||||
# Discovery constants
|
||||
MAX_RELAYS_TO_TRACK = 10
|
||||
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
|
||||
STREAM_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
# Extended interfaces for type checking
|
||||
@ -86,6 +89,8 @@ class RelayDiscovery(Service):
|
||||
auto_reserve: bool = False,
|
||||
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
|
||||
max_relays: int = MAX_RELAYS_TO_TRACK,
|
||||
stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT,
|
||||
peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the discovery service.
|
||||
@ -100,6 +105,10 @@ class RelayDiscovery(Service):
|
||||
How often to run discovery, in seconds
|
||||
max_relays : int
|
||||
Maximum number of relays to track
|
||||
stream_timeout : int
|
||||
Timeout for stream operations during discovery, in seconds
|
||||
peer_protocol_timeout : int
|
||||
Timeout for checking peer protocol support, in seconds
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
@ -107,6 +116,8 @@ class RelayDiscovery(Service):
|
||||
self.auto_reserve = auto_reserve
|
||||
self.discovery_interval = discovery_interval
|
||||
self.max_relays = max_relays
|
||||
self.stream_timeout = stream_timeout
|
||||
self.peer_protocol_timeout = peer_protocol_timeout
|
||||
self._discovered_relays: dict[ID, RelayInfo] = {}
|
||||
self._protocol_cache: dict[
|
||||
ID, set[str]
|
||||
@ -165,8 +176,8 @@ class RelayDiscovery(Service):
|
||||
self._discovered_relays[peer_id].last_seen = time.time()
|
||||
continue
|
||||
|
||||
# Check if peer supports the relay protocol
|
||||
with trio.move_on_after(5): # Don't wait too long for protocol info
|
||||
# Don't wait too long for protocol info
|
||||
with trio.move_on_after(self.peer_protocol_timeout):
|
||||
if await self._supports_relay_protocol(peer_id):
|
||||
await self._add_relay(peer_id)
|
||||
|
||||
@ -264,7 +275,7 @@ class RelayDiscovery(Service):
|
||||
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via direct connection."""
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
with trio.fail_after(self.stream_timeout):
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if stream:
|
||||
await stream.close()
|
||||
@ -370,7 +381,7 @@ class RelayDiscovery(Service):
|
||||
|
||||
# Open a stream to the relay with timeout
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
with trio.fail_after(self.stream_timeout):
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if not stream:
|
||||
logger.error("Failed to open stream to relay %s", peer_id)
|
||||
@ -386,7 +397,7 @@ class RelayDiscovery(Service):
|
||||
peer=self.host.get_id().to_bytes(),
|
||||
)
|
||||
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
with trio.fail_after(self.stream_timeout):
|
||||
await stream.write(request.SerializeToString())
|
||||
|
||||
# Wait for response
|
||||
|
||||
@ -5,6 +5,7 @@ This module implements the Circuit Relay v2 protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
@ -37,6 +38,15 @@ from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
DEFAULT_MAX_CIRCUIT_BYTES,
|
||||
DEFAULT_MAX_CIRCUIT_CONNS,
|
||||
DEFAULT_MAX_CIRCUIT_DURATION,
|
||||
DEFAULT_MAX_RESERVATIONS,
|
||||
DEFAULT_PROTOCOL_CLOSE_TIMEOUT,
|
||||
DEFAULT_PROTOCOL_READ_TIMEOUT,
|
||||
DEFAULT_PROTOCOL_WRITE_TIMEOUT,
|
||||
)
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
Limit,
|
||||
@ -58,18 +68,22 @@ logger = logging.getLogger("libp2p.relay.circuit_v2")
|
||||
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
|
||||
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
|
||||
|
||||
|
||||
# Direction enum for data piping
|
||||
class Pipe(Enum):
|
||||
SRC_TO_DST = auto()
|
||||
DST_TO_SRC = auto()
|
||||
|
||||
|
||||
# Default limits for relay resources
|
||||
DEFAULT_RELAY_LIMITS = RelayLimits(
|
||||
duration=60 * 60, # 1 hour
|
||||
data=1024 * 1024 * 1024, # 1GB
|
||||
max_circuit_conns=8,
|
||||
max_reservations=4,
|
||||
duration=DEFAULT_MAX_CIRCUIT_DURATION,
|
||||
data=DEFAULT_MAX_CIRCUIT_BYTES,
|
||||
max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS,
|
||||
max_reservations=DEFAULT_MAX_RESERVATIONS,
|
||||
)
|
||||
|
||||
# Stream operation timeouts
|
||||
STREAM_READ_TIMEOUT = 15 # seconds
|
||||
STREAM_WRITE_TIMEOUT = 15 # seconds
|
||||
STREAM_CLOSE_TIMEOUT = 10 # seconds
|
||||
# Stream operation constants
|
||||
MAX_READ_RETRIES = 5 # Maximum number of read retries
|
||||
|
||||
|
||||
@ -113,6 +127,9 @@ class CircuitV2Protocol(Service):
|
||||
host: IHost,
|
||||
limits: RelayLimits | None = None,
|
||||
allow_hop: bool = False,
|
||||
read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT,
|
||||
write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT,
|
||||
close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a Circuit Relay v2 protocol instance.
|
||||
@ -125,11 +142,20 @@ class CircuitV2Protocol(Service):
|
||||
Resource limits for the relay
|
||||
allow_hop : bool
|
||||
Whether to allow this node to act as a relay
|
||||
read_timeout : int
|
||||
Timeout for stream read operations, in seconds
|
||||
write_timeout : int
|
||||
Timeout for stream write operations, in seconds
|
||||
close_timeout : int
|
||||
Timeout for stream close operations, in seconds
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.limits = limits or DEFAULT_RELAY_LIMITS
|
||||
self.allow_hop = allow_hop
|
||||
self.read_timeout = read_timeout
|
||||
self.write_timeout = write_timeout
|
||||
self.close_timeout = close_timeout
|
||||
self.resource_manager = RelayResourceManager(self.limits)
|
||||
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
|
||||
self.event_started = trio.Event()
|
||||
@ -174,7 +200,7 @@ class CircuitV2Protocol(Service):
|
||||
return
|
||||
|
||||
try:
|
||||
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
|
||||
with trio.fail_after(self.close_timeout):
|
||||
await stream.close()
|
||||
except Exception:
|
||||
try:
|
||||
@ -216,7 +242,7 @@ class CircuitV2Protocol(Service):
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
with trio.fail_after(self.read_timeout):
|
||||
# Try reading with timeout
|
||||
logger.debug(
|
||||
"Attempting to read from stream (attempt %d/%d)",
|
||||
@ -293,7 +319,7 @@ class CircuitV2Protocol(Service):
|
||||
# First, handle the read timeout gracefully
|
||||
try:
|
||||
with trio.fail_after(
|
||||
STREAM_READ_TIMEOUT * 2
|
||||
self.read_timeout * 2
|
||||
): # Double the timeout for reading
|
||||
msg_bytes = await stream.read()
|
||||
if not msg_bytes:
|
||||
@ -414,7 +440,7 @@ class CircuitV2Protocol(Service):
|
||||
"""
|
||||
try:
|
||||
# Read the incoming message with timeout
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
with trio.fail_after(self.read_timeout):
|
||||
msg_bytes = await stream.read()
|
||||
stop_msg = StopMessage()
|
||||
stop_msg.ParseFromString(msg_bytes)
|
||||
@ -458,8 +484,20 @@ class CircuitV2Protocol(Service):
|
||||
|
||||
# Start relaying data
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._relay_data, src_stream, stream, peer_id)
|
||||
nursery.start_soon(self._relay_data, stream, src_stream, peer_id)
|
||||
nursery.start_soon(
|
||||
self._relay_data,
|
||||
src_stream,
|
||||
stream,
|
||||
peer_id,
|
||||
Pipe.SRC_TO_DST,
|
||||
)
|
||||
nursery.start_soon(
|
||||
self._relay_data,
|
||||
stream,
|
||||
src_stream,
|
||||
peer_id,
|
||||
Pipe.DST_TO_SRC,
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.error("Timeout reading from stop stream")
|
||||
@ -509,7 +547,7 @@ class CircuitV2Protocol(Service):
|
||||
ttl = self.resource_manager.reserve(peer_id)
|
||||
|
||||
# Send reservation success response
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
with trio.fail_after(self.write_timeout):
|
||||
status = create_status(
|
||||
code=StatusCode.OK, message="Reservation accepted"
|
||||
)
|
||||
@ -560,7 +598,7 @@ class CircuitV2Protocol(Service):
|
||||
# Always close the stream when done with reservation
|
||||
if cast(INetStreamWithExtras, stream).is_open():
|
||||
try:
|
||||
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
|
||||
with trio.fail_after(self.close_timeout):
|
||||
await stream.close()
|
||||
except Exception as close_err:
|
||||
logger.error("Error closing stream: %s", str(close_err))
|
||||
@ -596,7 +634,7 @@ class CircuitV2Protocol(Service):
|
||||
self._active_relays[peer_id] = (stream, None)
|
||||
|
||||
# Try to connect to the destination with timeout
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
with trio.fail_after(self.read_timeout):
|
||||
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
|
||||
if not dst_stream:
|
||||
raise ConnectionError("Could not connect to destination")
|
||||
@ -648,8 +686,20 @@ class CircuitV2Protocol(Service):
|
||||
|
||||
# Start relaying data
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
|
||||
nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)
|
||||
nursery.start_soon(
|
||||
self._relay_data,
|
||||
stream,
|
||||
dst_stream,
|
||||
peer_id,
|
||||
Pipe.SRC_TO_DST,
|
||||
)
|
||||
nursery.start_soon(
|
||||
self._relay_data,
|
||||
dst_stream,
|
||||
stream,
|
||||
peer_id,
|
||||
Pipe.DST_TO_SRC,
|
||||
)
|
||||
|
||||
except (trio.TooSlowError, ConnectionError) as e:
|
||||
logger.error("Error establishing relay connection: %s", str(e))
|
||||
@ -685,6 +735,7 @@ class CircuitV2Protocol(Service):
|
||||
src_stream: INetStream,
|
||||
dst_stream: INetStream,
|
||||
peer_id: ID,
|
||||
direction: Pipe,
|
||||
) -> None:
|
||||
"""
|
||||
Relay data between two streams.
|
||||
@ -698,24 +749,27 @@ class CircuitV2Protocol(Service):
|
||||
peer_id : ID
|
||||
ID of the peer being relayed
|
||||
|
||||
direction : Pipe
|
||||
Direction of data flow (``Pipe.SRC_TO_DST`` or ``Pipe.DST_TO_SRC``)
|
||||
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
# Read data with retries
|
||||
data = await self._read_stream_with_retry(src_stream)
|
||||
if not data:
|
||||
logger.info("Source stream closed/reset")
|
||||
logger.info("%s closed/reset", direction.name)
|
||||
break
|
||||
|
||||
# Write data with timeout
|
||||
try:
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
with trio.fail_after(self.write_timeout):
|
||||
await dst_stream.write(data)
|
||||
except trio.TooSlowError:
|
||||
logger.error("Timeout writing to destination stream")
|
||||
logger.error("Timeout writing in %s", direction.name)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error writing to destination stream: %s", str(e))
|
||||
logger.error("Error writing in %s: %s", direction.name, str(e))
|
||||
break
|
||||
|
||||
# Update resource usage
|
||||
@ -744,7 +798,7 @@ class CircuitV2Protocol(Service):
|
||||
"""Send a status message."""
|
||||
try:
|
||||
logger.debug("Sending status message with code %s: %s", code, message)
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
|
||||
with trio.fail_after(self.write_timeout * 2): # Double the timeout
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(
|
||||
@ -782,7 +836,7 @@ class CircuitV2Protocol(Service):
|
||||
"""Send a status message on a STOP stream."""
|
||||
try:
|
||||
logger.debug("Sending stop status message with code %s: %s", code, message)
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
|
||||
with trio.fail_after(self.write_timeout * 2): # Double the timeout
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(
|
||||
|
||||
@ -8,6 +8,7 @@ including reservations and connection limits.
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
from enum import Enum, auto
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
@ -19,6 +20,18 @@ from libp2p.peer.id import (
|
||||
# Import the protobuf definitions
|
||||
from .pb.circuit_pb2 import Reservation as PbReservation
|
||||
|
||||
RANDOM_BYTES_LENGTH = 16 # 128 bits of randomness
|
||||
TIMESTAMP_MULTIPLIER = 1000000 # To convert seconds to microseconds
|
||||
|
||||
|
||||
# Reservation status enum
|
||||
class ReservationStatus(Enum):
|
||||
"""Lifecycle status of a relay reservation."""
|
||||
|
||||
ACTIVE = auto()
|
||||
EXPIRED = auto()
|
||||
REJECTED = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayLimits:
|
||||
@ -68,8 +81,8 @@ class Reservation:
|
||||
# - Peer ID to bind it to the specific peer
|
||||
# - Timestamp for uniqueness
|
||||
# - Hash everything for a fixed size output
|
||||
random_bytes = os.urandom(16) # 128 bits of randomness
|
||||
timestamp = str(int(self.created_at * 1000000)).encode()
|
||||
random_bytes = os.urandom(RANDOM_BYTES_LENGTH)
|
||||
timestamp = str(int(self.created_at * TIMESTAMP_MULTIPLIER)).encode()
|
||||
peer_bytes = self.peer_id.to_bytes()
|
||||
|
||||
# Combine all elements and hash them
|
||||
@ -84,6 +97,15 @@ class Reservation:
|
||||
"""Check if the reservation has expired."""
|
||||
return time.time() > self.expires_at
|
||||
|
||||
# Expose a friendly status enum
|
||||
|
||||
@property
|
||||
def status(self) -> ReservationStatus:
|
||||
"""Return the current status as a ``ReservationStatus`` enum."""
|
||||
return (
|
||||
ReservationStatus.EXPIRED if self.is_expired() else ReservationStatus.ACTIVE
|
||||
)
|
||||
|
||||
def can_accept_connection(self) -> bool:
|
||||
"""Check if a new connection can be accepted."""
|
||||
return (
|
||||
|
||||
@ -89,6 +89,8 @@ class CircuitV2Transport(ITransport):
|
||||
auto_reserve=config.enable_client,
|
||||
discovery_interval=config.discovery_interval,
|
||||
max_relays=config.max_relays,
|
||||
stream_timeout=config.timeouts.discovery_stream_timeout,
|
||||
peer_protocol_timeout=config.timeouts.peer_protocol_timeout,
|
||||
)
|
||||
self.relay_counter = 0 # for round robin load balancing
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import (
|
||||
cast,
|
||||
)
|
||||
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
|
||||
FixedSizeLenMsgReadWriter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIZE_NOISE_MESSAGE_LEN = 2
|
||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
||||
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
||||
self.noise_state = noise_state
|
||||
|
||||
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
||||
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
|
||||
data_encrypted = self.encrypt(msg)
|
||||
if prefix_encoded:
|
||||
# Manually add the prefix if needed
|
||||
data_encrypted = self.prefix + data_encrypted
|
||||
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
|
||||
await self.read_writer.write_msg(data_encrypted)
|
||||
logger.debug("Noise write_msg: write completed successfully")
|
||||
|
||||
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
||||
logger.debug("Noise read_msg: reading encrypted message")
|
||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
|
||||
if prefix_encoded:
|
||||
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
else:
|
||||
return self.decrypt(noise_msg_encrypted)
|
||||
result = self.decrypt(noise_msg_encrypted)
|
||||
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
|
||||
return result
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.read_writer.close()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
|
||||
|
||||
from .pb import noise_pb2 as noise_pb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||
|
||||
|
||||
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
|
||||
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||
) -> bytes:
|
||||
data = make_data_to_be_signed(noise_static_pubkey)
|
||||
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
|
||||
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
|
||||
return id_privkey.sign(data)
|
||||
|
||||
|
||||
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
|
||||
2. signed by the private key corresponding to `id_pubkey`
|
||||
"""
|
||||
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_pubkey type: "
|
||||
f"{type(payload.id_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: noise_static_pubkey type: "
|
||||
f"{type(noise_static_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
|
||||
)
|
||||
try:
|
||||
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
|
||||
return False
|
||||
|
||||
@ -2,6 +2,7 @@ from abc import (
|
||||
ABC,
|
||||
abstractmethod,
|
||||
)
|
||||
import logging
|
||||
|
||||
from cryptography.hazmat.primitives import (
|
||||
serialization,
|
||||
@ -46,6 +47,8 @@ from .messages import (
|
||||
verify_handshake_payload_sig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPattern(ABC):
|
||||
@abstractmethod
|
||||
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
|
||||
self.early_data = early_data
|
||||
|
||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
noise_state.set_as_responder()
|
||||
noise_state.start_handshake()
|
||||
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
|
||||
# Consume msg#1.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#1")
|
||||
await read_writer.read_msg()
|
||||
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
|
||||
|
||||
# Send msg#2, which should include our handshake payload.
|
||||
logger.debug("Noise XX handshake_inbound: preparing msg#2")
|
||||
our_payload = self.make_handshake_payload()
|
||||
msg_2 = our_payload.serialize()
|
||||
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
|
||||
await read_writer.write_msg(msg_2)
|
||||
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
|
||||
|
||||
# Receive and consume msg#3.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#3")
|
||||
msg_3 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
|
||||
async def handshake_outbound(
|
||||
self, conn: IRawConnection, remote_peer: ID
|
||||
) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
|
||||
raise NoiseStateError("Handshake state is not initialized")
|
||||
|
||||
# Send msg#1, which is *not* encrypted.
|
||||
logger.debug("Noise XX handshake_outbound: sending msg#1")
|
||||
msg_1 = b""
|
||||
await read_writer.write_msg(msg_1)
|
||||
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
|
||||
|
||||
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||
logger.debug("Noise XX handshake_outbound: reading msg#2")
|
||||
msg_2 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
|
||||
)
|
||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
|
||||
)
|
||||
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
|
||||
f"{id_pubkey_repr}"
|
||||
)
|
||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||
logger.error(
|
||||
f"Noise XX handshake_outbound: signature verification failed for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
raise InvalidSignature
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: signature verification successful for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||
if remote_peer_id_from_pubkey != remote_peer:
|
||||
raise PeerIDMismatchesPubkey(
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
@ -15,6 +13,7 @@ from libp2p.abc import (
|
||||
from libp2p.stream_muxer.exceptions import (
|
||||
MuxedConnUnavailable,
|
||||
)
|
||||
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||
|
||||
from .constants import (
|
||||
HeaderTags,
|
||||
@ -34,72 +33,6 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
"""
|
||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||
|
||||
@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
Multiselect,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect_client import (
|
||||
@ -46,11 +47,17 @@ class MuxerMultistream:
|
||||
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
||||
multiselect: Multiselect
|
||||
multiselect_client: MultiselectClient
|
||||
negotiate_timeout: int
|
||||
|
||||
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> None:
|
||||
self.transports = OrderedDict()
|
||||
self.multiselect = Multiselect()
|
||||
self.multistream_client = MultiselectClient()
|
||||
self.negotiate_timeout = negotiate_timeout
|
||||
for protocol, transport in muxer_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
|
||||
@ -80,10 +87,12 @@ class MuxerMultistream:
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
protocol, _ = await self.multiselect.negotiate(
|
||||
communicator, self.negotiate_timeout
|
||||
)
|
||||
if protocol is None:
|
||||
raise MultiselectError(
|
||||
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
||||
@ -93,7 +102,7 @@ class MuxerMultistream:
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
protocol = await self.multistream_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||
)
|
||||
transport_class = self.transports[protocol]
|
||||
if protocol == PROTOCOL_ID:
|
||||
|
||||
70
libp2p/stream_muxer/rw_lock.py
Normal file
70
libp2p/stream_muxer/rw_lock.py
Normal file
@ -0,0 +1,70 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
|
||||
@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamError,
|
||||
MuxedStreamReset,
|
||||
)
|
||||
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||
|
||||
# Configure logger for this module
|
||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
|
||||
self.send_window = DEFAULT_WINDOW_SIZE
|
||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||
self.window_lock = trio.Lock()
|
||||
self.rw_lock = ReadWriteLock()
|
||||
self.close_lock = trio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "YamuxStream":
|
||||
"""Enter the async context manager."""
|
||||
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
|
||||
await self.close()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
async with self.rw_lock.write_lock():
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: "
|
||||
"Window is zero, waiting for update"
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
# Calculate how much we can send now
|
||||
to_send = min(self.send_window, total_len - sent)
|
||||
chunk = data[sent : sent + to_send]
|
||||
self.send_window -= to_send
|
||||
|
||||
# Send the data
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
# Calculate how much we can send now
|
||||
to_send = min(self.send_window, total_len - sent)
|
||||
chunk = data[sent : sent + to_send]
|
||||
self.send_window -= to_send
|
||||
|
||||
# Send the data
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||
)
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
|
||||
|
||||
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
||||
"""
|
||||
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
|
||||
return data
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.send_closed:
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.send_closed = True
|
||||
async with self.close_lock:
|
||||
if not self.send_closed:
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.send_closed = True
|
||||
|
||||
# Only set fully closed if both directions are closed
|
||||
if self.send_closed and self.recv_closed:
|
||||
self.closed = True
|
||||
else:
|
||||
# Stream is half-closed but not fully closed
|
||||
self.closed = False
|
||||
# Only set fully closed if both directions are closed
|
||||
if self.send_closed and self.recv_closed:
|
||||
self.closed = True
|
||||
else:
|
||||
# Stream is half-closed but not fully closed
|
||||
self.closed = False
|
||||
|
||||
async def reset(self) -> None:
|
||||
if not self.closed:
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.closed = True
|
||||
self.reset_received = True # Mark as reset
|
||||
async with self.close_lock:
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.closed = True
|
||||
self.reset_received = True # Mark as reset
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
|
||||
from .tcp.tcp import TCP
|
||||
from .websocket.transport import WebsocketTransport
|
||||
from .transport_registry import (
|
||||
TransportRegistry,
|
||||
create_transport_for_multiaddr,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from .upgrader import TransportUpgrader
|
||||
from libp2p.abc import ITransport
|
||||
|
||||
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
|
||||
"""
|
||||
Convenience function to create a transport instance.
|
||||
|
||||
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
|
||||
:param upgrader: Optional transport upgrader (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
|
||||
:return: Transport instance
|
||||
"""
|
||||
# First check if it's a built-in protocol
|
||||
if protocol in ["ws", "wss"]:
|
||||
if upgrader is None:
|
||||
raise ValueError(f"WebSocket transport requires an upgrader")
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
|
||||
)
|
||||
elif protocol == "tcp":
|
||||
return TCP()
|
||||
else:
|
||||
# Check if it's a custom registered transport
|
||||
registry = get_transport_registry()
|
||||
transport_class = registry.get_transport(protocol)
|
||||
if transport_class:
|
||||
transport = registry.create_transport(protocol, upgrader, **kwargs)
|
||||
if transport is None:
|
||||
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||
return transport
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||
|
||||
__all__ = [
|
||||
"TCP",
|
||||
"WebsocketTransport",
|
||||
"TransportRegistry",
|
||||
"create_transport_for_multiaddr",
|
||||
"create_transport",
|
||||
"get_transport_registry",
|
||||
"register_transport",
|
||||
"get_supported_transport_protocols",
|
||||
]
|
||||
|
||||
0
libp2p/transport/quic/__init__.py
Normal file
0
libp2p/transport/quic/__init__.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
@ -0,0 +1,345 @@
|
||||
"""
|
||||
Configuration classes for QUIC transport.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
import ssl
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.network.config import ConnectionConfig
|
||||
|
||||
|
||||
class QUICTransportKwargs(TypedDict, total=False):
|
||||
"""Type definition for kwargs accepted by new_transport function."""
|
||||
|
||||
# Connection settings
|
||||
idle_timeout: float
|
||||
max_datagram_size: int
|
||||
local_port: int | None
|
||||
|
||||
# Protocol version support
|
||||
enable_draft29: bool
|
||||
enable_v1: bool
|
||||
|
||||
# TLS settings
|
||||
verify_mode: ssl.VerifyMode
|
||||
alpn_protocols: list[str]
|
||||
|
||||
# Performance settings
|
||||
max_concurrent_streams: int
|
||||
connection_window: int
|
||||
stream_window: int
|
||||
|
||||
# Logging and debugging
|
||||
enable_qlog: bool
|
||||
qlog_dir: str | None
|
||||
|
||||
# Connection management
|
||||
max_connections: int
|
||||
connection_timeout: float
|
||||
|
||||
# Protocol identifiers
|
||||
PROTOCOL_QUIC_V1: TProtocol
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class QUICTransportConfig(ConnectionConfig):
|
||||
"""Configuration for QUIC transport."""
|
||||
|
||||
# Connection settings
|
||||
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
|
||||
max_datagram_size: int = (
|
||||
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
|
||||
)
|
||||
local_port: int | None = (
|
||||
None # Local port to bind to. If None, a random port is chosen.
|
||||
)
|
||||
|
||||
# Protocol version support
|
||||
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
|
||||
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
|
||||
|
||||
# TLS settings
|
||||
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
|
||||
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
|
||||
|
||||
# Performance settings
|
||||
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
|
||||
connection_window: int = 1024 * 1024 # Connection flow control window
|
||||
stream_window: int = 64 * 1024 # Stream flow control window
|
||||
|
||||
# Logging and debugging
|
||||
enable_qlog: bool = False # Enable QUIC logging
|
||||
qlog_dir: str | None = None # Directory for QUIC logs
|
||||
|
||||
# Connection management
|
||||
max_connections: int = 1000 # Maximum number of connections
|
||||
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||
|
||||
MAX_CONCURRENT_STREAMS: int = 1000
|
||||
"""Maximum number of concurrent streams per connection."""
|
||||
|
||||
MAX_INCOMING_STREAMS: int = 1000
|
||||
"""Maximum number of incoming streams per connection."""
|
||||
|
||||
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
|
||||
"""Timeout for connection handshake (seconds)."""
|
||||
|
||||
MAX_OUTGOING_STREAMS: int = 1000
|
||||
"""Maximum number of outgoing streams per connection."""
|
||||
|
||||
CONNECTION_CLOSE_TIMEOUT: int = 10
|
||||
"""Timeout for opening new connection (seconds)."""
|
||||
|
||||
# Stream timeouts
|
||||
STREAM_OPEN_TIMEOUT: float = 5.0
|
||||
"""Timeout for opening new streams (seconds)."""
|
||||
|
||||
STREAM_ACCEPT_TIMEOUT: float = 30.0
|
||||
"""Timeout for accepting incoming streams (seconds)."""
|
||||
|
||||
STREAM_READ_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream read operations (seconds)."""
|
||||
|
||||
STREAM_WRITE_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream write operations (seconds)."""
|
||||
|
||||
STREAM_CLOSE_TIMEOUT: float = 10.0
|
||||
"""Timeout for graceful stream close (seconds)."""
|
||||
|
||||
# Flow control configuration
|
||||
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
|
||||
"""Per-stream flow control window size."""
|
||||
|
||||
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
|
||||
"""Connection-wide flow control window size."""
|
||||
|
||||
# Buffer management
|
||||
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Maximum receive buffer size per stream."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
|
||||
"""Low watermark for stream receive buffer."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
|
||||
"""High watermark for stream receive buffer."""
|
||||
|
||||
# Stream lifecycle configuration
|
||||
ENABLE_STREAM_RESET_ON_ERROR: bool = True
|
||||
"""Whether to automatically reset streams on errors."""
|
||||
|
||||
STREAM_RESET_ERROR_CODE: int = 1
|
||||
"""Default error code for stream resets."""
|
||||
|
||||
ENABLE_STREAM_KEEP_ALIVE: bool = False
|
||||
"""Whether to enable stream keep-alive mechanisms."""
|
||||
|
||||
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
|
||||
"""Interval for stream keep-alive pings (seconds)."""
|
||||
|
||||
# Resource management
|
||||
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
|
||||
"""Whether to track stream resource usage."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Memory limit per individual stream."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
|
||||
"""Total memory limit for all streams per connection."""
|
||||
|
||||
# Concurrency and performance
|
||||
ENABLE_STREAM_BATCHING: bool = True
|
||||
"""Whether to batch multiple stream operations."""
|
||||
|
||||
STREAM_BATCH_SIZE: int = 10
|
||||
"""Number of streams to process in a batch."""
|
||||
|
||||
STREAM_PROCESSING_CONCURRENCY: int = 100
|
||||
"""Maximum concurrent stream processing tasks."""
|
||||
|
||||
# Debugging and monitoring
|
||||
ENABLE_STREAM_METRICS: bool = True
|
||||
"""Whether to collect stream metrics."""
|
||||
|
||||
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
|
||||
"""Whether to track stream lifecycle timelines."""
|
||||
|
||||
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
|
||||
"""Interval for collecting stream metrics (seconds)."""
|
||||
|
||||
# Error handling configuration
|
||||
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
|
||||
"""Number of retry attempts for recoverable stream errors."""
|
||||
|
||||
STREAM_ERROR_RETRY_DELAY: float = 1.0
|
||||
"""Initial delay between stream error retries (seconds)."""
|
||||
|
||||
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
|
||||
"""Backoff factor for stream error retries."""
|
||||
|
||||
# Protocol identifiers matching go-libp2p
|
||||
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration after initialization."""
|
||||
if not (self.enable_draft29 or self.enable_v1):
|
||||
raise ValueError("At least one QUIC version must be enabled")
|
||||
|
||||
if self.idle_timeout <= 0:
|
||||
raise ValueError("Idle timeout must be positive")
|
||||
|
||||
if self.max_datagram_size < 1200:
|
||||
raise ValueError("Max datagram size must be at least 1200 bytes")
|
||||
|
||||
# Validate timeouts
|
||||
timeout_fields = [
|
||||
"STREAM_OPEN_TIMEOUT",
|
||||
"STREAM_ACCEPT_TIMEOUT",
|
||||
"STREAM_READ_TIMEOUT",
|
||||
"STREAM_WRITE_TIMEOUT",
|
||||
"STREAM_CLOSE_TIMEOUT",
|
||||
]
|
||||
for timeout_field in timeout_fields:
|
||||
if getattr(self, timeout_field) <= 0:
|
||||
raise ValueError(f"{timeout_field} must be positive")
|
||||
|
||||
# Validate flow control windows
|
||||
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
|
||||
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
|
||||
|
||||
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
|
||||
raise ValueError(
|
||||
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
|
||||
)
|
||||
|
||||
# Validate buffer sizes
|
||||
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
|
||||
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
|
||||
|
||||
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
|
||||
"exceed MAX_STREAM_RECEIVE_BUFFER"
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
|
||||
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
|
||||
):
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
|
||||
)
|
||||
|
||||
# Validate memory limits
|
||||
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
|
||||
|
||||
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
|
||||
|
||||
expected_stream_memory = (
|
||||
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
|
||||
)
|
||||
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
|
||||
# Allow some headroom, but warn if configuration seems inconsistent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Stream memory configuration may be inconsistent: "
|
||||
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
|
||||
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
|
||||
"could exceed connection limit of"
|
||||
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
|
||||
)
|
||||
|
||||
def get_stream_config_dict(self) -> dict[str, Any]:
|
||||
"""Get stream-specific configuration as dictionary."""
|
||||
stream_config = {}
|
||||
for attr_name in dir(self):
|
||||
if attr_name.startswith(
|
||||
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
|
||||
):
|
||||
stream_config[attr_name.lower()] = getattr(self, attr_name)
|
||||
return stream_config
|
||||
|
||||
|
||||
# Additional configuration classes for specific stream features
|
||||
|
||||
|
||||
class QUICStreamFlowControlConfig:
|
||||
"""Configuration for QUIC stream flow control."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_window_size: int = 512 * 1024,
|
||||
max_window_size: int = 2 * 1024 * 1024,
|
||||
window_update_threshold: float = 0.5,
|
||||
enable_auto_tuning: bool = True,
|
||||
):
|
||||
self.initial_window_size = initial_window_size
|
||||
self.max_window_size = max_window_size
|
||||
self.window_update_threshold = window_update_threshold
|
||||
self.enable_auto_tuning = enable_auto_tuning
|
||||
|
||||
|
||||
def create_stream_config_for_use_case(
|
||||
use_case: Literal[
|
||||
"high_throughput", "low_latency", "many_streams", "memory_constrained"
|
||||
],
|
||||
) -> QUICTransportConfig:
|
||||
"""
|
||||
Create optimized stream configuration for specific use cases.
|
||||
|
||||
Args:
|
||||
use_case: One of "high_throughput", "low_latency", "many_streams","
|
||||
"memory_constrained"
|
||||
|
||||
Returns:
|
||||
Optimized QUICTransportConfig
|
||||
|
||||
"""
|
||||
base_config = QUICTransportConfig()
|
||||
|
||||
if use_case == "high_throughput":
|
||||
# Optimize for high throughput
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 200
|
||||
|
||||
elif use_case == "low_latency":
|
||||
# Optimize for low latency
|
||||
base_config.STREAM_OPEN_TIMEOUT = 1.0
|
||||
base_config.STREAM_READ_TIMEOUT = 5.0
|
||||
base_config.STREAM_WRITE_TIMEOUT = 5.0
|
||||
base_config.ENABLE_STREAM_BATCHING = False
|
||||
base_config.STREAM_BATCH_SIZE = 1
|
||||
|
||||
elif use_case == "many_streams":
|
||||
# Optimize for many concurrent streams
|
||||
base_config.MAX_CONCURRENT_STREAMS = 5000
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 500
|
||||
|
||||
elif use_case == "memory_constrained":
|
||||
# Optimize for low memory usage
|
||||
base_config.MAX_CONCURRENT_STREAMS = 100
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
|
||||
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 50
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown use case: {use_case}")
|
||||
|
||||
return base_config
|
||||
1489
libp2p/transport/quic/connection.py
Normal file
1489
libp2p/transport/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
391
libp2p/transport/quic/exceptions.py
Normal file
391
libp2p/transport/quic/exceptions.py
Normal file
@ -0,0 +1,391 @@
|
||||
"""
|
||||
QUIC Transport exceptions
|
||||
"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
class QUICError(Exception):
|
||||
"""Base exception for all QUIC transport errors."""
|
||||
|
||||
def __init__(self, message: str, error_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
# Transport-level exceptions
|
||||
|
||||
|
||||
class QUICTransportError(QUICError):
|
||||
"""Base exception for QUIC transport operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICDialError(QUICTransportError):
|
||||
"""Error occurred during QUIC connection establishment."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICListenError(QUICTransportError):
|
||||
"""Error occurred during QUIC listener operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICSecurityError(QUICTransportError):
|
||||
"""Error related to QUIC security/TLS operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Connection-level exceptions
|
||||
|
||||
|
||||
class QUICConnectionError(QUICError):
|
||||
"""Base exception for QUIC connection operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionClosedError(QUICConnectionError):
|
||||
"""QUIC connection has been closed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionTimeoutError(QUICConnectionError):
|
||||
"""QUIC connection operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICHandshakeError(QUICConnectionError):
|
||||
"""Error during QUIC handshake process."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICPeerVerificationError(QUICConnectionError):
|
||||
"""Error verifying peer identity during handshake."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Stream-level exceptions
|
||||
|
||||
|
||||
class QUICStreamError(QUICError):
|
||||
"""Base exception for QUIC stream operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
):
|
||||
super().__init__(message, error_code)
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
class QUICStreamClosedError(QUICStreamError):
|
||||
"""Stream is closed and cannot be used for I/O operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamResetError(QUICStreamError):
|
||||
"""Stream was reset by local or remote peer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
reset_by_peer: bool = False,
|
||||
):
|
||||
super().__init__(message, stream_id, error_code)
|
||||
self.reset_by_peer = reset_by_peer
|
||||
|
||||
|
||||
class QUICStreamTimeoutError(QUICStreamError):
|
||||
"""Stream operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamBackpressureError(QUICStreamError):
|
||||
"""Stream write blocked due to flow control."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamLimitError(QUICStreamError):
|
||||
"""Stream limit reached (too many concurrent streams)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamStateError(QUICStreamError):
|
||||
"""Invalid operation for current stream state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
current_state: str | None = None,
|
||||
attempted_operation: str | None = None,
|
||||
):
|
||||
super().__init__(message, stream_id)
|
||||
self.current_state = current_state
|
||||
self.attempted_operation = attempted_operation
|
||||
|
||||
|
||||
# Flow control exceptions
|
||||
|
||||
|
||||
class QUICFlowControlError(QUICError):
|
||||
"""Base exception for flow control related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlViolationError(QUICFlowControlError):
|
||||
"""Flow control limits were violated."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlDeadlockError(QUICFlowControlError):
|
||||
"""Flow control deadlock detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Resource management exceptions
|
||||
|
||||
|
||||
class QUICResourceError(QUICError):
|
||||
"""Base exception for resource management errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICMemoryLimitError(QUICResourceError):
|
||||
"""Memory limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionLimitError(QUICResourceError):
|
||||
"""Connection limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Multiaddr and addressing exceptions
|
||||
|
||||
|
||||
class QUICAddressError(QUICError):
|
||||
"""Base exception for QUIC addressing errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidMultiaddrError(QUICAddressError):
|
||||
"""Invalid multiaddr format for QUIC transport."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICAddressResolutionError(QUICAddressError):
|
||||
"""Failed to resolve QUIC address."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICProtocolError(QUICError):
|
||||
"""Base exception for QUIC protocol errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICVersionNegotiationError(QUICProtocolError):
|
||||
"""QUIC version negotiation failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICUnsupportedVersionError(QUICProtocolError):
|
||||
"""Unsupported QUIC version."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Configuration exceptions
|
||||
|
||||
|
||||
class QUICConfigurationError(QUICError):
|
||||
"""Base exception for QUIC configuration errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidConfigError(QUICConfigurationError):
|
||||
"""Invalid QUIC configuration parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICCertificateError(QUICConfigurationError):
|
||||
"""Error with TLS certificate configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def map_quic_error_code(error_code: int) -> str:
|
||||
"""
|
||||
Map QUIC error codes to human-readable descriptions.
|
||||
Based on RFC 9000 Transport Error Codes.
|
||||
"""
|
||||
error_codes = {
|
||||
0x00: "NO_ERROR",
|
||||
0x01: "INTERNAL_ERROR",
|
||||
0x02: "CONNECTION_REFUSED",
|
||||
0x03: "FLOW_CONTROL_ERROR",
|
||||
0x04: "STREAM_LIMIT_ERROR",
|
||||
0x05: "STREAM_STATE_ERROR",
|
||||
0x06: "FINAL_SIZE_ERROR",
|
||||
0x07: "FRAME_ENCODING_ERROR",
|
||||
0x08: "TRANSPORT_PARAMETER_ERROR",
|
||||
0x09: "CONNECTION_ID_LIMIT_ERROR",
|
||||
0x0A: "PROTOCOL_VIOLATION",
|
||||
0x0B: "INVALID_TOKEN",
|
||||
0x0C: "APPLICATION_ERROR",
|
||||
0x0D: "CRYPTO_BUFFER_EXCEEDED",
|
||||
0x0E: "KEY_UPDATE_ERROR",
|
||||
0x0F: "AEAD_LIMIT_REACHED",
|
||||
0x10: "NO_VIABLE_PATH",
|
||||
}
|
||||
|
||||
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
|
||||
|
||||
|
||||
def create_stream_error(
|
||||
error_type: str,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
) -> QUICStreamError:
|
||||
"""
|
||||
Factory function to create appropriate stream error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
|
||||
message: Error message
|
||||
stream_id: Stream identifier
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICStreamError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICStreamClosedError(message, stream_id, error_code)
|
||||
elif error_type == "reset":
|
||||
return QUICStreamResetError(message, stream_id, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICStreamTimeoutError(message, stream_id, error_code)
|
||||
elif error_type in ("backpressure", "flow_control"):
|
||||
return QUICStreamBackpressureError(message, stream_id, error_code)
|
||||
elif error_type in ("limit", "stream_limit"):
|
||||
return QUICStreamLimitError(message, stream_id, error_code)
|
||||
elif error_type == "state":
|
||||
return QUICStreamStateError(message, stream_id)
|
||||
else:
|
||||
return QUICStreamError(message, stream_id, error_code)
|
||||
|
||||
|
||||
def create_connection_error(
|
||||
error_type: str, message: str, error_code: int | None = None
|
||||
) -> QUICConnectionError:
|
||||
"""
|
||||
Factory function to create appropriate connection error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "timeout", "handshake", etc.)
|
||||
message: Error message
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICConnectionError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICConnectionClosedError(message, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICConnectionTimeoutError(message, error_code)
|
||||
elif error_type == "handshake":
|
||||
return QUICHandshakeError(message, error_code)
|
||||
elif error_type in ("peer_verification", "verification"):
|
||||
return QUICPeerVerificationError(message, error_code)
|
||||
else:
|
||||
return QUICConnectionError(message, error_code)
|
||||
|
||||
|
||||
class QUICErrorContext:
|
||||
"""
|
||||
Context manager for handling QUIC errors with automatic error mapping.
|
||||
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
|
||||
"""
|
||||
|
||||
def __init__(self, operation: str, component: str = "quic") -> None:
|
||||
self.operation = operation
|
||||
self.component = component
|
||||
|
||||
def __enter__(self) -> "QUICErrorContext":
|
||||
return self
|
||||
|
||||
# TODO: Fix types for exc_type
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: Any,
|
||||
) -> Literal[False]:
|
||||
if exc_type is None:
|
||||
return False
|
||||
|
||||
if exc_val is None:
|
||||
return False
|
||||
|
||||
# Map common aioquic exceptions to our exceptions
|
||||
if "ConnectionClosed" in str(exc_type):
|
||||
raise QUICConnectionClosedError(
|
||||
f"Connection closed during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "StreamReset" in str(exc_type):
|
||||
raise QUICStreamResetError(
|
||||
f"Stream reset during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "timeout" in str(exc_val).lower():
|
||||
if "stream" in self.component.lower():
|
||||
raise QUICStreamTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
else:
|
||||
raise QUICConnectionTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "flow control" in str(exc_val).lower():
|
||||
raise QUICStreamBackpressureError(
|
||||
f"Flow control error during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
|
||||
# Let other exceptions propagate
|
||||
return False
|
||||
1041
libp2p/transport/quic/listener.py
Normal file
1041
libp2p/transport/quic/listener.py
Normal file
File diff suppressed because it is too large
Load Diff
1165
libp2p/transport/quic/security.py
Normal file
1165
libp2p/transport/quic/security.py
Normal file
File diff suppressed because it is too large
Load Diff
656
libp2p/transport/quic/stream.py
Normal file
656
libp2p/transport/quic/stream.py
Normal file
@ -0,0 +1,656 @@
|
||||
"""
|
||||
QUIC Stream implementation
|
||||
Provides stream interface over QUIC's native multiplexing.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import trio
|
||||
|
||||
from .exceptions import (
|
||||
QUICStreamBackpressureError,
|
||||
QUICStreamClosedError,
|
||||
QUICStreamResetError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import IMuxedStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
from .connection import QUICConnection
|
||||
else:
|
||||
IMuxedStream = cast(type, object)
|
||||
TProtocol = cast(type, object)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamState(Enum):
|
||||
"""Stream lifecycle states following libp2p patterns."""
|
||||
|
||||
OPEN = "open"
|
||||
WRITE_CLOSED = "write_closed"
|
||||
READ_CLOSED = "read_closed"
|
||||
CLOSED = "closed"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
class StreamDirection(Enum):
|
||||
"""Stream direction for tracking initiator."""
|
||||
|
||||
INBOUND = "inbound"
|
||||
OUTBOUND = "outbound"
|
||||
|
||||
|
||||
class StreamTimeline:
|
||||
"""Track stream lifecycle events for debugging and monitoring."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.created_at = time.time()
|
||||
self.opened_at: float | None = None
|
||||
self.first_data_at: float | None = None
|
||||
self.closed_at: float | None = None
|
||||
self.reset_at: float | None = None
|
||||
self.error_code: int | None = None
|
||||
|
||||
def record_open(self) -> None:
|
||||
self.opened_at = time.time()
|
||||
|
||||
def record_first_data(self) -> None:
|
||||
if self.first_data_at is None:
|
||||
self.first_data_at = time.time()
|
||||
|
||||
def record_close(self) -> None:
|
||||
self.closed_at = time.time()
|
||||
|
||||
def record_reset(self, error_code: int) -> None:
|
||||
self.reset_at = time.time()
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class QUICStream(IMuxedStream):
|
||||
"""
|
||||
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||
|
||||
Based on patterns from go-libp2p and js-libp2p, this implementation:
|
||||
- Leverages QUIC's native multiplexing and flow control
|
||||
- Integrates with libp2p resource management
|
||||
- Provides comprehensive error handling with QUIC-specific codes
|
||||
- Supports bidirectional communication with independent close semantics
|
||||
- Implements proper stream lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "QUICConnection",
|
||||
stream_id: int,
|
||||
direction: StreamDirection,
|
||||
remote_addr: tuple[str, int],
|
||||
resource_scope: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC stream.
|
||||
|
||||
Args:
|
||||
connection: Parent QUIC connection
|
||||
stream_id: QUIC stream identifier
|
||||
direction: Stream direction (inbound/outbound)
|
||||
resource_scope: Resource manager scope for memory accounting
|
||||
remote_addr: Remote addr stream is connected to
|
||||
|
||||
"""
|
||||
self._connection = connection
|
||||
self._stream_id = stream_id
|
||||
self._direction = direction
|
||||
self._resource_scope = resource_scope
|
||||
|
||||
# libp2p interface compliance
|
||||
self._protocol: TProtocol | None = None
|
||||
self._metadata: dict[str, Any] = {}
|
||||
self._remote_addr = remote_addr
|
||||
|
||||
# Stream state management
|
||||
self._state = StreamState.OPEN
|
||||
self._state_lock = trio.Lock()
|
||||
|
||||
# Flow control and buffering
|
||||
self._receive_buffer = bytearray()
|
||||
self._receive_buffer_lock = trio.Lock()
|
||||
self._receive_event = trio.Event()
|
||||
self._backpressure_event = trio.Event()
|
||||
self._backpressure_event.set() # Initially no backpressure
|
||||
|
||||
# Close/reset state
|
||||
self._write_closed = False
|
||||
self._read_closed = False
|
||||
self._close_event = trio.Event()
|
||||
self._reset_error_code: int | None = None
|
||||
|
||||
# Lifecycle tracking
|
||||
self._timeline = StreamTimeline()
|
||||
self._timeline.record_open()
|
||||
|
||||
# Resource accounting
|
||||
self._memory_reserved = 0
|
||||
|
||||
# Stream constant configurations
|
||||
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
|
||||
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
|
||||
self.FLOW_CONTROL_WINDOW_SIZE = (
|
||||
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
|
||||
)
|
||||
self.MAX_RECEIVE_BUFFER_SIZE = (
|
||||
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
|
||||
)
|
||||
|
||||
if self._resource_scope:
|
||||
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
|
||||
|
||||
logger.debug(
|
||||
f"Created QUIC stream {stream_id} "
|
||||
f"({direction.value}, connection: {connection.remote_peer_id()})"
|
||||
)
|
||||
|
||||
# Properties for libp2p interface compliance
|
||||
|
||||
@property
|
||||
def protocol(self) -> TProtocol | None:
|
||||
"""Get the protocol identifier for this stream."""
|
||||
return self._protocol
|
||||
|
||||
@protocol.setter
|
||||
def protocol(self, protocol_id: TProtocol) -> None:
|
||||
"""Set the protocol identifier for this stream."""
|
||||
self._protocol = protocol_id
|
||||
self._metadata["protocol"] = protocol_id
|
||||
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str:
|
||||
"""Get stream ID as string for libp2p compatibility."""
|
||||
return str(self._stream_id)
|
||||
|
||||
@property
|
||||
def muxed_conn(self) -> "QUICConnection": # type: ignore
|
||||
"""Get the parent muxed connection."""
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def state(self) -> StreamState:
|
||||
"""Get current stream state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def direction(self) -> StreamDirection:
|
||||
"""Get stream direction."""
|
||||
return self._direction
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
"""Check if this stream was locally initiated."""
|
||||
return self._direction == StreamDirection.OUTBOUND
|
||||
|
||||
# Core stream operations
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read data from the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
n: Maximum number of bytes to read. If None or -1, read all available.
|
||||
|
||||
Returns:
|
||||
Data read from stream
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed
|
||||
QUICStreamResetError: Stream was reset
|
||||
QUICStreamTimeoutError: Read timeout exceeded
|
||||
|
||||
"""
|
||||
if n is None:
|
||||
n = -1
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._read_closed:
|
||||
# Return any remaining buffered data, then EOF
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
return b""
|
||||
|
||||
# Wait for data with timeout
|
||||
timeout = self.READ_TIMEOUT
|
||||
try:
|
||||
with trio.move_on_after(timeout) as cancel_scope:
|
||||
while True:
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
|
||||
# Check if stream was closed while waiting
|
||||
if self._read_closed:
|
||||
return b""
|
||||
|
||||
# Wait for more data
|
||||
await self._receive_event.wait()
|
||||
self._receive_event = trio.Event() # Reset for next wait
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
|
||||
|
||||
return b""
|
||||
except QUICStreamResetError:
|
||||
# Stream was reset while reading
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading from stream {self.stream_id}: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
data: Data to write
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed for writing
|
||||
QUICStreamBackpressureError: Flow control window exhausted
|
||||
QUICStreamResetError: Stream was reset
|
||||
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._write_closed:
|
||||
raise QUICStreamClosedError(
|
||||
f"Stream {self.stream_id} write side is closed"
|
||||
)
|
||||
|
||||
try:
|
||||
# Handle flow control backpressure
|
||||
await self._backpressure_event.wait()
|
||||
|
||||
# Send data through QUIC connection
|
||||
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._timeline.record_first_data()
|
||||
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to stream {self.stream_id}: {e}")
|
||||
# Convert QUIC-specific errors
|
||||
if "flow control" in str(e).lower():
|
||||
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the stream gracefully (both read and write sides).
|
||||
|
||||
This implements proper close semantics where both sides
|
||||
are closed and resources are cleaned up.
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
return
|
||||
|
||||
logger.debug(f"Closing stream {self.stream_id}")
|
||||
|
||||
# Close both sides
|
||||
if not self._write_closed:
|
||||
await self.close_write()
|
||||
if not self._read_closed:
|
||||
await self.close_read()
|
||||
|
||||
# Update state and cleanup
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.CLOSED
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_close()
|
||||
self._close_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} closed")
|
||||
|
||||
async def close_write(self) -> None:
|
||||
"""Close the write side of the stream."""
|
||||
if self._write_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
# Send FIN to close write side
|
||||
self._connection._quic.send_stream_data(
|
||||
self._stream_id, b"", end_stream=True
|
||||
)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} write side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def close_read(self) -> None:
|
||||
"""Close the read side of the stream."""
|
||||
if self._read_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
self._read_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up any pending reads
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} read side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def reset(self, error_code: int = 0) -> None:
|
||||
"""
|
||||
Reset the stream with the given error code.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code for the reset
|
||||
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Resetting stream {self.stream_id} with error code {error_code}"
|
||||
)
|
||||
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
try:
|
||||
# Send QUIC reset frame
|
||||
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||
await self._connection._transmit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
|
||||
finally:
|
||||
# Always cleanup resources
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if stream is completely closed."""
|
||||
return self._state in (StreamState.CLOSED, StreamState.RESET)
|
||||
|
||||
def is_reset(self) -> bool:
|
||||
"""Check if stream was reset."""
|
||||
return self._state == StreamState.RESET
|
||||
|
||||
def can_read(self) -> bool:
|
||||
"""Check if stream can be read from."""
|
||||
return not self._read_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
def can_write(self) -> bool:
|
||||
"""Check if stream can be written to."""
|
||||
return not self._write_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||
"""
|
||||
Handle data received from the QUIC connection.
|
||||
|
||||
Args:
|
||||
data: Received data
|
||||
end_stream: Whether this is the last data (FIN received)
|
||||
|
||||
"""
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
if data:
|
||||
async with self._receive_buffer_lock:
|
||||
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
|
||||
logger.warning(
|
||||
f"Stream {self.stream_id} receive buffer overflow, "
|
||||
f"dropping {len(data)} bytes"
|
||||
)
|
||||
return
|
||||
|
||||
self._receive_buffer.extend(data)
|
||||
self._timeline.record_first_data()
|
||||
|
||||
# Notify waiting readers
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
|
||||
|
||||
if end_stream:
|
||||
self._read_closed = True
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up readers to process remaining data and EOF
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||
|
||||
async def handle_stop_sending(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle STOP_SENDING frame from remote peer.
|
||||
|
||||
When a STOP_SENDING frame is received, the peer is requesting that we
|
||||
stop sending data on this stream. We respond by resetting the stream.
|
||||
|
||||
Args:
|
||||
error_code: Error code from the STOP_SENDING frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
|
||||
)
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
# Wake up any pending write operations
|
||||
self._backpressure_event.set()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.direction == StreamDirection.OUTBOUND:
|
||||
self._state = StreamState.CLOSED
|
||||
elif self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
# Only write side closed - add WRITE_CLOSED state if needed
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
# Send RESET_STREAM in response (QUIC protocol requirement)
|
||||
try:
|
||||
self._connection._quic.reset_stream(int(self.stream_id), error_code)
|
||||
await self._connection._transmit()
|
||||
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def handle_reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle stream reset from remote peer.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code from reset frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} reset by peer with error code {error_code}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
# Wake up any pending operations
|
||||
self._receive_event.set()
|
||||
self._backpressure_event.set()
|
||||
|
||||
async def handle_flow_control_update(self, available_window: int) -> None:
|
||||
"""
|
||||
Handle flow control window updates.
|
||||
|
||||
Args:
|
||||
available_window: Available flow control window size
|
||||
|
||||
"""
|
||||
if available_window > 0:
|
||||
self._backpressure_event.set()
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} flow control".__add__(
|
||||
f"window updated: {available_window}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._backpressure_event = trio.Event() # Reset to blocking state
|
||||
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
|
||||
|
||||
def _extract_data_from_buffer(self, n: int) -> bytes:
|
||||
"""Extract data from receive buffer with specified limit."""
|
||||
if n == -1:
|
||||
# Read all available data
|
||||
data = bytes(self._receive_buffer)
|
||||
self._receive_buffer.clear()
|
||||
else:
|
||||
# Read up to n bytes
|
||||
data = bytes(self._receive_buffer[:n])
|
||||
self._receive_buffer = self._receive_buffer[n:]
|
||||
|
||||
return data
|
||||
|
||||
async def _handle_stream_error(self, error: Exception) -> None:
|
||||
"""Handle errors by resetting the stream."""
|
||||
logger.error(f"Stream {self.stream_id} error: {error}")
|
||||
await self.reset(error_code=1) # Generic error code
|
||||
|
||||
def _reserve_memory(self, size: int) -> None:
|
||||
"""Reserve memory with resource manager."""
|
||||
if self._resource_scope:
|
||||
try:
|
||||
self._resource_scope.reserve_memory(size)
|
||||
self._memory_reserved += size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to reserve memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
def _release_memory(self, size: int) -> None:
|
||||
"""Release memory with resource manager."""
|
||||
if self._resource_scope and size > 0:
|
||||
try:
|
||||
self._resource_scope.release_memory(size)
|
||||
self._memory_reserved = max(0, self._memory_reserved - size)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to release memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def _cleanup_resources(self) -> None:
|
||||
"""Clean up stream resources."""
|
||||
# Release all reserved memory
|
||||
if self._memory_reserved > 0:
|
||||
self._release_memory(self._memory_reserved)
|
||||
|
||||
# Clear receive buffer
|
||||
async with self._receive_buffer_lock:
|
||||
self._receive_buffer.clear()
|
||||
|
||||
# Remove from connection's stream registry
|
||||
self._connection._remove_stream(self._stream_id)
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} resources cleaned up")
|
||||
|
||||
# Abstact implementations
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int]:
|
||||
return self._remote_addr
|
||||
|
||||
async def __aenter__(self) -> "QUICStream":
|
||||
"""Enter the async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
logger.debug("Exiting the context and closing the stream")
|
||||
await self.close()
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
Set a deadline for the stream. QUIC does not support deadlines natively,
|
||||
so this method always returns False to indicate the operation is unsupported.
|
||||
|
||||
:param ttl: Time-to-live in seconds (ignored).
|
||||
:return: False, as deadlines are not supported.
|
||||
"""
|
||||
raise NotImplementedError("QUIC does not support setting read deadlines")
|
||||
|
||||
# String representation for debugging
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"QUICStream(id={self.stream_id}, "
|
||||
f"state={self._state.value}, "
|
||||
f"direction={self._direction.value}, "
|
||||
f"protocol={self._protocol})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"QUICStream({self.stream_id})"
|
||||
491
libp2p/transport/quic/transport.py
Normal file
491
libp2p/transport/quic/transport.py
Normal file
@ -0,0 +1,491 @@
|
||||
"""
|
||||
QUIC Transport implementation
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import ssl
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from aioquic.quic.configuration import (
|
||||
QuicConfiguration,
|
||||
)
|
||||
from aioquic.quic.connection import (
|
||||
QuicConnection as NativeQUICConnection,
|
||||
)
|
||||
from aioquic.quic.logger import QuicLogger
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
)
|
||||
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_client_config_from_base,
|
||||
create_server_config_from_base,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.network.swarm import Swarm
|
||||
else:
|
||||
Swarm = cast(type, object)
|
||||
|
||||
from .config import (
|
||||
QUICTransportConfig,
|
||||
)
|
||||
from .connection import (
|
||||
QUICConnection,
|
||||
)
|
||||
from .exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
QUICSecurityError,
|
||||
)
|
||||
from .listener import (
|
||||
QUICListener,
|
||||
)
|
||||
from .security import (
|
||||
QUICTLSConfigManager,
|
||||
create_quic_security_transport,
|
||||
)
|
||||
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QUICTransport(ITransport):
|
||||
"""
|
||||
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Initialize QUIC transport with security integration.
|
||||
|
||||
Args:
|
||||
private_key: libp2p private key for identity and TLS cert generation
|
||||
config: QUIC transport configuration options
|
||||
|
||||
"""
|
||||
self._private_key = private_key
|
||||
self._peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
self._config = config or QUICTransportConfig()
|
||||
|
||||
# Connection management
|
||||
self._connections: dict[str, QUICConnection] = {}
|
||||
self._listeners: list[QUICListener] = []
|
||||
|
||||
# Security manager for TLS integration
|
||||
self._security_manager = create_quic_security_transport(
|
||||
self._private_key, self._peer_id
|
||||
)
|
||||
|
||||
# QUIC configurations for different versions
|
||||
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
|
||||
self._setup_quic_configurations()
|
||||
|
||||
# Resource management
|
||||
self._closed = False
|
||||
self._nursery_manager = trio.CapacityLimiter(1)
|
||||
self._background_nursery: trio.Nursery | None = None
|
||||
|
||||
self._swarm: Swarm | None = None
|
||||
|
||||
logger.debug(
|
||||
f"Initialized QUIC transport with security for peer {self._peer_id}"
|
||||
)
|
||||
|
||||
def set_background_nursery(self, nursery: trio.Nursery) -> None:
|
||||
"""Set the nursery to use for background tasks (called by swarm)."""
|
||||
self._background_nursery = nursery
|
||||
logger.debug("Transport background nursery set")
|
||||
|
||||
def set_swarm(self, swarm: Swarm) -> None:
|
||||
"""Set the swarm for adding incoming connections."""
|
||||
self._swarm = swarm
|
||||
|
||||
def _setup_quic_configurations(self) -> None:
|
||||
"""Setup QUIC configurations."""
|
||||
try:
|
||||
# Get TLS configuration from security manager
|
||||
server_tls_config = self._security_manager.create_server_config()
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
|
||||
# Base server configuration
|
||||
base_server_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=get_alpn_protocols(),
|
||||
verify_mode=self._config.verify_mode,
|
||||
max_datagram_frame_size=self._config.max_datagram_size,
|
||||
idle_timeout=self._config.idle_timeout,
|
||||
)
|
||||
|
||||
# Base client configuration
|
||||
base_client_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=get_alpn_protocols(),
|
||||
verify_mode=self._config.verify_mode,
|
||||
max_datagram_frame_size=self._config.max_datagram_size,
|
||||
idle_timeout=self._config.idle_timeout,
|
||||
)
|
||||
|
||||
# Apply TLS configuration
|
||||
self._apply_tls_configuration(base_server_config, server_tls_config)
|
||||
self._apply_tls_configuration(base_client_config, client_tls_config)
|
||||
|
||||
# QUIC v1 (RFC 9000) configurations
|
||||
if self._config.enable_v1:
|
||||
quic_v1_server_config = create_server_config_from_base(
|
||||
base_server_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
quic_v1_client_config = create_client_config_from_base(
|
||||
base_client_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
# Store both server and client configs for v1
|
||||
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
|
||||
quic_v1_server_config
|
||||
)
|
||||
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
|
||||
quic_v1_client_config
|
||||
)
|
||||
|
||||
# QUIC draft-29 configurations for compatibility
|
||||
if self._config.enable_draft29:
|
||||
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
|
||||
draft29_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
draft29_client_config = copy.copy(base_client_config)
|
||||
draft29_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
|
||||
draft29_server_config
|
||||
)
|
||||
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
|
||||
draft29_client_config
|
||||
)
|
||||
|
||||
logger.debug("QUIC configurations initialized with libp2p TLS security")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(
|
||||
f"Failed to setup QUIC TLS configurations: {e}"
|
||||
) from e
|
||||
|
||||
def _apply_tls_configuration(
|
||||
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
|
||||
) -> None:
|
||||
"""
|
||||
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||
|
||||
Args:
|
||||
config: QuicConfiguration to update
|
||||
tls_config: TLS configuration dictionary from security manager
|
||||
|
||||
"""
|
||||
try:
|
||||
config.certificate = tls_config.certificate
|
||||
config.private_key = tls_config.private_key
|
||||
config.certificate_chain = tls_config.certificate_chain
|
||||
config.alpn_protocols = tls_config.alpn_protocols
|
||||
config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
|
||||
|
||||
async def dial(
|
||||
self,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
) -> QUICConnection:
|
||||
"""
|
||||
Dial a remote peer using QUIC transport with security verification.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
|
||||
peer_id: Expected peer ID for verification
|
||||
nursery: Nursery to execute the background tasks
|
||||
|
||||
Returns:
|
||||
Raw connection interface to the remote peer
|
||||
|
||||
Raises:
|
||||
QUICDialError: If dialing fails
|
||||
QUICSecurityError: If security verification fails
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICDialError("Transport is closed")
|
||||
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
# Extract connection details from multiaddr
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
remote_peer_id = maddr.get_peer_id()
|
||||
if remote_peer_id is not None:
|
||||
remote_peer_id = ID.from_base58(remote_peer_id)
|
||||
|
||||
if remote_peer_id is None:
|
||||
logger.error("Unable to derive peer id from multiaddr")
|
||||
raise QUICDialError("Unable to derive peer id from multiaddr")
|
||||
quic_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
# Get appropriate QUIC client configuration
|
||||
config_key = TProtocol(f"{quic_version}_client")
|
||||
logger.debug("config_key", config_key, self._quic_configs.keys())
|
||||
config = self._quic_configs.get(config_key)
|
||||
if not config:
|
||||
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
|
||||
|
||||
config.is_client = True
|
||||
config.quic_logger = QuicLogger()
|
||||
|
||||
# Ensure client certificate is properly set for mutual authentication
|
||||
if not config.certificate or not config.private_key:
|
||||
logger.warning(
|
||||
"Client config missing certificate - applying TLS config"
|
||||
)
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
self._apply_tls_configuration(config, client_tls_config)
|
||||
|
||||
# Debug log to verify certificate is present
|
||||
logger.info(
|
||||
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
|
||||
)
|
||||
|
||||
logger.debug("Starting QUIC Connection")
|
||||
# Create QUIC connection using aioquic's sans-IO core
|
||||
native_quic_connection = NativeQUICConnection(configuration=config)
|
||||
|
||||
# Create trio-based QUIC connection wrapper with security
|
||||
connection = QUICConnection(
|
||||
quic_connection=native_quic_connection,
|
||||
remote_addr=(host, port),
|
||||
remote_peer_id=remote_peer_id,
|
||||
local_peer_id=self._peer_id,
|
||||
is_initiator=True,
|
||||
maddr=maddr,
|
||||
transport=self,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
logger.debug("QUIC Connection Created")
|
||||
|
||||
if self._background_nursery is None:
|
||||
logger.error("No nursery set to execute background tasks")
|
||||
raise QUICDialError("No nursery found to execute tasks")
|
||||
|
||||
await connection.connect(self._background_nursery)
|
||||
|
||||
# Store connection for management
|
||||
conn_id = f"{host}:{port}"
|
||||
self._connections[conn_id] = connection
|
||||
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
|
||||
raise QUICDialError(f"Dial failed: {e}") from e
|
||||
|
||||
async def _verify_peer_identity(
|
||||
self, connection: QUICConnection, expected_peer_id: ID
|
||||
) -> None:
|
||||
"""
|
||||
Verify remote peer identity after TLS handshake.
|
||||
|
||||
Args:
|
||||
connection: The established QUIC connection
|
||||
expected_peer_id: Expected peer ID
|
||||
|
||||
Raises:
|
||||
QUICSecurityError: If peer verification fails
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get peer certificate from the connection
|
||||
peer_certificate = await connection.get_peer_certificate()
|
||||
|
||||
if not peer_certificate:
|
||||
raise QUICSecurityError("No peer certificate available")
|
||||
|
||||
# Verify peer identity using security manager
|
||||
verified_peer_id = self._security_manager.verify_peer_identity(
|
||||
peer_certificate, expected_peer_id
|
||||
)
|
||||
|
||||
if verified_peer_id != expected_peer_id:
|
||||
raise QUICSecurityError(
|
||||
"Peer ID verification failed: expected "
|
||||
f"{expected_peer_id}, got {verified_peer_id}"
|
||||
)
|
||||
|
||||
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
|
||||
|
||||
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
|
||||
"""
|
||||
Create a QUIC listener with integrated security.
|
||||
|
||||
Args:
|
||||
handler_function: Function to handle new connections
|
||||
|
||||
Returns:
|
||||
QUIC listener instance
|
||||
|
||||
Raises:
|
||||
QUICListenError: If transport is closed
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICListenError("Transport is closed")
|
||||
|
||||
# Get server configurations for the listener
|
||||
server_configs = {
|
||||
version: config
|
||||
for version, config in self._quic_configs.items()
|
||||
if version.endswith("_server")
|
||||
}
|
||||
|
||||
listener = QUICListener(
|
||||
transport=self,
|
||||
handler_function=handler_function,
|
||||
quic_configs=server_configs,
|
||||
config=self._config,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
|
||||
self._listeners.append(listener)
|
||||
logger.debug("Created QUIC listener with security")
|
||||
return listener
|
||||
|
||||
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if this transport can dial the given multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if this transport can dial the address
|
||||
|
||||
"""
|
||||
return is_quic_multiaddr(maddr)
|
||||
|
||||
def protocols(self) -> list[TProtocol]:
|
||||
"""
|
||||
Get supported protocol identifiers.
|
||||
|
||||
Returns:
|
||||
List of supported protocol strings
|
||||
|
||||
"""
|
||||
protocols = [QUIC_V1_PROTOCOL]
|
||||
if self._config.enable_draft29:
|
||||
protocols.append(QUIC_DRAFT29_PROTOCOL)
|
||||
return protocols
|
||||
|
||||
def listen_order(self) -> int:
|
||||
"""
|
||||
Get the listen order priority for this transport.
|
||||
Matches go-libp2p's ListenOrder = 1 for QUIC.
|
||||
|
||||
Returns:
|
||||
Priority order for listening (lower = higher priority)
|
||||
|
||||
"""
|
||||
return 1
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and cleanup resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
logger.debug("Closing QUIC transport")
|
||||
|
||||
# Close all active connections and listeners concurrently using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Close all connections
|
||||
for connection in self._connections.values():
|
||||
nursery.start_soon(connection.close)
|
||||
|
||||
# Close all listeners
|
||||
for listener in self._listeners:
|
||||
nursery.start_soon(listener.close)
|
||||
|
||||
self._connections.clear()
|
||||
self._listeners.clear()
|
||||
|
||||
logger.debug("QUIC transport closed")
|
||||
|
||||
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
|
||||
"""Clean up a terminated connection from all listeners."""
|
||||
try:
|
||||
for listener in self._listeners:
|
||||
await listener._remove_connection_by_object(connection)
|
||||
logger.debug(
|
||||
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
|
||||
|
||||
def get_stats(self) -> dict[str, int | list[str] | object]:
|
||||
"""Get transport statistics including security info."""
|
||||
return {
|
||||
"active_connections": len(self._connections),
|
||||
"active_listeners": len(self._listeners),
|
||||
"supported_protocols": self.protocols(),
|
||||
"local_peer_id": str(self._peer_id),
|
||||
"security_enabled": True,
|
||||
"tls_configured": True,
|
||||
}
|
||||
|
||||
def get_security_manager(self) -> QUICTLSConfigManager:
|
||||
"""
|
||||
Get the security manager for this transport.
|
||||
|
||||
Returns:
|
||||
The QUIC TLS configuration manager
|
||||
|
||||
"""
|
||||
return self._security_manager
|
||||
|
||||
def get_listener_socket(self) -> trio.socket.SocketType | None:
|
||||
"""Get the socket from the first active listener."""
|
||||
for listener in self._listeners:
|
||||
if listener.is_listening() and listener._socket:
|
||||
return listener._socket
|
||||
return None
|
||||
466
libp2p/transport/quic/utils.py
Normal file
466
libp2p/transport/quic/utils.py
Normal file
@ -0,0 +1,466 @@
|
||||
"""
|
||||
Multiaddr utilities for QUIC transport - Module 4.
|
||||
Essential utilities required for QUIC transport implementation.
|
||||
Based on go-libp2p and js-libp2p QUIC implementations.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
import multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||
|
||||
from .config import QUICTransportConfig
|
||||
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Protocol constants
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
UDP_PROTOCOL = "udp"
|
||||
IP4_PROTOCOL = "ip4"
|
||||
IP6_PROTOCOL = "ip6"
|
||||
|
||||
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
|
||||
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
|
||||
|
||||
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
|
||||
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
|
||||
|
||||
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
|
||||
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
|
||||
CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000
|
||||
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||
}
|
||||
|
||||
# QUIC version to wire format mappings (required for aioquic)
|
||||
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
|
||||
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
|
||||
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
|
||||
}
|
||||
|
||||
# ALPN protocols for libp2p over QUIC
|
||||
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
|
||||
|
||||
|
||||
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if a multiaddr represents a QUIC address.
|
||||
|
||||
Valid QUIC multiaddrs:
|
||||
- /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
- /ip4/127.0.0.1/udp/4001/quic
|
||||
- /ip6/::1/udp/4001/quic-v1
|
||||
- /ip6/::1/udp/4001/quic
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if the multiaddr represents a QUIC address
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
# Check for required components
|
||||
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
|
||||
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
|
||||
has_quic = (
|
||||
f"/{QUIC_V1_PROTOCOL}" in addr_str
|
||||
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
|
||||
or "/quic" in addr_str
|
||||
)
|
||||
|
||||
return has_ip and has_udp and has_quic
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
|
||||
"""
|
||||
Extract host and port from a QUIC multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Tuple of (host, port)
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
host = None
|
||||
port = None
|
||||
|
||||
# Try to get IPv4 address
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get IPv6 address if IPv4 not found
|
||||
if host is None:
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get UDP port
|
||||
try:
|
||||
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
|
||||
port = int(port_str)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if host is None or port is None:
|
||||
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
|
||||
|
||||
return host, port
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(
|
||||
f"Failed to parse QUIC multiaddr {maddr}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
|
||||
"""
|
||||
Determine QUIC version from multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
QUIC version identifier ("quic-v1" or "quic")
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
|
||||
return QUIC_V1_PROTOCOL # RFC 9000
|
||||
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
|
||||
return QUIC_DRAFT29_PROTOCOL # draft-29
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(
|
||||
f"Failed to determine QUIC version from {maddr}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def create_quic_multiaddr(
|
||||
host: str, port: int, version: str = "quic-v1"
|
||||
) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Create a QUIC multiaddr from host, port, and version.
|
||||
|
||||
Args:
|
||||
host: IP address (IPv4 or IPv6)
|
||||
port: UDP port number
|
||||
version: QUIC version ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
QUIC multiaddr
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If invalid parameters provided
|
||||
|
||||
"""
|
||||
try:
|
||||
# Determine IP version
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
ip_proto = IP4_PROTOCOL
|
||||
else:
|
||||
ip_proto = IP6_PROTOCOL
|
||||
except ValueError:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
|
||||
|
||||
# Validate port
|
||||
if not (0 <= port <= 65535):
|
||||
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
|
||||
|
||||
# Validate and normalize QUIC version
|
||||
if version == "quic-v1" or version == "/quic-v1":
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
elif version == "quic" or version == "/quic":
|
||||
quic_proto = QUIC_DRAFT29_PROTOCOL
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
|
||||
|
||||
# Construct multiaddr
|
||||
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
|
||||
return multiaddr.Multiaddr(addr_str)
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
|
||||
|
||||
|
||||
def quic_version_to_wire_format(version: TProtocol) -> int:
|
||||
"""
|
||||
Convert QUIC version string to wire format integer for aioquic.
|
||||
|
||||
Args:
|
||||
version: QUIC version string ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
Wire format version number
|
||||
|
||||
Raises:
|
||||
QUICUnsupportedVersionError: If version is not supported
|
||||
|
||||
"""
|
||||
wire_version = QUIC_VERSION_MAPPINGS.get(version)
|
||||
if wire_version is None:
|
||||
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||
|
||||
return wire_version
|
||||
|
||||
|
||||
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
|
||||
"""
|
||||
Convert QUIC version string to wire format integer for aioquic.
|
||||
|
||||
Args:
|
||||
version: QUIC version string ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
Wire format version number
|
||||
|
||||
Raises:
|
||||
QUICUnsupportedVersionError: If version is not supported
|
||||
|
||||
"""
|
||||
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
|
||||
if wire_version is None:
|
||||
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||
|
||||
return wire_version
|
||||
|
||||
|
||||
def get_alpn_protocols() -> list[str]:
|
||||
"""
|
||||
Get ALPN protocols for libp2p over QUIC.
|
||||
|
||||
Returns:
|
||||
List of ALPN protocol identifiers
|
||||
|
||||
"""
|
||||
return LIBP2P_ALPN_PROTOCOLS.copy()
|
||||
|
||||
|
||||
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Normalize a QUIC multiaddr to canonical form.
|
||||
|
||||
Args:
|
||||
maddr: Input QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Normalized multiaddr
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
|
||||
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
return create_quic_multiaddr(host, port, version)
|
||||
|
||||
|
||||
def create_server_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a server configuration without using deepcopy.
|
||||
Manually copies attributes while handling cryptography objects properly.
|
||||
"""
|
||||
try:
|
||||
# Create new server configuration from scratch
|
||||
server_config = QuicConfiguration(is_client=False)
|
||||
server_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes (these are safe to copy)
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"stateless_retry",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(server_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(server_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
server_tls_config = security_manager.create_server_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if server_tls_config.certificate:
|
||||
server_config.certificate = server_tls_config.certificate
|
||||
if server_tls_config.private_key:
|
||||
server_config.private_key = server_tls_config.private_key
|
||||
if server_tls_config.certificate_chain:
|
||||
server_config.certificate_chain = (
|
||||
server_tls_config.certificate_chain
|
||||
)
|
||||
if server_tls_config.alpn_protocols:
|
||||
server_config.alpn_protocols = server_tls_config.alpn_protocols
|
||||
server_tls_config.request_client_certificate = True
|
||||
if getattr(server_tls_config, "request_client_certificate", False):
|
||||
server_config._libp2p_request_client_cert = True # type: ignore
|
||||
else:
|
||||
logger.error(
|
||||
"🔧 Failed to set request_client_certificate in server config"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Set transport-specific defaults if provided
|
||||
if transport_config:
|
||||
if server_config.idle_timeout == 0:
|
||||
server_config.idle_timeout = getattr(
|
||||
transport_config, "idle_timeout", 30.0
|
||||
)
|
||||
if server_config.max_datagram_frame_size is None:
|
||||
server_config.max_datagram_frame_size = getattr(
|
||||
transport_config, "max_datagram_size", 1200
|
||||
)
|
||||
# Ensure we have ALPN protocols
|
||||
if not server_config.alpn_protocols:
|
||||
server_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created server config without deepcopy")
|
||||
return server_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create server config: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_client_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a client configuration without using deepcopy.
|
||||
"""
|
||||
try:
|
||||
# Create new client configuration from scratch
|
||||
client_config = QuicConfiguration(is_client=True)
|
||||
client_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
client_tls_config = security_manager.create_client_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if client_tls_config.certificate:
|
||||
client_config.certificate = client_tls_config.certificate
|
||||
if client_tls_config.private_key:
|
||||
client_config.private_key = client_tls_config.private_key
|
||||
if client_tls_config.certificate_chain:
|
||||
client_config.certificate_chain = (
|
||||
client_tls_config.certificate_chain
|
||||
)
|
||||
if client_tls_config.alpn_protocols:
|
||||
client_config.alpn_protocols = client_tls_config.alpn_protocols
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Ensure we have ALPN protocols
|
||||
if not client_config.alpn_protocols:
|
||||
client_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created client config without deepcopy")
|
||||
return client_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create client config: {e}")
|
||||
raise
|
||||
267
libp2p/transport/transport_registry.py
Normal file
267
libp2p/transport/transport_registry.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Transport registry for dynamic transport selection based on multiaddr protocols.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
from libp2p.abc import ITransport
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import (
|
||||
is_valid_websocket_multiaddr,
|
||||
)
|
||||
|
||||
|
||||
# Import QUIC utilities here to avoid circular imports
|
||||
def _get_quic_transport() -> Any:
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
return QUICTransport
|
||||
|
||||
|
||||
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
|
||||
return is_quic_multiaddr
|
||||
|
||||
|
||||
# Import WebsocketTransport here to avoid circular imports
|
||||
def _get_websocket_transport() -> Any:
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
return WebsocketTransport
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.registry")
|
||||
|
||||
|
||||
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid TCP structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid TCP structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||
# or /ip6/::1/tcp/8080
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||
if len(protocols) < 2:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Should not have any protocols after tcp (unless it's a valid
|
||||
# continuation like p2p)
|
||||
# For now, we'll be strict and only allow network + tcp
|
||||
if len(protocols) > 2:
|
||||
# Check if the additional protocols are valid continuations
|
||||
valid_continuations = ["p2p"] # Add more as needed
|
||||
for i in range(2, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class TransportRegistry:
|
||||
"""
|
||||
Registry for mapping multiaddr protocols to transport implementations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._transports: dict[str, type[ITransport]] = {}
|
||||
self._register_default_transports()
|
||||
|
||||
def _register_default_transports(self) -> None:
|
||||
"""Register the default transport implementations."""
|
||||
# Register TCP transport for /tcp protocol
|
||||
self.register_transport("tcp", TCP)
|
||||
|
||||
# Register WebSocket transport for /ws and /wss protocols
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
self.register_transport("ws", WebsocketTransport)
|
||||
self.register_transport("wss", WebsocketTransport)
|
||||
|
||||
# Register QUIC transport for /quic and /quic-v1 protocols
|
||||
QUICTransport = _get_quic_transport()
|
||||
self.register_transport("quic", QUICTransport)
|
||||
self.register_transport("quic-v1", QUICTransport)
|
||||
|
||||
def register_transport(
|
||||
self, protocol: str, transport_class: type[ITransport]
|
||||
) -> None:
|
||||
"""
|
||||
Register a transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier (e.g., "tcp", "ws")
|
||||
:param transport_class: The transport class to register
|
||||
"""
|
||||
self._transports[protocol] = transport_class
|
||||
logger.debug(
|
||||
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||
)
|
||||
|
||||
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||
"""
|
||||
Get the transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:return: The transport class or None if not found
|
||||
"""
|
||||
return self._transports.get(protocol)
|
||||
|
||||
def get_supported_protocols(self) -> list[str]:
|
||||
"""Get list of supported transport protocols."""
|
||||
return list(self._transports.keys())
|
||||
|
||||
def create_transport(
|
||||
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create a transport instance for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:param upgrader: The transport upgrader instance (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
:return: Transport instance or None if protocol not supported or creation fails
|
||||
"""
|
||||
transport_class = self.get_transport(protocol)
|
||||
if transport_class is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if protocol in ["ws", "wss"]:
|
||||
# WebSocket transport requires upgrader
|
||||
if upgrader is None:
|
||||
logger.warning(
|
||||
f"WebSocket transport '{protocol}' requires upgrader"
|
||||
)
|
||||
return None
|
||||
# Use explicit WebsocketTransport to avoid type issues
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
|
||||
)
|
||||
elif protocol in ["quic", "quic-v1"]:
|
||||
# QUIC transport requires private_key
|
||||
private_key = kwargs.get("private_key")
|
||||
if private_key is None:
|
||||
logger.warning(f"QUIC transport '{protocol}' requires private_key")
|
||||
return None
|
||||
# Use explicit QUICTransport to avoid type issues
|
||||
QUICTransport = _get_quic_transport()
|
||||
config = kwargs.get("config")
|
||||
return QUICTransport(private_key, config)
|
||||
else:
|
||||
# TCP transport doesn't require upgrader
|
||||
return transport_class()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global transport registry instance (lazy initialization)
|
||||
_global_registry: TransportRegistry | None = None
|
||||
|
||||
|
||||
def get_transport_registry() -> TransportRegistry:
|
||||
"""Get the global transport registry instance."""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = TransportRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||
"""Register a transport class in the global registry."""
|
||||
registry = get_transport_registry()
|
||||
registry.register_transport(protocol, transport_class)
|
||||
|
||||
|
||||
def create_transport_for_multiaddr(
|
||||
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create the appropriate transport for a given multiaddr.
|
||||
|
||||
:param maddr: The multiaddr to create transport for
|
||||
:param upgrader: The transport upgrader instance
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
(e.g., private_key for QUIC)
|
||||
:return: Transport instance or None if no suitable transport found
|
||||
"""
|
||||
try:
|
||||
# Get all protocols in the multiaddr
|
||||
protocols = [proto.name for proto in maddr.protocols()]
|
||||
|
||||
# Check for supported transport protocols in order of preference
|
||||
# We need to validate that the multiaddr structure is valid for our transports
|
||||
if "quic" in protocols or "quic-v1" in protocols:
|
||||
# For QUIC, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/udp/4001/quic
|
||||
# /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
is_quic_multiaddr = _get_quic_validation()
|
||||
if is_quic_multiaddr(maddr):
|
||||
# Determine QUIC version
|
||||
registry = get_transport_registry()
|
||||
if "quic-v1" in protocols:
|
||||
return registry.create_transport("quic-v1", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("quic", upgrader, **kwargs)
|
||||
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
|
||||
# For WebSocket, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
if is_valid_websocket_multiaddr(maddr):
|
||||
# Determine if this is a secure WebSocket connection
|
||||
registry = get_transport_registry()
|
||||
if "wss" in protocols or "tls" in protocols:
|
||||
return registry.create_transport("wss", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("ws", upgrader, **kwargs)
|
||||
elif "tcp" in protocols:
|
||||
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
|
||||
# Check if the multiaddr has proper TCP structure
|
||||
if _is_valid_tcp_multiaddr(maddr):
|
||||
registry = get_transport_registry()
|
||||
return registry.create_transport("tcp", upgrader)
|
||||
|
||||
# If no supported transport protocol found or structure is invalid, return None
|
||||
logger.warning(
|
||||
f"No supported transport protocol found or invalid structure in "
|
||||
f"multiaddr: {maddr}"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors gracefully (e.g., invalid multiaddr)
|
||||
logger.warning(f"Error processing multiaddr {maddr}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_supported_transport_protocols() -> list[str]:
|
||||
"""Get list of supported transport protocols from the global registry."""
|
||||
registry = get_transport_registry()
|
||||
return registry.get_supported_protocols()
|
||||
@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
)
|
||||
from libp2p.security.exceptions import (
|
||||
HandshakeFailure,
|
||||
)
|
||||
@ -37,9 +40,12 @@ class TransportUpgrader:
|
||||
self,
|
||||
secure_transports_by_protocol: TSecurityOptions,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
):
|
||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(
|
||||
muxer_transports_by_protocol, negotiate_timeout
|
||||
)
|
||||
|
||||
async def upgrade_security(
|
||||
self,
|
||||
|
||||
198
libp2p/transport/websocket/connection.py
Normal file
198
libp2p/transport/websocket/connection.py
Normal file
@ -0,0 +1,198 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
Wraps a WebSocketConnection to provide the raw stream interface
|
||||
that libp2p protocols expect.
|
||||
|
||||
Implements production-ready buffer management and flow control
|
||||
as recommended in the libp2p WebSocket specification.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ws_connection: Any,
|
||||
ws_context: Any = None,
|
||||
is_secure: bool = False,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
) -> None:
|
||||
self._ws_connection = ws_connection
|
||||
self._ws_context = ws_context
|
||||
self._is_secure = is_secure
|
||||
self._read_buffer = b""
|
||||
self._read_lock = trio.Lock()
|
||||
self._connection_start_time = time.time()
|
||||
self._bytes_read = 0
|
||||
self._bytes_written = 0
|
||||
self._closed = False
|
||||
self._close_lock = trio.Lock()
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._write_lock = trio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write data with flow control and buffer management"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._write_lock:
|
||||
try:
|
||||
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||
|
||||
# Check buffer amount for flow control
|
||||
if hasattr(self._ws_connection, "bufferedAmount"):
|
||||
buffered = self._ws_connection.bufferedAmount
|
||||
if buffered > self._max_buffered_amount:
|
||||
logger.warning(f"WebSocket buffer full: {buffered} bytes")
|
||||
# In production, you might want to
|
||||
# wait or implement backpressure
|
||||
# For now, we'll continue but log the warning
|
||||
|
||||
# Send as a binary WebSocket message
|
||||
await self._ws_connection.send_message(data)
|
||||
self._bytes_written += len(data)
|
||||
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket write failed: {e}")
|
||||
self._closed = True
|
||||
raise IOException from e
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||
This implementation provides byte-level access to WebSocket messages,
|
||||
which is required for libp2p protocol compatibility.
|
||||
|
||||
For WebSocket compatibility with libp2p protocols, this method:
|
||||
1. Buffers incoming WebSocket messages
|
||||
2. Returns exactly the requested number of bytes when n is specified
|
||||
3. Accumulates multiple WebSocket messages if needed to satisfy the request
|
||||
4. Returns empty bytes (not raises) when connection is closed and no data
|
||||
available
|
||||
"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._read_lock:
|
||||
try:
|
||||
# If n is None, read at least one message and return all buffered data
|
||||
if n is None:
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
# No message available within timeout
|
||||
return b""
|
||||
except Exception:
|
||||
# Return empty bytes if no data available
|
||||
# (connection closed)
|
||||
return b""
|
||||
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
# For specific byte count requests, return UP TO n bytes (not exactly n)
|
||||
# This matches TCP semantics where read(1024) returns available data
|
||||
# up to 1024 bytes
|
||||
|
||||
# If we don't have any data buffered, try to get at least one message
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
return b"" # No data available
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
# Now return up to n bytes from the buffer (TCP-like semantics)
|
||||
if len(self._read_buffer) == 0:
|
||||
return b""
|
||||
|
||||
# Return up to n bytes (like TCP read())
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[len(result) :]
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket read failed: {e}")
|
||||
raise IOException from e
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection. This method is idempotent."""
|
||||
async with self._close_lock:
|
||||
if self._closed:
|
||||
return # Already closed
|
||||
|
||||
logger.debug("WebSocket connection closing")
|
||||
self._closed = True
|
||||
try:
|
||||
# Always close the connection directly, avoid context manager issues
|
||||
# The context manager may be causing cancel scope corruption
|
||||
logger.debug("WebSocket closing connection directly")
|
||||
await self._ws_connection.aclose()
|
||||
# Exit the context manager if we have one
|
||||
if self._ws_context is not None:
|
||||
await self._ws_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket close error: {e}")
|
||||
# Don't raise here, as close() should be idempotent
|
||||
finally:
|
||||
logger.debug("WebSocket connection closed")
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the connection is closed"""
|
||||
return self._closed
|
||||
|
||||
def conn_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
Return connection state information similar to Go's ConnState() method.
|
||||
|
||||
:return: Dictionary containing connection state information
|
||||
"""
|
||||
current_time = time.time()
|
||||
return {
|
||||
"transport": "websocket",
|
||||
"secure": self._is_secure,
|
||||
"connection_duration": current_time - self._connection_start_time,
|
||||
"bytes_read": self._bytes_read,
|
||||
"bytes_written": self._bytes_written,
|
||||
"total_bytes": self._bytes_read + self._bytes_written,
|
||||
}
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
# Try to get remote address from the WebSocket connection
|
||||
try:
|
||||
remote = self._ws_connection.remote
|
||||
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||
return str(remote.address), int(remote.port)
|
||||
elif isinstance(remote, str):
|
||||
# Parse address:port format
|
||||
if ":" in remote:
|
||||
host, port = remote.rsplit(":", 1)
|
||||
return host, int(port)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
225
libp2p/transport/websocket/listener.py
Normal file
225
libp2p/transport/websocket/listener.py
Normal file
@ -0,0 +1,225 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
from trio_typing import TaskStatus
|
||||
from trio_websocket import serve_websocket
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.websocket.listener")
|
||||
|
||||
|
||||
class WebsocketListener(IListener):
|
||||
"""
|
||||
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: THandler,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
) -> None:
|
||||
self._handler = handler
|
||||
self._upgrader = upgrader
|
||||
self._tls_config = tls_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._server = None
|
||||
self._shutdown_event = trio.Event()
|
||||
self._nursery: trio.Nursery | None = None
|
||||
self._listeners: Any = None
|
||||
self._is_wss = False # Track whether this is a WSS listener
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Check if WSS is requested but no TLS config provided
|
||||
if parsed.is_wss and self._tls_config is None:
|
||||
raise ValueError(
|
||||
f"Cannot listen on WSS address {maddr} without TLS configuration"
|
||||
)
|
||||
|
||||
# Store whether this is a WSS listener
|
||||
self._is_wss = parsed.is_wss
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
or "0.0.0.0"
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
|
||||
)
|
||||
|
||||
async def serve_websocket_tcp(
|
||||
handler: Callable[[Any], Awaitable[None]],
|
||||
port: int,
|
||||
host: str,
|
||||
task_status: TaskStatus[Any],
|
||||
) -> None:
|
||||
"""Start TCP server and handle WebSocket connections manually"""
|
||||
logger.debug(
|
||||
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
|
||||
)
|
||||
|
||||
async def websocket_handler(request: Any) -> None:
|
||||
"""Handle WebSocket requests"""
|
||||
logger.debug("WebSocket request received")
|
||||
try:
|
||||
# Apply handshake timeout
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
# Accept the WebSocket connection
|
||||
ws_connection = await request.accept()
|
||||
logger.debug("WebSocket handshake successful")
|
||||
|
||||
# Create the WebSocket connection wrapper
|
||||
conn = P2PWebSocketConnection(
|
||||
ws_connection, is_secure=parsed.is_wss
|
||||
) # type: ignore[no-untyped-call]
|
||||
|
||||
# Call the handler function that was passed to create_listener
|
||||
# This handler will handle the security and muxing upgrades
|
||||
logger.debug("Calling connection handler")
|
||||
await self._handler(conn)
|
||||
|
||||
# Don't keep the connection alive indefinitely
|
||||
# Let the handler manage the connection lifecycle
|
||||
logger.debug(
|
||||
"Handler completed, connection will be managed by handler"
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s"
|
||||
)
|
||||
try:
|
||||
await request.reject(408) # Request Timeout
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket connection error: {e}")
|
||||
logger.debug(f"Error type: {type(e)}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# Reject the connection
|
||||
try:
|
||||
await request.reject(400)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||
ssl_context = self._tls_config if parsed.is_wss else None
|
||||
await serve_websocket(
|
||||
websocket_handler, host, port, ssl_context, task_status=task_status
|
||||
)
|
||||
|
||||
# Store the nursery for shutdown
|
||||
self._nursery = nursery
|
||||
|
||||
# Start the server using nursery.start() like TCP does
|
||||
logger.debug("Calling nursery.start()...")
|
||||
started_listeners = await nursery.start(
|
||||
serve_websocket_tcp,
|
||||
None, # No handler needed since it's defined inside serve_websocket_tcp
|
||||
port,
|
||||
host,
|
||||
)
|
||||
logger.debug(f"nursery.start() returned: {started_listeners}")
|
||||
|
||||
if started_listeners is None:
|
||||
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||
return False
|
||||
|
||||
# Store the listeners for get_addrs() and close() - these are real
|
||||
# SocketListener objects
|
||||
self._listeners = started_listeners
|
||||
logger.debug(
|
||||
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||
)
|
||||
return True
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
if not hasattr(self, "_listeners") or not self._listeners:
|
||||
logger.debug("No listeners available for get_addrs()")
|
||||
return ()
|
||||
|
||||
# Handle WebSocketServer objects
|
||||
if hasattr(self._listeners, "port"):
|
||||
# This is a WebSocketServer object
|
||||
port = self._listeners.port
|
||||
# Create a multiaddr from the port with correct WSS/WS protocol
|
||||
protocol = "wss" if self._is_wss else "ws"
|
||||
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
|
||||
else:
|
||||
# This is a list of listeners (like TCP)
|
||||
listeners = self._listeners
|
||||
# Get addresses from listeners like TCP does
|
||||
return tuple(
|
||||
_multiaddr_from_socket(listener.socket, self._is_wss)
|
||||
for listener in listeners
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket listener and stop accepting new connections"""
|
||||
logger.debug("WebsocketListener.close called")
|
||||
if hasattr(self, "_listeners") and self._listeners:
|
||||
# Signal shutdown
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Close the WebSocket server
|
||||
if hasattr(self._listeners, "aclose"):
|
||||
# This is a WebSocketServer object
|
||||
logger.debug("Closing WebSocket server")
|
||||
await self._listeners.aclose()
|
||||
logger.debug("WebSocket server closed")
|
||||
elif isinstance(self._listeners, (list, tuple)):
|
||||
# This is a list of listeners (like TCP)
|
||||
logger.debug("Closing TCP listeners")
|
||||
for listener in self._listeners:
|
||||
await listener.aclose()
|
||||
logger.debug("TCP listeners closed")
|
||||
else:
|
||||
# Unknown type, try to close it directly
|
||||
logger.debug("Closing unknown listener type")
|
||||
if hasattr(self._listeners, "close"):
|
||||
self._listeners.close()
|
||||
logger.debug("Unknown listener closed")
|
||||
|
||||
# Clear the listeners reference
|
||||
self._listeners = None
|
||||
logger.debug("WebsocketListener.close completed")
|
||||
|
||||
|
||||
def _multiaddr_from_socket(
|
||||
socket: trio.socket.SocketType, is_wss: bool = False
|
||||
) -> Multiaddr:
|
||||
"""Convert socket to multiaddr"""
|
||||
ip, port = socket.getsockname()
|
||||
protocol = "wss" if is_wss else "ws"
|
||||
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")
|
||||
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
@ -0,0 +1,202 @@
|
||||
"""
|
||||
WebSocket multiaddr parsing utilities.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
|
||||
class ParsedWebSocketMultiaddr(NamedTuple):
|
||||
"""Parsed WebSocket multiaddr information."""
|
||||
|
||||
is_wss: bool
|
||||
sni: str | None
|
||||
rest_multiaddr: Multiaddr
|
||||
|
||||
|
||||
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
|
||||
"""
|
||||
Parse a WebSocket multiaddr and extract security information.
|
||||
|
||||
:param maddr: The multiaddr to parse
|
||||
:return: Parsed WebSocket multiaddr information
|
||||
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
|
||||
"""
|
||||
# First validate that this is a valid WebSocket multiaddr
|
||||
if not is_valid_websocket_multiaddr(maddr):
|
||||
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
|
||||
|
||||
protocols = list(maddr.protocols())
|
||||
|
||||
# Find the WebSocket protocol and check for security
|
||||
is_wss = False
|
||||
sni = None
|
||||
ws_index = -1
|
||||
tls_index = -1
|
||||
sni_index = -1
|
||||
|
||||
# Find protocol indices
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name == "ws":
|
||||
ws_index = i
|
||||
elif protocol.name == "wss":
|
||||
ws_index = i
|
||||
is_wss = True
|
||||
elif protocol.name == "tls":
|
||||
tls_index = i
|
||||
elif protocol.name == "sni":
|
||||
sni_index = i
|
||||
sni = protocol.value
|
||||
|
||||
if ws_index == -1:
|
||||
raise ValueError("Not a WebSocket multiaddr")
|
||||
|
||||
# Handle /wss protocol (convert to /tls/ws internally)
|
||||
if is_wss and tls_index == -1:
|
||||
# Convert /wss to /tls/ws format
|
||||
# Remove /wss to get the base multiaddr
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=True, sni=None, rest_multiaddr=without_wss
|
||||
)
|
||||
|
||||
# Handle /tls/ws and /tls/sni/.../ws formats
|
||||
if tls_index != -1:
|
||||
is_wss = True
|
||||
# Extract the base multiaddr (everything before /tls)
|
||||
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
|
||||
# Use multiaddr methods to properly extract the base
|
||||
rest_multiaddr = maddr
|
||||
# Remove /tls/ws or /tls/sni/.../ws from the end
|
||||
if sni_index != -1:
|
||||
# /tls/sni/example.com/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
else:
|
||||
# /tls/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
# Regular /ws multiaddr - remove /ws and any additional protocols
|
||||
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
|
||||
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid WebSocket structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid WebSocket structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# WebSocket multiaddr should have structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
|
||||
if len(protocols) < 3:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Check for valid WebSocket protocols
|
||||
ws_protocols = ["ws", "wss"]
|
||||
tls_protocols = ["tls"]
|
||||
sni_protocols = ["sni"]
|
||||
|
||||
# Find the WebSocket protocol
|
||||
ws_protocol_found = False
|
||||
tls_found = False
|
||||
# sni_found = False # Not used currently
|
||||
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name in ws_protocols:
|
||||
ws_protocol_found = True
|
||||
break
|
||||
elif protocol.name in tls_protocols:
|
||||
tls_found = True
|
||||
elif protocol.name in sni_protocols:
|
||||
pass # sni_found = True # Not used in current implementation
|
||||
|
||||
if not ws_protocol_found:
|
||||
return False
|
||||
|
||||
# Validate protocol sequence
|
||||
# For /ws: network + tcp + ws
|
||||
# For /wss: network + tcp + wss
|
||||
# For /tls/ws: network + tcp + tls + ws
|
||||
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
|
||||
|
||||
# Check if it's a simple /ws or /wss
|
||||
if len(protocols) == 3:
|
||||
return protocols[2].name in ["ws", "wss"]
|
||||
|
||||
# Check for /tls/ws or /tls/sni/.../ws patterns
|
||||
if tls_found:
|
||||
# Must end with /ws (not /wss when using /tls)
|
||||
if protocols[-1].name != "ws":
|
||||
return False
|
||||
|
||||
# Check for valid TLS sequence
|
||||
tls_index = None
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name == "tls":
|
||||
tls_index = i
|
||||
break
|
||||
|
||||
if tls_index is None:
|
||||
return False
|
||||
|
||||
# After tls, we can have sni, then ws
|
||||
remaining_protocols = protocols[tls_index + 1 :]
|
||||
if len(remaining_protocols) == 1:
|
||||
# /tls/ws
|
||||
return remaining_protocols[0].name == "ws"
|
||||
elif len(remaining_protocols) == 2:
|
||||
# /tls/sni/example.com/ws
|
||||
return (
|
||||
remaining_protocols[0].name == "sni"
|
||||
and remaining_protocols[1].name == "ws"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
# If we have more than 3 protocols but no TLS, check for valid continuations
|
||||
# Allow additional protocols after the WebSocket protocol (like /p2p)
|
||||
valid_continuations = ["p2p"]
|
||||
|
||||
# Find the WebSocket protocol index
|
||||
ws_index = None
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name in ["ws", "wss"]:
|
||||
ws_index = i
|
||||
break
|
||||
|
||||
if ws_index is not None:
|
||||
# Check protocols after the WebSocket protocol
|
||||
for i in range(ws_index + 1, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
229
libp2p/transport/websocket/transport.py
Normal file
229
libp2p/transport/websocket/transport.py
Normal file
@ -0,0 +1,229 @@
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.abc import IListener, ITransport
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
from .listener import WebsocketListener
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketTransport(ITransport):
|
||||
"""
|
||||
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
|
||||
|
||||
Implements production-ready WebSocket transport with:
|
||||
- Flow control and buffer management
|
||||
- Connection limits and rate limiting
|
||||
- Proper error handling and cleanup
|
||||
- Support for both WS and WSS protocols
|
||||
- TLS configuration and handshake timeout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
):
|
||||
self._upgrader = upgrader
|
||||
self._tls_client_config = tls_client_config
|
||||
self._tls_server_config = tls_server_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._connection_count = 0
|
||||
self._max_connections = 1000 # Production limit
|
||||
|
||||
async def dial(self, maddr: Multiaddr) -> RawConnection:
|
||||
"""Dial a WebSocket connection to the given multiaddr."""
|
||||
logger.debug(f"WebsocketTransport.dial called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
# Build WebSocket URL based on security
|
||||
if parsed.is_wss:
|
||||
ws_url = f"wss://{host}:{port}/"
|
||||
else:
|
||||
ws_url = f"ws://{host}:{port}/"
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
|
||||
)
|
||||
|
||||
try:
|
||||
# Check connection limits
|
||||
if self._connection_count >= self._max_connections:
|
||||
raise OpenConnectionError(
|
||||
f"Maximum connections reached: {self._max_connections}"
|
||||
)
|
||||
|
||||
# Prepare SSL context for WSS connections
|
||||
ssl_context = None
|
||||
if parsed.is_wss:
|
||||
if self._tls_client_config:
|
||||
ssl_context = self._tls_client_config
|
||||
else:
|
||||
# Create default SSL context for client
|
||||
ssl_context = ssl.create_default_context()
|
||||
# Set SNI if available
|
||||
if parsed.sni:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
|
||||
|
||||
# Use a different approach: start background nursery that will persist
|
||||
logger.debug("WebsocketTransport.dial establishing connection")
|
||||
|
||||
# Import trio-websocket functions
|
||||
from trio_websocket import connect_websocket
|
||||
from trio_websocket._impl import _url_to_host
|
||||
|
||||
# Parse the WebSocket URL to get host, port, resource
|
||||
# like trio-websocket does
|
||||
ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
|
||||
ws_url, ssl_context
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketTransport.dial parsed URL: host={ws_host}, "
|
||||
f"port={ws_port}, resource={ws_resource}"
|
||||
)
|
||||
|
||||
# Create a background task manager for this connection
|
||||
import trio
|
||||
|
||||
nursery_manager = trio.lowlevel.current_task().parent_nursery
|
||||
if nursery_manager is None:
|
||||
raise OpenConnectionError(
|
||||
f"No parent nursery available for WebSocket connection to {maddr}"
|
||||
)
|
||||
|
||||
# Apply timeout to the connection process
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
logger.debug("WebsocketTransport.dial connecting WebSocket")
|
||||
ws = await connect_websocket(
|
||||
nursery_manager, # Use the existing nursery from libp2p
|
||||
ws_host,
|
||||
ws_port,
|
||||
ws_resource,
|
||||
use_ssl=ws_ssl_context,
|
||||
message_queue_size=1024, # Reasonable defaults
|
||||
max_message_size=16 * 1024 * 1024, # 16MB max message
|
||||
)
|
||||
logger.debug("WebsocketTransport.dial WebSocket connection established")
|
||||
|
||||
# Create our connection wrapper with both WSS support and flow control
|
||||
conn = P2PWebSocketConnection(
|
||||
ws,
|
||||
None,
|
||||
is_secure=parsed.is_wss,
|
||||
max_buffered_amount=self._max_buffered_amount,
|
||||
)
|
||||
logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")
|
||||
|
||||
self._connection_count += 1
|
||||
logger.debug(f"Total connections: {self._connection_count}")
|
||||
|
||||
return RawConnection(conn, initiator=True)
|
||||
except trio.TooSlowError as e:
|
||||
raise OpenConnectionError(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s "
|
||||
f"for {maddr}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dial WebSocket {maddr}: {e}")
|
||||
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
|
||||
|
||||
def create_listener(self, handler: THandler) -> IListener: # type: ignore[override]
|
||||
"""
|
||||
The type checker is incorrectly reporting this as an inconsistent override.
|
||||
"""
|
||||
logger.debug("WebsocketTransport.create_listener called")
|
||||
return WebsocketListener(
|
||||
handler, self._upgrader, self._tls_server_config, self._handshake_timeout
|
||||
)
|
||||
|
||||
def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
|
||||
"""
|
||||
Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
|
||||
Similar to Go's Resolve() method.
|
||||
|
||||
:param maddr: The multiaddr to resolve
|
||||
:return: List of resolved multiaddrs
|
||||
"""
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
|
||||
return [maddr] # Return original if not a valid WebSocket multiaddr
|
||||
|
||||
logger.debug(
|
||||
f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
|
||||
)
|
||||
|
||||
if not parsed.is_wss:
|
||||
# No /tls/ws component, this isn't a secure websocket multiaddr
|
||||
return [maddr]
|
||||
|
||||
if parsed.sni is not None:
|
||||
# Already has SNI, return as-is
|
||||
return [maddr]
|
||||
|
||||
# Try to extract DNS name from the base multiaddr
|
||||
dns_name = None
|
||||
for protocol_name in ["dns", "dns4", "dns6"]:
|
||||
try:
|
||||
dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if dns_name is None:
|
||||
# No DNS name found, return original
|
||||
return [maddr]
|
||||
|
||||
# Create new multiaddr with SNI
|
||||
# For /dns/example.com/tcp/8080/wss ->
|
||||
# /dns/example.com/tcp/8080/tls/sni/example.com/ws
|
||||
try:
|
||||
# Remove /wss and add /tls/sni/example.com/ws
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
sni_component = Multiaddr(f"/sni/{dns_name}")
|
||||
resolved = (
|
||||
without_wss.encapsulate(Multiaddr("/tls"))
|
||||
.encapsulate(sni_component)
|
||||
.encapsulate(Multiaddr("/ws"))
|
||||
)
|
||||
logger.debug(f"Resolved {maddr} to {resolved}")
|
||||
return [resolved]
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
|
||||
return [maddr]
|
||||
@ -3,38 +3,24 @@ from __future__ import annotations
|
||||
import socket
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
try:
|
||||
from multiaddr.utils import ( # type: ignore
|
||||
get_network_addrs,
|
||||
get_thin_waist_addresses,
|
||||
)
|
||||
|
||||
_HAS_THIN_WAIST = True
|
||||
except ImportError: # pragma: no cover - only executed in older environments
|
||||
_HAS_THIN_WAIST = False
|
||||
get_thin_waist_addresses = None # type: ignore
|
||||
get_network_addrs = None # type: ignore
|
||||
from multiaddr.utils import get_network_addrs, get_thin_waist_addresses
|
||||
|
||||
|
||||
def _safe_get_network_addrs(ip_version: int) -> list[str]:
|
||||
"""
|
||||
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
|
||||
Falls back to minimal defaults when Thin Waist helpers are missing.
|
||||
|
||||
:param ip_version: 4 or 6
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_network_addrs:
|
||||
try:
|
||||
return get_network_addrs(ip_version) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return []
|
||||
# Fallback behavior (very conservative)
|
||||
if ip_version == 4:
|
||||
return ["127.0.0.1"]
|
||||
if ip_version == 6:
|
||||
return ["::1"]
|
||||
return []
|
||||
try:
|
||||
return get_network_addrs(ip_version) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
# Fallback behavior (very conservative)
|
||||
if ip_version == 4:
|
||||
return ["127.0.0.1"]
|
||||
if ip_version == 6:
|
||||
return ["::1"]
|
||||
return []
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
@ -47,16 +33,13 @@ def find_free_port() -> int:
|
||||
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
|
||||
"""
|
||||
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
|
||||
If Thin Waist isn't available, returns [addr] (identity).
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_thin_waist_addresses:
|
||||
try:
|
||||
if port is not None:
|
||||
return get_thin_waist_addresses(addr, port=port) or []
|
||||
return get_thin_waist_addresses(addr) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return [addr]
|
||||
return [addr]
|
||||
try:
|
||||
if port is not None:
|
||||
return get_thin_waist_addresses(addr, port=port) or []
|
||||
return get_thin_waist_addresses(addr) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return [addr]
|
||||
|
||||
|
||||
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
|
||||
@ -73,8 +56,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
|
||||
seen_v4: set[str] = set()
|
||||
|
||||
for ip in _safe_get_network_addrs(4):
|
||||
seen_v4.add(ip)
|
||||
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
||||
if ip not in seen_v4: # Avoid duplicates
|
||||
seen_v4.add(ip)
|
||||
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
||||
|
||||
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
|
||||
if seen_v4 and "127.0.0.1" not in seen_v4:
|
||||
@ -89,8 +73,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
|
||||
#
|
||||
# seen_v6: set[str] = set()
|
||||
# for ip in _safe_get_network_addrs(6):
|
||||
# seen_v6.add(ip)
|
||||
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
||||
# if ip not in seen_v6: # Avoid duplicates
|
||||
# seen_v6.add(ip)
|
||||
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
||||
#
|
||||
# # Always include IPv6 loopback for testing purposes when IPv6 is available
|
||||
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
|
||||
@ -99,7 +84,7 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
|
||||
|
||||
# Fallback if nothing discovered
|
||||
if not addrs:
|
||||
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
|
||||
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
|
||||
|
||||
return addrs
|
||||
|
||||
@ -120,6 +105,20 @@ def expand_wildcard_address(
|
||||
return expanded
|
||||
|
||||
|
||||
def get_wildcard_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||
"""
|
||||
Get wildcard address (0.0.0.0) when explicitly needed.
|
||||
|
||||
This function provides access to wildcard binding as a feature when
|
||||
explicitly required, preserving the ability to bind to all interfaces.
|
||||
|
||||
:param port: Port number.
|
||||
:param protocol: Transport protocol.
|
||||
:return: A Multiaddr with wildcard binding (0.0.0.0).
|
||||
"""
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
|
||||
|
||||
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||
"""
|
||||
Choose an optimal address for an example to bind to:
|
||||
@ -148,13 +147,14 @@ def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
|
||||
return c
|
||||
|
||||
# As a final fallback, produce a wildcard
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
# As a final fallback, produce a loopback address
|
||||
return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"get_wildcard_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user