Merge upstream/main into add-ws-transport

Resolved conflicts in:
- .gitignore: Combined JavaScript interop and Sphinx build ignores
- libp2p/__init__.py: Integrated QUIC transport support with WebSocket transport
- libp2p/network/swarm.py: Used upstream's improved listener handling
- pyproject.toml: Kept both WebSocket and QUIC dependencies

This merge brings in:
- QUIC transport implementation
- Enhanced swarm functionality
- Improved peer discovery
- Better error handling
- Updated dependencies and documentation

WebSocket transport implementation remains intact and functional.
This commit is contained in:
acul71
2025-09-07 23:47:41 +02:00
105 changed files with 13904 additions and 730 deletions

View File

@ -1,3 +1,11 @@
"""Libp2p Python implementation."""
import logging
from libp2p.transport.quic.utils import is_quic_multiaddr
from typing import Any
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.quic.config import QUICTransportConfig
from collections.abc import (
Mapping,
Sequence,
@ -6,15 +14,12 @@ from importlib.metadata import version as __version
from typing import (
Literal,
Optional,
Type,
cast,
)
import multiaddr
from libp2p.abc import (
IHost,
IMuxedConn,
INetworkService,
IPeerRouting,
IPeerStore,
@ -33,9 +38,6 @@ from libp2p.custom_types import (
TProtocol,
TSecurityOptions,
)
from libp2p.discovery.mdns.mdns import (
MDNSDiscovery,
)
from libp2p.host.basic_host import (
BasicHost,
)
@ -45,27 +47,34 @@ from libp2p.host.routed_host import (
from libp2p.network.swarm import (
Swarm,
)
from libp2p.network.config import (
ConnectionConfig,
RetryConfig
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import (
PeerStore,
create_signed_peer_record,
)
from libp2p.security.insecure.transport import (
PLAINTEXT_PROTOCOL_ID,
InsecureTransport,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import (
MPLEX_PROTOCOL_ID,
Mplex,
)
from libp2p.stream_muxer.yamux.yamux import (
PROTOCOL_ID as YAMUX_PROTOCOL_ID,
Yamux,
)
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
from libp2p.transport.tcp.tcp import (
TCP,
)
@ -91,7 +100,7 @@ MUXER_YAMUX = "YAMUX"
MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5
logger = logging.getLogger(__name__)
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
"""
@ -160,7 +169,6 @@ def get_default_muxer_options() -> TMuxerOptions:
else: # YAMUX is default
return create_yamux_muxer_option()
def new_swarm(
key_pair: KeyPair | None = None,
muxer_opt: TMuxerOptions | None = None,
@ -168,6 +176,9 @@ def new_swarm(
peerstore_opt: IPeerStore | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
enable_quic: bool = False,
retry_config: Optional["RetryConfig"] = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -178,6 +189,8 @@ def new_swarm(
:param peerstore_opt: optional peerstore
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_quic: enable quic for transport
:param quic_transport_opt: options for transport
:return: return a default swarm instance
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@ -190,8 +203,6 @@ def new_swarm(
id_opt = generate_peer_id_from(key_pair)
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
@ -255,36 +266,18 @@ def new_swarm(
else:
transport = transport_maybe
# Use given muxer preference if provided, otherwise use global default
if muxer_preference is not None:
temp_pref = muxer_preference.upper()
if temp_pref not in [MUXER_YAMUX, MUXER_MPLEX]:
raise ValueError(
f"Unknown muxer: {muxer_preference}. Use 'YAMUX' or 'MPLEX'."
)
active_preference = temp_pref
else:
active_preference = DEFAULT_MUXER
# Use provided muxer options if given, otherwise create based on preference
if muxer_opt is not None:
muxer_transports_by_protocol = muxer_opt
else:
if active_preference == MUXER_MPLEX:
muxer_transports_by_protocol = create_mplex_muxer_option()
else: # YAMUX is default
muxer_transports_by_protocol = create_yamux_muxer_option()
upgrader = TransportUpgrader(
secure_transports_by_protocol=secure_transports_by_protocol,
muxer_transports_by_protocol=muxer_transports_by_protocol,
)
peerstore = peerstore_opt or PeerStore()
# Store our key pair in peerstore
peerstore.add_key_pair(id_opt, key_pair)
return Swarm(id_opt, peerstore, upgrader, transport)
return Swarm(
id_opt,
peerstore,
upgrader,
transport,
retry_config=retry_config,
connection_config=connection_config
)
def new_host(
@ -298,6 +291,8 @@ def new_host(
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
enable_quic: bool = False,
quic_transport_opt: QUICTransportConfig | None = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.
@ -311,19 +306,33 @@ def new_host(
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_mDNS: whether to enable mDNS discovery
:param bootstrap: optional list of bootstrap peer addresses as strings
:param enable_quic: optinal choice to use QUIC for transport
:param transport_opt: optional configuration for quic transport
:return: return a host instance
"""
if not enable_quic and quic_transport_opt is not None:
logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config")
swarm = new_swarm(
enable_quic=enable_quic,
key_pair=key_pair,
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
muxer_preference=muxer_preference,
listen_addrs=listen_addrs,
connection_config=quic_transport_opt if enable_quic else None
)
if disc_opt is not None:
return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
return BasicHost(
network=swarm,
enable_mDNS=enable_mDNS,
bootstrap=bootstrap,
negotitate_timeout=negotiate_timeout
)
__version__ = __version("libp2p")

View File

@ -970,6 +970,14 @@ class IPeerStore(
# --------CERTIFIED-ADDR-BOOK----------
@abstractmethod
def get_local_record(self) -> Optional["Envelope"]:
"""Get the local-peer-record wrapped in Envelope"""
@abstractmethod
def set_local_record(self, envelope: "Envelope") -> None:
"""Set the local-peer-record wrapped in Envelope"""
@abstractmethod
def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
"""
@ -1404,15 +1412,16 @@ class INetwork(ABC):
----------
peerstore : IPeerStore
The peer store for managing peer information.
connections : dict[ID, INetConn]
A mapping of peer IDs to network connections.
connections : dict[ID, list[INetConn]]
A mapping of peer IDs to lists of network connections
(multiple connections per peer).
listeners : dict[str, IListener]
A mapping of listener identifiers to listener instances.
"""
peerstore: IPeerStore
connections: dict[ID, INetConn]
connections: dict[ID, list[INetConn]]
listeners: dict[str, IListener]
@abstractmethod
@ -1428,9 +1437,56 @@ class INetwork(ABC):
"""
@abstractmethod
async def dial_peer(self, peer_id: ID) -> INetConn:
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
"""
Create a connection to the specified peer.
Get connections for peer (like JS getConnections, Go ConnsToPeer).
Parameters
----------
peer_id : ID | None
The peer ID to get connections for. If None, returns all connections.
Returns
-------
list[INetConn]
List of connections to the specified peer, or all connections
if peer_id is None.
"""
@abstractmethod
def get_connections_map(self) -> dict[ID, list[INetConn]]:
"""
Get all connections map (like JS getConnectionsMap).
Returns
-------
dict[ID, list[INetConn]]
The complete mapping of peer IDs to their connection lists.
"""
@abstractmethod
def get_connection(self, peer_id: ID) -> INetConn | None:
"""
Get single connection for backward compatibility.
Parameters
----------
peer_id : ID
The peer ID to get a connection for.
Returns
-------
INetConn | None
The first available connection, or None if no connections exist.
"""
@abstractmethod
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
"""
Create connections to the specified peer with load balancing.
Parameters
----------
@ -1439,8 +1495,8 @@ class INetwork(ABC):
Returns
-------
INetConn
The network connection instance to the specified peer.
list[INetConn]
List of established connections to the peer.
Raises
------

View File

@ -5,17 +5,17 @@ from collections.abc import (
)
from typing import TYPE_CHECKING, NewType, Union, cast
from libp2p.transport.quic.stream import QUICStream
if TYPE_CHECKING:
from libp2p.abc import (
IMuxedConn,
INetStream,
ISecureTransport,
)
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
from libp2p.transport.quic.connection import QUICConnection
else:
IMuxedConn = cast(type, object)
INetStream = cast(type, object)
ISecureTransport = cast(type, object)
IMuxedStream = cast(type, object)
QUICConnection = cast(type, object)
from libp2p.io.abc import (
ReadWriteCloser,
@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
MessageID = NewType("MessageID", str)

View File

@ -2,15 +2,20 @@ import logging
from multiaddr import Multiaddr
from multiaddr.resolvers import DNSResolver
import trio
from libp2p.abc import ID, INetworkService, PeerInfo
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
from libp2p.discovery.events.peerDiscovery import peerDiscovery
from libp2p.network.exceptions import SwarmException
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PERMANENT_ADDR_TTL
logger = logging.getLogger("libp2p.discovery.bootstrap")
resolver = DNSResolver()
DEFAULT_CONNECTION_TIMEOUT = 10
class BootstrapDiscovery:
"""
@ -19,68 +24,147 @@ class BootstrapDiscovery:
"""
def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
"""
Initialize BootstrapDiscovery.
Args:
swarm: The network service (swarm) instance
bootstrap_addrs: List of bootstrap peer multiaddresses
"""
self.swarm = swarm
self.peerstore = swarm.peerstore
self.bootstrap_addrs = bootstrap_addrs or []
self.discovered_peers: set[str] = set()
self.connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT
async def start(self) -> None:
"""Process bootstrap addresses and emit peer discovery events."""
logger.debug(
"""Process bootstrap addresses and emit peer discovery events in parallel."""
logger.info(
f"Starting bootstrap discovery with "
f"{len(self.bootstrap_addrs)} bootstrap addresses"
)
# Show all bootstrap addresses being processed
for i, addr in enumerate(self.bootstrap_addrs):
logger.debug(f"{i + 1}. {addr}")
# Validate and filter bootstrap addresses
self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
logger.info(f"Valid addresses after validation: {len(self.bootstrap_addrs)}")
for addr_str in self.bootstrap_addrs:
try:
await self._process_bootstrap_addr(addr_str)
except Exception as e:
logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
# Use Trio nursery for PARALLEL address processing
try:
async with trio.open_nursery() as nursery:
logger.debug(
f"Starting {len(self.bootstrap_addrs)} parallel address "
f"processing tasks"
)
# Start all bootstrap address processing tasks in parallel
for addr_str in self.bootstrap_addrs:
logger.debug(f"Starting parallel task for: {addr_str}")
nursery.start_soon(self._process_bootstrap_addr, addr_str)
# The nursery will wait for all address processing tasks to complete
logger.debug(
"Nursery active - waiting for address processing tasks to complete"
)
except trio.Cancelled:
logger.debug("Bootstrap address processing cancelled - cleaning up tasks")
raise
except Exception as e:
logger.error(f"Bootstrap address processing failed: {e}")
raise
logger.info("Bootstrap discovery startup complete - all tasks finished")
def stop(self) -> None:
"""Clean up bootstrap discovery resources."""
logger.debug("Stopping bootstrap discovery")
logger.info("Stopping bootstrap discovery and cleaning up tasks")
# Clear discovered peers
self.discovered_peers.clear()
logger.debug("Bootstrap discovery cleanup completed")
async def _process_bootstrap_addr(self, addr_str: str) -> None:
"""Convert string address to PeerInfo and add to peerstore."""
try:
multiaddr = Multiaddr(addr_str)
try:
multiaddr = Multiaddr(addr_str)
except Exception as e:
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
return
if self.is_dns_addr(multiaddr):
resolved_addrs = await resolver.resolve(multiaddr)
if resolved_addrs is None:
logger.warning(f"DNS resolution returned None for: {addr_str}")
return
peer_id_str = multiaddr.get_peer_id()
if peer_id_str is None:
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
return
peer_id = ID.from_base58(peer_id_str)
addrs = [addr for addr in resolved_addrs]
if not addrs:
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
return
peer_info = PeerInfo(peer_id, addrs)
await self.add_addr(peer_info)
else:
peer_info = info_from_p2p_addr(multiaddr)
await self.add_addr(peer_info)
except Exception as e:
logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
return
if self.is_dns_addr(multiaddr):
resolved_addrs = await resolver.resolve(multiaddr)
peer_id_str = multiaddr.get_peer_id()
if peer_id_str is None:
logger.warning(f"Missing peer ID in DNS address: {addr_str}")
return
peer_id = ID.from_base58(peer_id_str)
addrs = [addr for addr in resolved_addrs]
if not addrs:
logger.warning(f"No addresses resolved for DNS address: {addr_str}")
return
peer_info = PeerInfo(peer_id, addrs)
self.add_addr(peer_info)
else:
self.add_addr(info_from_p2p_addr(multiaddr))
logger.warning(f"Failed to process bootstrap address {addr_str}: {e}")
def is_dns_addr(self, addr: Multiaddr) -> bool:
"""Check if the address is a DNS address."""
return any(protocol.name == "dnsaddr" for protocol in addr.protocols())
def add_addr(self, peer_info: PeerInfo) -> None:
"""Add a peer to the peerstore and emit discovery event."""
async def add_addr(self, peer_info: PeerInfo) -> None:
"""
Add a peer to the peerstore, emit discovery event,
and attempt connection in parallel.
"""
logger.debug(
f"Adding peer {peer_info.peer_id} with {len(peer_info.addrs)} addresses"
)
# Skip if it's our own peer
if peer_info.peer_id == self.swarm.get_peer_id():
logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
return
# Always add addresses to peerstore (allows multiple addresses for same peer)
self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
# Filter addresses to only include IPv4+TCP (only supported protocol)
ipv4_tcp_addrs = []
filtered_out_addrs = []
for addr in peer_info.addrs:
if self._is_ipv4_tcp_addr(addr):
ipv4_tcp_addrs.append(addr)
else:
filtered_out_addrs.append(addr)
# Log filtering results
logger.debug(
f"Address filtering for {peer_info.peer_id}: "
f"{len(ipv4_tcp_addrs)} IPv4+TCP, {len(filtered_out_addrs)} filtered"
)
# Skip peer if no IPv4+TCP addresses available
if not ipv4_tcp_addrs:
logger.warning(
f"❌ No IPv4+TCP addresses for {peer_info.peer_id} - "
f"skipping connection attempts"
)
return
# Add only IPv4+TCP addresses to peerstore
self.peerstore.add_addrs(peer_info.peer_id, ipv4_tcp_addrs, PERMANENT_ADDR_TTL)
# Only emit discovery event if this is the first time we see this peer
peer_id_str = str(peer_info.peer_id)
@ -89,6 +173,140 @@ class BootstrapDiscovery:
self.discovered_peers.add(peer_id_str)
# Emit peer discovery event
peerDiscovery.emit_peer_discovered(peer_info)
logger.debug(f"Peer discovered: {peer_info.peer_id}")
logger.info(f"Peer discovered: {peer_info.peer_id}")
# Connect to peer (parallel across different bootstrap addresses)
logger.debug("Connecting to discovered peer...")
await self._connect_to_peer(peer_info.peer_id)
else:
logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")
logger.debug(
f"Additional addresses added for existing peer: {peer_info.peer_id}"
)
# Even for existing peers, try to connect if not already connected
if peer_info.peer_id not in self.swarm.connections:
logger.debug("Connecting to existing peer...")
await self._connect_to_peer(peer_info.peer_id)
async def _connect_to_peer(self, peer_id: ID) -> None:
"""
Attempt to establish a connection to a peer with timeout.
Uses swarm.dial_peer to connect using addresses stored in peerstore.
Times out after self.connection_timeout seconds to prevent hanging.
"""
logger.debug(f"Connection attempt for peer: {peer_id}")
# Pre-connection validation: Check if already connected
if peer_id in self.swarm.connections:
logger.debug(
f"Already connected to {peer_id} - skipping connection attempt"
)
return
# Check available addresses before attempting connection
available_addrs = self.peerstore.addrs(peer_id)
logger.debug(f"Connecting to {peer_id} ({len(available_addrs)} addresses)")
if not available_addrs:
logger.error(f"❌ No addresses available for {peer_id} - cannot connect")
return
# Record start time for connection attempt monitoring
connection_start_time = trio.current_time()
try:
with trio.move_on_after(self.connection_timeout):
# Log connection attempt
logger.debug(
f"Attempting connection to {peer_id} using "
f"{len(available_addrs)} addresses"
)
# Use swarm.dial_peer to connect using stored addresses
await self.swarm.dial_peer(peer_id)
# Calculate connection time
connection_time = trio.current_time() - connection_start_time
# Post-connection validation: Verify connection was actually established
if peer_id in self.swarm.connections:
logger.info(
f"✅ Connected to {peer_id} (took {connection_time:.2f}s)"
)
else:
logger.warning(
f"Dial succeeded but connection not found for {peer_id}"
)
except trio.TooSlowError:
logger.warning(
f"❌ Connection to {peer_id} timed out after {self.connection_timeout}s"
)
except SwarmException as e:
# Calculate failed connection time
failed_connection_time = trio.current_time() - connection_start_time
# Enhanced error logging
error_msg = str(e)
if "no addresses established a successful connection" in error_msg:
logger.warning(
f"❌ Failed to connect to {peer_id} after trying all "
f"{len(available_addrs)} addresses "
f"(took {failed_connection_time:.2f}s)"
)
# Log individual address failures if this is a MultiError
if (
e.__cause__ is not None
and hasattr(e.__cause__, "exceptions")
and getattr(e.__cause__, "exceptions", None) is not None
):
exceptions_list = getattr(e.__cause__, "exceptions")
logger.debug("📋 Individual address failure details:")
for i, addr_exception in enumerate(exceptions_list, 1):
logger.debug(f"Address {i}: {addr_exception}")
# Also log the actual address that failed
if i <= len(available_addrs):
logger.debug(f"Failed address: {available_addrs[i - 1]}")
else:
logger.warning("No detailed exception information available")
else:
logger.warning(
f"❌ Failed to connect to {peer_id}: {e} "
f"(took {failed_connection_time:.2f}s)"
)
except Exception as e:
# Handle unexpected errors that aren't swarm-specific
failed_connection_time = trio.current_time() - connection_start_time
logger.error(
f"❌ Unexpected error connecting to {peer_id}: "
f"{e} (took {failed_connection_time:.2f}s)"
)
# Don't re-raise to prevent killing the nursery and other parallel tasks
def _is_ipv4_tcp_addr(self, addr: Multiaddr) -> bool:
"""
Check if address is IPv4 with TCP protocol only.
Filters out IPv6, UDP, QUIC, WebSocket, and other unsupported protocols.
Only IPv4+TCP addresses are supported by the current transport.
"""
try:
protocols = addr.protocols()
# Must have IPv4 protocol
has_ipv4 = any(p.name == "ip4" for p in protocols)
if not has_ipv4:
return False
# Must have TCP protocol
has_tcp = any(p.name == "tcp" for p in protocols)
if not has_tcp:
return False
return True
except Exception:
# If we can't parse the address, don't use it
return False

View File

@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
@ -110,6 +111,14 @@ class BasicHost(IHost):
if bootstrap:
self.bootstrap = BootstrapDiscovery(network, bootstrap)
# Cache a signed-record if the local-node in the PeerStore
envelope = create_signed_peer_record(
self.get_id(),
self.get_addrs(),
self.get_private_key(),
)
self.get_peerstore().set_local_record(envelope)
def get_id(self) -> ID:
"""
:return: peer_id of host
@ -288,6 +297,11 @@ class BasicHost(IHost):
protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream), self.negotiate_timeout
)
if protocol is None:
await net_stream.reset()
raise StreamFailure(
"Failed to negotiate protocol: no protocol selected"
)
except MultiselectError as error:
peer_id = net_stream.muxed_conn.peer_id
logger.debug(
@ -329,7 +343,7 @@ class BasicHost(IHost):
:param peer_id: ID of the peer to check
:return: True if peer has an active connection, False otherwise
"""
return peer_id in self._network.connections
return len(self._network.get_connections(peer_id)) > 0
def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
"""
@ -338,4 +352,4 @@ class BasicHost(IHost):
:param peer_id: ID of the peer to get info for
:return: Connection object if peer is connected, None otherwise
"""
return self._network.connections.get(peer_id)
return self._network.get_connection(peer_id)

View File

@ -15,8 +15,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.utils import (
decode_varint_with_size,
get_agent_version,
@ -66,9 +65,7 @@ def _mk_identify_protobuf(
protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)
# Create a signed peer-record for the remote peer
record = PeerRecord(host.get_id(), host.get_addrs())
envelope = seal_record(record, host.get_private_key())
protobuf = envelope.marshal_envelope()
envelope_bytes, _ = env_to_send_in_RPC(host)
observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
return Identify(
@ -78,7 +75,7 @@ def _mk_identify_protobuf(
listen_addrs=map(_multiaddr_to_bytes, laddrs),
observed_addr=observed_addr,
protocols=protocols,
signedPeerRecord=protobuf,
signedPeerRecord=envelope_bytes,
)

View File

@ -22,15 +22,18 @@ from libp2p.abc import (
IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import (
Service,
)
@ -234,6 +237,9 @@ class KadDHT(Service):
await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table")
closer_peer_envelope: Envelope | None = None
provider_peer_envelope: Envelope | None = None
try:
# Read varint-prefixed length for the message
length_prefix = b""
@ -274,6 +280,14 @@ class KadDHT(Service):
)
logger.debug(f"Found {len(closest_peers)} peers close to target")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Build response message with protobuf
response = Message()
response.type = Message.MessageType.FIND_NODE
@ -298,6 +312,21 @@ class KadDHT(Service):
except Exception:
pass
# Add the signed-peer-record for each peer in the peer-proto
# if cached in the peerstore
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -312,6 +341,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Consume the source signed-peer-record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Extract provider information
for provider_proto in message.providerPeers:
try:
@ -338,6 +375,17 @@ class KadDHT(Service):
logger.debug(
f"Added provider {provider_id} for key {key.hex()}"
)
# Process the signed-records of provider if sent
if not maybe_consume_signed_record(
provider_proto, self.host
):
logger.error(
"Received an invalid-signed-record,"
"dropping the stream"
)
await stream.close()
return
except Exception as e:
logger.warning(f"Failed to process provider info: {e}")
@ -346,6 +394,10 @@ class KadDHT(Service):
response.type = Message.MessageType.ADD_PROVIDER
response.key = key
# Add sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
@ -357,6 +409,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
# Find providers for the key
providers = self.provider_store.get_providers(key)
logger.debug(
@ -368,12 +428,28 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_PROVIDERS
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add provider information to response
for provider_info in providers:
provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add provider signed-records if cached
provider_peer_envelope = (
self.host.get_peerstore().get_peer_record(
provider_info.peer_id
)
)
if provider_peer_envelope is not None:
provider_proto.signedRecord = (
provider_peer_envelope.marshal_envelope()
)
# Add addresses if available
for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes())
@ -397,6 +473,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -417,6 +503,14 @@ class KadDHT(Service):
key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
value = self.value_store.get(key)
if value:
logger.debug(f"Found value for key {key.hex()}")
@ -431,6 +525,10 @@ class KadDHT(Service):
response.record.value = value
response.record.timeReceived = str(time.time())
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
@ -444,6 +542,10 @@ class KadDHT(Service):
response.type = Message.MessageType.GET_VALUE
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
@ -462,6 +564,16 @@ class KadDHT(Service):
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add signed-records of closer-peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer)
)
if closer_peer_envelope is not None:
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
@ -484,6 +596,15 @@ class KadDHT(Service):
key = message.record.key
value = message.record.value
success = False
# Consume the source signed_peer_record if sent
if not maybe_consume_signed_record(message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
await stream.close()
return
try:
if not (key and value):
raise ValueError(
@ -504,6 +625,12 @@ class KadDHT(Service):
response.type = Message.MessageType.PUT_VALUE
if success:
response.key = key
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)

View File

@ -27,6 +27,7 @@ message Message {
bytes id = 1;
repeated bytes addrs = 2;
ConnectionType connection = 3;
optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
}
MessageType type = 1;
@ -35,4 +36,6 @@ message Message {
Record record = 3;
repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9;
optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
}

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RECORD._serialized_start=36
_RECORD._serialized_end=94
_MESSAGE._serialized_start=97
_MESSAGE._serialized_end=555
_MESSAGE_PEER._serialized_start=281
_MESSAGE_PEER._serialized_end=359
_MESSAGE_MESSAGETYPE._serialized_start=361
_MESSAGE_MESSAGETYPE._serialized_end=466
_MESSAGE_CONNECTIONTYPE._serialized_start=468
_MESSAGE_CONNECTIONTYPE._serialized_end=555
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=643
_globals['_MESSAGE_PEER']._serialized_start=308
_globals['_MESSAGE_PEER']._serialized_end=430
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
# @@protoc_insertion_point(module_scope)

View File

@ -1,133 +1,70 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
class Record(_message.Message):
__slots__ = ("key", "value", "timeReceived")
KEY_FIELD_NUMBER: _ClassVar[int]
VALUE_FIELD_NUMBER: _ClassVar[int]
TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
key: bytes
value: bytes
timeReceived: str
def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Record(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
TIMERECEIVED_FIELD_NUMBER: builtins.int
key: builtins.bytes
value: builtins.bytes
timeReceived: builtins.str
def __init__(
self,
*,
key: builtins.bytes = ...,
value: builtins.bytes = ...,
timeReceived: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
global___Record = Record
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _MessageType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
PUT_VALUE: Message._MessageType.ValueType # 0
GET_VALUE: Message._MessageType.ValueType # 1
ADD_PROVIDER: Message._MessageType.ValueType # 2
GET_PROVIDERS: Message._MessageType.ValueType # 3
FIND_NODE: Message._MessageType.ValueType # 4
PING: Message._MessageType.ValueType # 5
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
PUT_VALUE: Message.MessageType.ValueType # 0
GET_VALUE: Message.MessageType.ValueType # 1
ADD_PROVIDER: Message.MessageType.ValueType # 2
GET_PROVIDERS: Message.MessageType.ValueType # 3
FIND_NODE: Message.MessageType.ValueType # 4
PING: Message.MessageType.ValueType # 5
class _ConnectionType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message
class Message(_message.Message):
__slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
PUT_VALUE: _ClassVar[Message.MessageType]
GET_VALUE: _ClassVar[Message.MessageType]
ADD_PROVIDER: _ClassVar[Message.MessageType]
GET_PROVIDERS: _ClassVar[Message.MessageType]
FIND_NODE: _ClassVar[Message.MessageType]
PING: _ClassVar[Message.MessageType]
PUT_VALUE: Message.MessageType
GET_VALUE: Message.MessageType
ADD_PROVIDER: Message.MessageType
GET_PROVIDERS: Message.MessageType
FIND_NODE: Message.MessageType
PING: Message.MessageType
class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NOT_CONNECTED: _ClassVar[Message.ConnectionType]
CONNECTED: _ClassVar[Message.ConnectionType]
CAN_CONNECT: _ClassVar[Message.ConnectionType]
CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
NOT_CONNECTED: Message.ConnectionType
CONNECTED: Message.ConnectionType
CAN_CONNECT: Message.ConnectionType
CANNOT_CONNECT: Message.ConnectionType
class Peer(_message.Message):
__slots__ = ("id", "addrs", "connection", "signedRecord")
ID_FIELD_NUMBER: _ClassVar[int]
ADDRS_FIELD_NUMBER: _ClassVar[int]
CONNECTION_FIELD_NUMBER: _ClassVar[int]
SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
id: bytes
addrs: _containers.RepeatedScalarFieldContainer[bytes]
connection: Message.ConnectionType
signedRecord: bytes
def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
TYPE_FIELD_NUMBER: _ClassVar[int]
CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
RECORD_FIELD_NUMBER: _ClassVar[int]
CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
type: Message.MessageType
clusterLevelRaw: int
key: bytes
record: Record
closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
senderRecord: bytes
def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore

View File

@ -15,12 +15,14 @@ from libp2p.abc import (
INetStream,
IPeerRouting,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -33,6 +35,7 @@ from .routing_table import (
RoutingTable,
)
from .utils import (
maybe_consume_signed_record,
sort_peer_ids_by_distance,
)
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes
# Create sender_signed_peer_record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
find_node_msg.senderRecord = envelope_bytes
# Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString()
logger.debug(
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
# Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE:
# Consume the sender_signed_peer_record
if not maybe_consume_signed_record(response_msg, self.host, peer):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
for peer_data in response_msg.closerPeers:
# Consume the received closer_peers signed-records, peer-id is
# sent with the peer-data
if not maybe_consume_signed_record(peer_data, self.host):
logger.error(
"Received an invalid-signed-record,ignoring the response"
)
return []
new_peer_id = ID(peer_data.id)
if new_peer_id not in results:
results.append(new_peer_id)
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
"""
try:
# Read message length
peer_id = stream.muxed_conn.peer_id
length_bytes = await stream.read(4)
if not length_bytes:
return
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):
# Parse protobuf message
kad_message = Message()
closer_peer_envelope: Envelope | None = None
try:
kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(kad_message, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, dropping the stream"
)
return
# Get target key directly from protobuf message
target_key = kad_message.key
@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
response = Message()
response.type = Message.MessageType.FIND_NODE
# Create sender_signed_peer_record for the response
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
# Add peer information to response
for peer_id in closest_peers:
peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add the signed-records of closest_peers if cached
closer_peer_envelope = (
self.host.get_peerstore().get_peer_record(peer_id)
)
if isinstance(closer_peer_envelope, Envelope):
peer_proto.signedRecord = (
closer_peer_envelope.marshal_envelope()
)
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)

View File

@ -22,12 +22,14 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
ALPHA,
@ -240,11 +242,18 @@ class ProviderStore:
message.type = Message.MessageType.ADD_PROVIDER
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Add our provider info
provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs)
# Add the provider's signed-peer-record
provider.signedRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -276,10 +285,15 @@ class ProviderStore:
response = Message()
response.ParseFromString(response_bytes)
# Check response type
response.type == Message.MessageType.ADD_PROVIDER
if response.type:
result = True
if response.type == Message.MessageType.ADD_PROVIDER:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
result = False
else:
result = True
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
@ -380,6 +394,10 @@ class ProviderStore:
message.type = Message.MessageType.GET_PROVIDERS
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -414,10 +432,26 @@ class ProviderStore:
if response.type != Message.MessageType.GET_PROVIDERS:
return []
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return []
# Extract provider information
providers = []
for provider_proto in response.providerPeers:
try:
# Consume the provider's signed-peer-record if sent, peer-id
# already sent with the provider-proto
if not maybe_consume_signed_record(provider_proto, self.host):
logger.error(
"Received an invalid-signed-record, "
"ignoring the response"
)
return []
# Create peer ID from bytes
provider_id = ID(provider_proto.id)
@ -431,6 +465,7 @@ class ProviderStore:
# Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs))
except Exception as e:
logger.warning(f"Failed to parse provider info: {e}")

View File

@ -2,13 +2,93 @@
Utility functions for Kademlia DHT implementation.
"""
import logging
import base58
import multihash
from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
ID,
)
from .pb.kademlia_pb2 import (
Message,
)
logger = logging.getLogger("kademlia-example.utils")
def maybe_consume_signed_record(
msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
DHT communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : Message | Message.Peer
The protobuf message received during DHT communication. Can either be a
top-level `Message` containing `senderRecord` or a `Message.Peer`
containing `signedRecord`.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if isinstance(msg, Message):
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.senderRecord,
"libp2p-peer-record",
)
if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
else:
if msg.HasField("signedRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(
msg.signedRecord,
"libp2p-peer-record",
)
if not record.peer_id.to_bytes() == msg.id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error(
"Failed to update the Certified-Addr-Book: %s",
e,
)
return False
return True
def create_key_from_binary(binary_data: bytes) -> bytes:
"""

View File

@ -15,9 +15,11 @@ from libp2p.abc import (
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .common import (
DEFAULT_TTL,
@ -110,6 +112,10 @@ class ValueStore:
message = Message()
message.type = Message.MessageType.PUT_VALUE
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Set message fields
message.key = key
message.record.key = key
@ -155,7 +161,13 @@ class ValueStore:
# Check if response is valid
if response.type == Message.MessageType.PUT_VALUE:
if response.key:
# Consume the sender's signed-peer-record if sent
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return False
if response.key == key:
result = True
return result
@ -231,6 +243,10 @@ class ValueStore:
message.type = Message.MessageType.GET_VALUE
message.key = key
# Create sender's signed-peer-record
envelope_bytes, _ = env_to_send_in_RPC(self.host)
message.senderRecord = envelope_bytes
# Serialize and send the protobuf message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
@ -275,6 +291,13 @@ class ValueStore:
and response.HasField("record")
and response.record.value
):
# Consume the sender's signed-peer-record
if not maybe_consume_signed_record(response, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the response"
)
return None
logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}"
)

70
libp2p/network/config.py Normal file
View File

@ -0,0 +1,70 @@
from dataclasses import dataclass
@dataclass
class RetryConfig:
"""
Configuration for retry logic with exponential backoff.
This configuration controls how connection attempts are retried when they fail.
The retry mechanism uses exponential backoff with jitter to prevent thundering
herd problems in distributed systems.
Attributes:
max_retries: Maximum number of retry attempts before giving up.
Default: 3 attempts
initial_delay: Initial delay in seconds before the first retry.
Default: 0.1 seconds (100ms)
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
Default: 30.0 seconds
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
the delay by this factor). Default: 2.0 (doubles each time)
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
and prevent synchronized retries. Default: 0.1 (10% jitter)
"""
max_retries: int = 3
initial_delay: float = 0.1
max_delay: float = 30.0
backoff_multiplier: float = 2.0
jitter_factor: float = 0.1
@dataclass
class ConnectionConfig:
"""
Configuration for multi-connection support.
This configuration controls how multiple connections per peer are managed,
including connection limits, timeouts, and load balancing strategies.
Attributes:
max_connections_per_peer: Maximum number of connections allowed to a single
peer. Default: 3 connections
connection_timeout: Timeout in seconds for establishing new connections.
Default: 30.0 seconds
load_balancing_strategy: Strategy for distributing streams across connections.
Options: "round_robin" (default) or "least_loaded"
"""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
load_balancing_strategy: str = "round_robin" # or "least_loaded"
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (
self.load_balancing_strategy == "round_robin"
or self.load_balancing_strategy == "least_loaded"
):
raise ValueError(
"Load balancing strategy can only be 'round_robin' or 'least_loaded'"
)
if self.max_connections_per_peer < 1:
raise ValueError("Max connection per peer should be atleast 1")
if self.connection_timeout < 0:
raise ValueError("Connection timeout should be positive")

View File

@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamError,
MuxedStreamReset,
)
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError
from .exceptions import (
StreamClosed,
@ -170,7 +171,7 @@ class NetStream(INetStream):
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_READ
raise StreamEOF() from error
except MuxedStreamReset as error:
except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
@ -199,7 +200,12 @@ class NetStream(INetStream):
try:
await self.muxed_stream.write(data)
except (MuxedStreamClosed, MuxedStreamError) as error:
except (
MuxedStreamClosed,
MuxedStreamError,
QUICStreamClosedError,
QUICStreamResetError,
) as error:
async with self._state_lock:
if self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE

View File

@ -3,6 +3,8 @@ from collections.abc import (
Callable,
)
import logging
import random
from typing import cast
from multiaddr import (
Multiaddr,
@ -25,6 +27,7 @@ from libp2p.custom_types import (
from libp2p.io.abc import (
ReadWriteCloser,
)
from libp2p.network.config import ConnectionConfig, RetryConfig
from libp2p.peer.id import (
ID,
)
@ -39,6 +42,9 @@ from libp2p.transport.exceptions import (
OpenConnectionError,
SecurityUpgradeFailure,
)
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.upgrader import (
TransportUpgrader,
)
@ -71,9 +77,7 @@ class Swarm(Service, INetworkService):
peerstore: IPeerStore
upgrader: TransportUpgrader
transport: ITransport
# TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
# whereas in Go one `peer_id` may point to multiple connections.
connections: dict[ID, INetConn]
connections: dict[ID, list[INetConn]]
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: trio.Nursery | None
@ -81,18 +85,31 @@ class Swarm(Service, INetworkService):
notifees: list[INotifee]
# Enhanced: New configuration
retry_config: RetryConfig
connection_config: ConnectionConfig | QUICTransportConfig
_round_robin_index: dict[ID, int]
def __init__(
self,
peer_id: ID,
peerstore: IPeerStore,
upgrader: TransportUpgrader,
transport: ITransport,
retry_config: RetryConfig | None = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
):
self.self_id = peer_id
self.peerstore = peerstore
self.upgrader = upgrader
self.transport = transport
self.connections = dict()
# Enhanced: Initialize retry and connection configuration
self.retry_config = retry_config or RetryConfig()
self.connection_config = connection_config or ConnectionConfig()
# Enhanced: Initialize connections as 1:many mapping
self.connections = {}
self.listeners = dict()
# Create Notifee array
@ -103,11 +120,19 @@ class Swarm(Service, INetworkService):
self.listener_nursery = None
self.event_listener_nursery_created = trio.Event()
# Load balancing state
self._round_robin_index = {}
async def run(self) -> None:
async with trio.open_nursery() as nursery:
# Create a nursery for listener tasks.
self.listener_nursery = nursery
self.event_listener_nursery_created.set()
if isinstance(self.transport, QUICTransport):
self.transport.set_background_nursery(nursery)
self.transport.set_swarm(self)
try:
await self.manager.wait_finished()
finally:
@ -122,18 +147,74 @@ class Swarm(Service, INetworkService):
def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
self.common_stream_handler = stream_handler
async def dial_peer(self, peer_id: ID) -> INetConn:
def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
"""
Try to create a connection to peer_id.
Get connections for peer (like JS getConnections, Go ConnsToPeer).
Parameters
----------
peer_id : ID | None
The peer ID to get connections for. If None, returns all connections.
Returns
-------
list[INetConn]
List of connections to the specified peer, or all connections
if peer_id is None.
"""
if peer_id is not None:
return self.connections.get(peer_id, [])
# Return all connections from all peers
all_conns = []
for conns in self.connections.values():
all_conns.extend(conns)
return all_conns
def get_connections_map(self) -> dict[ID, list[INetConn]]:
"""
Get all connections map (like JS getConnectionsMap).
Returns
-------
dict[ID, list[INetConn]]
The complete mapping of peer IDs to their connection lists.
"""
return self.connections.copy()
def get_connection(self, peer_id: ID) -> INetConn | None:
"""
Get single connection for backward compatibility.
Parameters
----------
peer_id : ID
The peer ID to get a connection for.
Returns
-------
INetConn | None
The first available connection, or None if no connections exist.
"""
conns = self.get_connections(peer_id)
return conns[0] if conns else None
async def dial_peer(self, peer_id: ID) -> list[INetConn]:
"""
Try to create connections to peer_id with enhanced retry logic.
:param peer_id: peer if we want to dial
:raises SwarmException: raised when an error occurs
:return: muxed connection
:return: list of muxed connections
"""
if peer_id in self.connections:
# If muxed connection already exists for peer_id,
# set muxed connection equal to existing muxed connection
return self.connections[peer_id]
# Check if we already have connections
existing_connections = self.get_connections(peer_id)
if existing_connections:
logger.debug(f"Reusing existing connections to peer {peer_id}")
return existing_connections
logger.debug("attempting to dial peer %s", peer_id)
@ -146,12 +227,19 @@ class Swarm(Service, INetworkService):
if not addrs:
raise SwarmException(f"No known addresses to peer {peer_id}")
connections = []
exceptions: list[SwarmException] = []
# Try all known addresses
# Enhanced: Try all known addresses with retry logic
for multiaddr in addrs:
try:
return await self.dial_addr(multiaddr, peer_id)
connection = await self._dial_with_retry(multiaddr, peer_id)
connections.append(connection)
# Limit number of connections per peer
if len(connections) >= self.connection_config.max_connections_per_peer:
break
except SwarmException as e:
exceptions.append(e)
logger.debug(
@ -161,15 +249,73 @@ class Swarm(Service, INetworkService):
exc_info=e,
)
# Tried all addresses, raising exception.
raise SwarmException(
f"unable to connect to {peer_id}, no addresses established a successful "
"connection (with exceptions)"
) from MultiError(exceptions)
if not connections:
# Tried all addresses, raising exception.
raise SwarmException(
f"unable to connect to {peer_id}, no addresses established a "
"successful connection (with exceptions)"
) from MultiError(exceptions)
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
return connections
async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
Try to create a connection to peer_id with addr.
Enhanced: Dial with retry logic and exponential backoff.
:param addr: the address to dial
:param peer_id: the peer we want to connect to
:raises SwarmException: raised when all retry attempts fail
:return: network connection
"""
last_exception = None
for attempt in range(self.retry_config.max_retries + 1):
try:
return await self._dial_addr_single_attempt(addr, peer_id)
except Exception as e:
last_exception = e
if attempt < self.retry_config.max_retries:
delay = self._calculate_backoff_delay(attempt)
logger.debug(
f"Connection attempt {attempt + 1} failed, "
f"retrying in {delay:.2f}s: {e}"
)
await trio.sleep(delay)
else:
logger.debug(f"All {self.retry_config.max_retries} attempts failed")
# Convert the last exception to SwarmException for consistency
if last_exception is not None:
if isinstance(last_exception, SwarmException):
raise last_exception
else:
raise SwarmException(
f"Failed to connect after {self.retry_config.max_retries} attempts"
) from last_exception
# This should never be reached, but mypy requires it
raise SwarmException("Unexpected error in retry logic")
def _calculate_backoff_delay(self, attempt: int) -> float:
"""
Enhanced: Calculate backoff delay with jitter to prevent thundering herd.
:param attempt: the current attempt number (0-based)
:return: delay in seconds
"""
delay = min(
self.retry_config.initial_delay
* (self.retry_config.backoff_multiplier**attempt),
self.retry_config.max_delay,
)
# Add jitter to prevent synchronized retries
jitter = delay * self.retry_config.jitter_factor
return delay + random.uniform(-jitter, jitter)
async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
Enhanced: Single attempt to dial an address (extracted from original dial_addr).
:param addr: the address we want to connect with
:param peer_id: the peer we want to connect to
@ -179,6 +325,7 @@ class Swarm(Service, INetworkService):
# Dial peer (connection to peer does not yet exist)
# Transport dials peer (gets back a raw conn)
try:
addr = Multiaddr(f"{addr}/p2p/{peer_id}")
raw_conn = await self.transport.dial(addr)
except OpenConnectionError as error:
logger.debug("fail to dial peer %s over base transport", peer_id)
@ -186,6 +333,15 @@ class Swarm(Service, INetworkService):
f"fail to open connection to peer {peer_id}"
) from error
if isinstance(self.transport, QUICTransport) and isinstance(
raw_conn, IMuxedConn
):
logger.info(
"Skipping upgrade for QUIC, QUIC connections are already multiplexed"
)
swarm_conn = await self.add_conn(raw_conn)
return swarm_conn
logger.debug("dialed peer %s over base transport", peer_id)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
@ -211,24 +367,103 @@ class Swarm(Service, INetworkService):
logger.debug("upgraded mux for peer %s", peer_id)
swarm_conn = await self.add_conn(muxed_conn)
logger.debug("successfully dialed peer %s", peer_id)
return swarm_conn
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
"""
Enhanced: Try to create a connection to peer_id with addr using retry logic.
:param addr: the address we want to connect with
:param peer_id: the peer we want to connect to
:raises SwarmException: raised when an error occurs
:return: network connection
"""
return await self._dial_with_retry(addr, peer_id)
async def new_stream(self, peer_id: ID) -> INetStream:
"""
Enhanced: Create a new stream with load balancing across multiple connections.
:param peer_id: peer_id of destination
:raises SwarmException: raised when an error occurs
:return: net stream instance
"""
logger.debug("attempting to open a stream to peer %s", peer_id)
# Get existing connections or dial new ones
connections = self.get_connections(peer_id)
if not connections:
connections = await self.dial_peer(peer_id)
swarm_conn = await self.dial_peer(peer_id)
# Load balancing strategy at interface level
connection = self._select_connection(connections, peer_id)
net_stream = await swarm_conn.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
if isinstance(self.transport, QUICTransport) and connection is not None:
conn = cast(SwarmConn, connection)
return await conn.new_stream()
try:
net_stream = await connection.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
return net_stream
except Exception as e:
logger.debug(f"Failed to create stream on connection: {e}")
# Try other connections if available
for other_conn in connections:
if other_conn != connection:
try:
net_stream = await other_conn.new_stream()
logger.debug(
f"Successfully opened a stream to peer {peer_id} "
"using alternative connection"
)
return net_stream
except Exception:
continue
# All connections failed, raise exception
raise SwarmException(f"Failed to create stream to peer {peer_id}") from e
def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
"""
Select connection based on load balancing strategy.
Parameters
----------
connections : list[INetConn]
List of available connections.
peer_id : ID
The peer ID for round-robin tracking.
strategy : str
Load balancing strategy ("round_robin", "least_loaded", etc.).
Returns
-------
INetConn
Selected connection.
"""
if not connections:
raise ValueError("No connections available")
strategy = self.connection_config.load_balancing_strategy
if strategy == "round_robin":
# Simple round-robin selection
if peer_id not in self._round_robin_index:
self._round_robin_index[peer_id] = 0
index = self._round_robin_index[peer_id] % len(connections)
self._round_robin_index[peer_id] += 1
return connections[index]
elif strategy == "least_loaded":
# Find connection with least streams
return min(connections, key=lambda c: len(c.get_streams()))
else:
# Default to first connection
return connections[0]
async def listen(self, *multiaddrs: Multiaddr) -> bool:
"""
@ -248,17 +483,35 @@ class Swarm(Service, INetworkService):
"""
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
# We need to wait until `self.listener_nursery` is created.
logger.debug("Starting to listen")
await self.event_listener_nursery_created.wait()
success_count = 0
for maddr in multiaddrs:
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
if str(maddr) in self.listeners:
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
return True
success_count += 1
continue
async def conn_handler(
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
) -> None:
# No need to upgrade QUIC Connection
if isinstance(self.transport, QUICTransport):
try:
quic_conn = cast(QUICConnection, read_write_closer)
await self.add_conn(quic_conn)
peer_id = quic_conn.peer_id
logger.debug(
f"successfully opened quic connection to peer {peer_id}"
)
# NOTE: This is a intentional barrier to prevent from the
# handler exiting and closing the connection.
await self.manager.wait_finished()
except Exception:
await read_write_closer.close()
return
raw_conn = RawConnection(read_write_closer, False)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
@ -309,13 +562,14 @@ class Swarm(Service, INetworkService):
# Call notifiers since event occurred
await self.notify_listen(maddr)
return True
success_count += 1
logger.debug("successfully started listening on: %s", maddr)
except OSError:
# Failed. Continue looping.
logger.debug("fail to listen on: %s", maddr)
# No maddr succeeded
return False
# Return true if at least one address succeeded
return success_count > 0
async def close(self) -> None:
"""
@ -328,9 +582,9 @@ class Swarm(Service, INetworkService):
# Perform alternative cleanup if the manager isn't initialized
# Close all connections manually
if hasattr(self, "connections"):
for conn_id in list(self.connections.keys()):
conn = self.connections[conn_id]
await conn.close()
for peer_id, conns in list(self.connections.items()):
for conn in conns:
await conn.close()
# Clear connection tracking dictionary
self.connections.clear()
@ -360,12 +614,28 @@ class Swarm(Service, INetworkService):
logger.debug("swarm successfully closed")
async def close_peer(self, peer_id: ID) -> None:
if peer_id not in self.connections:
"""
Close all connections to the specified peer.
Parameters
----------
peer_id : ID
The peer ID to close connections for.
"""
connections = self.get_connections(peer_id)
if not connections:
return
connection = self.connections[peer_id]
# NOTE: `connection.close` will delete `peer_id` from `self.connections`
# and `notify_disconnected` for us.
await connection.close()
# Close all connections
for connection in connections:
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection to {peer_id}: {e}")
# Remove from connections dict
self.connections.pop(peer_id, None)
logger.debug("successfully close the connection to peer %s", peer_id)
@ -379,26 +649,77 @@ class Swarm(Service, INetworkService):
muxed_conn,
self,
)
logger.debug("Swarm::add_conn | starting muxed connection")
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
logger.debug("Swarm::add_conn | starting swarm connection")
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()
# Store muxed_conn with peer id
self.connections[muxed_conn.peer_id] = swarm_conn
# Add to connections dict with deduplication
peer_id = muxed_conn.peer_id
if peer_id not in self.connections:
self.connections[peer_id] = []
# Check for duplicate connections by comparing the underlying muxed connection
for existing_conn in self.connections[peer_id]:
if existing_conn.muxed_conn == muxed_conn:
logger.debug(f"Connection already exists for peer {peer_id}")
# existing_conn is a SwarmConn since it's stored in the connections list
return existing_conn # type: ignore[return-value]
self.connections[peer_id].append(swarm_conn)
# Trim if we exceed max connections
max_conns = self.connection_config.max_connections_per_peer
if len(self.connections[peer_id]) > max_conns:
self._trim_connections(peer_id)
# Call notifiers since event occurred
await self.notify_connected(swarm_conn)
return swarm_conn
def _trim_connections(self, peer_id: ID) -> None:
"""
Remove oldest connections when limit is exceeded.
"""
connections = self.connections[peer_id]
if len(connections) <= self.connection_config.max_connections_per_peer:
return
# Sort by creation time and remove oldest
# For now, just keep the most recent connections
max_conns = self.connection_config.max_connections_per_peer
connections_to_remove = connections[:-max_conns]
for conn in connections_to_remove:
logger.debug(f"Trimming old connection for peer {peer_id}")
trio.lowlevel.spawn_system_task(self._close_connection_async, conn)
# Keep only the most recent connections
max_conns = self.connection_config.max_connections_per_peer
self.connections[peer_id] = connections[-max_conns:]
async def _close_connection_async(self, connection: INetConn) -> None:
"""Close a connection asynchronously."""
try:
await connection.close()
except Exception as e:
logger.warning(f"Error closing connection: {e}")
def remove_conn(self, swarm_conn: SwarmConn) -> None:
"""
Simply remove the connection from Swarm's records, without closing
the connection.
"""
peer_id = swarm_conn.muxed_conn.peer_id
if peer_id not in self.connections:
return
del self.connections[peer_id]
if peer_id in self.connections:
self.connections[peer_id] = [
conn for conn in self.connections[peer_id] if conn != swarm_conn
]
if not self.connections[peer_id]:
del self.connections[peer_id]
# Notifee
@ -444,3 +765,21 @@ class Swarm(Service, INetworkService):
async with trio.open_nursery() as nursery:
for notifee in self.notifees:
nursery.start_soon(notifier, notifee)
# Backward compatibility properties
@property
def connections_legacy(self) -> dict[ID, INetConn]:
"""
Legacy 1:1 mapping for backward compatibility.
Returns
-------
dict[ID, INetConn]
Legacy mapping with only the first connection per peer.
"""
legacy_conns = {}
for peer_id, conns in self.connections.items():
if conns:
legacy_conns[peer_id] = conns[0]
return legacy_conns

View File

@ -1,5 +1,7 @@
from typing import Any, cast
import multiaddr
from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
)
return False
def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
return {b for b in self.record().addrs}
def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
"""

View File

@ -16,6 +16,7 @@ import trio
from trio import MemoryReceiveChannel, MemorySendChannel
from libp2p.abc import (
IHost,
IPeerStore,
)
from libp2p.crypto.keys import (
@ -23,7 +24,8 @@ from libp2p.crypto.keys import (
PrivateKey,
PublicKey,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.peer_record import PeerRecord
from .id import (
ID,
@ -39,6 +41,86 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0
def create_signed_peer_record(
peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
) -> Envelope:
"""Creates a signed_peer_record wrapped in an Envelope"""
record = PeerRecord(peer_id, addrs)
envelope = seal_record(record, pvt_key)
return envelope
def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
"""
Return the signed peer record (Envelope) to be sent in an RPC.
This function checks whether the host already has a cached signed peer record
(SPR). If one exists and its addresses match the host's current listen
addresses, the cached envelope is reused. Otherwise, a new signed peer record
is created, cached, and returned.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
tuple[bytes, bool]
A 2-tuple where the first element is the serialized envelope (bytes)
for the signed peer record, and the second element is a boolean flag
indicating whether a new record was created (True) or an existing cached
one was reused (False).
"""
listen_addrs_set = {addr for addr in host.get_addrs()}
local_env = host.get_peerstore().get_local_record()
if local_env is None:
# No cached SPR yet -> create one
return issue_and_cache_local_record(host), True
else:
record_addrs_set = local_env._env_addrs_set()
if record_addrs_set == listen_addrs_set:
# Perfect match -> reuse cached envelope
return local_env.marshal_envelope(), False
else:
# Addresses changed -> issue a new SPR and cache it
return issue_and_cache_local_record(host), True
def issue_and_cache_local_record(host: IHost) -> bytes:
"""
Create and cache a new signed peer record (Envelope) for the host.
This function generates a new signed peer record from the hosts peer ID,
listen addresses, and private key. The resulting envelope is stored in
the peerstore as the local record for future reuse.
Parameters
----------
host : IHost
The local host instance, providing access to peer ID, listen addresses,
private key, and the peerstore.
Returns
-------
bytes
The serialized envelope (bytes) representing the newly created signed
peer record.
"""
env = create_signed_peer_record(
host.get_id(),
host.get_addrs(),
host.get_private_key(),
)
# Cache it for next time use
host.get_peerstore().set_local_record(env)
return env.marshal_envelope()
class PeerRecordState:
envelope: Envelope
seq: int
@ -55,8 +137,17 @@ class PeerStore(IPeerStore):
self.peer_data_map = defaultdict(PeerData)
self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
self.peer_record_map: dict[ID, PeerRecordState] = {}
self.local_peer_record: Envelope | None = None
self.max_records = max_records
def get_local_record(self) -> Envelope | None:
"""Get the local-signed-record wrapped in Envelope"""
return self.local_peer_record
def set_local_record(self, envelope: Envelope) -> None:
"""Set the local-signed-record wrapped in Envelope"""
self.local_peer_record = envelope
def peer_info(self, peer_id: ID) -> PeerInfo:
"""
:param peer_id: peer ID to get info for

View File

@ -1,3 +1,5 @@
from builtins import AssertionError
from libp2p.abc import (
IMultiselectCommunicator,
)
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:
# Handle for connection close during ongoing negotiation in QUIC
except (IOException, AssertionError, ValueError) as error:
raise MultiselectCommunicatorError(
"fail to write to multiselect communicator"
) from error

View File

@ -15,6 +15,7 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from .exceptions import (
PubsubRouterError,
@ -103,6 +104,11 @@ class FloodSub(IPubsubRouter):
)
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
# Add the senderRecord of the peer in the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
rpc_msg.senderRecord = envelope_bytes
logger.debug("publishing message %s", pubsub_msg)
if self.pubsub is None:

View File

@ -1,6 +1,3 @@
from ast import (
literal_eval,
)
from collections import (
defaultdict,
)
@ -22,6 +19,7 @@ from libp2p.abc import (
IPubsubRouter,
)
from libp2p.custom_types import (
MessageID,
TProtocol,
)
from libp2p.peer.id import (
@ -34,10 +32,12 @@ from libp2p.peer.peerinfo import (
)
from libp2p.peer.peerstore import (
PERMANENT_ADDR_TTL,
env_to_send_in_RPC,
)
from libp2p.pubsub import (
floodsub,
)
from libp2p.pubsub.utils import maybe_consume_signed_record
from libp2p.tools.async_service import (
Service,
)
@ -54,6 +54,10 @@ from .pb import (
from .pubsub import (
Pubsub,
)
from .utils import (
parse_message_id_safe,
safe_parse_message_id,
)
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@ -226,6 +230,12 @@ class GossipSub(IPubsubRouter, Service):
:param rpc: RPC message
:param sender_peer_id: id of the peer who sent the message
"""
# Process the senderRecord if sent
if isinstance(self.pubsub, Pubsub):
if not maybe_consume_signed_record(rpc, self.pubsub.host, sender_peer_id):
logger.error("Received an invalid-signed-record, ignoring the message")
return
control_message = rpc.control
# Relay each rpc control message to the appropriate handler
@ -253,6 +263,11 @@ class GossipSub(IPubsubRouter, Service):
)
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
# Add the senderRecord of the peer in the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
rpc_msg.senderRecord = envelope_bytes
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
@ -781,8 +796,8 @@ class GossipSub(IPubsubRouter, Service):
# Add all unknown message ids (ids that appear in ihave_msg but not in
# seen_seqnos) to list of messages we want to request
msg_ids_wanted: list[str] = [
msg_id
msg_ids_wanted: list[MessageID] = [
parse_message_id_safe(msg_id)
for msg_id in ihave_msg.messageIDs
if msg_id not in seen_seqnos_and_peers
]
@ -798,9 +813,9 @@ class GossipSub(IPubsubRouter, Service):
Forwards all request messages that are present in mcache to the
requesting peer.
"""
# FIXME: Update type of message ID
# FIXME: Find a better way to parse the msg ids
msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
msg_ids: list[tuple[bytes, bytes]] = [
safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
]
msgs_to_forward: list[rpc_pb2.Message] = []
for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache
@ -818,6 +833,13 @@ class GossipSub(IPubsubRouter, Service):
# 1) Package these messages into a single packet
packet: rpc_pb2.RPC = rpc_pb2.RPC()
# Here the an RPC message is being created and published in response
# to the iwant control msg, so we will send a freshly created senderRecord
# with the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
packet.senderRecord = envelope_bytes
packet.publish.extend(msgs_to_forward)
if self.pubsub is None:
@ -973,6 +995,12 @@ class GossipSub(IPubsubRouter, Service):
raise NoPubsubAttached
# Add control message to packet
packet: rpc_pb2.RPC = rpc_pb2.RPC()
# Add the sender's peer-record in the RPC msg
if isinstance(self.pubsub, Pubsub):
envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
packet.senderRecord = envelope_bytes
packet.control.CopyFrom(control_msg)
# Get stream for peer from pubsub

View File

@ -14,6 +14,7 @@ message RPC {
}
optional ControlMessage control = 3;
optional bytes senderRecord = 4;
}
message Message {

View File

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/pubsub/pb/rpc.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,39 +14,39 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xca\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x12\x14\n\x0csenderRecord\x18\x04 \x01(\x0c\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_RPC._serialized_start=42
_RPC._serialized_end=222
_RPC_SUBOPTS._serialized_start=177
_RPC_SUBOPTS._serialized_end=222
_MESSAGE._serialized_start=224
_MESSAGE._serialized_end=329
_CONTROLMESSAGE._serialized_start=332
_CONTROLMESSAGE._serialized_end=508
_CONTROLIHAVE._serialized_start=510
_CONTROLIHAVE._serialized_end=561
_CONTROLIWANT._serialized_start=563
_CONTROLIWANT._serialized_end=597
_CONTROLGRAFT._serialized_start=599
_CONTROLGRAFT._serialized_end=630
_CONTROLPRUNE._serialized_start=632
_CONTROLPRUNE._serialized_end=716
_PEERINFO._serialized_start=718
_PEERINFO._serialized_end=770
_TOPICDESCRIPTOR._serialized_start=773
_TOPICDESCRIPTOR._serialized_end=1164
_TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906
_TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992
_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030
_TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033
_TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121
_TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164
_globals['_RPC']._serialized_start=42
_globals['_RPC']._serialized_end=244
_globals['_RPC_SUBOPTS']._serialized_start=199
_globals['_RPC_SUBOPTS']._serialized_end=244
_globals['_MESSAGE']._serialized_start=246
_globals['_MESSAGE']._serialized_end=351
_globals['_CONTROLMESSAGE']._serialized_start=354
_globals['_CONTROLMESSAGE']._serialized_end=530
_globals['_CONTROLIHAVE']._serialized_start=532
_globals['_CONTROLIHAVE']._serialized_end=583
_globals['_CONTROLIWANT']._serialized_start=585
_globals['_CONTROLIWANT']._serialized_end=619
_globals['_CONTROLGRAFT']._serialized_start=621
_globals['_CONTROLGRAFT']._serialized_end=652
_globals['_CONTROLPRUNE']._serialized_start=654
_globals['_CONTROLPRUNE']._serialized_end=738
_globals['_PEERINFO']._serialized_start=740
_globals['_PEERINFO']._serialized_end=792
_globals['_TOPICDESCRIPTOR']._serialized_start=795
_globals['_TOPICDESCRIPTOR']._serialized_end=1186
_globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_start=928
_globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_end=1052
_globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_start=1014
_globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_end=1052
_globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_start=1055
_globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_end=1186
_globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_start=1143
_globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_end=1186
# @@protoc_insertion_point(module_scope)

View File

@ -1,323 +1,132 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Modified from https://github.com/libp2p/go-libp2p-pubsub/blob/master/pb/rpc.proto"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
class RPC(_message.Message):
__slots__ = ("subscriptions", "publish", "control", "senderRecord")
class SubOpts(_message.Message):
__slots__ = ("subscribe", "topicid")
SUBSCRIBE_FIELD_NUMBER: _ClassVar[int]
TOPICID_FIELD_NUMBER: _ClassVar[int]
subscribe: bool
topicid: str
def __init__(self, subscribe: bool = ..., topicid: _Optional[str] = ...) -> None: ...
SUBSCRIPTIONS_FIELD_NUMBER: _ClassVar[int]
PUBLISH_FIELD_NUMBER: _ClassVar[int]
CONTROL_FIELD_NUMBER: _ClassVar[int]
SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
subscriptions: _containers.RepeatedCompositeFieldContainer[RPC.SubOpts]
publish: _containers.RepeatedCompositeFieldContainer[Message]
control: ControlMessage
senderRecord: bytes
def __init__(self, subscriptions: _Optional[_Iterable[_Union[RPC.SubOpts, _Mapping]]] = ..., publish: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., control: _Optional[_Union[ControlMessage, _Mapping]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class Message(_message.Message):
__slots__ = ("from_id", "data", "seqno", "topicIDs", "signature", "key")
FROM_ID_FIELD_NUMBER: _ClassVar[int]
DATA_FIELD_NUMBER: _ClassVar[int]
SEQNO_FIELD_NUMBER: _ClassVar[int]
TOPICIDS_FIELD_NUMBER: _ClassVar[int]
SIGNATURE_FIELD_NUMBER: _ClassVar[int]
KEY_FIELD_NUMBER: _ClassVar[int]
from_id: bytes
data: bytes
seqno: bytes
topicIDs: _containers.RepeatedScalarFieldContainer[str]
signature: bytes
key: bytes
def __init__(self, from_id: _Optional[bytes] = ..., data: _Optional[bytes] = ..., seqno: _Optional[bytes] = ..., topicIDs: _Optional[_Iterable[str]] = ..., signature: _Optional[bytes] = ..., key: _Optional[bytes] = ...) -> None: ...
@typing.final
class RPC(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ControlMessage(_message.Message):
__slots__ = ("ihave", "iwant", "graft", "prune")
IHAVE_FIELD_NUMBER: _ClassVar[int]
IWANT_FIELD_NUMBER: _ClassVar[int]
GRAFT_FIELD_NUMBER: _ClassVar[int]
PRUNE_FIELD_NUMBER: _ClassVar[int]
ihave: _containers.RepeatedCompositeFieldContainer[ControlIHave]
iwant: _containers.RepeatedCompositeFieldContainer[ControlIWant]
graft: _containers.RepeatedCompositeFieldContainer[ControlGraft]
prune: _containers.RepeatedCompositeFieldContainer[ControlPrune]
def __init__(self, ihave: _Optional[_Iterable[_Union[ControlIHave, _Mapping]]] = ..., iwant: _Optional[_Iterable[_Union[ControlIWant, _Mapping]]] = ..., graft: _Optional[_Iterable[_Union[ControlGraft, _Mapping]]] = ..., prune: _Optional[_Iterable[_Union[ControlPrune, _Mapping]]] = ...) -> None: ... # type: ignore
@typing.final
class SubOpts(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ControlIHave(_message.Message):
__slots__ = ("topicID", "messageIDs")
TOPICID_FIELD_NUMBER: _ClassVar[int]
MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
topicID: str
messageIDs: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, topicID: _Optional[str] = ..., messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...
SUBSCRIBE_FIELD_NUMBER: builtins.int
TOPICID_FIELD_NUMBER: builtins.int
subscribe: builtins.bool
"""subscribe or unsubscribe"""
topicid: builtins.str
def __init__(
self,
*,
subscribe: builtins.bool | None = ...,
topicid: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> None: ...
class ControlIWant(_message.Message):
__slots__ = ("messageIDs",)
MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
messageIDs: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...
SUBSCRIPTIONS_FIELD_NUMBER: builtins.int
PUBLISH_FIELD_NUMBER: builtins.int
CONTROL_FIELD_NUMBER: builtins.int
@property
def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RPC.SubOpts]: ...
@property
def publish(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message]: ...
@property
def control(self) -> global___ControlMessage: ...
def __init__(
self,
*,
subscriptions: collections.abc.Iterable[global___RPC.SubOpts] | None = ...,
publish: collections.abc.Iterable[global___Message] | None = ...,
control: global___ControlMessage | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["control", b"control"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["control", b"control", "publish", b"publish", "subscriptions", b"subscriptions"]) -> None: ...
class ControlGraft(_message.Message):
__slots__ = ("topicID",)
TOPICID_FIELD_NUMBER: _ClassVar[int]
topicID: str
def __init__(self, topicID: _Optional[str] = ...) -> None: ...
global___RPC = RPC
class ControlPrune(_message.Message):
__slots__ = ("topicID", "peers", "backoff")
TOPICID_FIELD_NUMBER: _ClassVar[int]
PEERS_FIELD_NUMBER: _ClassVar[int]
BACKOFF_FIELD_NUMBER: _ClassVar[int]
topicID: str
peers: _containers.RepeatedCompositeFieldContainer[PeerInfo]
backoff: int
def __init__(self, topicID: _Optional[str] = ..., peers: _Optional[_Iterable[_Union[PeerInfo, _Mapping]]] = ..., backoff: _Optional[int] = ...) -> None: ... # type: ignore
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class PeerInfo(_message.Message):
__slots__ = ("peerID", "signedPeerRecord")
PEERID_FIELD_NUMBER: _ClassVar[int]
SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int]
peerID: bytes
signedPeerRecord: bytes
def __init__(self, peerID: _Optional[bytes] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ...
FROM_ID_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
SEQNO_FIELD_NUMBER: builtins.int
TOPICIDS_FIELD_NUMBER: builtins.int
SIGNATURE_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
from_id: builtins.bytes
data: builtins.bytes
seqno: builtins.bytes
signature: builtins.bytes
key: builtins.bytes
@property
def topicIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
from_id: builtins.bytes | None = ...,
data: builtins.bytes | None = ...,
seqno: builtins.bytes | None = ...,
topicIDs: collections.abc.Iterable[builtins.str] | None = ...,
signature: builtins.bytes | None = ...,
key: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature", "topicIDs", b"topicIDs"]) -> None: ...
global___Message = Message
@typing.final
class ControlMessage(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
IHAVE_FIELD_NUMBER: builtins.int
IWANT_FIELD_NUMBER: builtins.int
GRAFT_FIELD_NUMBER: builtins.int
PRUNE_FIELD_NUMBER: builtins.int
@property
def ihave(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIHave]: ...
@property
def iwant(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIWant]: ...
@property
def graft(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlGraft]: ...
@property
def prune(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlPrune]: ...
def __init__(
self,
*,
ihave: collections.abc.Iterable[global___ControlIHave] | None = ...,
iwant: collections.abc.Iterable[global___ControlIWant] | None = ...,
graft: collections.abc.Iterable[global___ControlGraft] | None = ...,
prune: collections.abc.Iterable[global___ControlPrune] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["graft", b"graft", "ihave", b"ihave", "iwant", b"iwant", "prune", b"prune"]) -> None: ...
global___ControlMessage = ControlMessage
@typing.final
class ControlIHave(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
MESSAGEIDS_FIELD_NUMBER: builtins.int
topicID: builtins.str
@property
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
topicID: builtins.str | None = ...,
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs", "topicID", b"topicID"]) -> None: ...
global___ControlIHave = ControlIHave
@typing.final
class ControlIWant(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
MESSAGEIDS_FIELD_NUMBER: builtins.int
@property
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
def __init__(
self,
*,
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs"]) -> None: ...
global___ControlIWant = ControlIWant
@typing.final
class ControlGraft(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
topicID: builtins.str
def __init__(
self,
*,
topicID: builtins.str | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ...
global___ControlGraft = ControlGraft
@typing.final
class ControlPrune(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
TOPICID_FIELD_NUMBER: builtins.int
PEERS_FIELD_NUMBER: builtins.int
BACKOFF_FIELD_NUMBER: builtins.int
topicID: builtins.str
backoff: builtins.int
@property
def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ...
def __init__(
self,
*,
topicID: builtins.str | None = ...,
peers: collections.abc.Iterable[global___PeerInfo] | None = ...,
backoff: builtins.int | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ...
global___ControlPrune = ControlPrune
@typing.final
class PeerInfo(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
PEERID_FIELD_NUMBER: builtins.int
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
peerID: builtins.bytes
signedPeerRecord: builtins.bytes
def __init__(
self,
*,
peerID: builtins.bytes | None = ...,
signedPeerRecord: builtins.bytes | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
global___PeerInfo = PeerInfo
@typing.final
class TopicDescriptor(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@typing.final
class AuthOpts(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _AuthMode:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _AuthModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.AuthOpts._AuthMode.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NONE: TopicDescriptor.AuthOpts._AuthMode.ValueType # 0
"""no authentication, anyone can publish"""
KEY: TopicDescriptor.AuthOpts._AuthMode.ValueType # 1
"""only messages signed by keys in the topic descriptor are accepted"""
WOT: TopicDescriptor.AuthOpts._AuthMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
class AuthMode(_AuthMode, metaclass=_AuthModeEnumTypeWrapper): ...
NONE: TopicDescriptor.AuthOpts.AuthMode.ValueType # 0
"""no authentication, anyone can publish"""
KEY: TopicDescriptor.AuthOpts.AuthMode.ValueType # 1
"""only messages signed by keys in the topic descriptor are accepted"""
WOT: TopicDescriptor.AuthOpts.AuthMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
MODE_FIELD_NUMBER: builtins.int
KEYS_FIELD_NUMBER: builtins.int
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType
@property
def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
"""root keys to trust"""
def __init__(
self,
*,
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType | None = ...,
keys: collections.abc.Iterable[builtins.bytes] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["keys", b"keys", "mode", b"mode"]) -> None: ...
@typing.final
class EncOpts(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _EncMode:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _EncModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.EncOpts._EncMode.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NONE: TopicDescriptor.EncOpts._EncMode.ValueType # 0
"""no encryption, anyone can read"""
SHAREDKEY: TopicDescriptor.EncOpts._EncMode.ValueType # 1
"""messages are encrypted with shared key"""
WOT: TopicDescriptor.EncOpts._EncMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
class EncMode(_EncMode, metaclass=_EncModeEnumTypeWrapper): ...
NONE: TopicDescriptor.EncOpts.EncMode.ValueType # 0
"""no encryption, anyone can read"""
SHAREDKEY: TopicDescriptor.EncOpts.EncMode.ValueType # 1
"""messages are encrypted with shared key"""
WOT: TopicDescriptor.EncOpts.EncMode.ValueType # 2
"""web of trust, certificates can allow publisher set to grow"""
MODE_FIELD_NUMBER: builtins.int
KEYHASHES_FIELD_NUMBER: builtins.int
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType
@property
def keyHashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
"""the hashes of the shared keys used (salted)"""
def __init__(
self,
*,
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType | None = ...,
keyHashes: collections.abc.Iterable[builtins.bytes] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["keyHashes", b"keyHashes", "mode", b"mode"]) -> None: ...
NAME_FIELD_NUMBER: builtins.int
AUTH_FIELD_NUMBER: builtins.int
ENC_FIELD_NUMBER: builtins.int
name: builtins.str
@property
def auth(self) -> global___TopicDescriptor.AuthOpts: ...
@property
def enc(self) -> global___TopicDescriptor.EncOpts: ...
def __init__(
self,
*,
name: builtins.str | None = ...,
auth: global___TopicDescriptor.AuthOpts | None = ...,
enc: global___TopicDescriptor.EncOpts | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> None: ...
global___TopicDescriptor = TopicDescriptor
class TopicDescriptor(_message.Message):
__slots__ = ("name", "auth", "enc")
class AuthOpts(_message.Message):
__slots__ = ("mode", "keys")
class AuthMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NONE: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
KEY: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
WOT: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
NONE: TopicDescriptor.AuthOpts.AuthMode
KEY: TopicDescriptor.AuthOpts.AuthMode
WOT: TopicDescriptor.AuthOpts.AuthMode
MODE_FIELD_NUMBER: _ClassVar[int]
KEYS_FIELD_NUMBER: _ClassVar[int]
mode: TopicDescriptor.AuthOpts.AuthMode
keys: _containers.RepeatedScalarFieldContainer[bytes]
def __init__(self, mode: _Optional[_Union[TopicDescriptor.AuthOpts.AuthMode, str]] = ..., keys: _Optional[_Iterable[bytes]] = ...) -> None: ...
class EncOpts(_message.Message):
__slots__ = ("mode", "keyHashes")
class EncMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
NONE: _ClassVar[TopicDescriptor.EncOpts.EncMode]
SHAREDKEY: _ClassVar[TopicDescriptor.EncOpts.EncMode]
WOT: _ClassVar[TopicDescriptor.EncOpts.EncMode]
NONE: TopicDescriptor.EncOpts.EncMode
SHAREDKEY: TopicDescriptor.EncOpts.EncMode
WOT: TopicDescriptor.EncOpts.EncMode
MODE_FIELD_NUMBER: _ClassVar[int]
KEYHASHES_FIELD_NUMBER: _ClassVar[int]
mode: TopicDescriptor.EncOpts.EncMode
keyHashes: _containers.RepeatedScalarFieldContainer[bytes]
def __init__(self, mode: _Optional[_Union[TopicDescriptor.EncOpts.EncMode, str]] = ..., keyHashes: _Optional[_Iterable[bytes]] = ...) -> None: ...
NAME_FIELD_NUMBER: _ClassVar[int]
AUTH_FIELD_NUMBER: _ClassVar[int]
ENC_FIELD_NUMBER: _ClassVar[int]
name: str
auth: TopicDescriptor.AuthOpts
enc: TopicDescriptor.EncOpts
def __init__(self, name: _Optional[str] = ..., auth: _Optional[_Union[TopicDescriptor.AuthOpts, _Mapping]] = ..., enc: _Optional[_Union[TopicDescriptor.EncOpts, _Mapping]] = ...) -> None: ... # type: ignore

View File

@ -56,6 +56,8 @@ from libp2p.peer.id import (
from libp2p.peer.peerdata import (
PeerDataError,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.pubsub.utils import maybe_consume_signed_record
from libp2p.tools.async_service import (
Service,
)
@ -247,6 +249,10 @@ class Pubsub(Service, IPubsub):
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
)
# Add the sender's signedRecord in the RPC message
envelope_bytes, _ = env_to_send_in_RPC(self.host)
packet.senderRecord = envelope_bytes
return packet
async def continuously_read_stream(self, stream: INetStream) -> None:
@ -263,6 +269,14 @@ class Pubsub(Service, IPubsub):
incoming: bytes = await read_varint_prefixed_bytes(stream)
rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
rpc_incoming.ParseFromString(incoming)
# Process the sender's signed-record if sent
if not maybe_consume_signed_record(rpc_incoming, self.host, peer_id):
logger.error(
"Received an invalid-signed-record, ignoring the incoming msg"
)
continue
if rpc_incoming.publish:
# deal with RPC.publish
for msg in rpc_incoming.publish:
@ -572,6 +586,9 @@ class Pubsub(Service, IPubsub):
[rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
)
# Add the senderRecord of the peer in the RPC msg
envelope_bytes, _ = env_to_send_in_RPC(self.host)
packet.senderRecord = envelope_bytes
# Send out subscribe message to all peers
await self.message_all_peers(packet.SerializeToString())
@ -604,6 +621,9 @@ class Pubsub(Service, IPubsub):
packet.subscriptions.extend(
[rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)]
)
# Add the senderRecord of the peer in the RPC msg
envelope_bytes, _ = env_to_send_in_RPC(self.host)
packet.senderRecord = envelope_bytes
# Send out unsubscribe message to all peers
await self.message_all_peers(packet.SerializeToString())

80
libp2p/pubsub/utils.py Normal file
View File

@ -0,0 +1,80 @@
import ast
import logging
from libp2p.abc import IHost
from libp2p.custom_types import (
MessageID,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ID
from libp2p.pubsub.pb.rpc_pb2 import RPC
logger = logging.getLogger("pubsub-example.utils")
def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
"""
Attempt to parse and store a signed-peer-record (Envelope) received during
PubSub communication. If the record is invalid, the peer-id does not match, or
updating the peerstore fails, the function logs an error and returns False.
Parameters
----------
msg : RPC
The protobuf message received during PubSub communication.
host : IHost
The local host instance, providing access to the peerstore for storing
verified peer records.
peer_id : ID | None, optional
The expected peer ID for record validation. If provided, the peer ID
inside the record must match this value.
Returns
-------
bool
True if a valid signed peer record was successfully consumed and stored,
False otherwise.
"""
if msg.HasField("senderRecord"):
try:
# Convert the signed-peer-record(Envelope) from
# protobuf bytes
envelope, record = consume_envelope(msg.senderRecord, "libp2p-peer-record")
if not record.peer_id == peer_id:
return False
# Use the default TTL of 2 hours (7200 seconds)
if not host.get_peerstore().consume_peer_record(envelope, 7200):
logger.error("Failed to update the Certified-Addr-Book")
return False
except Exception as e:
logger.error("Failed to update the Certified-Addr-Book: %s", e)
return False
return True
def parse_message_id_safe(msg_id_str: str) -> MessageID:
"""Safely handle message ID as string."""
return MessageID(msg_id_str)
def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
"""
Safely parse message ID using ast.literal_eval with validation.
:param msg_id_str: String representation of message ID
:return: Tuple of (seqno, from_id) as bytes
:raises ValueError: If parsing fails
"""
try:
parsed = ast.literal_eval(msg_id_str)
if not isinstance(parsed, tuple) or len(parsed) != 2:
raise ValueError("Invalid message ID format")
seqno, from_id = parsed
if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
raise ValueError("Message ID components must be bytes")
return (seqno, from_id)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Invalid message ID format: {e}")

View File

@ -118,6 +118,8 @@ class SecurityMultistream(ABC):
# Select protocol if non-initiator
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError("fail to negotiate a security protocol")
raise MultiselectError(
"Failed to negotiate a security protocol: no protocol selected"
)
# Return transport from protocol
return self.transports[protocol]

View File

@ -85,7 +85,9 @@ class MuxerMultistream:
else:
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError("fail to negotiate a stream muxer protocol")
raise MultiselectError(
"Fail to negotiate a stream muxer protocol: no protocol selected"
)
return self.transports[protocol]
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:

View File

View File

@ -0,0 +1,345 @@
"""
Configuration classes for QUIC transport.
"""
from dataclasses import (
dataclass,
field,
)
import ssl
from typing import Any, Literal, TypedDict
from libp2p.custom_types import TProtocol
from libp2p.network.config import ConnectionConfig
class QUICTransportKwargs(TypedDict, total=False):
"""Type definition for kwargs accepted by new_transport function."""
# Connection settings
idle_timeout: float
max_datagram_size: int
local_port: int | None
# Protocol version support
enable_draft29: bool
enable_v1: bool
# TLS settings
verify_mode: ssl.VerifyMode
alpn_protocols: list[str]
# Performance settings
max_concurrent_streams: int
connection_window: int
stream_window: int
# Logging and debugging
enable_qlog: bool
qlog_dir: str | None
# Connection management
max_connections: int
connection_timeout: float
# Protocol identifiers
PROTOCOL_QUIC_V1: TProtocol
PROTOCOL_QUIC_DRAFT29: TProtocol
@dataclass
class QUICTransportConfig(ConnectionConfig):
"""Configuration for QUIC transport."""
# Connection settings
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
max_datagram_size: int = (
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
)
local_port: int | None = (
None # Local port to bind to. If None, a random port is chosen.
)
# Protocol version support
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
# TLS settings
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
# Performance settings
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
connection_window: int = 1024 * 1024 # Connection flow control window
stream_window: int = 64 * 1024 # Stream flow control window
# Logging and debugging
enable_qlog: bool = False # Enable QUIC logging
qlog_dir: str | None = None # Directory for QUIC logs
# Connection management
max_connections: int = 1000 # Maximum number of connections
connection_timeout: float = 10.0 # Connection establishment timeout
MAX_CONCURRENT_STREAMS: int = 1000
"""Maximum number of concurrent streams per connection."""
MAX_INCOMING_STREAMS: int = 1000
"""Maximum number of incoming streams per connection."""
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
"""Timeout for connection handshake (seconds)."""
MAX_OUTGOING_STREAMS: int = 1000
"""Maximum number of outgoing streams per connection."""
CONNECTION_CLOSE_TIMEOUT: int = 10
"""Timeout for opening new connection (seconds)."""
# Stream timeouts
STREAM_OPEN_TIMEOUT: float = 5.0
"""Timeout for opening new streams (seconds)."""
STREAM_ACCEPT_TIMEOUT: float = 30.0
"""Timeout for accepting incoming streams (seconds)."""
STREAM_READ_TIMEOUT: float = 30.0
"""Default timeout for stream read operations (seconds)."""
STREAM_WRITE_TIMEOUT: float = 30.0
"""Default timeout for stream write operations (seconds)."""
STREAM_CLOSE_TIMEOUT: float = 10.0
"""Timeout for graceful stream close (seconds)."""
# Flow control configuration
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
"""Per-stream flow control window size."""
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
"""Connection-wide flow control window size."""
# Buffer management
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
"""Maximum receive buffer size per stream."""
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
"""Low watermark for stream receive buffer."""
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
"""High watermark for stream receive buffer."""
# Stream lifecycle configuration
ENABLE_STREAM_RESET_ON_ERROR: bool = True
"""Whether to automatically reset streams on errors."""
STREAM_RESET_ERROR_CODE: int = 1
"""Default error code for stream resets."""
ENABLE_STREAM_KEEP_ALIVE: bool = False
"""Whether to enable stream keep-alive mechanisms."""
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
"""Interval for stream keep-alive pings (seconds)."""
# Resource management
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
"""Whether to track stream resource usage."""
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
"""Memory limit per individual stream."""
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
"""Total memory limit for all streams per connection."""
# Concurrency and performance
ENABLE_STREAM_BATCHING: bool = True
"""Whether to batch multiple stream operations."""
STREAM_BATCH_SIZE: int = 10
"""Number of streams to process in a batch."""
STREAM_PROCESSING_CONCURRENCY: int = 100
"""Maximum concurrent stream processing tasks."""
# Debugging and monitoring
ENABLE_STREAM_METRICS: bool = True
"""Whether to collect stream metrics."""
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
"""Whether to track stream lifecycle timelines."""
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
"""Interval for collecting stream metrics (seconds)."""
# Error handling configuration
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
"""Number of retry attempts for recoverable stream errors."""
STREAM_ERROR_RETRY_DELAY: float = 1.0
"""Initial delay between stream error retries (seconds)."""
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
"""Backoff factor for stream error retries."""
# Protocol identifiers matching go-libp2p
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (self.enable_draft29 or self.enable_v1):
raise ValueError("At least one QUIC version must be enabled")
if self.idle_timeout <= 0:
raise ValueError("Idle timeout must be positive")
if self.max_datagram_size < 1200:
raise ValueError("Max datagram size must be at least 1200 bytes")
# Validate timeouts
timeout_fields = [
"STREAM_OPEN_TIMEOUT",
"STREAM_ACCEPT_TIMEOUT",
"STREAM_READ_TIMEOUT",
"STREAM_WRITE_TIMEOUT",
"STREAM_CLOSE_TIMEOUT",
]
for timeout_field in timeout_fields:
if getattr(self, timeout_field) <= 0:
raise ValueError(f"{timeout_field} must be positive")
# Validate flow control windows
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
raise ValueError(
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
)
# Validate buffer sizes
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
raise ValueError(
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
"exceed MAX_STREAM_RECEIVE_BUFFER"
)
)
if (
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
):
raise ValueError(
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
)
# Validate memory limits
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
expected_stream_memory = (
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
)
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
# Allow some headroom, but warn if configuration seems inconsistent
import logging
logger = logging.getLogger(__name__)
logger.warning(
"Stream memory configuration may be inconsistent: "
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
"could exceed connection limit of"
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
)
def get_stream_config_dict(self) -> dict[str, Any]:
"""Get stream-specific configuration as dictionary."""
stream_config = {}
for attr_name in dir(self):
if attr_name.startswith(
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
):
stream_config[attr_name.lower()] = getattr(self, attr_name)
return stream_config
# Additional configuration classes for specific stream features
class QUICStreamFlowControlConfig:
"""Configuration for QUIC stream flow control."""
def __init__(
self,
initial_window_size: int = 512 * 1024,
max_window_size: int = 2 * 1024 * 1024,
window_update_threshold: float = 0.5,
enable_auto_tuning: bool = True,
):
self.initial_window_size = initial_window_size
self.max_window_size = max_window_size
self.window_update_threshold = window_update_threshold
self.enable_auto_tuning = enable_auto_tuning
def create_stream_config_for_use_case(
use_case: Literal[
"high_throughput", "low_latency", "many_streams", "memory_constrained"
],
) -> QUICTransportConfig:
"""
Create optimized stream configuration for specific use cases.
Args:
use_case: One of "high_throughput", "low_latency", "many_streams","
"memory_constrained"
Returns:
Optimized QUICTransportConfig
"""
base_config = QUICTransportConfig()
if use_case == "high_throughput":
# Optimize for high throughput
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
base_config.STREAM_PROCESSING_CONCURRENCY = 200
elif use_case == "low_latency":
# Optimize for low latency
base_config.STREAM_OPEN_TIMEOUT = 1.0
base_config.STREAM_READ_TIMEOUT = 5.0
base_config.STREAM_WRITE_TIMEOUT = 5.0
base_config.ENABLE_STREAM_BATCHING = False
base_config.STREAM_BATCH_SIZE = 1
elif use_case == "many_streams":
# Optimize for many concurrent streams
base_config.MAX_CONCURRENT_STREAMS = 5000
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
base_config.STREAM_PROCESSING_CONCURRENCY = 500
elif use_case == "memory_constrained":
# Optimize for low memory usage
base_config.MAX_CONCURRENT_STREAMS = 100
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
base_config.STREAM_PROCESSING_CONCURRENCY = 50
else:
raise ValueError(f"Unknown use case: {use_case}")
return base_config

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,391 @@
"""
QUIC Transport exceptions
"""
from typing import Any, Literal
class QUICError(Exception):
"""Base exception for all QUIC transport errors."""
def __init__(self, message: str, error_code: int | None = None):
super().__init__(message)
self.error_code = error_code
# Transport-level exceptions
class QUICTransportError(QUICError):
"""Base exception for QUIC transport operations."""
pass
class QUICDialError(QUICTransportError):
"""Error occurred during QUIC connection establishment."""
pass
class QUICListenError(QUICTransportError):
"""Error occurred during QUIC listener operations."""
pass
class QUICSecurityError(QUICTransportError):
"""Error related to QUIC security/TLS operations."""
pass
# Connection-level exceptions
class QUICConnectionError(QUICError):
"""Base exception for QUIC connection operations."""
pass
class QUICConnectionClosedError(QUICConnectionError):
"""QUIC connection has been closed."""
pass
class QUICConnectionTimeoutError(QUICConnectionError):
"""QUIC connection operation timed out."""
pass
class QUICHandshakeError(QUICConnectionError):
"""Error during QUIC handshake process."""
pass
class QUICPeerVerificationError(QUICConnectionError):
"""Error verifying peer identity during handshake."""
pass
# Stream-level exceptions
class QUICStreamError(QUICError):
"""Base exception for QUIC stream operations."""
def __init__(
self,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
):
super().__init__(message, error_code)
self.stream_id = stream_id
class QUICStreamClosedError(QUICStreamError):
"""Stream is closed and cannot be used for I/O operations."""
pass
class QUICStreamResetError(QUICStreamError):
"""Stream was reset by local or remote peer."""
def __init__(
self,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
reset_by_peer: bool = False,
):
super().__init__(message, stream_id, error_code)
self.reset_by_peer = reset_by_peer
class QUICStreamTimeoutError(QUICStreamError):
"""Stream operation timed out."""
pass
class QUICStreamBackpressureError(QUICStreamError):
"""Stream write blocked due to flow control."""
pass
class QUICStreamLimitError(QUICStreamError):
"""Stream limit reached (too many concurrent streams)."""
pass
class QUICStreamStateError(QUICStreamError):
"""Invalid operation for current stream state."""
def __init__(
self,
message: str,
stream_id: str | None = None,
current_state: str | None = None,
attempted_operation: str | None = None,
):
super().__init__(message, stream_id)
self.current_state = current_state
self.attempted_operation = attempted_operation
# Flow control exceptions
class QUICFlowControlError(QUICError):
"""Base exception for flow control related errors."""
pass
class QUICFlowControlViolationError(QUICFlowControlError):
"""Flow control limits were violated."""
pass
class QUICFlowControlDeadlockError(QUICFlowControlError):
"""Flow control deadlock detected."""
pass
# Resource management exceptions
class QUICResourceError(QUICError):
"""Base exception for resource management errors."""
pass
class QUICMemoryLimitError(QUICResourceError):
"""Memory limit exceeded."""
pass
class QUICConnectionLimitError(QUICResourceError):
"""Connection limit exceeded."""
pass
# Multiaddr and addressing exceptions
class QUICAddressError(QUICError):
"""Base exception for QUIC addressing errors."""
pass
class QUICInvalidMultiaddrError(QUICAddressError):
"""Invalid multiaddr format for QUIC transport."""
pass
class QUICAddressResolutionError(QUICAddressError):
"""Failed to resolve QUIC address."""
pass
class QUICProtocolError(QUICError):
"""Base exception for QUIC protocol errors."""
pass
class QUICVersionNegotiationError(QUICProtocolError):
"""QUIC version negotiation failed."""
pass
class QUICUnsupportedVersionError(QUICProtocolError):
"""Unsupported QUIC version."""
pass
# Configuration exceptions
class QUICConfigurationError(QUICError):
"""Base exception for QUIC configuration errors."""
pass
class QUICInvalidConfigError(QUICConfigurationError):
"""Invalid QUIC configuration parameters."""
pass
class QUICCertificateError(QUICConfigurationError):
"""Error with TLS certificate configuration."""
pass
def map_quic_error_code(error_code: int) -> str:
"""
Map QUIC error codes to human-readable descriptions.
Based on RFC 9000 Transport Error Codes.
"""
error_codes = {
0x00: "NO_ERROR",
0x01: "INTERNAL_ERROR",
0x02: "CONNECTION_REFUSED",
0x03: "FLOW_CONTROL_ERROR",
0x04: "STREAM_LIMIT_ERROR",
0x05: "STREAM_STATE_ERROR",
0x06: "FINAL_SIZE_ERROR",
0x07: "FRAME_ENCODING_ERROR",
0x08: "TRANSPORT_PARAMETER_ERROR",
0x09: "CONNECTION_ID_LIMIT_ERROR",
0x0A: "PROTOCOL_VIOLATION",
0x0B: "INVALID_TOKEN",
0x0C: "APPLICATION_ERROR",
0x0D: "CRYPTO_BUFFER_EXCEEDED",
0x0E: "KEY_UPDATE_ERROR",
0x0F: "AEAD_LIMIT_REACHED",
0x10: "NO_VIABLE_PATH",
}
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
def create_stream_error(
error_type: str,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
) -> QUICStreamError:
"""
Factory function to create appropriate stream error based on type.
Args:
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
message: Error message
stream_id: Stream identifier
error_code: QUIC error code
Returns:
Appropriate QUICStreamError subclass
"""
error_type = error_type.lower()
if error_type in ("closed", "close"):
return QUICStreamClosedError(message, stream_id, error_code)
elif error_type == "reset":
return QUICStreamResetError(message, stream_id, error_code)
elif error_type == "timeout":
return QUICStreamTimeoutError(message, stream_id, error_code)
elif error_type in ("backpressure", "flow_control"):
return QUICStreamBackpressureError(message, stream_id, error_code)
elif error_type in ("limit", "stream_limit"):
return QUICStreamLimitError(message, stream_id, error_code)
elif error_type == "state":
return QUICStreamStateError(message, stream_id)
else:
return QUICStreamError(message, stream_id, error_code)
def create_connection_error(
error_type: str, message: str, error_code: int | None = None
) -> QUICConnectionError:
"""
Factory function to create appropriate connection error based on type.
Args:
error_type: Type of error ("closed", "timeout", "handshake", etc.)
message: Error message
error_code: QUIC error code
Returns:
Appropriate QUICConnectionError subclass
"""
error_type = error_type.lower()
if error_type in ("closed", "close"):
return QUICConnectionClosedError(message, error_code)
elif error_type == "timeout":
return QUICConnectionTimeoutError(message, error_code)
elif error_type == "handshake":
return QUICHandshakeError(message, error_code)
elif error_type in ("peer_verification", "verification"):
return QUICPeerVerificationError(message, error_code)
else:
return QUICConnectionError(message, error_code)
class QUICErrorContext:
"""
Context manager for handling QUIC errors with automatic error mapping.
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
"""
def __init__(self, operation: str, component: str = "quic") -> None:
self.operation = operation
self.component = component
def __enter__(self) -> "QUICErrorContext":
return self
# TODO: Fix types for exc_type
def __exit__(
self,
exc_type: type[BaseException] | None | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
if exc_type is None:
return False
if exc_val is None:
return False
# Map common aioquic exceptions to our exceptions
if "ConnectionClosed" in str(exc_type):
raise QUICConnectionClosedError(
f"Connection closed during {self.operation}: {exc_val}"
) from exc_val
elif "StreamReset" in str(exc_type):
raise QUICStreamResetError(
f"Stream reset during {self.operation}: {exc_val}"
) from exc_val
elif "timeout" in str(exc_val).lower():
if "stream" in self.component.lower():
raise QUICStreamTimeoutError(
f"Timeout during {self.operation}: {exc_val}"
) from exc_val
else:
raise QUICConnectionTimeoutError(
f"Timeout during {self.operation}: {exc_val}"
) from exc_val
elif "flow control" in str(exc_val).lower():
raise QUICStreamBackpressureError(
f"Flow control error during {self.operation}: {exc_val}"
) from exc_val
# Let other exceptions propagate
return False

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,656 @@
"""
QUIC Stream implementation
Provides stream interface over QUIC's native multiplexing.
"""
from enum import Enum
import logging
import time
from types import TracebackType
from typing import TYPE_CHECKING, Any, cast
import trio
from .exceptions import (
QUICStreamBackpressureError,
QUICStreamClosedError,
QUICStreamResetError,
QUICStreamTimeoutError,
)
if TYPE_CHECKING:
from libp2p.abc import IMuxedStream
from libp2p.custom_types import TProtocol
from .connection import QUICConnection
else:
IMuxedStream = cast(type, object)
TProtocol = cast(type, object)
logger = logging.getLogger(__name__)
class StreamState(Enum):
"""Stream lifecycle states following libp2p patterns."""
OPEN = "open"
WRITE_CLOSED = "write_closed"
READ_CLOSED = "read_closed"
CLOSED = "closed"
RESET = "reset"
class StreamDirection(Enum):
"""Stream direction for tracking initiator."""
INBOUND = "inbound"
OUTBOUND = "outbound"
class StreamTimeline:
"""Track stream lifecycle events for debugging and monitoring."""
def __init__(self) -> None:
self.created_at = time.time()
self.opened_at: float | None = None
self.first_data_at: float | None = None
self.closed_at: float | None = None
self.reset_at: float | None = None
self.error_code: int | None = None
def record_open(self) -> None:
self.opened_at = time.time()
def record_first_data(self) -> None:
if self.first_data_at is None:
self.first_data_at = time.time()
def record_close(self) -> None:
self.closed_at = time.time()
def record_reset(self, error_code: int) -> None:
self.reset_at = time.time()
self.error_code = error_code
class QUICStream(IMuxedStream):
"""
QUIC Stream implementation following libp2p IMuxedStream interface.
Based on patterns from go-libp2p and js-libp2p, this implementation:
- Leverages QUIC's native multiplexing and flow control
- Integrates with libp2p resource management
- Provides comprehensive error handling with QUIC-specific codes
- Supports bidirectional communication with independent close semantics
- Implements proper stream lifecycle management
"""
def __init__(
self,
connection: "QUICConnection",
stream_id: int,
direction: StreamDirection,
remote_addr: tuple[str, int],
resource_scope: Any | None = None,
):
"""
Initialize QUIC stream.
Args:
connection: Parent QUIC connection
stream_id: QUIC stream identifier
direction: Stream direction (inbound/outbound)
resource_scope: Resource manager scope for memory accounting
remote_addr: Remote addr stream is connected to
"""
self._connection = connection
self._stream_id = stream_id
self._direction = direction
self._resource_scope = resource_scope
# libp2p interface compliance
self._protocol: TProtocol | None = None
self._metadata: dict[str, Any] = {}
self._remote_addr = remote_addr
# Stream state management
self._state = StreamState.OPEN
self._state_lock = trio.Lock()
# Flow control and buffering
self._receive_buffer = bytearray()
self._receive_buffer_lock = trio.Lock()
self._receive_event = trio.Event()
self._backpressure_event = trio.Event()
self._backpressure_event.set() # Initially no backpressure
# Close/reset state
self._write_closed = False
self._read_closed = False
self._close_event = trio.Event()
self._reset_error_code: int | None = None
# Lifecycle tracking
self._timeline = StreamTimeline()
self._timeline.record_open()
# Resource accounting
self._memory_reserved = 0
# Stream constant configurations
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
self.FLOW_CONTROL_WINDOW_SIZE = (
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
)
self.MAX_RECEIVE_BUFFER_SIZE = (
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
)
if self._resource_scope:
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
logger.debug(
f"Created QUIC stream {stream_id} "
f"({direction.value}, connection: {connection.remote_peer_id()})"
)
# Properties for libp2p interface compliance
@property
def protocol(self) -> TProtocol | None:
"""Get the protocol identifier for this stream."""
return self._protocol
@protocol.setter
def protocol(self, protocol_id: TProtocol) -> None:
"""Set the protocol identifier for this stream."""
self._protocol = protocol_id
self._metadata["protocol"] = protocol_id
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
@property
def stream_id(self) -> str:
"""Get stream ID as string for libp2p compatibility."""
return str(self._stream_id)
@property
def muxed_conn(self) -> "QUICConnection": # type: ignore
"""Get the parent muxed connection."""
return self._connection
@property
def state(self) -> StreamState:
"""Get current stream state."""
return self._state
@property
def direction(self) -> StreamDirection:
"""Get stream direction."""
return self._direction
@property
def is_initiator(self) -> bool:
"""Check if this stream was locally initiated."""
return self._direction == StreamDirection.OUTBOUND
# Core stream operations
async def read(self, n: int | None = None) -> bytes:
"""
Read data from the stream with QUIC flow control.
Args:
n: Maximum number of bytes to read. If None or -1, read all available.
Returns:
Data read from stream
Raises:
QUICStreamClosedError: Stream is closed
QUICStreamResetError: Stream was reset
QUICStreamTimeoutError: Read timeout exceeded
"""
if n is None:
n = -1
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
if self._read_closed:
# Return any remaining buffered data, then EOF
async with self._receive_buffer_lock:
if self._receive_buffer:
data = self._extract_data_from_buffer(n)
self._timeline.record_first_data()
return data
return b""
# Wait for data with timeout
timeout = self.READ_TIMEOUT
try:
with trio.move_on_after(timeout) as cancel_scope:
while True:
async with self._receive_buffer_lock:
if self._receive_buffer:
data = self._extract_data_from_buffer(n)
self._timeline.record_first_data()
return data
# Check if stream was closed while waiting
if self._read_closed:
return b""
# Wait for more data
await self._receive_event.wait()
self._receive_event = trio.Event() # Reset for next wait
if cancel_scope.cancelled_caught:
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
return b""
except QUICStreamResetError:
# Stream was reset while reading
raise
except Exception as e:
logger.error(f"Error reading from stream {self.stream_id}: {e}")
await self._handle_stream_error(e)
raise
async def write(self, data: bytes) -> None:
"""
Write data to the stream with QUIC flow control.
Args:
data: Data to write
Raises:
QUICStreamClosedError: Stream is closed for writing
QUICStreamBackpressureError: Flow control window exhausted
QUICStreamResetError: Stream was reset
"""
if not data:
return
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
if self._write_closed:
raise QUICStreamClosedError(
f"Stream {self.stream_id} write side is closed"
)
try:
# Handle flow control backpressure
await self._backpressure_event.wait()
# Send data through QUIC connection
self._connection._quic.send_stream_data(self._stream_id, data)
await self._connection._transmit()
self._timeline.record_first_data()
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
except Exception as e:
logger.error(f"Error writing to stream {self.stream_id}: {e}")
# Convert QUIC-specific errors
if "flow control" in str(e).lower():
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
await self._handle_stream_error(e)
raise
async def close(self) -> None:
"""
Close the stream gracefully (both read and write sides).
This implements proper close semantics where both sides
are closed and resources are cleaned up.
"""
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
return
logger.debug(f"Closing stream {self.stream_id}")
# Close both sides
if not self._write_closed:
await self.close_write()
if not self._read_closed:
await self.close_read()
# Update state and cleanup
async with self._state_lock:
self._state = StreamState.CLOSED
await self._cleanup_resources()
self._timeline.record_close()
self._close_event.set()
logger.debug(f"Stream {self.stream_id} closed")
async def close_write(self) -> None:
"""Close the write side of the stream."""
if self._write_closed:
return
try:
# Send FIN to close write side
self._connection._quic.send_stream_data(
self._stream_id, b"", end_stream=True
)
await self._connection._transmit()
self._write_closed = True
async with self._state_lock:
if self._read_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.WRITE_CLOSED
logger.debug(f"Stream {self.stream_id} write side closed")
except Exception as e:
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
async def close_read(self) -> None:
"""Close the read side of the stream."""
if self._read_closed:
return
try:
self._read_closed = True
async with self._state_lock:
if self._write_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.READ_CLOSED
# Wake up any pending reads
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} read side closed")
except Exception as e:
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
async def reset(self, error_code: int = 0) -> None:
"""
Reset the stream with the given error code.
Args:
error_code: QUIC error code for the reset
"""
async with self._state_lock:
if self._state == StreamState.RESET:
return
logger.debug(
f"Resetting stream {self.stream_id} with error code {error_code}"
)
self._state = StreamState.RESET
self._reset_error_code = error_code
try:
# Send QUIC reset frame
self._connection._quic.reset_stream(self._stream_id, error_code)
await self._connection._transmit()
except Exception as e:
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
finally:
# Always cleanup resources
await self._cleanup_resources()
self._timeline.record_reset(error_code)
self._close_event.set()
def is_closed(self) -> bool:
"""Check if stream is completely closed."""
return self._state in (StreamState.CLOSED, StreamState.RESET)
def is_reset(self) -> bool:
"""Check if stream was reset."""
return self._state == StreamState.RESET
def can_read(self) -> bool:
"""Check if stream can be read from."""
return not self._read_closed and self._state not in (
StreamState.CLOSED,
StreamState.RESET,
)
def can_write(self) -> bool:
"""Check if stream can be written to."""
return not self._write_closed and self._state not in (
StreamState.CLOSED,
StreamState.RESET,
)
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
"""
Handle data received from the QUIC connection.
Args:
data: Received data
end_stream: Whether this is the last data (FIN received)
"""
if self._state == StreamState.RESET:
return
if data:
async with self._receive_buffer_lock:
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
logger.warning(
f"Stream {self.stream_id} receive buffer overflow, "
f"dropping {len(data)} bytes"
)
return
self._receive_buffer.extend(data)
self._timeline.record_first_data()
# Notify waiting readers
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
if end_stream:
self._read_closed = True
async with self._state_lock:
if self._write_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.READ_CLOSED
# Wake up readers to process remaining data and EOF
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} received FIN")
async def handle_stop_sending(self, error_code: int) -> None:
"""
Handle STOP_SENDING frame from remote peer.
When a STOP_SENDING frame is received, the peer is requesting that we
stop sending data on this stream. We respond by resetting the stream.
Args:
error_code: Error code from the STOP_SENDING frame
"""
logger.debug(
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
)
self._write_closed = True
# Wake up any pending write operations
self._backpressure_event.set()
async with self._state_lock:
if self.direction == StreamDirection.OUTBOUND:
self._state = StreamState.CLOSED
elif self._read_closed:
self._state = StreamState.CLOSED
else:
# Only write side closed - add WRITE_CLOSED state if needed
self._state = StreamState.WRITE_CLOSED
# Send RESET_STREAM in response (QUIC protocol requirement)
try:
self._connection._quic.reset_stream(int(self.stream_id), error_code)
await self._connection._transmit()
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
except Exception as e:
logger.warning(
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
)
async def handle_reset(self, error_code: int) -> None:
"""
Handle stream reset from remote peer.
Args:
error_code: QUIC error code from reset frame
"""
logger.debug(
f"Stream {self.stream_id} reset by peer with error code {error_code}"
)
async with self._state_lock:
self._state = StreamState.RESET
self._reset_error_code = error_code
await self._cleanup_resources()
self._timeline.record_reset(error_code)
self._close_event.set()
# Wake up any pending operations
self._receive_event.set()
self._backpressure_event.set()
async def handle_flow_control_update(self, available_window: int) -> None:
"""
Handle flow control window updates.
Args:
available_window: Available flow control window size
"""
if available_window > 0:
self._backpressure_event.set()
logger.debug(
f"Stream {self.stream_id} flow control".__add__(
f"window updated: {available_window}"
)
)
else:
self._backpressure_event = trio.Event() # Reset to blocking state
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
def _extract_data_from_buffer(self, n: int) -> bytes:
"""Extract data from receive buffer with specified limit."""
if n == -1:
# Read all available data
data = bytes(self._receive_buffer)
self._receive_buffer.clear()
else:
# Read up to n bytes
data = bytes(self._receive_buffer[:n])
self._receive_buffer = self._receive_buffer[n:]
return data
async def _handle_stream_error(self, error: Exception) -> None:
"""Handle errors by resetting the stream."""
logger.error(f"Stream {self.stream_id} error: {error}")
await self.reset(error_code=1) # Generic error code
def _reserve_memory(self, size: int) -> None:
"""Reserve memory with resource manager."""
if self._resource_scope:
try:
self._resource_scope.reserve_memory(size)
self._memory_reserved += size
except Exception as e:
logger.warning(
f"Failed to reserve memory for stream {self.stream_id}: {e}"
)
def _release_memory(self, size: int) -> None:
"""Release memory with resource manager."""
if self._resource_scope and size > 0:
try:
self._resource_scope.release_memory(size)
self._memory_reserved = max(0, self._memory_reserved - size)
except Exception as e:
logger.warning(
f"Failed to release memory for stream {self.stream_id}: {e}"
)
async def _cleanup_resources(self) -> None:
"""Clean up stream resources."""
# Release all reserved memory
if self._memory_reserved > 0:
self._release_memory(self._memory_reserved)
# Clear receive buffer
async with self._receive_buffer_lock:
self._receive_buffer.clear()
# Remove from connection's stream registry
self._connection._remove_stream(self._stream_id)
logger.debug(f"Stream {self.stream_id} resources cleaned up")
# Abstact implementations
def get_remote_address(self) -> tuple[str, int]:
return self._remote_addr
async def __aenter__(self) -> "QUICStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
logger.debug("Exiting the context and closing the stream")
await self.close()
def set_deadline(self, ttl: int) -> bool:
"""
Set a deadline for the stream. QUIC does not support deadlines natively,
so this method always returns False to indicate the operation is unsupported.
:param ttl: Time-to-live in seconds (ignored).
:return: False, as deadlines are not supported.
"""
raise NotImplementedError("QUIC does not support setting read deadlines")
# String representation for debugging
def __repr__(self) -> str:
return (
f"QUICStream(id={self.stream_id}, "
f"state={self._state.value}, "
f"direction={self._direction.value}, "
f"protocol={self._protocol})"
)
def __str__(self) -> str:
return f"QUICStream({self.stream_id})"

View File

@ -0,0 +1,491 @@
"""
QUIC Transport implementation
"""
import copy
import logging
import ssl
from typing import TYPE_CHECKING, cast
from aioquic.quic.configuration import (
QuicConfiguration,
)
from aioquic.quic.connection import (
QuicConnection as NativeQUICConnection,
)
from aioquic.quic.logger import QuicLogger
import multiaddr
import trio
from libp2p.abc import (
ITransport,
)
from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
from libp2p.peer.id import (
ID,
)
from libp2p.transport.quic.security import QUICTLSSecurityConfig
from libp2p.transport.quic.utils import (
create_client_config_from_base,
create_server_config_from_base,
get_alpn_protocols,
is_quic_multiaddr,
multiaddr_to_quic_version,
quic_multiaddr_to_endpoint,
quic_version_to_wire_format,
)
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm
else:
Swarm = cast(type, object)
from .config import (
QUICTransportConfig,
)
from .connection import (
QUICConnection,
)
from .exceptions import (
QUICDialError,
QUICListenError,
QUICSecurityError,
)
from .listener import (
QUICListener,
)
from .security import (
QUICTLSConfigManager,
create_quic_security_transport,
)
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
logger = logging.getLogger(__name__)
class QUICTransport(ITransport):
"""
QUIC Stream implementation following libp2p IMuxedStream interface.
"""
def __init__(
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
) -> None:
"""
Initialize QUIC transport with security integration.
Args:
private_key: libp2p private key for identity and TLS cert generation
config: QUIC transport configuration options
"""
self._private_key = private_key
self._peer_id = ID.from_pubkey(private_key.get_public_key())
self._config = config or QUICTransportConfig()
# Connection management
self._connections: dict[str, QUICConnection] = {}
self._listeners: list[QUICListener] = []
# Security manager for TLS integration
self._security_manager = create_quic_security_transport(
self._private_key, self._peer_id
)
# QUIC configurations for different versions
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
self._setup_quic_configurations()
# Resource management
self._closed = False
self._nursery_manager = trio.CapacityLimiter(1)
self._background_nursery: trio.Nursery | None = None
self._swarm: Swarm | None = None
logger.debug(
f"Initialized QUIC transport with security for peer {self._peer_id}"
)
def set_background_nursery(self, nursery: trio.Nursery) -> None:
"""Set the nursery to use for background tasks (called by swarm)."""
self._background_nursery = nursery
logger.debug("Transport background nursery set")
def set_swarm(self, swarm: Swarm) -> None:
"""Set the swarm for adding incoming connections."""
self._swarm = swarm
def _setup_quic_configurations(self) -> None:
"""Setup QUIC configurations."""
try:
# Get TLS configuration from security manager
server_tls_config = self._security_manager.create_server_config()
client_tls_config = self._security_manager.create_client_config()
# Base server configuration
base_server_config = QuicConfiguration(
is_client=False,
alpn_protocols=get_alpn_protocols(),
verify_mode=self._config.verify_mode,
max_datagram_frame_size=self._config.max_datagram_size,
idle_timeout=self._config.idle_timeout,
)
# Base client configuration
base_client_config = QuicConfiguration(
is_client=True,
alpn_protocols=get_alpn_protocols(),
verify_mode=self._config.verify_mode,
max_datagram_frame_size=self._config.max_datagram_size,
idle_timeout=self._config.idle_timeout,
)
# Apply TLS configuration
self._apply_tls_configuration(base_server_config, server_tls_config)
self._apply_tls_configuration(base_client_config, client_tls_config)
# QUIC v1 (RFC 9000) configurations
if self._config.enable_v1:
quic_v1_server_config = create_server_config_from_base(
base_server_config, self._security_manager, self._config
)
quic_v1_server_config.supported_versions = [
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
]
quic_v1_client_config = create_client_config_from_base(
base_client_config, self._security_manager, self._config
)
quic_v1_client_config.supported_versions = [
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
]
# Store both server and client configs for v1
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
quic_v1_server_config
)
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
quic_v1_client_config
)
# QUIC draft-29 configurations for compatibility
if self._config.enable_draft29:
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
draft29_server_config.supported_versions = [
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
]
draft29_client_config = copy.copy(base_client_config)
draft29_client_config.supported_versions = [
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
]
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
draft29_server_config
)
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
draft29_client_config
)
logger.debug("QUIC configurations initialized with libp2p TLS security")
except Exception as e:
raise QUICSecurityError(
f"Failed to setup QUIC TLS configurations: {e}"
) from e
def _apply_tls_configuration(
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
) -> None:
"""
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
Args:
config: QuicConfiguration to update
tls_config: TLS configuration dictionary from security manager
"""
try:
config.certificate = tls_config.certificate
config.private_key = tls_config.private_key
config.certificate_chain = tls_config.certificate_chain
config.alpn_protocols = tls_config.alpn_protocols
config.verify_mode = ssl.CERT_NONE
logger.debug("Successfully applied TLS configuration to QUIC config")
except Exception as e:
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
async def dial(
self,
maddr: multiaddr.Multiaddr,
) -> QUICConnection:
"""
Dial a remote peer using QUIC transport with security verification.
Args:
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
peer_id: Expected peer ID for verification
nursery: Nursery to execute the background tasks
Returns:
Raw connection interface to the remote peer
Raises:
QUICDialError: If dialing fails
QUICSecurityError: If security verification fails
"""
if self._closed:
raise QUICDialError("Transport is closed")
if not is_quic_multiaddr(maddr):
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
try:
# Extract connection details from multiaddr
host, port = quic_multiaddr_to_endpoint(maddr)
remote_peer_id = maddr.get_peer_id()
if remote_peer_id is not None:
remote_peer_id = ID.from_base58(remote_peer_id)
if remote_peer_id is None:
logger.error("Unable to derive peer id from multiaddr")
raise QUICDialError("Unable to derive peer id from multiaddr")
quic_version = multiaddr_to_quic_version(maddr)
# Get appropriate QUIC client configuration
config_key = TProtocol(f"{quic_version}_client")
logger.debug("config_key", config_key, self._quic_configs.keys())
config = self._quic_configs.get(config_key)
if not config:
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
config.is_client = True
config.quic_logger = QuicLogger()
# Ensure client certificate is properly set for mutual authentication
if not config.certificate or not config.private_key:
logger.warning(
"Client config missing certificate - applying TLS config"
)
client_tls_config = self._security_manager.create_client_config()
self._apply_tls_configuration(config, client_tls_config)
# Debug log to verify certificate is present
logger.info(
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
)
logger.debug("Starting QUIC Connection")
# Create QUIC connection using aioquic's sans-IO core
native_quic_connection = NativeQUICConnection(configuration=config)
# Create trio-based QUIC connection wrapper with security
connection = QUICConnection(
quic_connection=native_quic_connection,
remote_addr=(host, port),
remote_peer_id=remote_peer_id,
local_peer_id=self._peer_id,
is_initiator=True,
maddr=maddr,
transport=self,
security_manager=self._security_manager,
)
logger.debug("QUIC Connection Created")
if self._background_nursery is None:
logger.error("No nursery set to execute background tasks")
raise QUICDialError("No nursery found to execute tasks")
await connection.connect(self._background_nursery)
# Store connection for management
conn_id = f"{host}:{port}"
self._connections[conn_id] = connection
return connection
except Exception as e:
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
raise QUICDialError(f"Dial failed: {e}") from e
async def _verify_peer_identity(
self, connection: QUICConnection, expected_peer_id: ID
) -> None:
"""
Verify remote peer identity after TLS handshake.
Args:
connection: The established QUIC connection
expected_peer_id: Expected peer ID
Raises:
QUICSecurityError: If peer verification fails
"""
try:
# Get peer certificate from the connection
peer_certificate = await connection.get_peer_certificate()
if not peer_certificate:
raise QUICSecurityError("No peer certificate available")
# Verify peer identity using security manager
verified_peer_id = self._security_manager.verify_peer_identity(
peer_certificate, expected_peer_id
)
if verified_peer_id != expected_peer_id:
raise QUICSecurityError(
"Peer ID verification failed: expected "
f"{expected_peer_id}, got {verified_peer_id}"
)
logger.debug(f"Peer identity verified: {verified_peer_id}")
logger.debug(f"Peer identity verified: {verified_peer_id}")
except Exception as e:
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
"""
Create a QUIC listener with integrated security.
Args:
handler_function: Function to handle new connections
Returns:
QUIC listener instance
Raises:
QUICListenError: If transport is closed
"""
if self._closed:
raise QUICListenError("Transport is closed")
# Get server configurations for the listener
server_configs = {
version: config
for version, config in self._quic_configs.items()
if version.endswith("_server")
}
listener = QUICListener(
transport=self,
handler_function=handler_function,
quic_configs=server_configs,
config=self._config,
security_manager=self._security_manager,
)
self._listeners.append(listener)
logger.debug("Created QUIC listener with security")
return listener
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
"""
Check if this transport can dial the given multiaddr.
Args:
maddr: Multiaddr to check
Returns:
True if this transport can dial the address
"""
return is_quic_multiaddr(maddr)
def protocols(self) -> list[TProtocol]:
"""
Get supported protocol identifiers.
Returns:
List of supported protocol strings
"""
protocols = [QUIC_V1_PROTOCOL]
if self._config.enable_draft29:
protocols.append(QUIC_DRAFT29_PROTOCOL)
return protocols
def listen_order(self) -> int:
"""
Get the listen order priority for this transport.
Matches go-libp2p's ListenOrder = 1 for QUIC.
Returns:
Priority order for listening (lower = higher priority)
"""
return 1
async def close(self) -> None:
"""Close the transport and cleanup resources."""
if self._closed:
return
self._closed = True
logger.debug("Closing QUIC transport")
# Close all active connections and listeners concurrently using trio nursery
async with trio.open_nursery() as nursery:
# Close all connections
for connection in self._connections.values():
nursery.start_soon(connection.close)
# Close all listeners
for listener in self._listeners:
nursery.start_soon(listener.close)
self._connections.clear()
self._listeners.clear()
logger.debug("QUIC transport closed")
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
"""Clean up a terminated connection from all listeners."""
try:
for listener in self._listeners:
await listener._remove_connection_by_object(connection)
logger.debug(
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
)
except Exception as e:
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
def get_stats(self) -> dict[str, int | list[str] | object]:
"""Get transport statistics including security info."""
return {
"active_connections": len(self._connections),
"active_listeners": len(self._listeners),
"supported_protocols": self.protocols(),
"local_peer_id": str(self._peer_id),
"security_enabled": True,
"tls_configured": True,
}
def get_security_manager(self) -> QUICTLSConfigManager:
"""
Get the security manager for this transport.
Returns:
The QUIC TLS configuration manager
"""
return self._security_manager
def get_listener_socket(self) -> trio.socket.SocketType | None:
"""Get the socket from the first active listener."""
for listener in self._listeners:
if listener.is_listening() and listener._socket:
return listener._socket
return None

View File

@ -0,0 +1,466 @@
"""
Multiaddr utilities for QUIC transport - Module 4.
Essential utilities required for QUIC transport implementation.
Based on go-libp2p and js-libp2p QUIC implementations.
"""
import ipaddress
import logging
import ssl
from aioquic.quic.configuration import QuicConfiguration
import multiaddr
from libp2p.custom_types import TProtocol
from libp2p.transport.quic.security import QUICTLSConfigManager
from .config import QUICTransportConfig
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
logger = logging.getLogger(__name__)
# Protocol constants
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
UDP_PROTOCOL = "udp"
IP4_PROTOCOL = "ip4"
IP6_PROTOCOL = "ip6"
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
}
# QUIC version to wire format mappings (required for aioquic)
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
}
# ALPN protocols for libp2p over QUIC
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
"""
Check if a multiaddr represents a QUIC address.
Valid QUIC multiaddrs:
- /ip4/127.0.0.1/udp/4001/quic-v1
- /ip4/127.0.0.1/udp/4001/quic
- /ip6/::1/udp/4001/quic-v1
- /ip6/::1/udp/4001/quic
Args:
maddr: Multiaddr to check
Returns:
True if the multiaddr represents a QUIC address
"""
try:
addr_str = str(maddr)
# Check for required components
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
has_quic = (
f"/{QUIC_V1_PROTOCOL}" in addr_str
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
or "/quic" in addr_str
)
return has_ip and has_udp and has_quic
except Exception:
return False
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
"""
Extract host and port from a QUIC multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
Tuple of (host, port)
Raises:
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
"""
if not is_quic_multiaddr(maddr):
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
try:
host = None
port = None
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except Exception:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except Exception:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str)
except Exception:
pass
if host is None or port is None:
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
return host, port
except Exception as e:
raise QUICInvalidMultiaddrError(
f"Failed to parse QUIC multiaddr {maddr}: {e}"
) from e
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
"""
Determine QUIC version from multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
QUIC version identifier ("quic-v1" or "quic")
Raises:
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
"""
try:
addr_str = str(maddr)
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
return QUIC_V1_PROTOCOL # RFC 9000
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
return QUIC_DRAFT29_PROTOCOL # draft-29
else:
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
except Exception as e:
raise QUICInvalidMultiaddrError(
f"Failed to determine QUIC version from {maddr}: {e}"
) from e
def create_quic_multiaddr(
host: str, port: int, version: str = "quic-v1"
) -> multiaddr.Multiaddr:
"""
Create a QUIC multiaddr from host, port, and version.
Args:
host: IP address (IPv4 or IPv6)
port: UDP port number
version: QUIC version ("quic-v1" or "quic")
Returns:
QUIC multiaddr
Raises:
QUICInvalidMultiaddrError: If invalid parameters provided
"""
try:
# Determine IP version
try:
ip = ipaddress.ip_address(host)
if isinstance(ip, ipaddress.IPv4Address):
ip_proto = IP4_PROTOCOL
else:
ip_proto = IP6_PROTOCOL
except ValueError:
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
# Validate port
if not (0 <= port <= 65535):
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
# Validate and normalize QUIC version
if version == "quic-v1" or version == "/quic-v1":
quic_proto = QUIC_V1_PROTOCOL
elif version == "quic" or version == "/quic":
quic_proto = QUIC_DRAFT29_PROTOCOL
else:
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
# Construct multiaddr
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
return multiaddr.Multiaddr(addr_str)
except Exception as e:
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
def quic_version_to_wire_format(version: TProtocol) -> int:
"""
Convert QUIC version string to wire format integer for aioquic.
Args:
version: QUIC version string ("quic-v1" or "quic")
Returns:
Wire format version number
Raises:
QUICUnsupportedVersionError: If version is not supported
"""
wire_version = QUIC_VERSION_MAPPINGS.get(version)
if wire_version is None:
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
return wire_version
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
"""
Convert QUIC version string to wire format integer for aioquic.
Args:
version: QUIC version string ("quic-v1" or "quic")
Returns:
Wire format version number
Raises:
QUICUnsupportedVersionError: If version is not supported
"""
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
if wire_version is None:
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
return wire_version
def get_alpn_protocols() -> list[str]:
"""
Get ALPN protocols for libp2p over QUIC.
Returns:
List of ALPN protocol identifiers
"""
return LIBP2P_ALPN_PROTOCOLS.copy()
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
"""
Normalize a QUIC multiaddr to canonical form.
Args:
maddr: Input QUIC multiaddr
Returns:
Normalized multiaddr
Raises:
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
"""
if not is_quic_multiaddr(maddr):
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
host, port = quic_multiaddr_to_endpoint(maddr)
version = multiaddr_to_quic_version(maddr)
return create_quic_multiaddr(host, port, version)
def create_server_config_from_base(
base_config: QuicConfiguration,
security_manager: QUICTLSConfigManager | None = None,
transport_config: QUICTransportConfig | None = None,
) -> QuicConfiguration:
"""
Create a server configuration without using deepcopy.
Manually copies attributes while handling cryptography objects properly.
"""
try:
# Create new server configuration from scratch
server_config = QuicConfiguration(is_client=False)
server_config.verify_mode = ssl.CERT_NONE
# Copy basic configuration attributes (these are safe to copy)
copyable_attrs = [
"alpn_protocols",
"verify_mode",
"max_datagram_frame_size",
"idle_timeout",
"max_concurrent_streams",
"supported_versions",
"max_data",
"max_stream_data",
"stateless_retry",
"quantum_readiness_test",
]
for attr in copyable_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(server_config, attr, value)
# Handle cryptography objects - these need direct reference, not copying
crypto_attrs = [
"certificate",
"private_key",
"certificate_chain",
"ca_certs",
]
for attr in crypto_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(server_config, attr, value)
# Apply security manager configuration if available
if security_manager:
try:
server_tls_config = security_manager.create_server_config()
# Override with security manager's TLS configuration
if server_tls_config.certificate:
server_config.certificate = server_tls_config.certificate
if server_tls_config.private_key:
server_config.private_key = server_tls_config.private_key
if server_tls_config.certificate_chain:
server_config.certificate_chain = (
server_tls_config.certificate_chain
)
if server_tls_config.alpn_protocols:
server_config.alpn_protocols = server_tls_config.alpn_protocols
server_tls_config.request_client_certificate = True
if getattr(server_tls_config, "request_client_certificate", False):
server_config._libp2p_request_client_cert = True # type: ignore
else:
logger.error(
"🔧 Failed to set request_client_certificate in server config"
)
except Exception as e:
logger.warning(f"Failed to apply security manager config: {e}")
# Set transport-specific defaults if provided
if transport_config:
if server_config.idle_timeout == 0:
server_config.idle_timeout = getattr(
transport_config, "idle_timeout", 30.0
)
if server_config.max_datagram_frame_size is None:
server_config.max_datagram_frame_size = getattr(
transport_config, "max_datagram_size", 1200
)
# Ensure we have ALPN protocols
if not server_config.alpn_protocols:
server_config.alpn_protocols = ["libp2p"]
logger.debug("Successfully created server config without deepcopy")
return server_config
except Exception as e:
logger.error(f"Failed to create server config: {e}")
raise
def create_client_config_from_base(
base_config: QuicConfiguration,
security_manager: QUICTLSConfigManager | None = None,
transport_config: QUICTransportConfig | None = None,
) -> QuicConfiguration:
"""
Create a client configuration without using deepcopy.
"""
try:
# Create new client configuration from scratch
client_config = QuicConfiguration(is_client=True)
client_config.verify_mode = ssl.CERT_NONE
# Copy basic configuration attributes
copyable_attrs = [
"alpn_protocols",
"verify_mode",
"max_datagram_frame_size",
"idle_timeout",
"max_concurrent_streams",
"supported_versions",
"max_data",
"max_stream_data",
"quantum_readiness_test",
]
for attr in copyable_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(client_config, attr, value)
# Handle cryptography objects - these need direct reference, not copying
crypto_attrs = [
"certificate",
"private_key",
"certificate_chain",
"ca_certs",
]
for attr in crypto_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(client_config, attr, value)
# Apply security manager configuration if available
if security_manager:
try:
client_tls_config = security_manager.create_client_config()
# Override with security manager's TLS configuration
if client_tls_config.certificate:
client_config.certificate = client_tls_config.certificate
if client_tls_config.private_key:
client_config.private_key = client_tls_config.private_key
if client_tls_config.certificate_chain:
client_config.certificate_chain = (
client_tls_config.certificate_chain
)
if client_tls_config.alpn_protocols:
client_config.alpn_protocols = client_tls_config.alpn_protocols
except Exception as e:
logger.warning(f"Failed to apply security manager config: {e}")
# Ensure we have ALPN protocols
if not client_config.alpn_protocols:
client_config.alpn_protocols = ["libp2p"]
logger.debug("Successfully created client config without deepcopy")
return client_config
except Exception as e:
logger.error(f"Failed to create client config: {e}")
raise

View File

@ -1,9 +1,7 @@
from libp2p.abc import (
IListener,
IMuxedConn,
IRawConnection,
ISecureConn,
ITransport,
)
from libp2p.custom_types import (
TMuxerOptions,
@ -43,10 +41,6 @@ class TransportUpgrader:
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
def upgrade_listener(self, transport: ITransport, listeners: IListener) -> None:
"""Upgrade multiaddr listeners to libp2p-transport listeners."""
# TODO: Figure out what to do with this function.
async def upgrade_security(
self,
raw_conn: IRawConnection,

View File

@ -15,6 +15,13 @@ from libp2p.utils.version import (
get_agent_version,
)
from libp2p.utils.address_validation import (
get_available_interfaces,
get_optimal_binding_address,
expand_wildcard_address,
find_free_port,
)
__all__ = [
"decode_uvarint_from_stream",
"encode_delim",
@ -26,4 +33,8 @@ __all__ = [
"decode_varint_from_bytes",
"decode_varint_with_size",
"read_length_prefixed_protobuf",
"get_available_interfaces",
"get_optimal_binding_address",
"expand_wildcard_address",
"find_free_port",
]

View File

@ -0,0 +1,160 @@
from __future__ import annotations
import socket
from multiaddr import Multiaddr
try:
from multiaddr.utils import ( # type: ignore
get_network_addrs,
get_thin_waist_addresses,
)
_HAS_THIN_WAIST = True
except ImportError: # pragma: no cover - only executed in older environments
_HAS_THIN_WAIST = False
get_thin_waist_addresses = None # type: ignore
get_network_addrs = None # type: ignore
def _safe_get_network_addrs(ip_version: int) -> list[str]:
"""
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
Falls back to minimal defaults when Thin Waist helpers are missing.
:param ip_version: 4 or 6
"""
if _HAS_THIN_WAIST and get_network_addrs:
try:
return get_network_addrs(ip_version) or []
except Exception: # pragma: no cover - defensive
return []
# Fallback behavior (very conservative)
if ip_version == 4:
return ["127.0.0.1"]
if ip_version == 6:
return ["::1"]
return []
def find_free_port() -> int:
"""Find a free port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to a free port provided by the OS
return s.getsockname()[1]
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
"""
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
If Thin Waist isn't available, returns [addr] (identity).
"""
if _HAS_THIN_WAIST and get_thin_waist_addresses:
try:
if port is not None:
return get_thin_waist_addresses(addr, port=port) or []
return get_thin_waist_addresses(addr) or []
except Exception: # pragma: no cover - defensive
return [addr]
return [addr]
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
"""
Discover available network interfaces (IPv4 + IPv6 if supported) for binding.
:param port: Port number to bind to.
:param protocol: Transport protocol (e.g., "tcp" or "udp").
:return: List of Multiaddr objects representing candidate interface addresses.
"""
addrs: list[Multiaddr] = []
# IPv4 enumeration
seen_v4: set[str] = set()
for ip in _safe_get_network_addrs(4):
seen_v4.add(ip)
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
if seen_v4 and "127.0.0.1" not in seen_v4:
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
# TODO: IPv6 support temporarily disabled due to libp2p handshake issues
# IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure)
# Re-enable IPv6 support once the following issues are resolved:
# - libp2p security handshake over IPv6
# - multiselect protocol over IPv6
# - connection establishment over IPv6
#
# seen_v6: set[str] = set()
# for ip in _safe_get_network_addrs(6):
# seen_v6.add(ip)
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
#
# # Always include IPv6 loopback for testing purposes when IPv6 is available
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
# if "::1" not in seen_v6:
# addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}"))
# Fallback if nothing discovered
if not addrs:
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
return addrs
def expand_wildcard_address(
addr: Multiaddr, port: int | None = None
) -> list[Multiaddr]:
"""
Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces.
:param addr: Multiaddr to expand.
:param port: Optional override for port selection.
:return: List of concrete Multiaddr instances.
"""
expanded = _safe_expand(addr, port=port)
if not expanded: # Safety fallback
return [addr]
return expanded
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
"""
Choose an optimal address for an example to bind to:
- Prefer non-loopback IPv4
- Then non-loopback IPv6
- Fallback to loopback
- Fallback to wildcard
:param port: Port number.
:param protocol: Transport protocol.
:return: A single Multiaddr chosen heuristically.
"""
candidates = get_available_interfaces(port, protocol)
def is_non_loopback(ma: Multiaddr) -> bool:
s = str(ma)
return not ("/ip4/127." in s or "/ip6/::1" in s)
for c in candidates:
if "/ip4/" in str(c) and is_non_loopback(c):
return c
for c in candidates:
if "/ip6/" in str(c) and is_non_loopback(c):
return c
for c in candidates:
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
return c
# As a final fallback, produce a wildcard
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
__all__ = [
"get_available_interfaces",
"get_optimal_binding_address",
"expand_wildcard_address",
"find_free_port",
]

View File

@ -1,7 +1,4 @@
import atexit
from datetime import (
datetime,
)
import logging
import logging.handlers
import os
@ -21,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue()
# Store the current listener to stop it on exit
_current_listener: logging.handlers.QueueListener | None = None
# Store the handlers for proper cleanup
_current_handlers: list[logging.Handler] = []
# Event to track when the listener is ready
_listener_ready = threading.Event()
@ -95,7 +95,7 @@ def setup_logging() -> None:
- Child loggers inherit their parent's level unless explicitly set
- The root libp2p logger controls the default level
"""
global _current_listener, _listener_ready
global _current_listener, _listener_ready, _current_handlers
# Reset the event
_listener_ready.clear()
@ -105,6 +105,12 @@ def setup_logging() -> None:
_current_listener.stop()
_current_listener = None
# Close and clear existing handlers
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
_current_handlers.clear()
# Get the log level from environment variable
debug_str = os.environ.get("LIBP2P_DEBUG", "")
@ -148,13 +154,10 @@ def setup_logging() -> None:
log_path = Path(log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
else:
# Default log file with timestamp and unique identifier
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
unique_id = os.urandom(4).hex() # Add a unique identifier to prevent collisions
if os.name == "nt": # Windows
log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log"
else: # Unix-like
log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log"
# Use cross-platform temp file creation
from libp2p.utils.paths import create_temp_file
log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log"))
# Print the log file path so users know where to find it
print(f"Logging to: {log_file}", file=sys.stderr)
@ -195,6 +198,9 @@ def setup_logging() -> None:
logger.setLevel(level)
logger.propagate = False # Prevent message duplication
# Store handlers globally for cleanup
_current_handlers.extend(handlers)
# Start the listener AFTER configuring all loggers
_current_listener = logging.handlers.QueueListener(
log_queue, *handlers, respect_handler_level=True
@ -209,7 +215,13 @@ def setup_logging() -> None:
@atexit.register
def cleanup_logging() -> None:
"""Clean up logging resources on exit."""
global _current_listener
global _current_listener, _current_handlers
if _current_listener is not None:
_current_listener.stop()
_current_listener = None
# Close all file handlers to ensure proper cleanup on Windows
for handler in _current_handlers:
if isinstance(handler, logging.FileHandler):
handler.close()
_current_handlers.clear()

267
libp2p/utils/paths.py Normal file
View File

@ -0,0 +1,267 @@
"""
Cross-platform path utilities for py-libp2p.
This module provides standardized path operations to ensure consistent
behavior across Windows, macOS, and Linux platforms.
"""
import os
from pathlib import Path
import sys
import tempfile
from typing import Union
PathLike = Union[str, Path]
def get_temp_dir() -> Path:
"""
Get cross-platform temporary directory.
Returns:
Path: Platform-specific temporary directory path
"""
return Path(tempfile.gettempdir())
def get_project_root() -> Path:
"""
Get the project root directory.
Returns:
Path: Path to the py-libp2p project root
"""
# Navigate from libp2p/utils/paths.py to project root
return Path(__file__).parent.parent.parent
def join_paths(*parts: PathLike) -> Path:
"""
Cross-platform path joining.
Args:
*parts: Path components to join
Returns:
Path: Joined path using platform-appropriate separator
"""
return Path(*parts)
def ensure_dir_exists(path: PathLike) -> Path:
"""
Ensure directory exists, create if needed.
Args:
path: Directory path to ensure exists
Returns:
Path: Path object for the directory
"""
path_obj = Path(path)
path_obj.mkdir(parents=True, exist_ok=True)
return path_obj
def get_config_dir() -> Path:
"""
Get user config directory (cross-platform).
Returns:
Path: Platform-specific config directory
"""
if os.name == "nt": # Windows
appdata = os.environ.get("APPDATA", "")
if appdata:
return Path(appdata) / "py-libp2p"
else:
# Fallback to user home directory
return Path.home() / "AppData" / "Roaming" / "py-libp2p"
else: # Unix-like (Linux, macOS)
return Path.home() / ".config" / "py-libp2p"
def get_script_dir(script_path: PathLike | None = None) -> Path:
"""
Get the directory containing a script file.
Args:
script_path: Path to the script file. If None, uses __file__
Returns:
Path: Directory containing the script
Raises:
RuntimeError: If script path cannot be determined
"""
if script_path is None:
# This will be the directory of the calling script
import inspect
frame = inspect.currentframe()
if frame and frame.f_back:
script_path = frame.f_back.f_globals.get("__file__")
else:
raise RuntimeError("Could not determine script path")
if script_path is None:
raise RuntimeError("Script path is None")
return Path(script_path).parent.absolute()
def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path:
"""
Create a temporary file with a unique name.
Args:
prefix: File name prefix
suffix: File name suffix
Returns:
Path: Path to the created temporary file
"""
temp_dir = get_temp_dir()
# Create a unique filename using timestamp and random bytes
import secrets
import time
timestamp = time.strftime("%Y%m%d_%H%M%S")
microseconds = f"{time.time() % 1:.6f}"[2:] # Get microseconds as string
unique_id = secrets.token_hex(4)
filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}"
temp_file = temp_dir / filename
# Create the file by touching it
temp_file.touch()
return temp_file
def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path:
"""
Resolve a relative path from a base path.
Args:
base_path: Base directory path
relative_path: Relative path to resolve
Returns:
Path: Resolved absolute path
"""
base = Path(base_path).resolve()
relative = Path(relative_path)
if relative.is_absolute():
return relative
else:
return (base / relative).resolve()
def normalize_path(path: PathLike) -> Path:
"""
Normalize a path, resolving any symbolic links and relative components.
Args:
path: Path to normalize
Returns:
Path: Normalized absolute path
"""
return Path(path).resolve()
def get_venv_path() -> Path | None:
"""
Get virtual environment path if active.
Returns:
Path: Virtual environment path if active, None otherwise
"""
venv_path = os.environ.get("VIRTUAL_ENV")
if venv_path:
return Path(venv_path)
return None
def get_python_executable() -> Path:
"""
Get current Python executable path.
Returns:
Path: Path to the current Python executable
"""
return Path(sys.executable)
def find_executable(name: str) -> Path | None:
"""
Find executable in system PATH.
Args:
name: Name of the executable to find
Returns:
Path: Path to executable if found, None otherwise
"""
# Check if name already contains path
if os.path.dirname(name):
path = Path(name)
if path.exists() and os.access(path, os.X_OK):
return path
return None
# Search in PATH
for path_dir in os.environ.get("PATH", "").split(os.pathsep):
if not path_dir:
continue
path = Path(path_dir) / name
if path.exists() and os.access(path, os.X_OK):
return path
return None
def get_script_binary_path() -> Path:
"""
Get path to script's binary directory.
Returns:
Path: Directory containing the script's binary
"""
return get_python_executable().parent
def get_binary_path(binary_name: str) -> Path | None:
"""
Find binary in PATH or virtual environment.
Args:
binary_name: Name of the binary to find
Returns:
Path: Path to binary if found, None otherwise
"""
# First check in virtual environment if active
venv_path = get_venv_path()
if venv_path:
venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts"
binary_path = venv_bin / binary_name
if binary_path.exists() and os.access(binary_path, os.X_OK):
return binary_path
# Fall back to system PATH
return find_executable(binary_name)