Mirror of https://github.com/varun-r-mallya/py-libp2p.git, synced 2025-12-31 20:36:24 +00:00.
Merge upstream/main into add-ws-transport

Resolved conflicts in:
- .gitignore: Combined JavaScript interop and Sphinx build ignores
- libp2p/__init__.py: Integrated QUIC transport support with WebSocket transport
- libp2p/network/swarm.py: Used upstream's improved listener handling
- pyproject.toml: Kept both WebSocket and QUIC dependencies

This merge brings in:
- QUIC transport implementation
- Enhanced swarm functionality
- Improved peer discovery
- Better error handling
- Updated dependencies and documentation

WebSocket transport implementation remains intact and functional.
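The merged QUIC support is driven by the `enable_quic` flag on `new_host`, with an optional `QUICTransportConfig`. A minimal sketch of exercising the post-merge API shown in this diff; the default-constructed config is an assumption, not part of the diff:

```python
# Sketch only: assumes the post-merge public API shown in this diff
# (new_host, enable_quic, quic_transport_opt, QUICTransportConfig).
from libp2p import new_host
from libp2p.transport.quic.config import QUICTransportConfig

# The config is only honored when enable_quic=True; otherwise new_host
# logs a warning and ignores it (see new_host below).
host = new_host(
    enable_quic=True,
    quic_transport_opt=QUICTransportConfig(),  # assumed default-constructible
)
```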
@@ -1,3 +1,11 @@
"""Libp2p Python implementation."""

import logging

from libp2p.transport.quic.utils import is_quic_multiaddr
from typing import Any
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.quic.config import QUICTransportConfig
from collections.abc import (
    Mapping,
    Sequence,
@@ -6,15 +14,12 @@ from importlib.metadata import version as __version
from typing import (
    Literal,
    Optional,
    Type,
    cast,
)

import multiaddr

from libp2p.abc import (
    IHost,
    IMuxedConn,
    INetworkService,
    IPeerRouting,
    IPeerStore,
@@ -33,9 +38,6 @@ from libp2p.custom_types import (
    TProtocol,
    TSecurityOptions,
)
from libp2p.discovery.mdns.mdns import (
    MDNSDiscovery,
)
from libp2p.host.basic_host import (
    BasicHost,
)
@@ -45,27 +47,34 @@ from libp2p.host.routed_host import (
from libp2p.network.swarm import (
    Swarm,
)
from libp2p.network.config import (
    ConnectionConfig,
    RetryConfig
)
from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerstore import (
    PeerStore,
    create_signed_peer_record,
)
from libp2p.security.insecure.transport import (
    PLAINTEXT_PROTOCOL_ID,
    InsecureTransport,
)
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import (
    PROTOCOL_ID as NOISE_PROTOCOL_ID,
    Transport as NoiseTransport,
)
import libp2p.security.secio.transport as secio
from libp2p.stream_muxer.mplex.mplex import (
    MPLEX_PROTOCOL_ID,
    Mplex,
)
from libp2p.stream_muxer.yamux.yamux import (
    PROTOCOL_ID as YAMUX_PROTOCOL_ID,
    Yamux,
)
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
from libp2p.transport.tcp.tcp import (
    TCP,
)
@@ -91,7 +100,7 @@ MUXER_YAMUX = "YAMUX"
MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5


logger = logging.getLogger(__name__)


def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
    """
@@ -160,7 +169,6 @@ def get_default_muxer_options() -> TMuxerOptions:
    else:  # YAMUX is default
        return create_yamux_muxer_option()


def new_swarm(
    key_pair: KeyPair | None = None,
    muxer_opt: TMuxerOptions | None = None,
@@ -168,6 +176,9 @@ def new_swarm(
    peerstore_opt: IPeerStore | None = None,
    muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
    listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
    enable_quic: bool = False,
    retry_config: Optional["RetryConfig"] = None,
    connection_config: ConnectionConfig | QUICTransportConfig | None = None,
) -> INetworkService:
    """
    Create a swarm instance based on the parameters.
@@ -178,6 +189,8 @@ def new_swarm(
    :param peerstore_opt: optional peerstore
    :param muxer_preference: optional explicit muxer preference
    :param listen_addrs: optional list of multiaddrs to listen on
    :param enable_quic: enable QUIC for transport
    :param connection_config: optional connection/transport configuration
    :return: return a default swarm instance

    Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@@ -190,8 +203,6 @@ def new_swarm(

    id_opt = generate_peer_id_from(key_pair)


    # Generate X25519 keypair for Noise
    noise_key_pair = create_new_x25519_key_pair()

@@ -255,36 +266,18 @@ def new_swarm(
    else:
        transport = transport_maybe

    # Use given muxer preference if provided, otherwise use global default
    if muxer_preference is not None:
        temp_pref = muxer_preference.upper()
        if temp_pref not in [MUXER_YAMUX, MUXER_MPLEX]:
            raise ValueError(
                f"Unknown muxer: {muxer_preference}. Use 'YAMUX' or 'MPLEX'."
            )
        active_preference = temp_pref
    else:
        active_preference = DEFAULT_MUXER

    # Use provided muxer options if given, otherwise create based on preference
    if muxer_opt is not None:
        muxer_transports_by_protocol = muxer_opt
    else:
        if active_preference == MUXER_MPLEX:
            muxer_transports_by_protocol = create_mplex_muxer_option()
        else:  # YAMUX is default
            muxer_transports_by_protocol = create_yamux_muxer_option()

    upgrader = TransportUpgrader(
        secure_transports_by_protocol=secure_transports_by_protocol,
        muxer_transports_by_protocol=muxer_transports_by_protocol,
    )

    peerstore = peerstore_opt or PeerStore()
    # Store our key pair in peerstore
    peerstore.add_key_pair(id_opt, key_pair)

    return Swarm(id_opt, peerstore, upgrader, transport)
    return Swarm(
        id_opt,
        peerstore,
        upgrader,
        transport,
        retry_config=retry_config,
        connection_config=connection_config
    )


def new_host(
@@ -298,6 +291,8 @@ def new_host(
    enable_mDNS: bool = False,
    bootstrap: list[str] | None = None,
    negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
    enable_quic: bool = False,
    quic_transport_opt: QUICTransportConfig | None = None,
) -> IHost:
    """
    Create a new libp2p host based on the given parameters.
@@ -311,19 +306,33 @@ def new_host(
    :param listen_addrs: optional list of multiaddrs to listen on
    :param enable_mDNS: whether to enable mDNS discovery
    :param bootstrap: optional list of bootstrap peer addresses as strings
    :param enable_quic: optional choice to use QUIC for transport
    :param quic_transport_opt: optional configuration for the QUIC transport
    :return: return a host instance
    """

    if not enable_quic and quic_transport_opt is not None:
        logger.warning("QUIC config provided but QUIC not enabled, ignoring QUIC config")

    swarm = new_swarm(
        enable_quic=enable_quic,
        key_pair=key_pair,
        muxer_opt=muxer_opt,
        sec_opt=sec_opt,
        peerstore_opt=peerstore_opt,
        muxer_preference=muxer_preference,
        listen_addrs=listen_addrs,
        connection_config=quic_transport_opt if enable_quic else None
    )

    if disc_opt is not None:
        return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
    return BasicHost(network=swarm, enable_mDNS=enable_mDNS, bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
    return BasicHost(
        network=swarm,
        enable_mDNS=enable_mDNS,
        bootstrap=bootstrap,
        negotitate_timeout=negotiate_timeout
    )


__version__ = __version("libp2p")
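For context, a short sketch of how the muxer-preference machinery above can be driven; the specific values are illustrative:

```python
# Sketch only: uses set_default_muxer / new_host from this module.
from libp2p import new_host, set_default_muxer

# Process-wide default; "YAMUX" is already the default preference.
set_default_muxer("MPLEX")

# Per-host override; muxer_preference wins over the global default.
host = new_host(muxer_preference="YAMUX")
```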
@@ -970,6 +970,14 @@ class IPeerStore(
    # --------CERTIFIED-ADDR-BOOK----------

    @abstractmethod
    def get_local_record(self) -> Optional["Envelope"]:
        """Get the local-peer-record wrapped in Envelope"""

    @abstractmethod
    def set_local_record(self, envelope: "Envelope") -> None:
        """Set the local-peer-record wrapped in Envelope"""

    @abstractmethod
    def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
        """
@@ -1404,15 +1412,16 @@ class INetwork(ABC):
    ----------
    peerstore : IPeerStore
        The peer store for managing peer information.
    connections : dict[ID, INetConn]
        A mapping of peer IDs to network connections.
    connections : dict[ID, list[INetConn]]
        A mapping of peer IDs to lists of network connections
        (multiple connections per peer).
    listeners : dict[str, IListener]
        A mapping of listener identifiers to listener instances.

    """

    peerstore: IPeerStore
    connections: dict[ID, INetConn]
    connections: dict[ID, list[INetConn]]
    listeners: dict[str, IListener]

    @abstractmethod
@@ -1428,9 +1437,56 @@ class INetwork(ABC):
        """

    @abstractmethod
    async def dial_peer(self, peer_id: ID) -> INetConn:
    def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
        """
        Create a connection to the specified peer.
        Get connections for peer (like JS getConnections, Go ConnsToPeer).

        Parameters
        ----------
        peer_id : ID | None
            The peer ID to get connections for. If None, returns all connections.

        Returns
        -------
        list[INetConn]
            List of connections to the specified peer, or all connections
            if peer_id is None.

        """

    @abstractmethod
    def get_connections_map(self) -> dict[ID, list[INetConn]]:
        """
        Get all connections map (like JS getConnectionsMap).

        Returns
        -------
        dict[ID, list[INetConn]]
            The complete mapping of peer IDs to their connection lists.

        """

    @abstractmethod
    def get_connection(self, peer_id: ID) -> INetConn | None:
        """
        Get single connection for backward compatibility.

        Parameters
        ----------
        peer_id : ID
            The peer ID to get a connection for.

        Returns
        -------
        INetConn | None
            The first available connection, or None if no connections exist.

        """

    @abstractmethod
    async def dial_peer(self, peer_id: ID) -> list[INetConn]:
        """
        Create connections to the specified peer with load balancing.

        Parameters
        ----------
@@ -1439,8 +1495,8 @@ class INetwork(ABC):

        Returns
        -------
        INetConn
            The network connection instance to the specified peer.
        list[INetConn]
            List of established connections to the peer.

        Raises
        ------
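The new multi-connection surface replaces single-connection lookups. A hedged sketch of how a caller might use the methods declared above; the `network` object and `peer_id` are assumed to exist:

```python
# Sketch only: exercises the INetwork methods declared above.
async def inspect_peer(network, peer_id) -> None:
    conns = network.get_connections(peer_id)   # live connections to one peer
    everything = network.get_connections()     # all connections, flattened
    cmap = network.get_connections_map()       # peer ID -> list of connections
    if not conns:
        # dial_peer now returns a list of established connections
        conns = await network.dial_peer(peer_id)
    first = network.get_connection(peer_id)    # backward-compatible single lookup
```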
@@ -5,17 +5,17 @@ from collections.abc import (
)
from typing import TYPE_CHECKING, NewType, Union, cast

from libp2p.transport.quic.stream import QUICStream

if TYPE_CHECKING:
    from libp2p.abc import (
        IMuxedConn,
        INetStream,
        ISecureTransport,
    )
    from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
    from libp2p.transport.quic.connection import QUICConnection
else:
    IMuxedConn = cast(type, object)
    INetStream = cast(type, object)
    ISecureTransport = cast(type, object)

    IMuxedStream = cast(type, object)
    QUICConnection = cast(type, object)

from libp2p.io.abc import (
    ReadWriteCloser,
@@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
MessageID = NewType("MessageID", str)
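A hedged sketch of a coroutine matching the new `TQUICStreamHandlerFn` alias; the stream read/write signature is an assumption, not confirmed by this diff:

```python
# Sketch only: a coroutine satisfying TQUICStreamHandlerFn.
async def echo_stream(stream: "QUICStream") -> None:
    # Echo whatever arrives on the stream back to the peer.
    data = await stream.read(1024)  # read(max_bytes) signature is an assumption
    await stream.write(data)
```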
@@ -2,15 +2,20 @@ import logging

from multiaddr import Multiaddr
from multiaddr.resolvers import DNSResolver
import trio

from libp2p.abc import ID, INetworkService, PeerInfo
from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses
from libp2p.discovery.events.peerDiscovery import peerDiscovery
from libp2p.network.exceptions import SwarmException
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PERMANENT_ADDR_TTL

logger = logging.getLogger("libp2p.discovery.bootstrap")
resolver = DNSResolver()

DEFAULT_CONNECTION_TIMEOUT = 10


class BootstrapDiscovery:
    """
@@ -19,68 +24,147 @@ class BootstrapDiscovery:
    """

    def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]):
        """
        Initialize BootstrapDiscovery.

        Args:
            swarm: The network service (swarm) instance
            bootstrap_addrs: List of bootstrap peer multiaddresses

        """
        self.swarm = swarm
        self.peerstore = swarm.peerstore
        self.bootstrap_addrs = bootstrap_addrs or []
        self.discovered_peers: set[str] = set()
        self.connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT

    async def start(self) -> None:
        """Process bootstrap addresses and emit peer discovery events."""
        logger.debug(
        """Process bootstrap addresses and emit peer discovery events in parallel."""
        logger.info(
            f"Starting bootstrap discovery with "
            f"{len(self.bootstrap_addrs)} bootstrap addresses"
        )

        # Show all bootstrap addresses being processed
        for i, addr in enumerate(self.bootstrap_addrs):
            logger.debug(f"{i + 1}. {addr}")

        # Validate and filter bootstrap addresses
        self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs)
        logger.info(f"Valid addresses after validation: {len(self.bootstrap_addrs)}")

        for addr_str in self.bootstrap_addrs:
            try:
                await self._process_bootstrap_addr(addr_str)
            except Exception as e:
                logger.debug(f"Failed to process bootstrap address {addr_str}: {e}")
        # Use Trio nursery for PARALLEL address processing
        try:
            async with trio.open_nursery() as nursery:
                logger.debug(
                    f"Starting {len(self.bootstrap_addrs)} parallel address "
                    f"processing tasks"
                )

                # Start all bootstrap address processing tasks in parallel
                for addr_str in self.bootstrap_addrs:
                    logger.debug(f"Starting parallel task for: {addr_str}")
                    nursery.start_soon(self._process_bootstrap_addr, addr_str)

                # The nursery will wait for all address processing tasks to complete
                logger.debug(
                    "Nursery active - waiting for address processing tasks to complete"
                )

        except trio.Cancelled:
            logger.debug("Bootstrap address processing cancelled - cleaning up tasks")
            raise
        except Exception as e:
            logger.error(f"Bootstrap address processing failed: {e}")
            raise

        logger.info("Bootstrap discovery startup complete - all tasks finished")
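A hedged sketch of driving this parallel startup from application code; the swarm object is a placeholder, and the address is the well-known IPFS bootstrap node used purely as an example:

```python
# Sketch only: runs BootstrapDiscovery.start() under trio.
import trio

async def main(swarm) -> None:
    bootstrap = BootstrapDiscovery(
        swarm,
        # Example bootstrap multiaddr; substitute your own peers.
        ["/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"],
    )
    await bootstrap.start()  # returns once all parallel dial tasks settle

# trio.run(main, swarm)
```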
    def stop(self) -> None:
        """Clean up bootstrap discovery resources."""
        logger.debug("Stopping bootstrap discovery")
        logger.info("Stopping bootstrap discovery and cleaning up tasks")

        # Clear discovered peers
        self.discovered_peers.clear()

        logger.debug("Bootstrap discovery cleanup completed")

    async def _process_bootstrap_addr(self, addr_str: str) -> None:
        """Convert string address to PeerInfo and add to peerstore."""
        try:
            multiaddr = Multiaddr(addr_str)
        try:
            multiaddr = Multiaddr(addr_str)
        except Exception as e:
            logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
            return

        if self.is_dns_addr(multiaddr):
            resolved_addrs = await resolver.resolve(multiaddr)
            if resolved_addrs is None:
                logger.warning(f"DNS resolution returned None for: {addr_str}")
                return

            peer_id_str = multiaddr.get_peer_id()
            if peer_id_str is None:
                logger.warning(f"Missing peer ID in DNS address: {addr_str}")
                return
            peer_id = ID.from_base58(peer_id_str)
            addrs = [addr for addr in resolved_addrs]
            if not addrs:
                logger.warning(f"No addresses resolved for DNS address: {addr_str}")
                return
            peer_info = PeerInfo(peer_id, addrs)
            await self.add_addr(peer_info)
        else:
            peer_info = info_from_p2p_addr(multiaddr)
            await self.add_addr(peer_info)
        except Exception as e:
            logger.debug(f"Invalid multiaddr format '{addr_str}': {e}")
            return
        if self.is_dns_addr(multiaddr):
            resolved_addrs = await resolver.resolve(multiaddr)
            peer_id_str = multiaddr.get_peer_id()
            if peer_id_str is None:
                logger.warning(f"Missing peer ID in DNS address: {addr_str}")
                return
            peer_id = ID.from_base58(peer_id_str)
            addrs = [addr for addr in resolved_addrs]
            if not addrs:
                logger.warning(f"No addresses resolved for DNS address: {addr_str}")
                return
            peer_info = PeerInfo(peer_id, addrs)
            self.add_addr(peer_info)
        else:
            self.add_addr(info_from_p2p_addr(multiaddr))
            logger.warning(f"Failed to process bootstrap address {addr_str}: {e}")

    def is_dns_addr(self, addr: Multiaddr) -> bool:
        """Check if the address is a DNS address."""
        return any(protocol.name == "dnsaddr" for protocol in addr.protocols())

    def add_addr(self, peer_info: PeerInfo) -> None:
        """Add a peer to the peerstore and emit discovery event."""
    async def add_addr(self, peer_info: PeerInfo) -> None:
        """
        Add a peer to the peerstore, emit discovery event,
        and attempt connection in parallel.
        """
        logger.debug(
            f"Adding peer {peer_info.peer_id} with {len(peer_info.addrs)} addresses"
        )

        # Skip if it's our own peer
        if peer_info.peer_id == self.swarm.get_peer_id():
            logger.debug(f"Skipping own peer ID: {peer_info.peer_id}")
            return

        # Always add addresses to peerstore (allows multiple addresses for same peer)
        self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10)
        # Filter addresses to only include IPv4+TCP (only supported protocol)
        ipv4_tcp_addrs = []
        filtered_out_addrs = []

        for addr in peer_info.addrs:
            if self._is_ipv4_tcp_addr(addr):
                ipv4_tcp_addrs.append(addr)
            else:
                filtered_out_addrs.append(addr)

        # Log filtering results
        logger.debug(
            f"Address filtering for {peer_info.peer_id}: "
            f"{len(ipv4_tcp_addrs)} IPv4+TCP, {len(filtered_out_addrs)} filtered"
        )

        # Skip peer if no IPv4+TCP addresses available
        if not ipv4_tcp_addrs:
            logger.warning(
                f"❌ No IPv4+TCP addresses for {peer_info.peer_id} - "
                f"skipping connection attempts"
            )
            return

        # Add only IPv4+TCP addresses to peerstore
        self.peerstore.add_addrs(peer_info.peer_id, ipv4_tcp_addrs, PERMANENT_ADDR_TTL)

        # Only emit discovery event if this is the first time we see this peer
        peer_id_str = str(peer_info.peer_id)
@@ -89,6 +173,140 @@ class BootstrapDiscovery:
            self.discovered_peers.add(peer_id_str)
            # Emit peer discovery event
            peerDiscovery.emit_peer_discovered(peer_info)
            logger.debug(f"Peer discovered: {peer_info.peer_id}")
            logger.info(f"Peer discovered: {peer_info.peer_id}")

            # Connect to peer (parallel across different bootstrap addresses)
            logger.debug("Connecting to discovered peer...")
            await self._connect_to_peer(peer_info.peer_id)

        else:
            logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}")
            logger.debug(
                f"Additional addresses added for existing peer: {peer_info.peer_id}"
            )
            # Even for existing peers, try to connect if not already connected
            if peer_info.peer_id not in self.swarm.connections:
                logger.debug("Connecting to existing peer...")
                await self._connect_to_peer(peer_info.peer_id)

    async def _connect_to_peer(self, peer_id: ID) -> None:
        """
        Attempt to establish a connection to a peer with timeout.

        Uses swarm.dial_peer to connect using addresses stored in peerstore.
        Times out after self.connection_timeout seconds to prevent hanging.
        """
        logger.debug(f"Connection attempt for peer: {peer_id}")

        # Pre-connection validation: Check if already connected
        if peer_id in self.swarm.connections:
            logger.debug(
                f"Already connected to {peer_id} - skipping connection attempt"
            )
            return

        # Check available addresses before attempting connection
        available_addrs = self.peerstore.addrs(peer_id)
        logger.debug(f"Connecting to {peer_id} ({len(available_addrs)} addresses)")

        if not available_addrs:
            logger.error(f"❌ No addresses available for {peer_id} - cannot connect")
            return

        # Record start time for connection attempt monitoring
        connection_start_time = trio.current_time()

        try:
            with trio.move_on_after(self.connection_timeout):
                # Log connection attempt
                logger.debug(
                    f"Attempting connection to {peer_id} using "
                    f"{len(available_addrs)} addresses"
                )

                # Use swarm.dial_peer to connect using stored addresses
                await self.swarm.dial_peer(peer_id)

                # Calculate connection time
                connection_time = trio.current_time() - connection_start_time

                # Post-connection validation: Verify connection was actually established
                if peer_id in self.swarm.connections:
                    logger.info(
                        f"✅ Connected to {peer_id} (took {connection_time:.2f}s)"
                    )

                else:
                    logger.warning(
                        f"Dial succeeded but connection not found for {peer_id}"
                    )
        except trio.TooSlowError:
            logger.warning(
                f"❌ Connection to {peer_id} timed out after {self.connection_timeout}s"
            )
        except SwarmException as e:
            # Calculate failed connection time
            failed_connection_time = trio.current_time() - connection_start_time

            # Enhanced error logging
            error_msg = str(e)
            if "no addresses established a successful connection" in error_msg:
                logger.warning(
                    f"❌ Failed to connect to {peer_id} after trying all "
                    f"{len(available_addrs)} addresses "
                    f"(took {failed_connection_time:.2f}s)"
                )
                # Log individual address failures if this is a MultiError
                if (
                    e.__cause__ is not None
                    and hasattr(e.__cause__, "exceptions")
                    and getattr(e.__cause__, "exceptions", None) is not None
                ):
                    exceptions_list = getattr(e.__cause__, "exceptions")
                    logger.debug("📋 Individual address failure details:")
                    for i, addr_exception in enumerate(exceptions_list, 1):
                        logger.debug(f"Address {i}: {addr_exception}")
                        # Also log the actual address that failed
                        if i <= len(available_addrs):
                            logger.debug(f"Failed address: {available_addrs[i - 1]}")
                else:
                    logger.warning("No detailed exception information available")
            else:
                logger.warning(
                    f"❌ Failed to connect to {peer_id}: {e} "
                    f"(took {failed_connection_time:.2f}s)"
                )

        except Exception as e:
            # Handle unexpected errors that aren't swarm-specific
            failed_connection_time = trio.current_time() - connection_start_time
            logger.error(
                f"❌ Unexpected error connecting to {peer_id}: "
                f"{e} (took {failed_connection_time:.2f}s)"
            )
            # Don't re-raise to prevent killing the nursery and other parallel tasks

    def _is_ipv4_tcp_addr(self, addr: Multiaddr) -> bool:
        """
        Check if address is IPv4 with TCP protocol only.

        Filters out IPv6, UDP, QUIC, WebSocket, and other unsupported protocols.
        Only IPv4+TCP addresses are supported by the current transport.
        """
        try:
            protocols = addr.protocols()

            # Must have IPv4 protocol
            has_ipv4 = any(p.name == "ip4" for p in protocols)
            if not has_ipv4:
                return False

            # Must have TCP protocol
            has_tcp = any(p.name == "tcp" for p in protocols)
            if not has_tcp:
                return False

            return True

        except Exception:
            # If we can't parse the address, don't use it
            return False
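A small sketch of the filter's behavior on concrete multiaddrs; the addresses are illustrative:

```python
# Sketch only: the filter accepts IPv4+TCP and rejects everything else.
from multiaddr import Multiaddr

tcp_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
assert any(p.name == "tcp" for p in tcp_addr.protocols())
# A QUIC address like /ip4/127.0.0.1/udp/4001/quic-v1 has no "tcp"
# protocol component, so _is_ipv4_tcp_addr would return False for it.
```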
@@ -43,6 +43,7 @@ from libp2p.peer.id import (
from libp2p.peer.peerinfo import (
    PeerInfo,
)
from libp2p.peer.peerstore import create_signed_peer_record
from libp2p.protocol_muxer.exceptions import (
    MultiselectClientError,
    MultiselectError,
@@ -110,6 +111,14 @@ class BasicHost(IHost):
        if bootstrap:
            self.bootstrap = BootstrapDiscovery(network, bootstrap)

        # Cache a signed-record of the local-node in the PeerStore
        envelope = create_signed_peer_record(
            self.get_id(),
            self.get_addrs(),
            self.get_private_key(),
        )
        self.get_peerstore().set_local_record(envelope)

    def get_id(self) -> ID:
        """
        :return: peer_id of host
@@ -288,6 +297,11 @@ class BasicHost(IHost):
            protocol, handler = await self.multiselect.negotiate(
                MultiselectCommunicator(net_stream), self.negotiate_timeout
            )
            if protocol is None:
                await net_stream.reset()
                raise StreamFailure(
                    "Failed to negotiate protocol: no protocol selected"
                )
        except MultiselectError as error:
            peer_id = net_stream.muxed_conn.peer_id
            logger.debug(
@@ -329,7 +343,7 @@ class BasicHost(IHost):
        :param peer_id: ID of the peer to check
        :return: True if peer has an active connection, False otherwise
        """
        return peer_id in self._network.connections
        return len(self._network.get_connections(peer_id)) > 0

    def get_peer_connection_info(self, peer_id: ID) -> INetConn | None:
        """
@@ -338,4 +352,4 @@ class BasicHost(IHost):
        :param peer_id: ID of the peer to get info for
        :return: Connection object if peer is connected, None otherwise
        """
        return self._network.connections.get(peer_id)
        return self._network.get_connection(peer_id)
@@ -15,8 +15,7 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
    StreamClosed,
)
from libp2p.peer.envelope import seal_record
from libp2p.peer.peer_record import PeerRecord
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.utils import (
    decode_varint_with_size,
    get_agent_version,
@@ -66,9 +65,7 @@ def _mk_identify_protobuf(
    protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None)

    # Create a signed peer-record for the remote peer
    record = PeerRecord(host.get_id(), host.get_addrs())
    envelope = seal_record(record, host.get_private_key())
    protobuf = envelope.marshal_envelope()
    envelope_bytes, _ = env_to_send_in_RPC(host)

    observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b""
    return Identify(
@@ -78,7 +75,7 @@ def _mk_identify_protobuf(
        listen_addrs=map(_multiaddr_to_bytes, laddrs),
        observed_addr=observed_addr,
        protocols=protocols,
        signedPeerRecord=protobuf,
        signedPeerRecord=envelope_bytes,
    )
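The identify change swaps a per-call `seal_record` for the host's cached envelope. A hedged sketch of the new path; the second return value of `env_to_send_in_RPC` is elided here because this diff does not show its meaning:

```python
# Sketch only: env_to_send_in_RPC returns the marshalled envelope bytes
# (plus a second value, elided) for the host's cached signed peer record.
envelope_bytes, _ = env_to_send_in_RPC(host)
identify_msg = Identify(signedPeerRecord=envelope_bytes)  # other fields omitted
```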
@@ -22,15 +22,18 @@ from libp2p.abc import (
    IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import (
    INetStream,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerinfo import (
    PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.tools.async_service import (
    Service,
)
@@ -234,6 +237,9 @@ class KadDHT(Service):
            await self.add_peer(peer_id)
            logger.debug(f"Added peer {peer_id} to routing table")

        closer_peer_envelope: Envelope | None = None
        provider_peer_envelope: Envelope | None = None

        try:
            # Read varint-prefixed length for the message
            length_prefix = b""
@@ -274,6 +280,14 @@ class KadDHT(Service):
                    )
                    logger.debug(f"Found {len(closest_peers)} peers close to target")

                    # Consume the source signed_peer_record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid-signed-record, dropping the stream"
                        )
                        await stream.close()
                        return

                    # Build response message with protobuf
                    response = Message()
                    response.type = Message.MessageType.FIND_NODE
@@ -298,6 +312,21 @@ class KadDHT(Service):
                        except Exception:
                            pass

                        # Add the signed-peer-record for each peer in the peer-proto
                        # if cached in the peerstore
                        closer_peer_envelope = (
                            self.host.get_peerstore().get_peer_record(peer)
                        )

                        if closer_peer_envelope is not None:
                            peer_proto.signedRecord = (
                                closer_peer_envelope.marshal_envelope()
                            )

                    # Create sender_signed_peer_record
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    # Serialize and send response
                    response_bytes = response.SerializeToString()
                    await stream.write(varint.encode(len(response_bytes)))
@@ -312,6 +341,14 @@ class KadDHT(Service):
                    key = message.key
                    logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")

                    # Consume the source signed-peer-record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid-signed-record, dropping the stream"
                        )
                        await stream.close()
                        return

                    # Extract provider information
                    for provider_proto in message.providerPeers:
                        try:
@@ -338,6 +375,17 @@ class KadDHT(Service):
                            logger.debug(
                                f"Added provider {provider_id} for key {key.hex()}"
                            )

                            # Process the signed-records of provider if sent
                            if not maybe_consume_signed_record(
                                provider_proto, self.host
                            ):
                                logger.error(
                                    "Received an invalid-signed-record, "
                                    "dropping the stream"
                                )
                                await stream.close()
                                return
                        except Exception as e:
                            logger.warning(f"Failed to process provider info: {e}")

@@ -346,6 +394,10 @@ class KadDHT(Service):
                    response.type = Message.MessageType.ADD_PROVIDER
                    response.key = key

                    # Add sender's signed-peer-record
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    response_bytes = response.SerializeToString()
                    await stream.write(varint.encode(len(response_bytes)))
                    await stream.write(response_bytes)
@@ -357,6 +409,14 @@ class KadDHT(Service):
                    key = message.key
                    logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")

                    # Consume the source signed_peer_record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid-signed-record, dropping the stream"
                        )
                        await stream.close()
                        return

                    # Find providers for the key
                    providers = self.provider_store.get_providers(key)
                    logger.debug(
@@ -368,12 +428,28 @@ class KadDHT(Service):
                    response.type = Message.MessageType.GET_PROVIDERS
                    response.key = key

                    # Create sender_signed_peer_record for the response
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    # Add provider information to response
                    for provider_info in providers:
                        provider_proto = response.providerPeers.add()
                        provider_proto.id = provider_info.peer_id.to_bytes()
                        provider_proto.connection = Message.ConnectionType.CAN_CONNECT

                        # Add provider signed-records if cached
                        provider_peer_envelope = (
                            self.host.get_peerstore().get_peer_record(
                                provider_info.peer_id
                            )
                        )

                        if provider_peer_envelope is not None:
                            provider_proto.signedRecord = (
                                provider_peer_envelope.marshal_envelope()
                            )

                        # Add addresses if available
                        for addr in provider_info.addrs:
                            provider_proto.addrs.append(addr.to_bytes())
@@ -397,6 +473,16 @@ class KadDHT(Service):
                        peer_proto.id = peer.to_bytes()
                        peer_proto.connection = Message.ConnectionType.CAN_CONNECT

                        # Add the signed-records of closest_peers if cached
                        closer_peer_envelope = (
                            self.host.get_peerstore().get_peer_record(peer)
                        )

                        if closer_peer_envelope is not None:
                            peer_proto.signedRecord = (
                                closer_peer_envelope.marshal_envelope()
                            )

                        # Add addresses if available
                        try:
                            addrs = self.host.get_peerstore().addrs(peer)
@@ -417,6 +503,14 @@ class KadDHT(Service):
                    key = message.key
                    logger.debug(f"Received GET_VALUE request for key {key.hex()}")

                    # Consume the sender_signed_peer_record
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid-signed-record, dropping the stream"
                        )
                        await stream.close()
                        return

                    value = self.value_store.get(key)
                    if value:
                        logger.debug(f"Found value for key {key.hex()}")
@@ -431,6 +525,10 @@ class KadDHT(Service):
                        response.record.value = value
                        response.record.timeReceived = str(time.time())

                        # Create sender_signed_peer_record
                        envelope_bytes, _ = env_to_send_in_RPC(self.host)
                        response.senderRecord = envelope_bytes

                        # Serialize and send response
                        response_bytes = response.SerializeToString()
                        await stream.write(varint.encode(len(response_bytes)))
@@ -444,6 +542,10 @@ class KadDHT(Service):
                        response.type = Message.MessageType.GET_VALUE
                        response.key = key

                        # Create sender_signed_peer_record for the response
                        envelope_bytes, _ = env_to_send_in_RPC(self.host)
                        response.senderRecord = envelope_bytes

                        # Add closest peers to key
                        closest_peers = self.routing_table.find_local_closest_peers(
                            key, 20
@@ -462,6 +564,16 @@ class KadDHT(Service):
                            peer_proto.id = peer.to_bytes()
                            peer_proto.connection = Message.ConnectionType.CAN_CONNECT

                            # Add signed-records of closer-peers if cached
                            closer_peer_envelope = (
                                self.host.get_peerstore().get_peer_record(peer)
                            )

                            if closer_peer_envelope is not None:
                                peer_proto.signedRecord = (
                                    closer_peer_envelope.marshal_envelope()
                                )

                            # Add addresses if available
                            try:
                                addrs = self.host.get_peerstore().addrs(peer)
@@ -484,6 +596,15 @@ class KadDHT(Service):
                    key = message.record.key
                    value = message.record.value
                    success = False

                    # Consume the source signed_peer_record if sent
                    if not maybe_consume_signed_record(message, self.host, peer_id):
                        logger.error(
                            "Received an invalid-signed-record, dropping the stream"
                        )
                        await stream.close()
                        return

                    try:
                        if not (key and value):
                            raise ValueError(
@@ -504,6 +625,12 @@ class KadDHT(Service):
                    response.type = Message.MessageType.PUT_VALUE
                    if success:
                        response.key = key

                    # Create sender_signed_peer_record for the response
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    # Serialize and send response
                    response_bytes = response.SerializeToString()
                    await stream.write(varint.encode(len(response_bytes)))
                    await stream.write(response_bytes)
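Every handler branch above applies the same two-step pattern: validate the sender's signed record before acting, then attach our own record to the reply. A condensed sketch of that shared pattern, using only names from this diff:

```python
# Sketch only: the validation gate each DHT handler applies.
if not maybe_consume_signed_record(message, self.host, peer_id):
    # Invalid or mismatched signed peer record: drop the stream.
    await stream.close()
    return

# Every response then carries our own cached signed record.
envelope_bytes, _ = env_to_send_in_RPC(self.host)
response.senderRecord = envelope_bytes
```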
@@ -27,6 +27,7 @@ message Message {
    bytes id = 1;
    repeated bytes addrs = 2;
    ConnectionType connection = 3;
    optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded
  }

  MessageType type = 1;
@@ -35,4 +36,6 @@ message Message {
  Record record = 3;
  repeated Peer closerPeers = 8;
  repeated Peer providerPeers = 9;

  optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
}
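Because both new fields are declared proto3 `optional`, field presence is explicit and can be tested on the wire. A hedged sketch; `payload` and `process_envelope` are hypothetical placeholders:

```python
# Sketch only: proto3 explicit-presence check for the new fields.
from libp2p.kad_dht.pb.kademlia_pb2 import Message

msg = Message()
msg.ParseFromString(payload)  # payload: received bytes (assumed)
if msg.HasField("senderRecord"):
    process_envelope(msg.senderRecord)  # process_envelope is hypothetical
```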
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()
@@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  _RECORD._serialized_start=36
  _RECORD._serialized_end=94
  _MESSAGE._serialized_start=97
  _MESSAGE._serialized_end=555
  _MESSAGE_PEER._serialized_start=281
  _MESSAGE_PEER._serialized_end=359
  _MESSAGE_MESSAGETYPE._serialized_start=361
  _MESSAGE_MESSAGETYPE._serialized_end=466
  _MESSAGE_CONNECTIONTYPE._serialized_start=468
  _MESSAGE_CONNECTIONTYPE._serialized_end=555
  _globals['_RECORD']._serialized_start=36
  _globals['_RECORD']._serialized_end=94
  _globals['_MESSAGE']._serialized_start=97
  _globals['_MESSAGE']._serialized_end=643
  _globals['_MESSAGE_PEER']._serialized_start=308
  _globals['_MESSAGE_PEER']._serialized_end=430
  _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432
  _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537
  _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539
  _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626
# @@protoc_insertion_point(module_scope)
@@ -1,133 +1,70 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union

import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor

if sys.version_info >= (3, 10):
    import typing as typing_extensions
else:
    import typing_extensions
class Record(_message.Message):
    __slots__ = ("key", "value", "timeReceived")
    KEY_FIELD_NUMBER: _ClassVar[int]
    VALUE_FIELD_NUMBER: _ClassVar[int]
    TIMERECEIVED_FIELD_NUMBER: _ClassVar[int]
    key: bytes
    value: bytes
    timeReceived: str
    def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ...

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

@typing.final
class Record(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    KEY_FIELD_NUMBER: builtins.int
    VALUE_FIELD_NUMBER: builtins.int
    TIMERECEIVED_FIELD_NUMBER: builtins.int
    key: builtins.bytes
    value: builtins.bytes
    timeReceived: builtins.str
    def __init__(
        self,
        *,
        key: builtins.bytes = ...,
        value: builtins.bytes = ...,
        timeReceived: builtins.str = ...,
    ) -> None: ...
    def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...

global___Record = Record

@typing.final
class Message(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    class _MessageType:
        ValueType = typing.NewType("ValueType", builtins.int)
        V: typing_extensions.TypeAlias = ValueType

    class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
        PUT_VALUE: Message._MessageType.ValueType  # 0
        GET_VALUE: Message._MessageType.ValueType  # 1
        ADD_PROVIDER: Message._MessageType.ValueType  # 2
        GET_PROVIDERS: Message._MessageType.ValueType  # 3
        FIND_NODE: Message._MessageType.ValueType  # 4
        PING: Message._MessageType.ValueType  # 5

    class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
    PUT_VALUE: Message.MessageType.ValueType  # 0
    GET_VALUE: Message.MessageType.ValueType  # 1
    ADD_PROVIDER: Message.MessageType.ValueType  # 2
    GET_PROVIDERS: Message.MessageType.ValueType  # 3
    FIND_NODE: Message.MessageType.ValueType  # 4
    PING: Message.MessageType.ValueType  # 5

    class _ConnectionType:
        ValueType = typing.NewType("ValueType", builtins.int)
        V: typing_extensions.TypeAlias = ValueType

    class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
        NOT_CONNECTED: Message._ConnectionType.ValueType  # 0
        CONNECTED: Message._ConnectionType.ValueType  # 1
        CAN_CONNECT: Message._ConnectionType.ValueType  # 2
        CANNOT_CONNECT: Message._ConnectionType.ValueType  # 3

    class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
    NOT_CONNECTED: Message.ConnectionType.ValueType  # 0
    CONNECTED: Message.ConnectionType.ValueType  # 1
    CAN_CONNECT: Message.ConnectionType.ValueType  # 2
    CANNOT_CONNECT: Message.ConnectionType.ValueType  # 3

    @typing.final
    class Peer(google.protobuf.message.Message):
        DESCRIPTOR: google.protobuf.descriptor.Descriptor

        ID_FIELD_NUMBER: builtins.int
        ADDRS_FIELD_NUMBER: builtins.int
        CONNECTION_FIELD_NUMBER: builtins.int
        id: builtins.bytes
        connection: global___Message.ConnectionType.ValueType
        @property
        def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
        def __init__(
            self,
            *,
            id: builtins.bytes = ...,
            addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
            connection: global___Message.ConnectionType.ValueType = ...,
        ) -> None: ...
        def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...

    TYPE_FIELD_NUMBER: builtins.int
    CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
    KEY_FIELD_NUMBER: builtins.int
    RECORD_FIELD_NUMBER: builtins.int
    CLOSERPEERS_FIELD_NUMBER: builtins.int
    PROVIDERPEERS_FIELD_NUMBER: builtins.int
    type: global___Message.MessageType.ValueType
    clusterLevelRaw: builtins.int
    key: builtins.bytes
    @property
    def record(self) -> global___Record: ...
    @property
    def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
    @property
    def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
    def __init__(
        self,
        *,
        type: global___Message.MessageType.ValueType = ...,
        clusterLevelRaw: builtins.int = ...,
        key: builtins.bytes = ...,
        record: global___Record | None = ...,
        closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
        providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...

global___Message = Message
class Message(_message.Message):
    __slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord")
    class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        __slots__ = ()
        PUT_VALUE: _ClassVar[Message.MessageType]
        GET_VALUE: _ClassVar[Message.MessageType]
        ADD_PROVIDER: _ClassVar[Message.MessageType]
        GET_PROVIDERS: _ClassVar[Message.MessageType]
        FIND_NODE: _ClassVar[Message.MessageType]
        PING: _ClassVar[Message.MessageType]
    PUT_VALUE: Message.MessageType
    GET_VALUE: Message.MessageType
    ADD_PROVIDER: Message.MessageType
    GET_PROVIDERS: Message.MessageType
    FIND_NODE: Message.MessageType
    PING: Message.MessageType
    class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        __slots__ = ()
        NOT_CONNECTED: _ClassVar[Message.ConnectionType]
        CONNECTED: _ClassVar[Message.ConnectionType]
        CAN_CONNECT: _ClassVar[Message.ConnectionType]
        CANNOT_CONNECT: _ClassVar[Message.ConnectionType]
    NOT_CONNECTED: Message.ConnectionType
    CONNECTED: Message.ConnectionType
    CAN_CONNECT: Message.ConnectionType
    CANNOT_CONNECT: Message.ConnectionType
    class Peer(_message.Message):
        __slots__ = ("id", "addrs", "connection", "signedRecord")
        ID_FIELD_NUMBER: _ClassVar[int]
        ADDRS_FIELD_NUMBER: _ClassVar[int]
        CONNECTION_FIELD_NUMBER: _ClassVar[int]
        SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int]
        id: bytes
        addrs: _containers.RepeatedScalarFieldContainer[bytes]
        connection: Message.ConnectionType
        signedRecord: bytes
        def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ...
    TYPE_FIELD_NUMBER: _ClassVar[int]
    CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int]
    KEY_FIELD_NUMBER: _ClassVar[int]
    RECORD_FIELD_NUMBER: _ClassVar[int]
    CLOSERPEERS_FIELD_NUMBER: _ClassVar[int]
    PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int]
    SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
    type: Message.MessageType
    clusterLevelRaw: int
    key: bytes
    record: Record
    closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
    providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer]
    senderRecord: bytes
    def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ...  # type: ignore
@ -15,12 +15,14 @@ from libp2p.abc import (
|
||||
INetStream,
|
||||
IPeerRouting,
|
||||
)
|
||||
from libp2p.peer.envelope import Envelope
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.peer.peerstore import env_to_send_in_RPC
|
||||
|
||||
from .common import (
|
||||
ALPHA,
|
||||
@ -33,6 +35,7 @@ from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .utils import (
|
||||
maybe_consume_signed_record,
|
||||
sort_peer_ids_by_distance,
|
||||
)
|
||||
|
||||
@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting):
|
||||
find_node_msg.type = Message.MessageType.FIND_NODE
|
||||
find_node_msg.key = target_key # Set target key directly as bytes
|
||||
|
||||
# Create sender_signed_peer_record
|
||||
envelope_bytes, _ = env_to_send_in_RPC(self.host)
|
||||
find_node_msg.senderRecord = envelope_bytes
|
||||
|
||||
# Serialize and send the protobuf message with varint length prefix
|
||||
proto_bytes = find_node_msg.SerializeToString()
|
||||
logger.debug(
|
||||
@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Process closest peers from response
|
||||
if response_msg.type == Message.MessageType.FIND_NODE:
|
||||
# Consume the sender_signed_peer_record
|
||||
if not maybe_consume_signed_record(response_msg, self.host, peer):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record,ignoring the response"
|
||||
)
|
||||
return []
|
||||
|
||||
for peer_data in response_msg.closerPeers:
|
||||
# Consume the received closer_peers signed-records, peer-id is
|
||||
# sent with the peer-data
|
||||
if not maybe_consume_signed_record(peer_data, self.host):
|
||||
logger.error(
|
||||
"Received an invalid-signed-record,ignoring the response"
|
||||
)
|
||||
return []
|
||||
|
||||
new_peer_id = ID(peer_data.id)
|
||||
if new_peer_id not in results:
|
||||
results.append(new_peer_id)
|
||||
@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting):
        """
        try:
            # Read message length
            peer_id = stream.muxed_conn.peer_id
            length_bytes = await stream.read(4)
            if not length_bytes:
                return
@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting):

            # Parse protobuf message
            kad_message = Message()
            closer_peer_envelope: Envelope | None = None
            try:
                kad_message.ParseFromString(message_bytes)

                if kad_message.type == Message.MessageType.FIND_NODE:
                    # Consume the sender's signed-peer-record if sent
                    if not maybe_consume_signed_record(kad_message, self.host, peer_id):
                        logger.error(
                            "Received an invalid-signed-record, dropping the stream"
                        )
                        return

                    # Get target key directly from protobuf message
                    target_key = kad_message.key

@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting):
                    response = Message()
                    response.type = Message.MessageType.FIND_NODE

                    # Create sender_signed_peer_record for the response
                    envelope_bytes, _ = env_to_send_in_RPC(self.host)
                    response.senderRecord = envelope_bytes

                    # Add peer information to response
                    for peer_id in closest_peers:
                        peer_proto = response.closerPeers.add()
                        peer_proto.id = peer_id.to_bytes()
                        peer_proto.connection = Message.ConnectionType.CAN_CONNECT

                        # Add the signed-records of closest_peers if cached
                        closer_peer_envelope = (
                            self.host.get_peerstore().get_peer_record(peer_id)
                        )

                        if isinstance(closer_peer_envelope, Envelope):
                            peer_proto.signedRecord = (
                                closer_peer_envelope.marshal_envelope()
                            )

                        # Add addresses if available
                        try:
                            addrs = self.host.get_peerstore().addrs(peer_id)

@ -22,12 +22,14 @@ from libp2p.abc import (
from libp2p.custom_types import (
    TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerinfo import (
    PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC

from .common import (
    ALPHA,
@ -240,11 +242,18 @@ class ProviderStore:
        message.type = Message.MessageType.ADD_PROVIDER
        message.key = key

        # Create sender's signed-peer-record
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        message.senderRecord = envelope_bytes

        # Add our provider info
        provider = message.providerPeers.add()
        provider.id = self.local_peer_id.to_bytes()
        provider.addrs.extend(addrs)

        # Add the provider's signed-peer-record
        provider.signedRecord = envelope_bytes

        # Serialize and send the message
        proto_bytes = message.SerializeToString()
        await stream.write(varint.encode(len(proto_bytes)))
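The message above is framed with an unsigned-varint length prefix: the sender
writes varint(len(payload)) followed by the protobuf bytes, and the receiver
reads the prefix to know how many bytes to consume. A minimal standalone sketch
of that framing (encode_uvarint/decode_uvarint below are illustrative helpers,
not part of the py-libp2p API):

def encode_uvarint(n: int) -> bytes:
    # Low 7 bits per byte; high bit set while more bytes follow.
    out = bytearray()
    while True:
        byte = n & 0x7F
        n >>= 7
        out.append(byte | (0x80 if n else 0))
        if not n:
            return bytes(out)

def decode_uvarint(buf: bytes) -> tuple[int, int]:
    # Returns (value, bytes_consumed).
    value = shift = 0
    for i, byte in enumerate(buf):
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return value, i + 1
        shift += 7
    raise ValueError("truncated varint")

frame = encode_uvarint(len(b"payload")) + b"payload"
length, offset = decode_uvarint(frame)
assert frame[offset:offset + length] == b"payload"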
@ -276,10 +285,15 @@ class ProviderStore:
        response = Message()
        response.ParseFromString(response_bytes)

        # Check response type
        response.type == Message.MessageType.ADD_PROVIDER
        if response.type:
            result = True
        if response.type == Message.MessageType.ADD_PROVIDER:
            # Consume the sender's signed-peer-record if sent
            if not maybe_consume_signed_record(response, self.host, peer_id):
                logger.error(
                    "Received an invalid-signed-record, ignoring the response"
                )
                result = False
            else:
                result = True

    except Exception as e:
        logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
@ -380,6 +394,10 @@ class ProviderStore:
        message.type = Message.MessageType.GET_PROVIDERS
        message.key = key

        # Create sender's signed-peer-record
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        message.senderRecord = envelope_bytes

        # Serialize and send the message
        proto_bytes = message.SerializeToString()
        await stream.write(varint.encode(len(proto_bytes)))
@ -414,10 +432,26 @@ class ProviderStore:
        if response.type != Message.MessageType.GET_PROVIDERS:
            return []

        # Consume the sender's signed-peer-record if sent
        if not maybe_consume_signed_record(response, self.host, peer_id):
            logger.error(
                "Received an invalid-signed-record, ignoring the response"
            )
            return []

        # Extract provider information
        providers = []
        for provider_proto in response.providerPeers:
            try:
                # Consume the provider's signed-peer-record if sent; the
                # peer-id is already sent with the provider-proto
                if not maybe_consume_signed_record(provider_proto, self.host):
                    logger.error(
                        "Received an invalid-signed-record, "
                        "ignoring the response"
                    )
                    return []

                # Create peer ID from bytes
                provider_id = ID(provider_proto.id)

@ -431,6 +465,7 @@ class ProviderStore:

                # Create PeerInfo and add to result
                providers.append(PeerInfo(provider_id, addrs))

            except Exception as e:
                logger.warning(f"Failed to parse provider info: {e}")


@ -2,13 +2,93 @@
Utility functions for Kademlia DHT implementation.
"""

import logging

import base58
import multihash

from libp2p.abc import IHost
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import (
    ID,
)

from .pb.kademlia_pb2 import (
    Message,
)

logger = logging.getLogger("kademlia-example.utils")


def maybe_consume_signed_record(
    msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None
) -> bool:
    """
    Attempt to parse and store a signed-peer-record (Envelope) received during
    DHT communication. If the record is invalid, the peer-id does not match, or
    updating the peerstore fails, the function logs an error and returns False.

    Parameters
    ----------
    msg : Message | Message.Peer
        The protobuf message received during DHT communication. Can either be a
        top-level `Message` containing `senderRecord` or a `Message.Peer`
        containing `signedRecord`.
    host : IHost
        The local host instance, providing access to the peerstore for storing
        verified peer records.
    peer_id : ID | None, optional
        The expected peer ID for record validation. If provided, the peer ID
        inside the record must match this value.

    Returns
    -------
    bool
        True if a valid signed peer record was successfully consumed and stored,
        False otherwise.

    """
    if isinstance(msg, Message):
        if msg.HasField("senderRecord"):
            try:
                # Convert the signed-peer-record (Envelope) from
                # protobuf bytes
                envelope, record = consume_envelope(
                    msg.senderRecord,
                    "libp2p-peer-record",
                )
                if not (isinstance(peer_id, ID) and record.peer_id == peer_id):
                    return False
                # Use the default TTL of 2 hours (7200 seconds)
                if not host.get_peerstore().consume_peer_record(envelope, 7200):
                    logger.error("Failed to update the Certified-Addr-Book")
                    return False
            except Exception as e:
                logger.error("Failed to update the Certified-Addr-Book: %s", e)
                return False
    else:
        if msg.HasField("signedRecord"):
            try:
                # Convert the signed-peer-record (Envelope) from
                # protobuf bytes
                envelope, record = consume_envelope(
                    msg.signedRecord,
                    "libp2p-peer-record",
                )
                if record.peer_id.to_bytes() != msg.id:
                    return False
                # Use the default TTL of 2 hours (7200 seconds)
                if not host.get_peerstore().consume_peer_record(envelope, 7200):
                    logger.error("Failed to update the Certified-Addr-Book")
                    return False
            except Exception as e:
                logger.error(
                    "Failed to update the Certified-Addr-Book: %s",
                    e,
                )
                return False
    return True
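A caller-side sketch of how this helper is meant to be used (hedged: `host`,
`expected_peer_id`, and `message_bytes` are assumed to be in scope, as in the
DHT handlers above):

response = Message()
response.ParseFromString(message_bytes)

# Validate the sender's signed record before trusting the payload.
if not maybe_consume_signed_record(response, host, expected_peer_id):
    return []  # discard the whole response

# From here on, response.closerPeers / response.providerPeers can be read,
# validating each embedded Message.Peer record the same way.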


def create_key_from_binary(binary_data: bytes) -> bytes:
    """

@ -15,9 +15,11 @@ from libp2p.abc import (
from libp2p.custom_types import (
    TProtocol,
)
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC

from .common import (
    DEFAULT_TTL,
@ -110,6 +112,10 @@ class ValueStore:
        message = Message()
        message.type = Message.MessageType.PUT_VALUE

        # Create sender's signed-peer-record
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        message.senderRecord = envelope_bytes

        # Set message fields
        message.key = key
        message.record.key = key
@ -155,7 +161,13 @@ class ValueStore:

        # Check if response is valid
        if response.type == Message.MessageType.PUT_VALUE:
            if response.key:
                # Consume the sender's signed-peer-record if sent
                if not maybe_consume_signed_record(response, self.host, peer_id):
                    logger.error(
                        "Received an invalid-signed-record, ignoring the response"
                    )
                    return False
            if response.key == key:
                result = True
        return result

@ -231,6 +243,10 @@ class ValueStore:
        message.type = Message.MessageType.GET_VALUE
        message.key = key

        # Create sender's signed-peer-record
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        message.senderRecord = envelope_bytes

        # Serialize and send the protobuf message
        proto_bytes = message.SerializeToString()
        await stream.write(varint.encode(len(proto_bytes)))
@ -275,6 +291,13 @@ class ValueStore:
            and response.HasField("record")
            and response.record.value
        ):
            # Consume the sender's signed-peer-record
            if not maybe_consume_signed_record(response, self.host, peer_id):
                logger.error(
                    "Received an invalid-signed-record, ignoring the response"
                )
                return None

            logger.debug(
                f"Received value for key {key.hex()} from peer {peer_id}"
            )

70  libp2p/network/config.py  Normal file
@ -0,0 +1,70 @@
from dataclasses import dataclass


@dataclass
class RetryConfig:
    """
    Configuration for retry logic with exponential backoff.

    This configuration controls how connection attempts are retried when they fail.
    The retry mechanism uses exponential backoff with jitter to prevent thundering
    herd problems in distributed systems.

    Attributes:
        max_retries: Maximum number of retry attempts before giving up.
            Default: 3 attempts
        initial_delay: Initial delay in seconds before the first retry.
            Default: 0.1 seconds (100ms)
        max_delay: Maximum delay cap in seconds to prevent excessive wait times.
            Default: 30.0 seconds
        backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
            the delay by this factor). Default: 2.0 (doubles each time)
        jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
            and prevent synchronized retries. Default: 0.1 (10% jitter)

    """

    max_retries: int = 3
    initial_delay: float = 0.1
    max_delay: float = 30.0
    backoff_multiplier: float = 2.0
    jitter_factor: float = 0.1


@dataclass
class ConnectionConfig:
    """
    Configuration for multi-connection support.

    This configuration controls how multiple connections per peer are managed,
    including connection limits, timeouts, and load balancing strategies.

    Attributes:
        max_connections_per_peer: Maximum number of connections allowed to a single
            peer. Default: 3 connections
        connection_timeout: Timeout in seconds for establishing new connections.
            Default: 30.0 seconds
        load_balancing_strategy: Strategy for distributing streams across connections.
            Options: "round_robin" (default) or "least_loaded"

    """

    max_connections_per_peer: int = 3
    connection_timeout: float = 30.0
    load_balancing_strategy: str = "round_robin"  # or "least_loaded"

    def __post_init__(self) -> None:
        """Validate configuration after initialization."""
        if self.load_balancing_strategy not in ("round_robin", "least_loaded"):
            raise ValueError(
                "Load balancing strategy can only be 'round_robin' or 'least_loaded'"
            )

        if self.max_connections_per_peer < 1:
            raise ValueError("Max connections per peer must be at least 1")

        if self.connection_timeout < 0:
            raise ValueError("Connection timeout must be non-negative")
@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
    MuxedStreamError,
    MuxedStreamReset,
)
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError

from .exceptions import (
    StreamClosed,
@ -170,7 +171,7 @@ class NetStream(INetStream):
                elif self.__stream_state == StreamState.OPEN:
                    self.__stream_state = StreamState.CLOSE_READ
                raise StreamEOF() from error
        except MuxedStreamReset as error:
        except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
            async with self._state_lock:
                if self.__stream_state in [
                    StreamState.OPEN,
@ -199,7 +200,12 @@ class NetStream(INetStream):

        try:
            await self.muxed_stream.write(data)
        except (MuxedStreamClosed, MuxedStreamError) as error:
        except (
            MuxedStreamClosed,
            MuxedStreamError,
            QUICStreamClosedError,
            QUICStreamResetError,
        ) as error:
            async with self._state_lock:
                if self.__stream_state == StreamState.OPEN:
                    self.__stream_state = StreamState.CLOSE_WRITE

@ -3,6 +3,8 @@ from collections.abc import (
    Callable,
)
import logging
import random
from typing import cast

from multiaddr import (
    Multiaddr,
@ -25,6 +27,7 @@ from libp2p.custom_types import (
from libp2p.io.abc import (
    ReadWriteCloser,
)
from libp2p.network.config import ConnectionConfig, RetryConfig
from libp2p.peer.id import (
    ID,
)
@ -39,6 +42,9 @@ from libp2p.transport.exceptions import (
    OpenConnectionError,
    SecurityUpgradeFailure,
)
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.upgrader import (
    TransportUpgrader,
)
@ -71,9 +77,7 @@ class Swarm(Service, INetworkService):
    peerstore: IPeerStore
    upgrader: TransportUpgrader
    transport: ITransport
    # TODO: Connection and `peer_id` are 1-1 mapping in our implementation,
    # whereas in Go one `peer_id` may point to multiple connections.
    connections: dict[ID, INetConn]
    connections: dict[ID, list[INetConn]]
    listeners: dict[str, IListener]
    common_stream_handler: StreamHandlerFn
    listener_nursery: trio.Nursery | None
@ -81,18 +85,31 @@ class Swarm(Service, INetworkService):

    notifees: list[INotifee]

    # Enhanced: New configuration
    retry_config: RetryConfig
    connection_config: ConnectionConfig | QUICTransportConfig
    _round_robin_index: dict[ID, int]

    def __init__(
        self,
        peer_id: ID,
        peerstore: IPeerStore,
        upgrader: TransportUpgrader,
        transport: ITransport,
        retry_config: RetryConfig | None = None,
        connection_config: ConnectionConfig | QUICTransportConfig | None = None,
    ):
        self.self_id = peer_id
        self.peerstore = peerstore
        self.upgrader = upgrader
        self.transport = transport
        self.connections = dict()

        # Enhanced: Initialize retry and connection configuration
        self.retry_config = retry_config or RetryConfig()
        self.connection_config = connection_config or ConnectionConfig()

        # Enhanced: Initialize connections as 1:many mapping
        self.connections = {}
        self.listeners = dict()

        # Create Notifee array
@ -103,11 +120,19 @@ class Swarm(Service, INetworkService):
        self.listener_nursery = None
        self.event_listener_nursery_created = trio.Event()

        # Load balancing state
        self._round_robin_index = {}

    async def run(self) -> None:
        async with trio.open_nursery() as nursery:
            # Create a nursery for listener tasks.
            self.listener_nursery = nursery
            self.event_listener_nursery_created.set()

            if isinstance(self.transport, QUICTransport):
                self.transport.set_background_nursery(nursery)
                self.transport.set_swarm(self)

            try:
                await self.manager.wait_finished()
            finally:
@ -122,18 +147,74 @@ class Swarm(Service, INetworkService):
    def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None:
        self.common_stream_handler = stream_handler

    async def dial_peer(self, peer_id: ID) -> INetConn:
    def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
        """
        Try to create a connection to peer_id.
        Get connections for peer (like JS getConnections, Go ConnsToPeer).

        Parameters
        ----------
        peer_id : ID | None
            The peer ID to get connections for. If None, returns all connections.

        Returns
        -------
        list[INetConn]
            List of connections to the specified peer, or all connections
            if peer_id is None.

        """
        if peer_id is not None:
            return self.connections.get(peer_id, [])

        # Return all connections from all peers
        all_conns = []
        for conns in self.connections.values():
            all_conns.extend(conns)
        return all_conns

    def get_connections_map(self) -> dict[ID, list[INetConn]]:
        """
        Get all connections map (like JS getConnectionsMap).

        Returns
        -------
        dict[ID, list[INetConn]]
            The complete mapping of peer IDs to their connection lists.

        """
        return self.connections.copy()

    def get_connection(self, peer_id: ID) -> INetConn | None:
        """
        Get single connection for backward compatibility.

        Parameters
        ----------
        peer_id : ID
            The peer ID to get a connection for.

        Returns
        -------
        INetConn | None
            The first available connection, or None if no connections exist.

        """
        conns = self.get_connections(peer_id)
        return conns[0] if conns else None
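Illustrative usage of the accessors above (assuming `swarm` and `peer_id` are
in scope):

all_conns = swarm.get_connections()           # every connection, all peers
peer_conns = swarm.get_connections(peer_id)   # list for one peer, possibly []
first = swarm.get_connection(peer_id)         # first connection or None
snapshot = swarm.get_connections_map()        # copied dict[ID, list[INetConn]]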

    async def dial_peer(self, peer_id: ID) -> list[INetConn]:
        """
        Try to create connections to peer_id with enhanced retry logic.

        :param peer_id: peer id we want to dial
        :raises SwarmException: raised when an error occurs
        :return: muxed connection
        :return: list of muxed connections
        """
        if peer_id in self.connections:
            # If muxed connection already exists for peer_id,
            # set muxed connection equal to existing muxed connection
            return self.connections[peer_id]
        # Check if we already have connections
        existing_connections = self.get_connections(peer_id)
        if existing_connections:
            logger.debug(f"Reusing existing connections to peer {peer_id}")
            return existing_connections

        logger.debug("attempting to dial peer %s", peer_id)

@ -146,12 +227,19 @@ class Swarm(Service, INetworkService):
        if not addrs:
            raise SwarmException(f"No known addresses to peer {peer_id}")

        connections = []
        exceptions: list[SwarmException] = []

        # Try all known addresses
        # Enhanced: Try all known addresses with retry logic
        for multiaddr in addrs:
            try:
                return await self.dial_addr(multiaddr, peer_id)
                connection = await self._dial_with_retry(multiaddr, peer_id)
                connections.append(connection)

                # Limit number of connections per peer
                if len(connections) >= self.connection_config.max_connections_per_peer:
                    break

            except SwarmException as e:
                exceptions.append(e)
                logger.debug(
@ -161,15 +249,73 @@ class Swarm(Service, INetworkService):
                    exc_info=e,
                )

        # Tried all addresses, raising exception.
        raise SwarmException(
            f"unable to connect to {peer_id}, no addresses established a successful "
            "connection (with exceptions)"
        ) from MultiError(exceptions)
        if not connections:
            # Tried all addresses, raising exception.
            raise SwarmException(
                f"unable to connect to {peer_id}, no addresses established a "
                "successful connection (with exceptions)"
            ) from MultiError(exceptions)

    async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        return connections

    async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        """
        Try to create a connection to peer_id with addr.
        Enhanced: Dial with retry logic and exponential backoff.

        :param addr: the address to dial
        :param peer_id: the peer we want to connect to
        :raises SwarmException: raised when all retry attempts fail
        :return: network connection
        """
        last_exception = None

        for attempt in range(self.retry_config.max_retries + 1):
            try:
                return await self._dial_addr_single_attempt(addr, peer_id)
            except Exception as e:
                last_exception = e
                if attempt < self.retry_config.max_retries:
                    delay = self._calculate_backoff_delay(attempt)
                    logger.debug(
                        f"Connection attempt {attempt + 1} failed, "
                        f"retrying in {delay:.2f}s: {e}"
                    )
                    await trio.sleep(delay)
                else:
                    logger.debug(f"All {self.retry_config.max_retries} attempts failed")

        # Convert the last exception to SwarmException for consistency
        if last_exception is not None:
            if isinstance(last_exception, SwarmException):
                raise last_exception
            else:
                raise SwarmException(
                    f"Failed to connect after {self.retry_config.max_retries} attempts"
                ) from last_exception

        # This should never be reached, but mypy requires it
        raise SwarmException("Unexpected error in retry logic")

    def _calculate_backoff_delay(self, attempt: int) -> float:
        """
        Enhanced: Calculate backoff delay with jitter to prevent thundering herd.

        :param attempt: the current attempt number (0-based)
        :return: delay in seconds
        """
        delay = min(
            self.retry_config.initial_delay
            * (self.retry_config.backoff_multiplier**attempt),
            self.retry_config.max_delay,
        )

        # Add jitter to prevent synchronized retries
        jitter = delay * self.retry_config.jitter_factor
        return delay + random.uniform(-jitter, jitter)
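With the RetryConfig defaults (initial_delay=0.1, backoff_multiplier=2.0,
max_delay=30.0, jitter_factor=0.1), the base delays for attempts 0..3 are
0.1 s, 0.2 s, 0.4 s, and 0.8 s, each perturbed by up to +/-10%. A standalone
sketch of the same computation:

import random

def backoff_delay(attempt: int, initial: float = 0.1, mult: float = 2.0,
                  cap: float = 30.0, jitter: float = 0.1) -> float:
    # Exponential growth capped at `cap`, then random jitter applied.
    delay = min(initial * (mult ** attempt), cap)
    return delay + random.uniform(-delay * jitter, delay * jitter)

print([round(backoff_delay(a), 3) for a in range(4)])
# roughly [0.1, 0.2, 0.4, 0.8], each with +/-10% jitter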

    async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        """
        Enhanced: Single attempt to dial an address (extracted from original dial_addr).

        :param addr: the address we want to connect with
        :param peer_id: the peer we want to connect to
@ -179,6 +325,7 @@ class Swarm(Service, INetworkService):
        # Dial peer (connection to peer does not yet exist)
        # Transport dials peer (gets back a raw conn)
        try:
            addr = Multiaddr(f"{addr}/p2p/{peer_id}")
            raw_conn = await self.transport.dial(addr)
        except OpenConnectionError as error:
            logger.debug("fail to dial peer %s over base transport", peer_id)
@ -186,6 +333,15 @@ class Swarm(Service, INetworkService):
                f"fail to open connection to peer {peer_id}"
            ) from error

        if isinstance(self.transport, QUICTransport) and isinstance(
            raw_conn, IMuxedConn
        ):
            logger.info(
                "Skipping upgrade for QUIC, QUIC connections are already multiplexed"
            )
            swarm_conn = await self.add_conn(raw_conn)
            return swarm_conn

        logger.debug("dialed peer %s over base transport", peer_id)

        # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
@ -211,24 +367,103 @@ class Swarm(Service, INetworkService):
        logger.debug("upgraded mux for peer %s", peer_id)

        swarm_conn = await self.add_conn(muxed_conn)

        logger.debug("successfully dialed peer %s", peer_id)

        return swarm_conn

    async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
        """
        Enhanced: Try to create a connection to peer_id with addr using retry logic.

        :param addr: the address we want to connect with
        :param peer_id: the peer we want to connect to
        :raises SwarmException: raised when an error occurs
        :return: network connection
        """
        return await self._dial_with_retry(addr, peer_id)

    async def new_stream(self, peer_id: ID) -> INetStream:
        """
        Enhanced: Create a new stream with load balancing across multiple connections.

        :param peer_id: peer_id of destination
        :raises SwarmException: raised when an error occurs
        :return: net stream instance
        """
        logger.debug("attempting to open a stream to peer %s", peer_id)
        # Get existing connections or dial new ones
        connections = self.get_connections(peer_id)
        if not connections:
            connections = await self.dial_peer(peer_id)

        swarm_conn = await self.dial_peer(peer_id)
        # Load balancing strategy at interface level
        connection = self._select_connection(connections, peer_id)

        net_stream = await swarm_conn.new_stream()
        logger.debug("successfully opened a stream to peer %s", peer_id)
        return net_stream
        if isinstance(self.transport, QUICTransport) and connection is not None:
            conn = cast(SwarmConn, connection)
            return await conn.new_stream()

        try:
            net_stream = await connection.new_stream()
            logger.debug("successfully opened a stream to peer %s", peer_id)
            return net_stream
        except Exception as e:
            logger.debug(f"Failed to create stream on connection: {e}")
            # Try other connections if available
            for other_conn in connections:
                if other_conn != connection:
                    try:
                        net_stream = await other_conn.new_stream()
                        logger.debug(
                            f"Successfully opened a stream to peer {peer_id} "
                            "using alternative connection"
                        )
                        return net_stream
                    except Exception:
                        continue

            # All connections failed, raise exception
            raise SwarmException(f"Failed to create stream to peer {peer_id}") from e

    def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn:
        """
        Select connection based on the load balancing strategy; the strategy
        is read from self.connection_config.load_balancing_strategy.

        Parameters
        ----------
        connections : list[INetConn]
            List of available connections.
        peer_id : ID
            The peer ID for round-robin tracking.

        Returns
        -------
        INetConn
            Selected connection.

        """
        if not connections:
            raise ValueError("No connections available")

        strategy = self.connection_config.load_balancing_strategy

        if strategy == "round_robin":
            # Simple round-robin selection
            if peer_id not in self._round_robin_index:
                self._round_robin_index[peer_id] = 0

            index = self._round_robin_index[peer_id] % len(connections)
            self._round_robin_index[peer_id] += 1
            return connections[index]

        elif strategy == "least_loaded":
            # Find connection with least streams
            return min(connections, key=lambda c: len(c.get_streams()))

        else:
            # Default to first connection
            return connections[0]
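A standalone sketch of how the two strategies behave, using a stand-in
connection type (FakeConn is hypothetical; real connections implement
INetConn.get_streams()):

class FakeConn:
    def __init__(self, n_streams: int) -> None:
        self._streams = list(range(n_streams))

    def get_streams(self):
        return self._streams

conns = [FakeConn(2), FakeConn(1), FakeConn(3)]

# round_robin: a per-peer counter cycles the list: index 0, 1, 2, 0, ...
rr_index = 0
picked = conns[rr_index % len(conns)]
rr_index += 1

# least_loaded: fewest active streams wins (here, conns[1])
least = min(conns, key=lambda c: len(c.get_streams()))
assert least is conns[1]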

    async def listen(self, *multiaddrs: Multiaddr) -> bool:
        """
@ -248,17 +483,35 @@ class Swarm(Service, INetworkService):
        """
        logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
        # We need to wait until `self.listener_nursery` is created.
        logger.debug("Starting to listen")
        await self.event_listener_nursery_created.wait()

        success_count = 0
        for maddr in multiaddrs:
            logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
            if str(maddr) in self.listeners:
                logger.debug(f"Swarm.listen: listener already exists for {maddr}")
                return True
                success_count += 1
                continue

            async def conn_handler(
                read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
            ) -> None:
                # No need to upgrade QUIC Connection
                if isinstance(self.transport, QUICTransport):
                    try:
                        quic_conn = cast(QUICConnection, read_write_closer)
                        await self.add_conn(quic_conn)
                        peer_id = quic_conn.peer_id
                        logger.debug(
                            f"successfully opened quic connection to peer {peer_id}"
                        )
                        # NOTE: This is an intentional barrier to prevent the
                        # handler from exiting and closing the connection.
                        await self.manager.wait_finished()
                    except Exception:
                        await read_write_closer.close()
                    return

                raw_conn = RawConnection(read_write_closer, False)

                # Per, https://discuss.libp2p.io/t/multistream-security/130, we first
@ -309,13 +562,14 @@ class Swarm(Service, INetworkService):
                # Call notifiers since event occurred
                await self.notify_listen(maddr)

                return True
                success_count += 1
                logger.debug("successfully started listening on: %s", maddr)
            except OSError:
                # Failed. Continue looping.
                logger.debug("fail to listen on: %s", maddr)

        # No maddr succeeded
        return False
        # Return true if at least one address succeeded
        return success_count > 0

    async def close(self) -> None:
        """
@ -328,9 +582,9 @@ class Swarm(Service, INetworkService):
            # Perform alternative cleanup if the manager isn't initialized
            # Close all connections manually
            if hasattr(self, "connections"):
                for conn_id in list(self.connections.keys()):
                    conn = self.connections[conn_id]
                    await conn.close()
                for peer_id, conns in list(self.connections.items()):
                    for conn in conns:
                        await conn.close()

                # Clear connection tracking dictionary
                self.connections.clear()
@ -360,12 +614,28 @@ class Swarm(Service, INetworkService):
        logger.debug("swarm successfully closed")

    async def close_peer(self, peer_id: ID) -> None:
        if peer_id not in self.connections:
        """
        Close all connections to the specified peer.

        Parameters
        ----------
        peer_id : ID
            The peer ID to close connections for.

        """
        connections = self.get_connections(peer_id)
        if not connections:
            return
        connection = self.connections[peer_id]
        # NOTE: `connection.close` will delete `peer_id` from `self.connections`
        # and `notify_disconnected` for us.
        await connection.close()

        # Close all connections
        for connection in connections:
            try:
                await connection.close()
            except Exception as e:
                logger.warning(f"Error closing connection to {peer_id}: {e}")

        # Remove from connections dict
        self.connections.pop(peer_id, None)

        logger.debug("successfully closed connections to peer %s", peer_id)

@ -379,26 +649,77 @@ class Swarm(Service, INetworkService):
            muxed_conn,
            self,
        )

        logger.debug("Swarm::add_conn | starting muxed connection")
        self.manager.run_task(muxed_conn.start)
        await muxed_conn.event_started.wait()
        logger.debug("Swarm::add_conn | starting swarm connection")
        self.manager.run_task(swarm_conn.start)
        await swarm_conn.event_started.wait()
        # Store muxed_conn with peer id
        self.connections[muxed_conn.peer_id] = swarm_conn

        # Add to connections dict with deduplication
        peer_id = muxed_conn.peer_id
        if peer_id not in self.connections:
            self.connections[peer_id] = []

        # Check for duplicate connections by comparing the underlying muxed connection
        for existing_conn in self.connections[peer_id]:
            if existing_conn.muxed_conn == muxed_conn:
                logger.debug(f"Connection already exists for peer {peer_id}")
                # existing_conn is a SwarmConn since it's stored in the connections list
                return existing_conn  # type: ignore[return-value]

        self.connections[peer_id].append(swarm_conn)

        # Trim if we exceed max connections
        max_conns = self.connection_config.max_connections_per_peer
        if len(self.connections[peer_id]) > max_conns:
            self._trim_connections(peer_id)

        # Call notifiers since event occurred
        await self.notify_connected(swarm_conn)
        return swarm_conn

    def _trim_connections(self, peer_id: ID) -> None:
        """
        Remove oldest connections when limit is exceeded.
        """
        connections = self.connections[peer_id]
        if len(connections) <= self.connection_config.max_connections_per_peer:
            return

        # Sort by creation time and remove oldest
        # For now, just keep the most recent connections
        max_conns = self.connection_config.max_connections_per_peer
        connections_to_remove = connections[:-max_conns]

        for conn in connections_to_remove:
            logger.debug(f"Trimming old connection for peer {peer_id}")
            trio.lowlevel.spawn_system_task(self._close_connection_async, conn)

        # Keep only the most recent connections
        self.connections[peer_id] = connections[-max_conns:]

    async def _close_connection_async(self, connection: INetConn) -> None:
        """Close a connection asynchronously."""
        try:
            await connection.close()
        except Exception as e:
            logger.warning(f"Error closing connection: {e}")

    def remove_conn(self, swarm_conn: SwarmConn) -> None:
        """
        Simply remove the connection from Swarm's records, without closing
        the connection.
        """
        peer_id = swarm_conn.muxed_conn.peer_id
        if peer_id not in self.connections:
            return
        del self.connections[peer_id]

        if peer_id in self.connections:
            self.connections[peer_id] = [
                conn for conn in self.connections[peer_id] if conn != swarm_conn
            ]
            if not self.connections[peer_id]:
                del self.connections[peer_id]

    # Notifee

@ -444,3 +765,21 @@ class Swarm(Service, INetworkService):
        async with trio.open_nursery() as nursery:
            for notifee in self.notifees:
                nursery.start_soon(notifier, notifee)

    # Backward compatibility properties
    @property
    def connections_legacy(self) -> dict[ID, INetConn]:
        """
        Legacy 1:1 mapping for backward compatibility.

        Returns
        -------
        dict[ID, INetConn]
            Legacy mapping with only the first connection per peer.

        """
        legacy_conns = {}
        for peer_id, conns in self.connections.items():
            if conns:
                legacy_conns[peer_id] = conns[0]
        return legacy_conns

@ -1,5 +1,7 @@
from typing import Any, cast

import multiaddr

from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PrivateKey, PublicKey
from libp2p.crypto.rsa import RSAPublicKey
@ -131,6 +133,9 @@ class Envelope:
            )
            return False

    def _env_addrs_set(self) -> set[multiaddr.Multiaddr]:
        return {b for b in self.record().addrs}


def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey:
    """

@ -16,6 +16,7 @@ import trio
from trio import MemoryReceiveChannel, MemorySendChannel

from libp2p.abc import (
    IHost,
    IPeerStore,
)
from libp2p.crypto.keys import (
@ -23,7 +24,8 @@ from libp2p.crypto.keys import (
    PrivateKey,
    PublicKey,
)
from libp2p.peer.envelope import Envelope
from libp2p.peer.envelope import Envelope, seal_record
from libp2p.peer.peer_record import PeerRecord

from .id import (
    ID,
@ -39,6 +41,86 @@ from .peerinfo import (
PERMANENT_ADDR_TTL = 0


def create_signed_peer_record(
    peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey
) -> Envelope:
    """Creates a signed_peer_record wrapped in an Envelope"""
    record = PeerRecord(peer_id, addrs)
    envelope = seal_record(record, pvt_key)
    return envelope


def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]:
    """
    Return the signed peer record (Envelope) to be sent in an RPC.

    This function checks whether the host already has a cached signed peer record
    (SPR). If one exists and its addresses match the host's current listen
    addresses, the cached envelope is reused. Otherwise, a new signed peer record
    is created, cached, and returned.

    Parameters
    ----------
    host : IHost
        The local host instance, providing access to peer ID, listen addresses,
        private key, and the peerstore.

    Returns
    -------
    tuple[bytes, bool]
        A 2-tuple where the first element is the serialized envelope (bytes)
        for the signed peer record, and the second element is a boolean flag
        indicating whether a new record was created (True) or an existing cached
        one was reused (False).

    """
    listen_addrs_set = {addr for addr in host.get_addrs()}
    local_env = host.get_peerstore().get_local_record()

    if local_env is None:
        # No cached SPR yet -> create one
        return issue_and_cache_local_record(host), True
    else:
        record_addrs_set = local_env._env_addrs_set()
        if record_addrs_set == listen_addrs_set:
            # Perfect match -> reuse cached envelope
            return local_env.marshal_envelope(), False
        else:
            # Addresses changed -> issue a new SPR and cache it
            return issue_and_cache_local_record(host), True


def issue_and_cache_local_record(host: IHost) -> bytes:
    """
    Create and cache a new signed peer record (Envelope) for the host.

    This function generates a new signed peer record from the host’s peer ID,
    listen addresses, and private key. The resulting envelope is stored in
    the peerstore as the local record for future reuse.

    Parameters
    ----------
    host : IHost
        The local host instance, providing access to peer ID, listen addresses,
        private key, and the peerstore.

    Returns
    -------
    bytes
        The serialized envelope (bytes) representing the newly created signed
        peer record.

    """
    env = create_signed_peer_record(
        host.get_id(),
        host.get_addrs(),
        host.get_private_key(),
    )
    # Cache it for next-time use
    host.get_peerstore().set_local_record(env)
    return env.marshal_envelope()
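A hedged usage sketch of the caching flow above (assuming a running host with
a private key and listen addresses in scope):

envelope_bytes, created = env_to_send_in_RPC(host)
# First call: no cached record yet, so a fresh one is issued -> created is True.
# Later calls with unchanged listen addresses reuse the cache -> created is False.

# The underlying record is built and sealed with the host's private key:
# env = create_signed_peer_record(
#     host.get_id(), host.get_addrs(), host.get_private_key()
# )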


class PeerRecordState:
    envelope: Envelope
    seq: int
@ -55,8 +137,17 @@ class PeerStore(IPeerStore):
        self.peer_data_map = defaultdict(PeerData)
        self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
        self.peer_record_map: dict[ID, PeerRecordState] = {}
        self.local_peer_record: Envelope | None = None
        self.max_records = max_records

    def get_local_record(self) -> Envelope | None:
        """Get the local-signed-record wrapped in Envelope"""
        return self.local_peer_record

    def set_local_record(self, envelope: Envelope) -> None:
        """Set the local-signed-record wrapped in Envelope"""
        self.local_peer_record = envelope

    def peer_info(self, peer_id: ID) -> PeerInfo:
        """
        :param peer_id: peer ID to get info for

@ -1,3 +1,5 @@
from builtins import AssertionError

from libp2p.abc import (
    IMultiselectCommunicator,
)
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
        msg_bytes = encode_delim(msg_str.encode())
        try:
            await self.read_writer.write(msg_bytes)
        except IOException as error:
        # Handle connection close during ongoing negotiation in QUIC
        except (IOException, AssertionError, ValueError) as error:
            raise MultiselectCommunicatorError(
                "fail to write to multiselect communicator"
            ) from error

@ -15,6 +15,7 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerstore import env_to_send_in_RPC

from .exceptions import (
    PubsubRouterError,
@ -103,6 +104,11 @@ class FloodSub(IPubsubRouter):
        )
        rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])

        # Add the senderRecord of the peer in the RPC msg
        if isinstance(self.pubsub, Pubsub):
            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
            rpc_msg.senderRecord = envelope_bytes

        logger.debug("publishing message %s", pubsub_msg)

        if self.pubsub is None:

@ -1,6 +1,3 @@
from ast import (
    literal_eval,
)
from collections import (
    defaultdict,
)
@ -22,6 +19,7 @@ from libp2p.abc import (
    IPubsubRouter,
)
from libp2p.custom_types import (
    MessageID,
    TProtocol,
)
from libp2p.peer.id import (
@ -34,10 +32,12 @@ from libp2p.peer.peerinfo import (
)
from libp2p.peer.peerstore import (
    PERMANENT_ADDR_TTL,
    env_to_send_in_RPC,
)
from libp2p.pubsub import (
    floodsub,
)
from libp2p.pubsub.utils import maybe_consume_signed_record
from libp2p.tools.async_service import (
    Service,
)
@ -54,6 +54,10 @@ from .pb import (
from .pubsub import (
    Pubsub,
)
from .utils import (
    parse_message_id_safe,
    safe_parse_message_id,
)

PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@ -226,6 +230,12 @@ class GossipSub(IPubsubRouter, Service):
        :param rpc: RPC message
        :param sender_peer_id: id of the peer who sent the message
        """
        # Process the senderRecord if sent
        if isinstance(self.pubsub, Pubsub):
            if not maybe_consume_signed_record(rpc, self.pubsub.host, sender_peer_id):
                logger.error("Received an invalid-signed-record, ignoring the message")
                return

        control_message = rpc.control

        # Relay each rpc control message to the appropriate handler
@ -253,6 +263,11 @@ class GossipSub(IPubsubRouter, Service):
        )
        rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])

        # Add the senderRecord of the peer in the RPC msg
        if isinstance(self.pubsub, Pubsub):
            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
            rpc_msg.senderRecord = envelope_bytes

        logger.debug("publishing message %s", pubsub_msg)

        for peer_id in peers_gen:

@ -781,8 +796,8 @@ class GossipSub(IPubsubRouter, Service):

        # Add all unknown message ids (ids that appear in ihave_msg but not in
        # seen_seqnos) to list of messages we want to request
        msg_ids_wanted: list[str] = [
            msg_id
        msg_ids_wanted: list[MessageID] = [
            parse_message_id_safe(msg_id)
            for msg_id in ihave_msg.messageIDs
            if msg_id not in seen_seqnos_and_peers
        ]
@ -798,9 +813,9 @@ class GossipSub(IPubsubRouter, Service):
        Forwards all request messages that are present in mcache to the
        requesting peer.
        """
        # FIXME: Update type of message ID
        # FIXME: Find a better way to parse the msg ids
        msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
        msg_ids: list[tuple[bytes, bytes]] = [
            safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
        ]
        msgs_to_forward: list[rpc_pb2.Message] = []
        for msg_id_iwant in msg_ids:
            # Check if the wanted message ID is present in mcache
@ -818,6 +833,13 @@ class GossipSub(IPubsubRouter, Service):
        # 1) Package these messages into a single packet
        packet: rpc_pb2.RPC = rpc_pb2.RPC()

        # Here an RPC message is created and published in response
        # to the iwant control msg, so we will send a freshly created senderRecord
        # with the RPC msg
        if isinstance(self.pubsub, Pubsub):
            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
            packet.senderRecord = envelope_bytes

        packet.publish.extend(msgs_to_forward)

        if self.pubsub is None:

@ -973,6 +995,12 @@ class GossipSub(IPubsubRouter, Service):
            raise NoPubsubAttached
        # Add control message to packet
        packet: rpc_pb2.RPC = rpc_pb2.RPC()

        # Add the sender's peer-record in the RPC msg
        if isinstance(self.pubsub, Pubsub):
            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
            packet.senderRecord = envelope_bytes

        packet.control.CopyFrom(control_msg)

        # Get stream for peer from pubsub

@ -14,6 +14,7 @@ message RPC {
    }

    optional ControlMessage control = 3;
    optional bytes senderRecord = 4;
}

message Message {

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/pubsub/pb/rpc.proto
# Protobuf Python Version: 4.25.3
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()
@ -13,39 +14,39 @@ _sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xca\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x12\x14\n\x0csenderRecord\x18\x04 \x01(\x0c\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:

    DESCRIPTOR._options = None
    _RPC._serialized_start=42
    _RPC._serialized_end=222
    _RPC_SUBOPTS._serialized_start=177
    _RPC_SUBOPTS._serialized_end=222
    _MESSAGE._serialized_start=224
    _MESSAGE._serialized_end=329
    _CONTROLMESSAGE._serialized_start=332
    _CONTROLMESSAGE._serialized_end=508
    _CONTROLIHAVE._serialized_start=510
    _CONTROLIHAVE._serialized_end=561
    _CONTROLIWANT._serialized_start=563
    _CONTROLIWANT._serialized_end=597
    _CONTROLGRAFT._serialized_start=599
    _CONTROLGRAFT._serialized_end=630
    _CONTROLPRUNE._serialized_start=632
    _CONTROLPRUNE._serialized_end=716
    _PEERINFO._serialized_start=718
    _PEERINFO._serialized_end=770
    _TOPICDESCRIPTOR._serialized_start=773
    _TOPICDESCRIPTOR._serialized_end=1164
    _TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906
    _TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030
    _TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992
    _TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030
    _TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033
    _TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164
    _TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121
    _TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164
    _globals['_RPC']._serialized_start=42
    _globals['_RPC']._serialized_end=244
    _globals['_RPC_SUBOPTS']._serialized_start=199
    _globals['_RPC_SUBOPTS']._serialized_end=244
    _globals['_MESSAGE']._serialized_start=246
    _globals['_MESSAGE']._serialized_end=351
    _globals['_CONTROLMESSAGE']._serialized_start=354
    _globals['_CONTROLMESSAGE']._serialized_end=530
    _globals['_CONTROLIHAVE']._serialized_start=532
    _globals['_CONTROLIHAVE']._serialized_end=583
    _globals['_CONTROLIWANT']._serialized_start=585
    _globals['_CONTROLIWANT']._serialized_end=619
    _globals['_CONTROLGRAFT']._serialized_start=621
    _globals['_CONTROLGRAFT']._serialized_end=652
    _globals['_CONTROLPRUNE']._serialized_start=654
    _globals['_CONTROLPRUNE']._serialized_end=738
    _globals['_PEERINFO']._serialized_start=740
    _globals['_PEERINFO']._serialized_end=792
    _globals['_TOPICDESCRIPTOR']._serialized_start=795
    _globals['_TOPICDESCRIPTOR']._serialized_end=1186
    _globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_start=928
    _globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_end=1052
    _globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_start=1014
    _globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_end=1052
    _globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_start=1055
    _globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_end=1186
    _globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_start=1143
    _globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_end=1186
# @@protoc_insertion_point(module_scope)

@ -1,323 +1,132 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
Modified from https://github.com/libp2p/go-libp2p-pubsub/blob/master/pb/rpc.proto"""
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union

import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
DESCRIPTOR: _descriptor.FileDescriptor

if sys.version_info >= (3, 10):
    import typing as typing_extensions
else:
    import typing_extensions
class RPC(_message.Message):
    __slots__ = ("subscriptions", "publish", "control", "senderRecord")
    class SubOpts(_message.Message):
        __slots__ = ("subscribe", "topicid")
        SUBSCRIBE_FIELD_NUMBER: _ClassVar[int]
        TOPICID_FIELD_NUMBER: _ClassVar[int]
        subscribe: bool
        topicid: str
        def __init__(self, subscribe: bool = ..., topicid: _Optional[str] = ...) -> None: ...
    SUBSCRIPTIONS_FIELD_NUMBER: _ClassVar[int]
    PUBLISH_FIELD_NUMBER: _ClassVar[int]
    CONTROL_FIELD_NUMBER: _ClassVar[int]
    SENDERRECORD_FIELD_NUMBER: _ClassVar[int]
    subscriptions: _containers.RepeatedCompositeFieldContainer[RPC.SubOpts]
    publish: _containers.RepeatedCompositeFieldContainer[Message]
    control: ControlMessage
    senderRecord: bytes
    def __init__(self, subscriptions: _Optional[_Iterable[_Union[RPC.SubOpts, _Mapping]]] = ..., publish: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., control: _Optional[_Union[ControlMessage, _Mapping]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ...  # type: ignore

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
class Message(_message.Message):
    __slots__ = ("from_id", "data", "seqno", "topicIDs", "signature", "key")
    FROM_ID_FIELD_NUMBER: _ClassVar[int]
    DATA_FIELD_NUMBER: _ClassVar[int]
    SEQNO_FIELD_NUMBER: _ClassVar[int]
    TOPICIDS_FIELD_NUMBER: _ClassVar[int]
    SIGNATURE_FIELD_NUMBER: _ClassVar[int]
    KEY_FIELD_NUMBER: _ClassVar[int]
    from_id: bytes
    data: bytes
    seqno: bytes
    topicIDs: _containers.RepeatedScalarFieldContainer[str]
    signature: bytes
    key: bytes
    def __init__(self, from_id: _Optional[bytes] = ..., data: _Optional[bytes] = ..., seqno: _Optional[bytes] = ..., topicIDs: _Optional[_Iterable[str]] = ..., signature: _Optional[bytes] = ..., key: _Optional[bytes] = ...) -> None: ...

@typing.final
class RPC(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ControlMessage(_message.Message):
    __slots__ = ("ihave", "iwant", "graft", "prune")
    IHAVE_FIELD_NUMBER: _ClassVar[int]
    IWANT_FIELD_NUMBER: _ClassVar[int]
    GRAFT_FIELD_NUMBER: _ClassVar[int]
    PRUNE_FIELD_NUMBER: _ClassVar[int]
    ihave: _containers.RepeatedCompositeFieldContainer[ControlIHave]
    iwant: _containers.RepeatedCompositeFieldContainer[ControlIWant]
    graft: _containers.RepeatedCompositeFieldContainer[ControlGraft]
    prune: _containers.RepeatedCompositeFieldContainer[ControlPrune]
    def __init__(self, ihave: _Optional[_Iterable[_Union[ControlIHave, _Mapping]]] = ..., iwant: _Optional[_Iterable[_Union[ControlIWant, _Mapping]]] = ..., graft: _Optional[_Iterable[_Union[ControlGraft, _Mapping]]] = ..., prune: _Optional[_Iterable[_Union[ControlPrune, _Mapping]]] = ...) -> None: ...  # type: ignore

    @typing.final
    class SubOpts(google.protobuf.message.Message):
        DESCRIPTOR: google.protobuf.descriptor.Descriptor
class ControlIHave(_message.Message):
    __slots__ = ("topicID", "messageIDs")
    TOPICID_FIELD_NUMBER: _ClassVar[int]
    MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
    topicID: str
    messageIDs: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, topicID: _Optional[str] = ..., messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...

        SUBSCRIBE_FIELD_NUMBER: builtins.int
        TOPICID_FIELD_NUMBER: builtins.int
        subscribe: builtins.bool
        """subscribe or unsubscribe"""
        topicid: builtins.str
        def __init__(
            self,
            *,
            subscribe: builtins.bool | None = ...,
            topicid: builtins.str | None = ...,
        ) -> None: ...
        def HasField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> builtins.bool: ...
        def ClearField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> None: ...
class ControlIWant(_message.Message):
    __slots__ = ("messageIDs",)
    MESSAGEIDS_FIELD_NUMBER: _ClassVar[int]
    messageIDs: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, messageIDs: _Optional[_Iterable[str]] = ...) -> None: ...

    SUBSCRIPTIONS_FIELD_NUMBER: builtins.int
    PUBLISH_FIELD_NUMBER: builtins.int
    CONTROL_FIELD_NUMBER: builtins.int
    @property
    def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RPC.SubOpts]: ...
    @property
    def publish(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message]: ...
    @property
    def control(self) -> global___ControlMessage: ...
    def __init__(
        self,
        *,
        subscriptions: collections.abc.Iterable[global___RPC.SubOpts] | None = ...,
        publish: collections.abc.Iterable[global___Message] | None = ...,
        control: global___ControlMessage | None = ...,
    ) -> None: ...
    def HasField(self, field_name: typing.Literal["control", b"control"]) -> builtins.bool: ...
    def ClearField(self, field_name: typing.Literal["control", b"control", "publish", b"publish", "subscriptions", b"subscriptions"]) -> None: ...
class ControlGraft(_message.Message):
    __slots__ = ("topicID",)
    TOPICID_FIELD_NUMBER: _ClassVar[int]
    topicID: str
    def __init__(self, topicID: _Optional[str] = ...) -> None: ...

global___RPC = RPC
class ControlPrune(_message.Message):
    __slots__ = ("topicID", "peers", "backoff")
    TOPICID_FIELD_NUMBER: _ClassVar[int]
    PEERS_FIELD_NUMBER: _ClassVar[int]
    BACKOFF_FIELD_NUMBER: _ClassVar[int]
    topicID: str
    peers: _containers.RepeatedCompositeFieldContainer[PeerInfo]
    backoff: int
    def __init__(self, topicID: _Optional[str] = ..., peers: _Optional[_Iterable[_Union[PeerInfo, _Mapping]]] = ..., backoff: _Optional[int] = ...) -> None: ...  # type: ignore

@typing.final
class Message(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
class PeerInfo(_message.Message):
|
||||
__slots__ = ("peerID", "signedPeerRecord")
|
||||
PEERID_FIELD_NUMBER: _ClassVar[int]
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int]
|
||||
peerID: bytes
|
||||
signedPeerRecord: bytes
|
||||
def __init__(self, peerID: _Optional[bytes] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ...
|
||||
|
||||
FROM_ID_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
SEQNO_FIELD_NUMBER: builtins.int
|
||||
TOPICIDS_FIELD_NUMBER: builtins.int
|
||||
SIGNATURE_FIELD_NUMBER: builtins.int
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
from_id: builtins.bytes
|
||||
data: builtins.bytes
|
||||
seqno: builtins.bytes
|
||||
signature: builtins.bytes
|
||||
key: builtins.bytes
|
||||
@property
|
||||
def topicIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
from_id: builtins.bytes | None = ...,
|
||||
data: builtins.bytes | None = ...,
|
||||
seqno: builtins.bytes | None = ...,
|
||||
topicIDs: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
signature: builtins.bytes | None = ...,
|
||||
key: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature", "topicIDs", b"topicIDs"]) -> None: ...
|
||||
|
||||
global___Message = Message
|
||||
|
||||
@typing.final
|
||||
class ControlMessage(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
IHAVE_FIELD_NUMBER: builtins.int
|
||||
IWANT_FIELD_NUMBER: builtins.int
|
||||
GRAFT_FIELD_NUMBER: builtins.int
|
||||
PRUNE_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def ihave(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIHave]: ...
|
||||
@property
|
||||
def iwant(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIWant]: ...
|
||||
@property
|
||||
def graft(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlGraft]: ...
|
||||
@property
|
||||
def prune(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlPrune]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ihave: collections.abc.Iterable[global___ControlIHave] | None = ...,
|
||||
iwant: collections.abc.Iterable[global___ControlIWant] | None = ...,
|
||||
graft: collections.abc.Iterable[global___ControlGraft] | None = ...,
|
||||
prune: collections.abc.Iterable[global___ControlPrune] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["graft", b"graft", "ihave", b"ihave", "iwant", b"iwant", "prune", b"prune"]) -> None: ...
|
||||
|
||||
global___ControlMessage = ControlMessage
|
||||
|
||||
@typing.final
|
||||
class ControlIHave(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
MESSAGEIDS_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
@property
|
||||
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs", "topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlIHave = ControlIHave
|
||||
|
||||
@typing.final
|
||||
class ControlIWant(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
MESSAGEIDS_FIELD_NUMBER: builtins.int
|
||||
@property
|
||||
def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
messageIDs: collections.abc.Iterable[builtins.str] | None = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs"]) -> None: ...
|
||||
|
||||
global___ControlIWant = ControlIWant
|
||||
|
||||
@typing.final
|
||||
class ControlGraft(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlGraft = ControlGraft
|
||||
|
||||
@typing.final
|
||||
class ControlPrune(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
TOPICID_FIELD_NUMBER: builtins.int
|
||||
PEERS_FIELD_NUMBER: builtins.int
|
||||
BACKOFF_FIELD_NUMBER: builtins.int
|
||||
topicID: builtins.str
|
||||
backoff: builtins.int
|
||||
@property
|
||||
def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
topicID: builtins.str | None = ...,
|
||||
peers: collections.abc.Iterable[global___PeerInfo] | None = ...,
|
||||
backoff: builtins.int | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ...
|
||||
|
||||
global___ControlPrune = ControlPrune
|
||||
|
||||
@typing.final
|
||||
class PeerInfo(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
PEERID_FIELD_NUMBER: builtins.int
|
||||
SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int
|
||||
peerID: builtins.bytes
|
||||
signedPeerRecord: builtins.bytes
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
peerID: builtins.bytes | None = ...,
|
||||
signedPeerRecord: builtins.bytes | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ...
|
||||
|
||||
global___PeerInfo = PeerInfo
|
||||
|
||||
@typing.final
|
||||
class TopicDescriptor(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
@typing.final
|
||||
class AuthOpts(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _AuthMode:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _AuthModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.AuthOpts._AuthMode.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NONE: TopicDescriptor.AuthOpts._AuthMode.ValueType # 0
|
||||
"""no authentication, anyone can publish"""
|
||||
KEY: TopicDescriptor.AuthOpts._AuthMode.ValueType # 1
|
||||
"""only messages signed by keys in the topic descriptor are accepted"""
|
||||
WOT: TopicDescriptor.AuthOpts._AuthMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
class AuthMode(_AuthMode, metaclass=_AuthModeEnumTypeWrapper): ...
|
||||
NONE: TopicDescriptor.AuthOpts.AuthMode.ValueType # 0
|
||||
"""no authentication, anyone can publish"""
|
||||
KEY: TopicDescriptor.AuthOpts.AuthMode.ValueType # 1
|
||||
"""only messages signed by keys in the topic descriptor are accepted"""
|
||||
WOT: TopicDescriptor.AuthOpts.AuthMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
MODE_FIELD_NUMBER: builtins.int
|
||||
KEYS_FIELD_NUMBER: builtins.int
|
||||
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType
|
||||
@property
|
||||
def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
|
||||
"""root keys to trust"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType | None = ...,
|
||||
keys: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["keys", b"keys", "mode", b"mode"]) -> None: ...
|
||||
|
||||
@typing.final
|
||||
class EncOpts(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _EncMode:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _EncModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.EncOpts._EncMode.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NONE: TopicDescriptor.EncOpts._EncMode.ValueType # 0
|
||||
"""no encryption, anyone can read"""
|
||||
SHAREDKEY: TopicDescriptor.EncOpts._EncMode.ValueType # 1
|
||||
"""messages are encrypted with shared key"""
|
||||
WOT: TopicDescriptor.EncOpts._EncMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
class EncMode(_EncMode, metaclass=_EncModeEnumTypeWrapper): ...
|
||||
NONE: TopicDescriptor.EncOpts.EncMode.ValueType # 0
|
||||
"""no encryption, anyone can read"""
|
||||
SHAREDKEY: TopicDescriptor.EncOpts.EncMode.ValueType # 1
|
||||
"""messages are encrypted with shared key"""
|
||||
WOT: TopicDescriptor.EncOpts.EncMode.ValueType # 2
|
||||
"""web of trust, certificates can allow publisher set to grow"""
|
||||
|
||||
MODE_FIELD_NUMBER: builtins.int
|
||||
KEYHASHES_FIELD_NUMBER: builtins.int
|
||||
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType
|
||||
@property
|
||||
def keyHashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
|
||||
"""the hashes of the shared keys used (salted)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
mode: global___TopicDescriptor.EncOpts.EncMode.ValueType | None = ...,
|
||||
keyHashes: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["keyHashes", b"keyHashes", "mode", b"mode"]) -> None: ...
|
||||
|
||||
NAME_FIELD_NUMBER: builtins.int
|
||||
AUTH_FIELD_NUMBER: builtins.int
|
||||
ENC_FIELD_NUMBER: builtins.int
|
||||
name: builtins.str
|
||||
@property
|
||||
def auth(self) -> global___TopicDescriptor.AuthOpts: ...
|
||||
@property
|
||||
def enc(self) -> global___TopicDescriptor.EncOpts: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: builtins.str | None = ...,
|
||||
auth: global___TopicDescriptor.AuthOpts | None = ...,
|
||||
enc: global___TopicDescriptor.EncOpts | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> None: ...
|
||||
|
||||
global___TopicDescriptor = TopicDescriptor
|
||||
class TopicDescriptor(_message.Message):
|
||||
__slots__ = ("name", "auth", "enc")
|
||||
class AuthOpts(_message.Message):
|
||||
__slots__ = ("mode", "keys")
|
||||
class AuthMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
NONE: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
|
||||
KEY: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
|
||||
WOT: _ClassVar[TopicDescriptor.AuthOpts.AuthMode]
|
||||
NONE: TopicDescriptor.AuthOpts.AuthMode
|
||||
KEY: TopicDescriptor.AuthOpts.AuthMode
|
||||
WOT: TopicDescriptor.AuthOpts.AuthMode
|
||||
MODE_FIELD_NUMBER: _ClassVar[int]
|
||||
KEYS_FIELD_NUMBER: _ClassVar[int]
|
||||
mode: TopicDescriptor.AuthOpts.AuthMode
|
||||
keys: _containers.RepeatedScalarFieldContainer[bytes]
|
||||
def __init__(self, mode: _Optional[_Union[TopicDescriptor.AuthOpts.AuthMode, str]] = ..., keys: _Optional[_Iterable[bytes]] = ...) -> None: ...
|
||||
class EncOpts(_message.Message):
|
||||
__slots__ = ("mode", "keyHashes")
|
||||
class EncMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||
__slots__ = ()
|
||||
NONE: _ClassVar[TopicDescriptor.EncOpts.EncMode]
|
||||
SHAREDKEY: _ClassVar[TopicDescriptor.EncOpts.EncMode]
|
||||
WOT: _ClassVar[TopicDescriptor.EncOpts.EncMode]
|
||||
NONE: TopicDescriptor.EncOpts.EncMode
|
||||
SHAREDKEY: TopicDescriptor.EncOpts.EncMode
|
||||
WOT: TopicDescriptor.EncOpts.EncMode
|
||||
MODE_FIELD_NUMBER: _ClassVar[int]
|
||||
KEYHASHES_FIELD_NUMBER: _ClassVar[int]
|
||||
mode: TopicDescriptor.EncOpts.EncMode
|
||||
keyHashes: _containers.RepeatedScalarFieldContainer[bytes]
|
||||
def __init__(self, mode: _Optional[_Union[TopicDescriptor.EncOpts.EncMode, str]] = ..., keyHashes: _Optional[_Iterable[bytes]] = ...) -> None: ...
|
||||
NAME_FIELD_NUMBER: _ClassVar[int]
|
||||
AUTH_FIELD_NUMBER: _ClassVar[int]
|
||||
ENC_FIELD_NUMBER: _ClassVar[int]
|
||||
name: str
|
||||
auth: TopicDescriptor.AuthOpts
|
||||
enc: TopicDescriptor.EncOpts
|
||||
def __init__(self, name: _Optional[str] = ..., auth: _Optional[_Union[TopicDescriptor.AuthOpts, _Mapping]] = ..., enc: _Optional[_Union[TopicDescriptor.EncOpts, _Mapping]] = ...) -> None: ... # type: ignore
|
||||
|
||||
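The regenerated stubs above describe the wire types pubsub exchanges. A minimal round-trip sketch (it assumes only the generated rpc_pb2 module and standard protobuf runtime behavior; the senderRecord payload is a placeholder, not a real signed envelope):

from libp2p.pubsub.pb import rpc_pb2

# Build an RPC announcing one subscription and carrying a placeholder
# senderRecord (in practice this is a serialized signed peer record).
packet = rpc_pb2.RPC()
packet.subscriptions.add(subscribe=True, topicid="demo-topic")
packet.senderRecord = b"envelope-bytes"

# Serialize and parse back; optional fields support presence checks.
data = packet.SerializeToString()
decoded = rpc_pb2.RPC()
decoded.ParseFromString(data)
assert decoded.HasField("senderRecord")
assert decoded.subscriptions[0].topicid == "demo-topic"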
@@ -56,6 +56,8 @@ from libp2p.peer.id import (
from libp2p.peer.peerdata import (
    PeerDataError,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.pubsub.utils import maybe_consume_signed_record
from libp2p.tools.async_service import (
    Service,
)
@@ -247,6 +249,10 @@ class Pubsub(Service, IPubsub):
        packet.subscriptions.extend(
            [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
        )
        # Add the sender's signed record to the RPC message
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        packet.senderRecord = envelope_bytes

        return packet

    async def continuously_read_stream(self, stream: INetStream) -> None:
@@ -263,6 +269,14 @@ class Pubsub(Service, IPubsub):
            incoming: bytes = await read_varint_prefixed_bytes(stream)
            rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
            rpc_incoming.ParseFromString(incoming)

            # Process the sender's signed record if sent
            if not maybe_consume_signed_record(rpc_incoming, self.host, peer_id):
                logger.error(
                    "Received an invalid signed record; ignoring the incoming message"
                )
                continue

            if rpc_incoming.publish:
                # deal with RPC.publish
                for msg in rpc_incoming.publish:
@@ -572,6 +586,9 @@ class Pubsub(Service, IPubsub):
            [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
        )

        # Add the senderRecord of the peer to the RPC message
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        packet.senderRecord = envelope_bytes
        # Send out subscribe message to all peers
        await self.message_all_peers(packet.SerializeToString())

@@ -604,6 +621,9 @@ class Pubsub(Service, IPubsub):
        packet.subscriptions.extend(
            [rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)]
        )
        # Add the senderRecord of the peer to the RPC message
        envelope_bytes, _ = env_to_send_in_RPC(self.host)
        packet.senderRecord = envelope_bytes

        # Send out unsubscribe message to all peers
        await self.message_all_peers(packet.SerializeToString())
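For context on the read loop above: each RPC travels on the stream as an unsigned-varint length prefix followed by the serialized message, which is what read_varint_prefixed_bytes undoes. A rough sketch of the sender side; the _VarintBytes helper is a protobuf-internal import and is assumed here purely for illustration:

from google.protobuf.internal.encoder import _VarintBytes

from libp2p.pubsub.pb import rpc_pb2

packet = rpc_pb2.RPC()
packet.subscriptions.add(subscribe=True, topicid="demo-topic")
payload = packet.SerializeToString()
# Varint length prefix, then the protobuf payload.
frame = _VarintBytes(len(payload)) + payload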
80
libp2p/pubsub/utils.py
Normal file
@@ -0,0 +1,80 @@
import ast
import logging

from libp2p.abc import IHost
from libp2p.custom_types import (
    MessageID,
)
from libp2p.peer.envelope import consume_envelope
from libp2p.peer.id import ID
from libp2p.pubsub.pb.rpc_pb2 import RPC

logger = logging.getLogger("pubsub-example.utils")


def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
    """
    Attempt to parse and store a signed peer record (Envelope) received during
    PubSub communication. If the record is invalid, the peer ID does not match, or
    updating the peerstore fails, the function logs an error and returns False.

    Parameters
    ----------
    msg : RPC
        The protobuf message received during PubSub communication.
    host : IHost
        The local host instance, providing access to the peerstore for storing
        verified peer records.
    peer_id : ID
        The expected peer ID for record validation. The peer ID inside the
        record must match this value.

    Returns
    -------
    bool
        True if a valid signed peer record was successfully consumed and stored,
        False otherwise.

    """
    if msg.HasField("senderRecord"):
        try:
            # Convert the signed peer record (Envelope) from protobuf bytes
            envelope, record = consume_envelope(msg.senderRecord, "libp2p-peer-record")
            if record.peer_id != peer_id:
                return False

            # Use the default TTL of 2 hours (7200 seconds)
            if not host.get_peerstore().consume_peer_record(envelope, 7200):
                logger.error("Failed to update the Certified-Addr-Book")
                return False
        except Exception as e:
            logger.error("Failed to update the Certified-Addr-Book: %s", e)
            return False
    return True


def parse_message_id_safe(msg_id_str: str) -> MessageID:
    """Safely handle a message ID as a string."""
    return MessageID(msg_id_str)


def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
    """
    Safely parse a message ID using ast.literal_eval with validation.

    :param msg_id_str: String representation of the message ID
    :return: Tuple of (seqno, from_id) as bytes
    :raises ValueError: If parsing fails
    """
    try:
        parsed = ast.literal_eval(msg_id_str)
        if not isinstance(parsed, tuple) or len(parsed) != 2:
            raise ValueError("Invalid message ID format")

        seqno, from_id = parsed
        if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
            raise ValueError("Message ID components must be bytes")

        return (seqno, from_id)
    except (ValueError, SyntaxError) as e:
        raise ValueError(f"Invalid message ID format: {e}")
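A quick usage sketch for the parser above (assuming it is imported from libp2p.pubsub.utils): message IDs of the form str((seqno, from_id)) round-trip cleanly, and malformed input raises ValueError.

from libp2p.pubsub.utils import safe_parse_message_id

# A message ID is the string form of a (seqno, from_id) bytes tuple.
raw = str((b"\x00\x01", b"peer-id-bytes"))
seqno, from_id = safe_parse_message_id(raw)
assert (seqno, from_id) == (b"\x00\x01", b"peer-id-bytes")

try:
    safe_parse_message_id("not-a-tuple")
except ValueError as err:
    print(err)  # "Invalid message ID format: ..."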
@@ -118,6 +118,8 @@ class SecurityMultistream(ABC):
            # Select protocol if non-initiator
            protocol, _ = await self.multiselect.negotiate(communicator)
            if protocol is None:
                raise MultiselectError("fail to negotiate a security protocol")
                raise MultiselectError(
                    "Failed to negotiate a security protocol: no protocol selected"
                )
        # Return transport from protocol
        return self.transports[protocol]

@@ -85,7 +85,9 @@ class MuxerMultistream:
        else:
            protocol, _ = await self.multiselect.negotiate(communicator)
            if protocol is None:
                raise MultiselectError("fail to negotiate a stream muxer protocol")
                raise MultiselectError(
                    "Fail to negotiate a stream muxer protocol: no protocol selected"
                )
        return self.transports[protocol]

    async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
0
libp2p/transport/quic/__init__.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
@@ -0,0 +1,345 @@
"""
|
||||
Configuration classes for QUIC transport.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
import ssl
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.network.config import ConnectionConfig
|
||||
|
||||
|
||||
class QUICTransportKwargs(TypedDict, total=False):
|
||||
"""Type definition for kwargs accepted by new_transport function."""
|
||||
|
||||
# Connection settings
|
||||
idle_timeout: float
|
||||
max_datagram_size: int
|
||||
local_port: int | None
|
||||
|
||||
# Protocol version support
|
||||
enable_draft29: bool
|
||||
enable_v1: bool
|
||||
|
||||
# TLS settings
|
||||
verify_mode: ssl.VerifyMode
|
||||
alpn_protocols: list[str]
|
||||
|
||||
# Performance settings
|
||||
max_concurrent_streams: int
|
||||
connection_window: int
|
||||
stream_window: int
|
||||
|
||||
# Logging and debugging
|
||||
enable_qlog: bool
|
||||
qlog_dir: str | None
|
||||
|
||||
# Connection management
|
||||
max_connections: int
|
||||
connection_timeout: float
|
||||
|
||||
# Protocol identifiers
|
||||
PROTOCOL_QUIC_V1: TProtocol
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class QUICTransportConfig(ConnectionConfig):
|
||||
"""Configuration for QUIC transport."""
|
||||
|
||||
# Connection settings
|
||||
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
|
||||
max_datagram_size: int = (
|
||||
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
|
||||
)
|
||||
local_port: int | None = (
|
||||
None # Local port to bind to. If None, a random port is chosen.
|
||||
)
|
||||
|
||||
# Protocol version support
|
||||
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
|
||||
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
|
||||
|
||||
# TLS settings
|
||||
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
|
||||
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
|
||||
|
||||
# Performance settings
|
||||
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
|
||||
connection_window: int = 1024 * 1024 # Connection flow control window
|
||||
stream_window: int = 64 * 1024 # Stream flow control window
|
||||
|
||||
# Logging and debugging
|
||||
enable_qlog: bool = False # Enable QUIC logging
|
||||
qlog_dir: str | None = None # Directory for QUIC logs
|
||||
|
||||
# Connection management
|
||||
max_connections: int = 1000 # Maximum number of connections
|
||||
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||
|
||||
MAX_CONCURRENT_STREAMS: int = 1000
|
||||
"""Maximum number of concurrent streams per connection."""
|
||||
|
||||
MAX_INCOMING_STREAMS: int = 1000
|
||||
"""Maximum number of incoming streams per connection."""
|
||||
|
||||
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
|
||||
"""Timeout for connection handshake (seconds)."""
|
||||
|
||||
MAX_OUTGOING_STREAMS: int = 1000
|
||||
"""Maximum number of outgoing streams per connection."""
|
||||
|
||||
CONNECTION_CLOSE_TIMEOUT: int = 10
|
||||
"""Timeout for opening new connection (seconds)."""
|
||||
|
||||
# Stream timeouts
|
||||
STREAM_OPEN_TIMEOUT: float = 5.0
|
||||
"""Timeout for opening new streams (seconds)."""
|
||||
|
||||
STREAM_ACCEPT_TIMEOUT: float = 30.0
|
||||
"""Timeout for accepting incoming streams (seconds)."""
|
||||
|
||||
STREAM_READ_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream read operations (seconds)."""
|
||||
|
||||
STREAM_WRITE_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream write operations (seconds)."""
|
||||
|
||||
STREAM_CLOSE_TIMEOUT: float = 10.0
|
||||
"""Timeout for graceful stream close (seconds)."""
|
||||
|
||||
# Flow control configuration
|
||||
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
|
||||
"""Per-stream flow control window size."""
|
||||
|
||||
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
|
||||
"""Connection-wide flow control window size."""
|
||||
|
||||
# Buffer management
|
||||
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Maximum receive buffer size per stream."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
|
||||
"""Low watermark for stream receive buffer."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
|
||||
"""High watermark for stream receive buffer."""
|
||||
|
||||
# Stream lifecycle configuration
|
||||
ENABLE_STREAM_RESET_ON_ERROR: bool = True
|
||||
"""Whether to automatically reset streams on errors."""
|
||||
|
||||
STREAM_RESET_ERROR_CODE: int = 1
|
||||
"""Default error code for stream resets."""
|
||||
|
||||
ENABLE_STREAM_KEEP_ALIVE: bool = False
|
||||
"""Whether to enable stream keep-alive mechanisms."""
|
||||
|
||||
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
|
||||
"""Interval for stream keep-alive pings (seconds)."""
|
||||
|
||||
# Resource management
|
||||
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
|
||||
"""Whether to track stream resource usage."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Memory limit per individual stream."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
|
||||
"""Total memory limit for all streams per connection."""
|
||||
|
||||
# Concurrency and performance
|
||||
ENABLE_STREAM_BATCHING: bool = True
|
||||
"""Whether to batch multiple stream operations."""
|
||||
|
||||
STREAM_BATCH_SIZE: int = 10
|
||||
"""Number of streams to process in a batch."""
|
||||
|
||||
STREAM_PROCESSING_CONCURRENCY: int = 100
|
||||
"""Maximum concurrent stream processing tasks."""
|
||||
|
||||
# Debugging and monitoring
|
||||
ENABLE_STREAM_METRICS: bool = True
|
||||
"""Whether to collect stream metrics."""
|
||||
|
||||
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
|
||||
"""Whether to track stream lifecycle timelines."""
|
||||
|
||||
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
|
||||
"""Interval for collecting stream metrics (seconds)."""
|
||||
|
||||
# Error handling configuration
|
||||
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
|
||||
"""Number of retry attempts for recoverable stream errors."""
|
||||
|
||||
STREAM_ERROR_RETRY_DELAY: float = 1.0
|
||||
"""Initial delay between stream error retries (seconds)."""
|
||||
|
||||
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
|
||||
"""Backoff factor for stream error retries."""
|
||||
|
||||
# Protocol identifiers matching go-libp2p
|
||||
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration after initialization."""
|
||||
if not (self.enable_draft29 or self.enable_v1):
|
||||
raise ValueError("At least one QUIC version must be enabled")
|
||||
|
||||
if self.idle_timeout <= 0:
|
||||
raise ValueError("Idle timeout must be positive")
|
||||
|
||||
if self.max_datagram_size < 1200:
|
||||
raise ValueError("Max datagram size must be at least 1200 bytes")
|
||||
|
||||
# Validate timeouts
|
||||
timeout_fields = [
|
||||
"STREAM_OPEN_TIMEOUT",
|
||||
"STREAM_ACCEPT_TIMEOUT",
|
||||
"STREAM_READ_TIMEOUT",
|
||||
"STREAM_WRITE_TIMEOUT",
|
||||
"STREAM_CLOSE_TIMEOUT",
|
||||
]
|
||||
for timeout_field in timeout_fields:
|
||||
if getattr(self, timeout_field) <= 0:
|
||||
raise ValueError(f"{timeout_field} must be positive")
|
||||
|
||||
# Validate flow control windows
|
||||
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
|
||||
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
|
||||
|
||||
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
|
||||
raise ValueError(
|
||||
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
|
||||
)
|
||||
|
||||
# Validate buffer sizes
|
||||
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
|
||||
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
|
||||
|
||||
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
|
||||
"exceed MAX_STREAM_RECEIVE_BUFFER"
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
|
||||
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
|
||||
):
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
|
||||
)
|
||||
|
||||
# Validate memory limits
|
||||
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
|
||||
|
||||
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
|
||||
|
||||
expected_stream_memory = (
|
||||
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
|
||||
)
|
||||
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
|
||||
# Allow some headroom, but warn if configuration seems inconsistent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Stream memory configuration may be inconsistent: "
|
||||
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
|
||||
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
|
||||
"could exceed connection limit of"
|
||||
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
|
||||
)
|
||||
|
||||
def get_stream_config_dict(self) -> dict[str, Any]:
|
||||
"""Get stream-specific configuration as dictionary."""
|
||||
stream_config = {}
|
||||
for attr_name in dir(self):
|
||||
if attr_name.startswith(
|
||||
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
|
||||
):
|
||||
stream_config[attr_name.lower()] = getattr(self, attr_name)
|
||||
return stream_config
|
||||
|
||||
|
||||
# Additional configuration classes for specific stream features
|
||||
|
||||
|
||||
class QUICStreamFlowControlConfig:
|
||||
"""Configuration for QUIC stream flow control."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_window_size: int = 512 * 1024,
|
||||
max_window_size: int = 2 * 1024 * 1024,
|
||||
window_update_threshold: float = 0.5,
|
||||
enable_auto_tuning: bool = True,
|
||||
):
|
||||
self.initial_window_size = initial_window_size
|
||||
self.max_window_size = max_window_size
|
||||
self.window_update_threshold = window_update_threshold
|
||||
self.enable_auto_tuning = enable_auto_tuning
|
||||
|
||||
|
||||
def create_stream_config_for_use_case(
|
||||
use_case: Literal[
|
||||
"high_throughput", "low_latency", "many_streams", "memory_constrained"
|
||||
],
|
||||
) -> QUICTransportConfig:
|
||||
"""
|
||||
Create optimized stream configuration for specific use cases.
|
||||
|
||||
Args:
|
||||
use_case: One of "high_throughput", "low_latency", "many_streams","
|
||||
"memory_constrained"
|
||||
|
||||
Returns:
|
||||
Optimized QUICTransportConfig
|
||||
|
||||
"""
|
||||
base_config = QUICTransportConfig()
|
||||
|
||||
if use_case == "high_throughput":
|
||||
# Optimize for high throughput
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 200
|
||||
|
||||
elif use_case == "low_latency":
|
||||
# Optimize for low latency
|
||||
base_config.STREAM_OPEN_TIMEOUT = 1.0
|
||||
base_config.STREAM_READ_TIMEOUT = 5.0
|
||||
base_config.STREAM_WRITE_TIMEOUT = 5.0
|
||||
base_config.ENABLE_STREAM_BATCHING = False
|
||||
base_config.STREAM_BATCH_SIZE = 1
|
||||
|
||||
elif use_case == "many_streams":
|
||||
# Optimize for many concurrent streams
|
||||
base_config.MAX_CONCURRENT_STREAMS = 5000
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 500
|
||||
|
||||
elif use_case == "memory_constrained":
|
||||
# Optimize for low memory usage
|
||||
base_config.MAX_CONCURRENT_STREAMS = 100
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
|
||||
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 50
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown use case: {use_case}")
|
||||
|
||||
return base_config
|
||||
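A short usage sketch of the helpers above, using only names defined in this module (and assuming ConnectionConfig's inherited fields all have defaults, as the module's own QUICTransportConfig() call implies):

from libp2p.transport.quic.config import (
    QUICTransportConfig,
    create_stream_config_for_use_case,
)

config = create_stream_config_for_use_case("low_latency")
assert config.STREAM_OPEN_TIMEOUT == 1.0

# Collect the derived stream settings as a lower-cased dict.
stream_settings = config.get_stream_config_dict()
print(stream_settings["stream_open_timeout"])

# __post_init__ validation rejects inconsistent settings at construction time.
try:
    QUICTransportConfig(max_datagram_size=500)
except ValueError as err:
    print(err)  # "Max datagram size must be at least 1200 bytes"

Note that the per-use-case tweaks are applied after construction, so they bypass the __post_init__ re-validation; a caller that mutates fields further should keep the window and buffer invariants in mind.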
1487
libp2p/transport/quic/connection.py
Normal file
File diff suppressed because it is too large
391
libp2p/transport/quic/exceptions.py
Normal file
@@ -0,0 +1,391 @@
"""
|
||||
QUIC Transport exceptions
|
||||
"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
class QUICError(Exception):
|
||||
"""Base exception for all QUIC transport errors."""
|
||||
|
||||
def __init__(self, message: str, error_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
# Transport-level exceptions
|
||||
|
||||
|
||||
class QUICTransportError(QUICError):
|
||||
"""Base exception for QUIC transport operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICDialError(QUICTransportError):
|
||||
"""Error occurred during QUIC connection establishment."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICListenError(QUICTransportError):
|
||||
"""Error occurred during QUIC listener operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICSecurityError(QUICTransportError):
|
||||
"""Error related to QUIC security/TLS operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Connection-level exceptions
|
||||
|
||||
|
||||
class QUICConnectionError(QUICError):
|
||||
"""Base exception for QUIC connection operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionClosedError(QUICConnectionError):
|
||||
"""QUIC connection has been closed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionTimeoutError(QUICConnectionError):
|
||||
"""QUIC connection operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICHandshakeError(QUICConnectionError):
|
||||
"""Error during QUIC handshake process."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICPeerVerificationError(QUICConnectionError):
|
||||
"""Error verifying peer identity during handshake."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Stream-level exceptions
|
||||
|
||||
|
||||
class QUICStreamError(QUICError):
|
||||
"""Base exception for QUIC stream operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
):
|
||||
super().__init__(message, error_code)
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
class QUICStreamClosedError(QUICStreamError):
|
||||
"""Stream is closed and cannot be used for I/O operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamResetError(QUICStreamError):
|
||||
"""Stream was reset by local or remote peer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
reset_by_peer: bool = False,
|
||||
):
|
||||
super().__init__(message, stream_id, error_code)
|
||||
self.reset_by_peer = reset_by_peer
|
||||
|
||||
|
||||
class QUICStreamTimeoutError(QUICStreamError):
|
||||
"""Stream operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamBackpressureError(QUICStreamError):
|
||||
"""Stream write blocked due to flow control."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamLimitError(QUICStreamError):
|
||||
"""Stream limit reached (too many concurrent streams)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamStateError(QUICStreamError):
|
||||
"""Invalid operation for current stream state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
current_state: str | None = None,
|
||||
attempted_operation: str | None = None,
|
||||
):
|
||||
super().__init__(message, stream_id)
|
||||
self.current_state = current_state
|
||||
self.attempted_operation = attempted_operation
|
||||
|
||||
|
||||
# Flow control exceptions
|
||||
|
||||
|
||||
class QUICFlowControlError(QUICError):
|
||||
"""Base exception for flow control related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlViolationError(QUICFlowControlError):
|
||||
"""Flow control limits were violated."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlDeadlockError(QUICFlowControlError):
|
||||
"""Flow control deadlock detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Resource management exceptions
|
||||
|
||||
|
||||
class QUICResourceError(QUICError):
|
||||
"""Base exception for resource management errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICMemoryLimitError(QUICResourceError):
|
||||
"""Memory limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionLimitError(QUICResourceError):
|
||||
"""Connection limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Multiaddr and addressing exceptions
|
||||
|
||||
|
||||
class QUICAddressError(QUICError):
|
||||
"""Base exception for QUIC addressing errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidMultiaddrError(QUICAddressError):
|
||||
"""Invalid multiaddr format for QUIC transport."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICAddressResolutionError(QUICAddressError):
|
||||
"""Failed to resolve QUIC address."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICProtocolError(QUICError):
|
||||
"""Base exception for QUIC protocol errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICVersionNegotiationError(QUICProtocolError):
|
||||
"""QUIC version negotiation failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICUnsupportedVersionError(QUICProtocolError):
|
||||
"""Unsupported QUIC version."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Configuration exceptions
|
||||
|
||||
|
||||
class QUICConfigurationError(QUICError):
|
||||
"""Base exception for QUIC configuration errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidConfigError(QUICConfigurationError):
|
||||
"""Invalid QUIC configuration parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICCertificateError(QUICConfigurationError):
|
||||
"""Error with TLS certificate configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def map_quic_error_code(error_code: int) -> str:
|
||||
"""
|
||||
Map QUIC error codes to human-readable descriptions.
|
||||
Based on RFC 9000 Transport Error Codes.
|
||||
"""
|
||||
error_codes = {
|
||||
0x00: "NO_ERROR",
|
||||
0x01: "INTERNAL_ERROR",
|
||||
0x02: "CONNECTION_REFUSED",
|
||||
0x03: "FLOW_CONTROL_ERROR",
|
||||
0x04: "STREAM_LIMIT_ERROR",
|
||||
0x05: "STREAM_STATE_ERROR",
|
||||
0x06: "FINAL_SIZE_ERROR",
|
||||
0x07: "FRAME_ENCODING_ERROR",
|
||||
0x08: "TRANSPORT_PARAMETER_ERROR",
|
||||
0x09: "CONNECTION_ID_LIMIT_ERROR",
|
||||
0x0A: "PROTOCOL_VIOLATION",
|
||||
0x0B: "INVALID_TOKEN",
|
||||
0x0C: "APPLICATION_ERROR",
|
||||
0x0D: "CRYPTO_BUFFER_EXCEEDED",
|
||||
0x0E: "KEY_UPDATE_ERROR",
|
||||
0x0F: "AEAD_LIMIT_REACHED",
|
||||
0x10: "NO_VIABLE_PATH",
|
||||
}
|
||||
|
||||
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
|
||||
|
||||
|
||||
def create_stream_error(
|
||||
error_type: str,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
) -> QUICStreamError:
|
||||
"""
|
||||
Factory function to create appropriate stream error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
|
||||
message: Error message
|
||||
stream_id: Stream identifier
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICStreamError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICStreamClosedError(message, stream_id, error_code)
|
||||
elif error_type == "reset":
|
||||
return QUICStreamResetError(message, stream_id, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICStreamTimeoutError(message, stream_id, error_code)
|
||||
elif error_type in ("backpressure", "flow_control"):
|
||||
return QUICStreamBackpressureError(message, stream_id, error_code)
|
||||
elif error_type in ("limit", "stream_limit"):
|
||||
return QUICStreamLimitError(message, stream_id, error_code)
|
||||
elif error_type == "state":
|
||||
return QUICStreamStateError(message, stream_id)
|
||||
else:
|
||||
return QUICStreamError(message, stream_id, error_code)
|
||||
|
||||
|
||||
def create_connection_error(
|
||||
error_type: str, message: str, error_code: int | None = None
|
||||
) -> QUICConnectionError:
|
||||
"""
|
||||
Factory function to create appropriate connection error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "timeout", "handshake", etc.)
|
||||
message: Error message
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICConnectionError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICConnectionClosedError(message, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICConnectionTimeoutError(message, error_code)
|
||||
elif error_type == "handshake":
|
||||
return QUICHandshakeError(message, error_code)
|
||||
elif error_type in ("peer_verification", "verification"):
|
||||
return QUICPeerVerificationError(message, error_code)
|
||||
else:
|
||||
return QUICConnectionError(message, error_code)
|
||||
|
||||
|
||||
class QUICErrorContext:
|
||||
"""
|
||||
Context manager for handling QUIC errors with automatic error mapping.
|
||||
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
|
||||
"""
|
||||
|
||||
def __init__(self, operation: str, component: str = "quic") -> None:
|
||||
self.operation = operation
|
||||
self.component = component
|
||||
|
||||
def __enter__(self) -> "QUICErrorContext":
|
||||
return self
|
||||
|
||||
# TODO: Fix types for exc_type
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: Any,
|
||||
) -> Literal[False]:
|
||||
if exc_type is None:
|
||||
return False
|
||||
|
||||
if exc_val is None:
|
||||
return False
|
||||
|
||||
# Map common aioquic exceptions to our exceptions
|
||||
if "ConnectionClosed" in str(exc_type):
|
||||
raise QUICConnectionClosedError(
|
||||
f"Connection closed during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "StreamReset" in str(exc_type):
|
||||
raise QUICStreamResetError(
|
||||
f"Stream reset during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "timeout" in str(exc_val).lower():
|
||||
if "stream" in self.component.lower():
|
||||
raise QUICStreamTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
else:
|
||||
raise QUICConnectionTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "flow control" in str(exc_val).lower():
|
||||
raise QUICStreamBackpressureError(
|
||||
f"Flow control error during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
|
||||
# Let other exceptions propagate
|
||||
return False
|
||||
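A brief sketch exercising the helpers above; all names come from this file, and the TimeoutError is raised deliberately to show the mapping:

from libp2p.transport.quic.exceptions import (
    QUICErrorContext,
    QUICStreamTimeoutError,
    create_stream_error,
    map_quic_error_code,
)

print(map_quic_error_code(0x0A))  # -> "PROTOCOL_VIOLATION"

# The factory picks the subclass from the error_type string.
err = create_stream_error("reset", "peer reset", stream_id="8", error_code=0x0A)
print(type(err).__name__, map_quic_error_code(err.error_code or 0))

# The context manager rewrites generic timeouts into QUIC-specific errors.
try:
    with QUICErrorContext("stream read", component="stream"):
        raise TimeoutError("operation timeout")
except QUICStreamTimeoutError as exc:
    print(exc)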
1041
libp2p/transport/quic/listener.py
Normal file
File diff suppressed because it is too large
1165
libp2p/transport/quic/security.py
Normal file
File diff suppressed because it is too large
656
libp2p/transport/quic/stream.py
Normal file
@@ -0,0 +1,656 @@
"""
|
||||
QUIC Stream implementation
|
||||
Provides stream interface over QUIC's native multiplexing.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import trio
|
||||
|
||||
from .exceptions import (
|
||||
QUICStreamBackpressureError,
|
||||
QUICStreamClosedError,
|
||||
QUICStreamResetError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import IMuxedStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
from .connection import QUICConnection
|
||||
else:
|
||||
IMuxedStream = cast(type, object)
|
||||
TProtocol = cast(type, object)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamState(Enum):
|
||||
"""Stream lifecycle states following libp2p patterns."""
|
||||
|
||||
OPEN = "open"
|
||||
WRITE_CLOSED = "write_closed"
|
||||
READ_CLOSED = "read_closed"
|
||||
CLOSED = "closed"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
class StreamDirection(Enum):
|
||||
"""Stream direction for tracking initiator."""
|
||||
|
||||
INBOUND = "inbound"
|
||||
OUTBOUND = "outbound"
|
||||
|
||||
|
||||
class StreamTimeline:
|
||||
"""Track stream lifecycle events for debugging and monitoring."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.created_at = time.time()
|
||||
self.opened_at: float | None = None
|
||||
self.first_data_at: float | None = None
|
||||
self.closed_at: float | None = None
|
||||
self.reset_at: float | None = None
|
||||
self.error_code: int | None = None
|
||||
|
||||
def record_open(self) -> None:
|
||||
self.opened_at = time.time()
|
||||
|
||||
def record_first_data(self) -> None:
|
||||
if self.first_data_at is None:
|
||||
self.first_data_at = time.time()
|
||||
|
||||
def record_close(self) -> None:
|
||||
self.closed_at = time.time()
|
||||
|
||||
def record_reset(self, error_code: int) -> None:
|
||||
self.reset_at = time.time()
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class QUICStream(IMuxedStream):
|
||||
"""
|
||||
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||
|
||||
Based on patterns from go-libp2p and js-libp2p, this implementation:
|
||||
- Leverages QUIC's native multiplexing and flow control
|
||||
- Integrates with libp2p resource management
|
||||
- Provides comprehensive error handling with QUIC-specific codes
|
||||
- Supports bidirectional communication with independent close semantics
|
||||
- Implements proper stream lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "QUICConnection",
|
||||
stream_id: int,
|
||||
direction: StreamDirection,
|
||||
remote_addr: tuple[str, int],
|
||||
resource_scope: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC stream.
|
||||
|
||||
Args:
|
||||
connection: Parent QUIC connection
|
||||
stream_id: QUIC stream identifier
|
||||
direction: Stream direction (inbound/outbound)
|
||||
resource_scope: Resource manager scope for memory accounting
|
||||
remote_addr: Remote addr stream is connected to
|
||||
|
||||
"""
|
||||
self._connection = connection
|
||||
self._stream_id = stream_id
|
||||
self._direction = direction
|
||||
self._resource_scope = resource_scope
|
||||
|
||||
# libp2p interface compliance
|
||||
self._protocol: TProtocol | None = None
|
||||
self._metadata: dict[str, Any] = {}
|
||||
self._remote_addr = remote_addr
|
||||
|
||||
# Stream state management
|
||||
self._state = StreamState.OPEN
|
||||
self._state_lock = trio.Lock()
|
||||
|
||||
# Flow control and buffering
|
||||
self._receive_buffer = bytearray()
|
||||
self._receive_buffer_lock = trio.Lock()
|
||||
self._receive_event = trio.Event()
|
||||
self._backpressure_event = trio.Event()
|
||||
self._backpressure_event.set() # Initially no backpressure
|
||||
|
||||
# Close/reset state
|
||||
self._write_closed = False
|
||||
self._read_closed = False
|
||||
self._close_event = trio.Event()
|
||||
self._reset_error_code: int | None = None
|
||||
|
||||
# Lifecycle tracking
|
||||
self._timeline = StreamTimeline()
|
||||
self._timeline.record_open()
|
||||
|
||||
# Resource accounting
|
||||
self._memory_reserved = 0
|
||||
|
||||
# Stream constant configurations
|
||||
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
|
||||
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
|
||||
self.FLOW_CONTROL_WINDOW_SIZE = (
|
||||
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
|
||||
)
|
||||
self.MAX_RECEIVE_BUFFER_SIZE = (
|
||||
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
|
||||
)
|
||||
|
||||
if self._resource_scope:
|
||||
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
|
||||
|
||||
logger.debug(
|
||||
f"Created QUIC stream {stream_id} "
|
||||
f"({direction.value}, connection: {connection.remote_peer_id()})"
|
||||
)
|
||||
|
||||
# Properties for libp2p interface compliance
|
||||
|
||||
@property
|
||||
def protocol(self) -> TProtocol | None:
|
||||
"""Get the protocol identifier for this stream."""
|
||||
return self._protocol
|
||||
|
||||
@protocol.setter
|
||||
def protocol(self, protocol_id: TProtocol) -> None:
|
||||
"""Set the protocol identifier for this stream."""
|
||||
self._protocol = protocol_id
|
||||
self._metadata["protocol"] = protocol_id
|
||||
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str:
|
||||
"""Get stream ID as string for libp2p compatibility."""
|
||||
return str(self._stream_id)
|
||||
|
||||
@property
|
||||
def muxed_conn(self) -> "QUICConnection": # type: ignore
|
||||
"""Get the parent muxed connection."""
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def state(self) -> StreamState:
|
||||
"""Get current stream state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def direction(self) -> StreamDirection:
|
||||
"""Get stream direction."""
|
||||
return self._direction
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
"""Check if this stream was locally initiated."""
|
||||
return self._direction == StreamDirection.OUTBOUND
|
||||
|
||||
# Core stream operations
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read data from the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
n: Maximum number of bytes to read. If None or -1, read all available.
|
||||
|
||||
Returns:
|
||||
Data read from stream
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed
|
||||
QUICStreamResetError: Stream was reset
|
||||
QUICStreamTimeoutError: Read timeout exceeded
|
||||
|
||||
"""
|
||||
if n is None:
|
||||
n = -1
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._read_closed:
|
||||
# Return any remaining buffered data, then EOF
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
return b""
|
||||
|
||||
# Wait for data with timeout
|
||||
timeout = self.READ_TIMEOUT
|
||||
try:
|
||||
with trio.move_on_after(timeout) as cancel_scope:
|
||||
while True:
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
|
||||
# Check if stream was closed while waiting
|
||||
if self._read_closed:
|
||||
return b""
|
||||
|
||||
# Wait for more data
|
||||
await self._receive_event.wait()
|
||||
self._receive_event = trio.Event() # Reset for next wait
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
|
||||
|
||||
return b""
|
||||
except QUICStreamResetError:
|
||||
# Stream was reset while reading
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading from stream {self.stream_id}: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
data: Data to write
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed for writing
|
||||
QUICStreamBackpressureError: Flow control window exhausted
|
||||
QUICStreamResetError: Stream was reset
|
||||
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._write_closed:
|
||||
raise QUICStreamClosedError(
|
||||
f"Stream {self.stream_id} write side is closed"
|
||||
)
|
||||
|
||||
try:
|
||||
# Handle flow control backpressure
|
||||
await self._backpressure_event.wait()
|
||||
|
||||
# Send data through QUIC connection
|
||||
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._timeline.record_first_data()
|
||||
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to stream {self.stream_id}: {e}")
|
||||
# Convert QUIC-specific errors
|
||||
if "flow control" in str(e).lower():
|
||||
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the stream gracefully (both read and write sides).
|
||||
|
||||
This implements proper close semantics where both sides
|
||||
are closed and resources are cleaned up.
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
return
|
||||
|
||||
logger.debug(f"Closing stream {self.stream_id}")
|
||||
|
||||
# Close both sides
|
||||
if not self._write_closed:
|
||||
await self.close_write()
|
||||
if not self._read_closed:
|
||||
await self.close_read()
|
||||
|
||||
# Update state and cleanup
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.CLOSED
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_close()
|
||||
self._close_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} closed")
|
||||
|
||||
async def close_write(self) -> None:
|
||||
"""Close the write side of the stream."""
|
||||
if self._write_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
# Send FIN to close write side
|
||||
self._connection._quic.send_stream_data(
|
||||
self._stream_id, b"", end_stream=True
|
||||
)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} write side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def close_read(self) -> None:
|
||||
"""Close the read side of the stream."""
|
||||
if self._read_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
self._read_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up any pending reads
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} read side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def reset(self, error_code: int = 0) -> None:
|
||||
"""
|
||||
Reset the stream with the given error code.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code for the reset
|
||||
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Resetting stream {self.stream_id} with error code {error_code}"
|
||||
)
|
||||
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
try:
|
||||
# Send QUIC reset frame
|
||||
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||
await self._connection._transmit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
|
||||
finally:
|
||||
# Always cleanup resources
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if stream is completely closed."""
|
||||
return self._state in (StreamState.CLOSED, StreamState.RESET)
|
||||
|
||||
def is_reset(self) -> bool:
|
||||
"""Check if stream was reset."""
|
||||
return self._state == StreamState.RESET
|
||||
|
||||
def can_read(self) -> bool:
|
||||
"""Check if stream can be read from."""
|
||||
return not self._read_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
def can_write(self) -> bool:
|
||||
"""Check if stream can be written to."""
|
||||
return not self._write_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||
"""
|
||||
Handle data received from the QUIC connection.
|
||||
|
||||
Args:
|
||||
data: Received data
|
||||
end_stream: Whether this is the last data (FIN received)
|
||||
|
||||
"""
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
if data:
|
||||
async with self._receive_buffer_lock:
|
||||
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
|
||||
logger.warning(
|
||||
f"Stream {self.stream_id} receive buffer overflow, "
|
||||
f"dropping {len(data)} bytes"
|
||||
)
|
||||
return
|
||||
|
||||
self._receive_buffer.extend(data)
|
||||
self._timeline.record_first_data()
|
||||
|
||||
# Notify waiting readers
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
|
||||
|
||||
if end_stream:
|
||||
self._read_closed = True
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up readers to process remaining data and EOF
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||
|
||||
async def handle_stop_sending(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle STOP_SENDING frame from remote peer.
|
||||
|
||||
When a STOP_SENDING frame is received, the peer is requesting that we
|
||||
stop sending data on this stream. We respond by resetting the stream.
|
||||
|
||||
Args:
|
||||
error_code: Error code from the STOP_SENDING frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
|
||||
)
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
# Wake up any pending write operations
|
||||
self._backpressure_event.set()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.direction == StreamDirection.OUTBOUND:
|
||||
self._state = StreamState.CLOSED
|
||||
elif self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
                # Only the write side is closed
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
# Send RESET_STREAM in response (QUIC protocol requirement)
|
||||
try:
|
||||
self._connection._quic.reset_stream(int(self.stream_id), error_code)
|
||||
await self._connection._transmit()
|
||||
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def handle_reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle stream reset from remote peer.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code from reset frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} reset by peer with error code {error_code}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
# Wake up any pending operations
|
||||
self._receive_event.set()
|
||||
self._backpressure_event.set()
|
||||
|
||||
async def handle_flow_control_update(self, available_window: int) -> None:
|
||||
"""
|
||||
Handle flow control window updates.
|
||||
|
||||
Args:
|
||||
available_window: Available flow control window size
|
||||
|
||||
"""
|
||||
if available_window > 0:
|
||||
self._backpressure_event.set()
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} flow control".__add__(
|
||||
f"window updated: {available_window}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._backpressure_event = trio.Event() # Reset to blocking state
|
||||
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
|
||||
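    # Note on the pattern above: trio.Event is one-shot, so the gate is
    # "un-set" by replacing the Event with a fresh instance, and pending
    # writers are released by calling set(). A minimal standalone sketch
    # (illustrative only, not part of this module):
    #
    #     import trio
    #
    #     async def demo() -> None:
    #         gate = trio.Event()
    #
    #         async def writer() -> None:
    #             await gate.wait()  # blocked while the window is exhausted
    #             print("window open, writing")
    #
    #         async with trio.open_nursery() as nursery:
    #             nursery.start_soon(writer)
    #             await trio.sleep(0.1)
    #             gate.set()  # a window update releases pending writers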
|
||||
def _extract_data_from_buffer(self, n: int) -> bytes:
|
||||
"""Extract data from receive buffer with specified limit."""
|
||||
if n == -1:
|
||||
# Read all available data
|
||||
data = bytes(self._receive_buffer)
|
||||
self._receive_buffer.clear()
|
||||
else:
|
||||
# Read up to n bytes
|
||||
data = bytes(self._receive_buffer[:n])
|
||||
self._receive_buffer = self._receive_buffer[n:]
|
||||
|
||||
return data
|
||||
|
||||
async def _handle_stream_error(self, error: Exception) -> None:
|
||||
"""Handle errors by resetting the stream."""
|
||||
logger.error(f"Stream {self.stream_id} error: {error}")
|
||||
await self.reset(error_code=1) # Generic error code
|
||||
|
||||
def _reserve_memory(self, size: int) -> None:
|
||||
"""Reserve memory with resource manager."""
|
||||
if self._resource_scope:
|
||||
try:
|
||||
self._resource_scope.reserve_memory(size)
|
||||
self._memory_reserved += size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to reserve memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
def _release_memory(self, size: int) -> None:
|
||||
"""Release memory with resource manager."""
|
||||
if self._resource_scope and size > 0:
|
||||
try:
|
||||
self._resource_scope.release_memory(size)
|
||||
self._memory_reserved = max(0, self._memory_reserved - size)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to release memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def _cleanup_resources(self) -> None:
|
||||
"""Clean up stream resources."""
|
||||
# Release all reserved memory
|
||||
if self._memory_reserved > 0:
|
||||
self._release_memory(self._memory_reserved)
|
||||
|
||||
# Clear receive buffer
|
||||
async with self._receive_buffer_lock:
|
||||
self._receive_buffer.clear()
|
||||
|
||||
# Remove from connection's stream registry
|
||||
self._connection._remove_stream(self._stream_id)
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} resources cleaned up")
|
||||
|
||||
    # Abstract method implementations
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int]:
|
||||
return self._remote_addr
|
||||
|
||||
async def __aenter__(self) -> "QUICStream":
|
||||
"""Enter the async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
logger.debug("Exiting the context and closing the stream")
|
||||
await self.close()
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
        Set a deadline for the stream. QUIC does not support deadlines
        natively, so this operation is unsupported.

        :param ttl: Time-to-live in seconds (ignored).
        :raises NotImplementedError: Deadlines are not supported by QUIC.
        """
        raise NotImplementedError("QUIC does not support setting read deadlines")
|
||||
|
||||
# String representation for debugging
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"QUICStream(id={self.stream_id}, "
|
||||
f"state={self._state.value}, "
|
||||
f"direction={self._direction.value}, "
|
||||
f"protocol={self._protocol})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"QUICStream({self.stream_id})"
|
||||
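# Illustrative usage sketch (not part of the module). It assumes an
# established QUICConnection that can open streams; the open-stream call
# name below is hypothetical:
#
#     async def echo_once(conn) -> bytes:
#         stream = await conn.open_stream()  # hypothetical helper
#         async with stream:                 # __aexit__ closes the stream
#             await stream.write(b"ping")
#             await stream.close_write()     # send FIN, keep the read side open
#             return await stream.read()     # read until EOF or timeout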
libp2p/transport/quic/transport.py (new file, 491 lines)
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
QUIC Transport implementation
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import ssl
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from aioquic.quic.configuration import (
|
||||
QuicConfiguration,
|
||||
)
|
||||
from aioquic.quic.connection import (
|
||||
QuicConnection as NativeQUICConnection,
|
||||
)
|
||||
from aioquic.quic.logger import QuicLogger
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
)
|
||||
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_client_config_from_base,
|
||||
create_server_config_from_base,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.network.swarm import Swarm
|
||||
else:
|
||||
Swarm = cast(type, object)
|
||||
|
||||
from .config import (
|
||||
QUICTransportConfig,
|
||||
)
|
||||
from .connection import (
|
||||
QUICConnection,
|
||||
)
|
||||
from .exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
QUICSecurityError,
|
||||
)
|
||||
from .listener import (
|
||||
QUICListener,
|
||||
)
|
||||
from .security import (
|
||||
QUICTLSConfigManager,
|
||||
create_quic_security_transport,
|
||||
)
|
||||
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QUICTransport(ITransport):
|
||||
"""
|
||||
    QUIC transport implementation following the libp2p ITransport interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Initialize QUIC transport with security integration.
|
||||
|
||||
Args:
|
||||
private_key: libp2p private key for identity and TLS cert generation
|
||||
config: QUIC transport configuration options
|
||||
|
||||
"""
|
||||
self._private_key = private_key
|
||||
self._peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
self._config = config or QUICTransportConfig()
|
||||
|
||||
# Connection management
|
||||
self._connections: dict[str, QUICConnection] = {}
|
||||
self._listeners: list[QUICListener] = []
|
||||
|
||||
# Security manager for TLS integration
|
||||
self._security_manager = create_quic_security_transport(
|
||||
self._private_key, self._peer_id
|
||||
)
|
||||
|
||||
# QUIC configurations for different versions
|
||||
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
|
||||
self._setup_quic_configurations()
|
||||
|
||||
# Resource management
|
||||
self._closed = False
|
||||
self._nursery_manager = trio.CapacityLimiter(1)
|
||||
self._background_nursery: trio.Nursery | None = None
|
||||
|
||||
self._swarm: Swarm | None = None
|
||||
|
||||
logger.debug(
|
||||
f"Initialized QUIC transport with security for peer {self._peer_id}"
|
||||
)
|
||||
|
||||
def set_background_nursery(self, nursery: trio.Nursery) -> None:
|
||||
"""Set the nursery to use for background tasks (called by swarm)."""
|
||||
self._background_nursery = nursery
|
||||
logger.debug("Transport background nursery set")
|
||||
|
||||
def set_swarm(self, swarm: Swarm) -> None:
|
||||
"""Set the swarm for adding incoming connections."""
|
||||
self._swarm = swarm
|
||||
|
||||
def _setup_quic_configurations(self) -> None:
|
||||
"""Setup QUIC configurations."""
|
||||
try:
|
||||
# Get TLS configuration from security manager
|
||||
server_tls_config = self._security_manager.create_server_config()
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
|
||||
# Base server configuration
|
||||
base_server_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=get_alpn_protocols(),
|
||||
verify_mode=self._config.verify_mode,
|
||||
max_datagram_frame_size=self._config.max_datagram_size,
|
||||
idle_timeout=self._config.idle_timeout,
|
||||
)
|
||||
|
||||
# Base client configuration
|
||||
base_client_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=get_alpn_protocols(),
|
||||
verify_mode=self._config.verify_mode,
|
||||
max_datagram_frame_size=self._config.max_datagram_size,
|
||||
idle_timeout=self._config.idle_timeout,
|
||||
)
|
||||
|
||||
# Apply TLS configuration
|
||||
self._apply_tls_configuration(base_server_config, server_tls_config)
|
||||
self._apply_tls_configuration(base_client_config, client_tls_config)
|
||||
|
||||
# QUIC v1 (RFC 9000) configurations
|
||||
if self._config.enable_v1:
|
||||
quic_v1_server_config = create_server_config_from_base(
|
||||
base_server_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
quic_v1_client_config = create_client_config_from_base(
|
||||
base_client_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
# Store both server and client configs for v1
|
||||
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
|
||||
quic_v1_server_config
|
||||
)
|
||||
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
|
||||
quic_v1_client_config
|
||||
)
|
||||
|
||||
# QUIC draft-29 configurations for compatibility
|
||||
if self._config.enable_draft29:
|
||||
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
|
||||
draft29_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
draft29_client_config = copy.copy(base_client_config)
|
||||
draft29_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
|
||||
draft29_server_config
|
||||
)
|
||||
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
|
||||
draft29_client_config
|
||||
)
|
||||
|
||||
logger.debug("QUIC configurations initialized with libp2p TLS security")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(
|
||||
f"Failed to setup QUIC TLS configurations: {e}"
|
||||
) from e
|
||||
|
||||
def _apply_tls_configuration(
|
||||
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
|
||||
) -> None:
|
||||
"""
|
||||
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||
|
||||
Args:
|
||||
config: QuicConfiguration to update
|
||||
tls_config: TLS configuration dictionary from security manager
|
||||
|
||||
"""
|
||||
try:
|
||||
config.certificate = tls_config.certificate
|
||||
config.private_key = tls_config.private_key
|
||||
config.certificate_chain = tls_config.certificate_chain
|
||||
config.alpn_protocols = tls_config.alpn_protocols
|
||||
config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
|
||||
|
||||
async def dial(
|
||||
self,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
) -> QUICConnection:
|
||||
"""
|
||||
Dial a remote peer using QUIC transport with security verification.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
|
||||
peer_id: Expected peer ID for verification
|
||||
nursery: Nursery to execute the background tasks
|
||||
|
||||
Returns:
|
||||
Raw connection interface to the remote peer
|
||||
|
||||
Raises:
|
||||
QUICDialError: If dialing fails
|
||||
QUICSecurityError: If security verification fails
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICDialError("Transport is closed")
|
||||
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
# Extract connection details from multiaddr
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
remote_peer_id = maddr.get_peer_id()
|
||||
if remote_peer_id is not None:
|
||||
remote_peer_id = ID.from_base58(remote_peer_id)
|
||||
|
||||
if remote_peer_id is None:
|
||||
logger.error("Unable to derive peer id from multiaddr")
|
||||
raise QUICDialError("Unable to derive peer id from multiaddr")
|
||||
quic_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
# Get appropriate QUIC client configuration
|
||||
config_key = TProtocol(f"{quic_version}_client")
|
||||
logger.debug("config_key", config_key, self._quic_configs.keys())
|
||||
config = self._quic_configs.get(config_key)
|
||||
if not config:
|
||||
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
|
||||
|
||||
config.is_client = True
|
||||
config.quic_logger = QuicLogger()
|
||||
|
||||
# Ensure client certificate is properly set for mutual authentication
|
||||
if not config.certificate or not config.private_key:
|
||||
logger.warning(
|
||||
"Client config missing certificate - applying TLS config"
|
||||
)
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
self._apply_tls_configuration(config, client_tls_config)
|
||||
|
||||
            logger.info(
                f"Dialing QUIC connection to {host}:{port} (version: {quic_version})"
            )
|
||||
|
||||
logger.debug("Starting QUIC Connection")
|
||||
# Create QUIC connection using aioquic's sans-IO core
|
||||
native_quic_connection = NativeQUICConnection(configuration=config)
|
||||
|
||||
# Create trio-based QUIC connection wrapper with security
|
||||
connection = QUICConnection(
|
||||
quic_connection=native_quic_connection,
|
||||
remote_addr=(host, port),
|
||||
remote_peer_id=remote_peer_id,
|
||||
local_peer_id=self._peer_id,
|
||||
is_initiator=True,
|
||||
maddr=maddr,
|
||||
transport=self,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
logger.debug("QUIC Connection Created")
|
||||
|
||||
if self._background_nursery is None:
|
||||
logger.error("No nursery set to execute background tasks")
|
||||
raise QUICDialError("No nursery found to execute tasks")
|
||||
|
||||
await connection.connect(self._background_nursery)
|
||||
|
||||
# Store connection for management
|
||||
conn_id = f"{host}:{port}"
|
||||
self._connections[conn_id] = connection
|
||||
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
|
||||
raise QUICDialError(f"Dial failed: {e}") from e
|
||||
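    # Illustrative dialing sketch (not part of the module). The key pair and
    # peer id below are placeholders; a background nursery must be set before
    # dial() is called:
    #
    #     import multiaddr
    #     import trio
    #
    #     async def main(private_key) -> None:  # private_key: libp2p PrivateKey
    #         transport = QUICTransport(private_key)
    #         async with trio.open_nursery() as nursery:
    #             transport.set_background_nursery(nursery)
    #             maddr = multiaddr.Multiaddr(
    #                 "/ip4/127.0.0.1/udp/4001/quic-v1/p2p/<remote-peer-id>"
    #             )
    #             conn = await transport.dial(maddr)
    #             ...
    #             await transport.close()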
|
||||
async def _verify_peer_identity(
|
||||
self, connection: QUICConnection, expected_peer_id: ID
|
||||
) -> None:
|
||||
"""
|
||||
Verify remote peer identity after TLS handshake.
|
||||
|
||||
Args:
|
||||
connection: The established QUIC connection
|
||||
expected_peer_id: Expected peer ID
|
||||
|
||||
Raises:
|
||||
QUICSecurityError: If peer verification fails
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get peer certificate from the connection
|
||||
peer_certificate = await connection.get_peer_certificate()
|
||||
|
||||
if not peer_certificate:
|
||||
raise QUICSecurityError("No peer certificate available")
|
||||
|
||||
# Verify peer identity using security manager
|
||||
verified_peer_id = self._security_manager.verify_peer_identity(
|
||||
peer_certificate, expected_peer_id
|
||||
)
|
||||
|
||||
if verified_peer_id != expected_peer_id:
|
||||
raise QUICSecurityError(
|
||||
"Peer ID verification failed: expected "
|
||||
f"{expected_peer_id}, got {verified_peer_id}"
|
||||
)
|
||||
|
||||
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
|
||||
|
||||
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
|
||||
"""
|
||||
Create a QUIC listener with integrated security.
|
||||
|
||||
Args:
|
||||
handler_function: Function to handle new connections
|
||||
|
||||
Returns:
|
||||
QUIC listener instance
|
||||
|
||||
Raises:
|
||||
QUICListenError: If transport is closed
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICListenError("Transport is closed")
|
||||
|
||||
# Get server configurations for the listener
|
||||
server_configs = {
|
||||
version: config
|
||||
for version, config in self._quic_configs.items()
|
||||
if version.endswith("_server")
|
||||
}
|
||||
|
||||
listener = QUICListener(
|
||||
transport=self,
|
||||
handler_function=handler_function,
|
||||
quic_configs=server_configs,
|
||||
config=self._config,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
|
||||
self._listeners.append(listener)
|
||||
logger.debug("Created QUIC listener with security")
|
||||
return listener
|
||||
|
||||
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if this transport can dial the given multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if this transport can dial the address
|
||||
|
||||
"""
|
||||
return is_quic_multiaddr(maddr)
|
||||
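    # Example (illustrative): can_dial() is a pure multiaddr check and never
    # touches the network:
    #
    #     transport.can_dial(multiaddr.Multiaddr("/ip4/1.2.3.4/udp/4001/quic-v1"))
    #     # -> True
    #     transport.can_dial(multiaddr.Multiaddr("/ip4/1.2.3.4/tcp/4001"))
    #     # -> False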
|
||||
def protocols(self) -> list[TProtocol]:
|
||||
"""
|
||||
Get supported protocol identifiers.
|
||||
|
||||
Returns:
|
||||
List of supported protocol strings
|
||||
|
||||
"""
|
||||
protocols = [QUIC_V1_PROTOCOL]
|
||||
if self._config.enable_draft29:
|
||||
protocols.append(QUIC_DRAFT29_PROTOCOL)
|
||||
return protocols
|
||||
|
||||
def listen_order(self) -> int:
|
||||
"""
|
||||
Get the listen order priority for this transport.
|
||||
Matches go-libp2p's ListenOrder = 1 for QUIC.
|
||||
|
||||
Returns:
|
||||
Priority order for listening (lower = higher priority)
|
||||
|
||||
"""
|
||||
return 1
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and cleanup resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
logger.debug("Closing QUIC transport")
|
||||
|
||||
# Close all active connections and listeners concurrently using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Close all connections
|
||||
for connection in self._connections.values():
|
||||
nursery.start_soon(connection.close)
|
||||
|
||||
# Close all listeners
|
||||
for listener in self._listeners:
|
||||
nursery.start_soon(listener.close)
|
||||
|
||||
self._connections.clear()
|
||||
self._listeners.clear()
|
||||
|
||||
logger.debug("QUIC transport closed")
|
||||
|
||||
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
|
||||
"""Clean up a terminated connection from all listeners."""
|
||||
try:
|
||||
for listener in self._listeners:
|
||||
await listener._remove_connection_by_object(connection)
|
||||
logger.debug(
|
||||
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
|
||||
|
||||
def get_stats(self) -> dict[str, int | list[str] | object]:
|
||||
"""Get transport statistics including security info."""
|
||||
return {
|
||||
"active_connections": len(self._connections),
|
||||
"active_listeners": len(self._listeners),
|
||||
"supported_protocols": self.protocols(),
|
||||
"local_peer_id": str(self._peer_id),
|
||||
"security_enabled": True,
|
||||
"tls_configured": True,
|
||||
}
|
||||
|
||||
def get_security_manager(self) -> QUICTLSConfigManager:
|
||||
"""
|
||||
Get the security manager for this transport.
|
||||
|
||||
Returns:
|
||||
The QUIC TLS configuration manager
|
||||
|
||||
"""
|
||||
return self._security_manager
|
||||
|
||||
def get_listener_socket(self) -> trio.socket.SocketType | None:
|
||||
"""Get the socket from the first active listener."""
|
||||
for listener in self._listeners:
|
||||
if listener.is_listening() and listener._socket:
|
||||
return listener._socket
|
||||
return None
|
||||
libp2p/transport/quic/utils.py (new file, 466 lines)
@@ -0,0 +1,466 @@
|
||||
"""
|
||||
Multiaddr utilities for QUIC transport - Module 4.
|
||||
Essential utilities required for QUIC transport implementation.
|
||||
Based on go-libp2p and js-libp2p QUIC implementations.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
import multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||
|
||||
from .config import QUICTransportConfig
|
||||
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Protocol constants
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
UDP_PROTOCOL = "udp"
|
||||
IP4_PROTOCOL = "ip4"
|
||||
IP6_PROTOCOL = "ip6"
|
||||
|
||||
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
|
||||
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
|
||||
|
||||
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
|
||||
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
|
||||
|
||||
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
|
||||
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
|
||||
    CLIENT_CONFIG_PROTOCOL_V1: 0x00000001,  # RFC 9000
|
||||
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||
}
|
||||
|
||||
# QUIC version to wire format mappings (required for aioquic)
|
||||
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
|
||||
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
|
||||
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
|
||||
}
|
||||
|
||||
# ALPN protocols for libp2p over QUIC
|
||||
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
|
||||
|
||||
|
||||
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if a multiaddr represents a QUIC address.
|
||||
|
||||
Valid QUIC multiaddrs:
|
||||
- /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
- /ip4/127.0.0.1/udp/4001/quic
|
||||
- /ip6/::1/udp/4001/quic-v1
|
||||
- /ip6/::1/udp/4001/quic
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if the multiaddr represents a QUIC address
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
# Check for required components
|
||||
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
|
||||
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
|
||||
has_quic = (
|
||||
f"/{QUIC_V1_PROTOCOL}" in addr_str
|
||||
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
|
||||
or "/quic" in addr_str
|
||||
)
|
||||
|
||||
return has_ip and has_udp and has_quic
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
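# Examples (illustrative):
#
#     is_quic_multiaddr(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1"))
#     # -> True
#     is_quic_multiaddr(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"))
#     # -> False (TCP, no QUIC component)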
|
||||
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
|
||||
"""
|
||||
Extract host and port from a QUIC multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Tuple of (host, port)
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
host = None
|
||||
port = None
|
||||
|
||||
# Try to get IPv4 address
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get IPv6 address if IPv4 not found
|
||||
if host is None:
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get UDP port
|
||||
try:
|
||||
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
|
||||
port = int(port_str)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if host is None or port is None:
|
||||
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
|
||||
|
||||
return host, port
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(
|
||||
f"Failed to parse QUIC multiaddr {maddr}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
|
||||
"""
|
||||
Determine QUIC version from multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
QUIC version identifier ("quic-v1" or "quic")
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
|
||||
return QUIC_V1_PROTOCOL # RFC 9000
|
||||
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
|
||||
return QUIC_DRAFT29_PROTOCOL # draft-29
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(
|
||||
f"Failed to determine QUIC version from {maddr}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def create_quic_multiaddr(
|
||||
host: str, port: int, version: str = "quic-v1"
|
||||
) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Create a QUIC multiaddr from host, port, and version.
|
||||
|
||||
Args:
|
||||
host: IP address (IPv4 or IPv6)
|
||||
port: UDP port number
|
||||
version: QUIC version ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
QUIC multiaddr
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If invalid parameters provided
|
||||
|
||||
"""
|
||||
try:
|
||||
# Determine IP version
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
ip_proto = IP4_PROTOCOL
|
||||
else:
|
||||
ip_proto = IP6_PROTOCOL
|
||||
except ValueError:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
|
||||
|
||||
# Validate port
|
||||
if not (0 <= port <= 65535):
|
||||
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
|
||||
|
||||
# Validate and normalize QUIC version
|
||||
if version == "quic-v1" or version == "/quic-v1":
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
elif version == "quic" or version == "/quic":
|
||||
quic_proto = QUIC_DRAFT29_PROTOCOL
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
|
||||
|
||||
# Construct multiaddr
|
||||
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
|
||||
return multiaddr.Multiaddr(addr_str)
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
|
||||
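# Round-trip sketch (illustrative): create, inspect, and canonicalize a QUIC
# multiaddr with the helpers in this module:
#
#     maddr = create_quic_multiaddr("127.0.0.1", 4001, "quic-v1")
#     host, port = quic_multiaddr_to_endpoint(maddr)  # ("127.0.0.1", 4001)
#     version = multiaddr_to_quic_version(maddr)      # "quic-v1"
#     canonical = normalize_quic_multiaddr(maddr)     # defined further below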
|
||||
|
||||
def quic_version_to_wire_format(version: TProtocol) -> int:
|
||||
"""
|
||||
Convert QUIC version string to wire format integer for aioquic.
|
||||
|
||||
Args:
|
||||
version: QUIC version string ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
Wire format version number
|
||||
|
||||
Raises:
|
||||
QUICUnsupportedVersionError: If version is not supported
|
||||
|
||||
"""
|
||||
wire_version = QUIC_VERSION_MAPPINGS.get(version)
|
||||
if wire_version is None:
|
||||
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||
|
||||
return wire_version
|
||||
|
||||
|
||||
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
|
||||
"""
|
||||
Convert QUIC version string to wire format integer for aioquic.
|
||||
|
||||
Args:
|
||||
version: QUIC version string ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
Wire format version number
|
||||
|
||||
Raises:
|
||||
QUICUnsupportedVersionError: If version is not supported
|
||||
|
||||
"""
|
||||
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
|
||||
if wire_version is None:
|
||||
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||
|
||||
return wire_version
|
||||
|
||||
|
||||
def get_alpn_protocols() -> list[str]:
|
||||
"""
|
||||
Get ALPN protocols for libp2p over QUIC.
|
||||
|
||||
Returns:
|
||||
List of ALPN protocol identifiers
|
||||
|
||||
"""
|
||||
return LIBP2P_ALPN_PROTOCOLS.copy()
|
||||
|
||||
|
||||
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Normalize a QUIC multiaddr to canonical form.
|
||||
|
||||
Args:
|
||||
maddr: Input QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Normalized multiaddr
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
|
||||
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
return create_quic_multiaddr(host, port, version)
|
||||
|
||||
|
||||
def create_server_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a server configuration without using deepcopy.
|
||||
Manually copies attributes while handling cryptography objects properly.
|
||||
"""
|
||||
try:
|
||||
# Create new server configuration from scratch
|
||||
server_config = QuicConfiguration(is_client=False)
|
||||
server_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes (these are safe to copy)
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"stateless_retry",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(server_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(server_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
server_tls_config = security_manager.create_server_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if server_tls_config.certificate:
|
||||
server_config.certificate = server_tls_config.certificate
|
||||
if server_tls_config.private_key:
|
||||
server_config.private_key = server_tls_config.private_key
|
||||
if server_tls_config.certificate_chain:
|
||||
server_config.certificate_chain = (
|
||||
server_tls_config.certificate_chain
|
||||
)
|
||||
if server_tls_config.alpn_protocols:
|
||||
server_config.alpn_protocols = server_tls_config.alpn_protocols
|
||||
                # Request a client certificate for mutual authentication
                server_tls_config.request_client_certificate = True
                server_config._libp2p_request_client_cert = True  # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Set transport-specific defaults if provided
|
||||
if transport_config:
|
||||
if server_config.idle_timeout == 0:
|
||||
server_config.idle_timeout = getattr(
|
||||
transport_config, "idle_timeout", 30.0
|
||||
)
|
||||
if server_config.max_datagram_frame_size is None:
|
||||
server_config.max_datagram_frame_size = getattr(
|
||||
transport_config, "max_datagram_size", 1200
|
||||
)
|
||||
# Ensure we have ALPN protocols
|
||||
if not server_config.alpn_protocols:
|
||||
server_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created server config without deepcopy")
|
||||
return server_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create server config: {e}")
|
||||
raise
|
||||
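# Usage sketch (illustrative, without a security manager): build a server
# configuration from a base config; TLS material must then be applied
# separately before the config is usable:
#
#     from aioquic.quic.configuration import QuicConfiguration
#
#     base = QuicConfiguration(
#         is_client=False,
#         alpn_protocols=["libp2p"],
#         idle_timeout=30.0,
#     )
#     server_cfg = create_server_config_from_base(base)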
|
||||
|
||||
def create_client_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a client configuration without using deepcopy.
|
||||
"""
|
||||
try:
|
||||
# Create new client configuration from scratch
|
||||
client_config = QuicConfiguration(is_client=True)
|
||||
client_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
client_tls_config = security_manager.create_client_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if client_tls_config.certificate:
|
||||
client_config.certificate = client_tls_config.certificate
|
||||
if client_tls_config.private_key:
|
||||
client_config.private_key = client_tls_config.private_key
|
||||
if client_tls_config.certificate_chain:
|
||||
client_config.certificate_chain = (
|
||||
client_tls_config.certificate_chain
|
||||
)
|
||||
if client_tls_config.alpn_protocols:
|
||||
client_config.alpn_protocols = client_tls_config.alpn_protocols
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Ensure we have ALPN protocols
|
||||
if not client_config.alpn_protocols:
|
||||
client_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created client config without deepcopy")
|
||||
return client_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create client config: {e}")
|
||||
raise
|
||||
@@ -1,9 +1,7 @@
|
||||
from libp2p.abc import (
|
||||
IListener,
|
||||
IMuxedConn,
|
||||
IRawConnection,
|
||||
ISecureConn,
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TMuxerOptions,
|
||||
@@ -43,10 +41,6 @@ class TransportUpgrader:
|
||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
||||
|
||||
def upgrade_listener(self, transport: ITransport, listeners: IListener) -> None:
|
||||
"""Upgrade multiaddr listeners to libp2p-transport listeners."""
|
||||
# TODO: Figure out what to do with this function.
|
||||
|
||||
async def upgrade_security(
|
||||
self,
|
||||
raw_conn: IRawConnection,
|
||||
|
||||
@@ -15,6 +15,13 @@ from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
)
|
||||
|
||||
from libp2p.utils.address_validation import (
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
expand_wildcard_address,
|
||||
find_free_port,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"decode_uvarint_from_stream",
|
||||
"encode_delim",
|
||||
@@ -26,4 +33,8 @@ __all__ = [
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
"read_length_prefixed_protobuf",
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
|
||||
libp2p/utils/address_validation.py (new file, 160 lines)
@@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
try:
|
||||
from multiaddr.utils import ( # type: ignore
|
||||
get_network_addrs,
|
||||
get_thin_waist_addresses,
|
||||
)
|
||||
|
||||
_HAS_THIN_WAIST = True
|
||||
except ImportError: # pragma: no cover - only executed in older environments
|
||||
_HAS_THIN_WAIST = False
|
||||
get_thin_waist_addresses = None # type: ignore
|
||||
get_network_addrs = None # type: ignore
|
||||
|
||||
|
||||
def _safe_get_network_addrs(ip_version: int) -> list[str]:
|
||||
"""
|
||||
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
|
||||
Falls back to minimal defaults when Thin Waist helpers are missing.
|
||||
|
||||
:param ip_version: 4 or 6
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_network_addrs:
|
||||
try:
|
||||
return get_network_addrs(ip_version) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return []
|
||||
# Fallback behavior (very conservative)
|
||||
if ip_version == 4:
|
||||
return ["127.0.0.1"]
|
||||
if ip_version == 6:
|
||||
return ["::1"]
|
||||
return []
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0)) # Bind to a free port provided by the OS
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
|
||||
"""
|
||||
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
|
||||
If Thin Waist isn't available, returns [addr] (identity).
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_thin_waist_addresses:
|
||||
try:
|
||||
if port is not None:
|
||||
return get_thin_waist_addresses(addr, port=port) or []
|
||||
return get_thin_waist_addresses(addr) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return [addr]
|
||||
return [addr]
|
||||
|
||||
|
||||
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
|
||||
"""
|
||||
Discover available network interfaces (IPv4 + IPv6 if supported) for binding.
|
||||
|
||||
:param port: Port number to bind to.
|
||||
:param protocol: Transport protocol (e.g., "tcp" or "udp").
|
||||
:return: List of Multiaddr objects representing candidate interface addresses.
|
||||
"""
|
||||
addrs: list[Multiaddr] = []
|
||||
|
||||
# IPv4 enumeration
|
||||
seen_v4: set[str] = set()
|
||||
|
||||
for ip in _safe_get_network_addrs(4):
|
||||
seen_v4.add(ip)
|
||||
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
||||
|
||||
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
|
||||
if seen_v4 and "127.0.0.1" not in seen_v4:
|
||||
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
|
||||
|
||||
# TODO: IPv6 support temporarily disabled due to libp2p handshake issues
|
||||
# IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure)
|
||||
# Re-enable IPv6 support once the following issues are resolved:
|
||||
# - libp2p security handshake over IPv6
|
||||
# - multiselect protocol over IPv6
|
||||
# - connection establishment over IPv6
|
||||
#
|
||||
# seen_v6: set[str] = set()
|
||||
# for ip in _safe_get_network_addrs(6):
|
||||
# seen_v6.add(ip)
|
||||
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
||||
#
|
||||
# # Always include IPv6 loopback for testing purposes when IPv6 is available
|
||||
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
|
||||
# if "::1" not in seen_v6:
|
||||
# addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}"))
|
||||
|
||||
# Fallback if nothing discovered
|
||||
if not addrs:
|
||||
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
|
||||
|
||||
return addrs
|
||||
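# Example (illustrative; actual results depend on the host's interfaces):
#
#     addrs = get_available_interfaces(4001, protocol="tcp")
#     # e.g. [Multiaddr("/ip4/192.168.1.10/tcp/4001"),
#     #       Multiaddr("/ip4/127.0.0.1/tcp/4001")]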
|
||||
|
||||
def expand_wildcard_address(
|
||||
addr: Multiaddr, port: int | None = None
|
||||
) -> list[Multiaddr]:
|
||||
"""
|
||||
Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces.
|
||||
|
||||
:param addr: Multiaddr to expand.
|
||||
:param port: Optional override for port selection.
|
||||
:return: List of concrete Multiaddr instances.
|
||||
"""
|
||||
expanded = _safe_expand(addr, port=port)
|
||||
if not expanded: # Safety fallback
|
||||
return [addr]
|
||||
return expanded
|
||||
|
||||
|
||||
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||
"""
|
||||
    Choose an optimal address to bind to:
|
||||
- Prefer non-loopback IPv4
|
||||
- Then non-loopback IPv6
|
||||
- Fallback to loopback
|
||||
- Fallback to wildcard
|
||||
|
||||
:param port: Port number.
|
||||
:param protocol: Transport protocol.
|
||||
:return: A single Multiaddr chosen heuristically.
|
||||
"""
|
||||
candidates = get_available_interfaces(port, protocol)
|
||||
|
||||
def is_non_loopback(ma: Multiaddr) -> bool:
|
||||
s = str(ma)
|
||||
return not ("/ip4/127." in s or "/ip6/::1" in s)
|
||||
|
||||
for c in candidates:
|
||||
if "/ip4/" in str(c) and is_non_loopback(c):
|
||||
return c
|
||||
for c in candidates:
|
||||
if "/ip6/" in str(c) and is_non_loopback(c):
|
||||
return c
|
||||
for c in candidates:
|
||||
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
|
||||
return c
|
||||
|
||||
# As a final fallback, produce a wildcard
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
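# Example (illustrative): pick a free port, then a concrete bind address:
#
#     port = find_free_port()
#     addr = get_optimal_binding_address(port, protocol="tcp")
#     # e.g. Multiaddr("/ip4/192.168.1.10/tcp/54321") on a typical LAN host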
|
||||
|
||||
__all__ = [
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
@@ -1,7 +1,4 @@
|
||||
import atexit
|
||||
from datetime import (
|
||||
datetime,
|
||||
)
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
@@ -21,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue()
|
||||
# Store the current listener to stop it on exit
|
||||
_current_listener: logging.handlers.QueueListener | None = None
|
||||
|
||||
# Store the handlers for proper cleanup
|
||||
_current_handlers: list[logging.Handler] = []
|
||||
|
||||
# Event to track when the listener is ready
|
||||
_listener_ready = threading.Event()
|
||||
|
||||
@ -95,7 +95,7 @@ def setup_logging() -> None:
|
||||
- Child loggers inherit their parent's level unless explicitly set
|
||||
- The root libp2p logger controls the default level
|
||||
"""
|
||||
global _current_listener, _listener_ready
|
||||
global _current_listener, _listener_ready, _current_handlers
|
||||
|
||||
# Reset the event
|
||||
_listener_ready.clear()
|
||||
@ -105,6 +105,12 @@ def setup_logging() -> None:
|
||||
_current_listener.stop()
|
||||
_current_listener = None
|
||||
|
||||
# Close and clear existing handlers
|
||||
for handler in _current_handlers:
|
||||
if isinstance(handler, logging.FileHandler):
|
||||
handler.close()
|
||||
_current_handlers.clear()
|
||||
|
||||
# Get the log level from environment variable
|
||||
debug_str = os.environ.get("LIBP2P_DEBUG", "")
|
||||
|
||||
@ -148,13 +154,10 @@ def setup_logging() -> None:
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
# Default log file with timestamp and unique identifier
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
unique_id = os.urandom(4).hex() # Add a unique identifier to prevent collisions
|
||||
if os.name == "nt": # Windows
|
||||
log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log"
|
||||
else: # Unix-like
|
||||
log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log"
|
||||
# Use cross-platform temp file creation
|
||||
from libp2p.utils.paths import create_temp_file
|
||||
|
||||
log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log"))
|
||||
|
||||
# Print the log file path so users know where to find it
|
||||
print(f"Logging to: {log_file}", file=sys.stderr)
|
||||
@ -195,6 +198,9 @@ def setup_logging() -> None:
|
||||
logger.setLevel(level)
|
||||
logger.propagate = False # Prevent message duplication
|
||||
|
||||
# Store handlers globally for cleanup
|
||||
_current_handlers.extend(handlers)
|
||||
|
||||
# Start the listener AFTER configuring all loggers
|
||||
_current_listener = logging.handlers.QueueListener(
|
||||
log_queue, *handlers, respect_handler_level=True
|
||||
@ -209,7 +215,13 @@ def setup_logging() -> None:
|
||||
@atexit.register
|
||||
def cleanup_logging() -> None:
|
||||
"""Clean up logging resources on exit."""
|
||||
global _current_listener
|
||||
global _current_listener, _current_handlers
|
||||
if _current_listener is not None:
|
||||
_current_listener.stop()
|
||||
_current_listener = None
|
||||
|
||||
# Close all file handlers to ensure proper cleanup on Windows
|
||||
for handler in _current_handlers:
|
||||
if isinstance(handler, logging.FileHandler):
|
||||
handler.close()
|
||||
_current_handlers.clear()
|
||||
|
||||
libp2p/utils/paths.py (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Cross-platform path utilities for py-libp2p.
|
||||
|
||||
This module provides standardized path operations to ensure consistent
|
||||
behavior across Windows, macOS, and Linux platforms.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Union
|
||||
|
||||
PathLike = Union[str, Path]
|
||||
|
||||
|
||||
def get_temp_dir() -> Path:
|
||||
"""
|
||||
Get cross-platform temporary directory.
|
||||
|
||||
Returns:
|
||||
Path: Platform-specific temporary directory path
|
||||
|
||||
"""
|
||||
return Path(tempfile.gettempdir())
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""
|
||||
Get the project root directory.
|
||||
|
||||
Returns:
|
||||
Path: Path to the py-libp2p project root
|
||||
|
||||
"""
|
||||
# Navigate from libp2p/utils/paths.py to project root
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
def join_paths(*parts: PathLike) -> Path:
|
||||
"""
|
||||
Cross-platform path joining.
|
||||
|
||||
Args:
|
||||
*parts: Path components to join
|
||||
|
||||
Returns:
|
||||
Path: Joined path using platform-appropriate separator
|
||||
|
||||
"""
|
||||
return Path(*parts)
|
||||
|
||||
|
||||
def ensure_dir_exists(path: PathLike) -> Path:
|
||||
"""
|
||||
Ensure directory exists, create if needed.
|
||||
|
||||
Args:
|
||||
path: Directory path to ensure exists
|
||||
|
||||
Returns:
|
||||
Path: Path object for the directory
|
||||
|
||||
"""
|
||||
path_obj = Path(path)
|
||||
path_obj.mkdir(parents=True, exist_ok=True)
|
||||
return path_obj
|
||||
|
||||
|
||||
def get_config_dir() -> Path:
|
||||
"""
|
||||
Get user config directory (cross-platform).
|
||||
|
||||
Returns:
|
||||
Path: Platform-specific config directory
|
||||
|
||||
"""
|
||||
if os.name == "nt": # Windows
|
||||
appdata = os.environ.get("APPDATA", "")
|
||||
if appdata:
|
||||
return Path(appdata) / "py-libp2p"
|
||||
else:
|
||||
# Fallback to user home directory
|
||||
return Path.home() / "AppData" / "Roaming" / "py-libp2p"
|
||||
else: # Unix-like (Linux, macOS)
|
||||
return Path.home() / ".config" / "py-libp2p"
|
||||

def get_script_dir(script_path: PathLike | None = None) -> Path:
    """
    Get the directory containing a script file.

    Args:
        script_path: Path to the script file. If None, the calling
            frame's __file__ is used.

    Returns:
        Path: Directory containing the script

    Raises:
        RuntimeError: If script path cannot be determined

    """
    if script_path is None:
        # Look up __file__ in the caller's frame globals
        import inspect

        frame = inspect.currentframe()
        if frame and frame.f_back:
            script_path = frame.f_back.f_globals.get("__file__")
        else:
            raise RuntimeError("Could not determine script path")

    if script_path is None:
        raise RuntimeError("Script path is None")

    return Path(script_path).parent.absolute()
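Because the no-argument form walks up one stack frame to read the caller's __file__, it behaves like the explicit form when called from a regular module, but can raise in contexts without __file__ (an interactive session, for instance). A sketch of both call styles:

    from libp2p.utils.paths import get_script_dir

    here = get_script_dir(__file__)  # explicit: directory of this file
    same = get_script_dir()          # implicit: inspects the calling frame
    assert here == same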

def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path:
    """
    Create a temporary file with a unique name.

    Args:
        prefix: File name prefix
        suffix: File name suffix

    Returns:
        Path: Path to the created temporary file

    """
    temp_dir = get_temp_dir()
    # Create a unique filename using timestamp and random bytes
    import secrets
    import time

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    microseconds = f"{time.time() % 1:.6f}"[2:]  # Get microseconds as string
    unique_id = secrets.token_hex(4)
    filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}"

    temp_file = temp_dir / filename
    # Create the file by touching it
    temp_file.touch()
    return temp_file
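The generated name combines a second-resolution timestamp, the fractional-second digits, and four random bytes, so two calls in the same second still get distinct files. A call like create_temp_file(prefix="py-libp2p_", suffix=".log") would yield something of the form (illustrative values):

    /tmp/py-libp2p_20250101_120000_123456_9f2a6c1d.log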

def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path:
    """
    Resolve a relative path from a base path.

    Args:
        base_path: Base directory path
        relative_path: Relative path to resolve

    Returns:
        Path: Resolved absolute path

    """
    base = Path(base_path).resolve()
    relative = Path(relative_path)

    if relative.is_absolute():
        return relative
    else:
        return (base / relative).resolve()
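Note the asymmetry: an absolute relative_path is returned untouched, while a relative one is anchored at base_path (a sketch with hypothetical Unix-style paths):

    from libp2p.utils.paths import resolve_relative_path

    resolve_relative_path("/opt/app", "data/peers.json")  # /opt/app/data/peers.json
    resolve_relative_path("/opt/app", "/etc/hosts")       # /etc/hosts, returned as-is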

def normalize_path(path: PathLike) -> Path:
    """
    Normalize a path, resolving any symbolic links and relative components.

    Args:
        path: Path to normalize

    Returns:
        Path: Normalized absolute path

    """
    return Path(path).resolve()


def get_venv_path() -> Path | None:
    """
    Get virtual environment path if active.

    Returns:
        Path: Virtual environment path if active, None otherwise

    """
    venv_path = os.environ.get("VIRTUAL_ENV")
    if venv_path:
        return Path(venv_path)
    return None


def get_python_executable() -> Path:
    """
    Get current Python executable path.

    Returns:
        Path: Path to the current Python executable

    """
    return Path(sys.executable)


def find_executable(name: str) -> Path | None:
    """
    Find executable in system PATH.

    Args:
        name: Name of the executable to find

    Returns:
        Path: Path to executable if found, None otherwise

    """
    # If the name already contains a path component, check it directly
    if os.path.dirname(name):
        path = Path(name)
        if path.exists() and os.access(path, os.X_OK):
            return path
        return None

    # Otherwise search each directory in PATH
    for path_dir in os.environ.get("PATH", "").split(os.pathsep):
        if not path_dir:
            continue
        path = Path(path_dir) / name
        if path.exists() and os.access(path, os.X_OK):
            return path

    return None


def get_script_binary_path() -> Path:
    """
    Get the directory containing the current Python executable.

    Returns:
        Path: Directory containing the interpreter binary

    """
    return get_python_executable().parent


def get_binary_path(binary_name: str) -> Path | None:
    """
    Find binary in virtual environment or system PATH.

    Args:
        binary_name: Name of the binary to find

    Returns:
        Path: Path to binary if found, None otherwise

    """
    # First check in virtual environment if active
    venv_path = get_venv_path()
    if venv_path:
        venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts"
        binary_path = venv_bin / binary_name
        if binary_path.exists() and os.access(binary_path, os.X_OK):
            return binary_path

    # Fall back to system PATH
    return find_executable(binary_name)
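Lookup order matters here: an active virtual environment shadows the system PATH, which is usually what callers want for tools installed per-project. A sketch (pytest is just an example binary name):

    from libp2p.utils.paths import get_binary_path

    pytest_bin = get_binary_path("pytest")
    if pytest_bin is None:
        print("pytest not found in venv or PATH")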