mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge remote-tracking branch 'origin/main' into fix_pubsub_msg_id_type_inconsistency
This commit is contained in:
17
libp2p/discovery/random_walk/__init__.py
Normal file
17
libp2p/discovery/random_walk/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Random walk discovery modules for py-libp2p."""
|
||||
|
||||
from .rt_refresh_manager import RTRefreshManager
|
||||
from .random_walk import RandomWalk
|
||||
from .exceptions import (
|
||||
RoutingTableRefreshError,
|
||||
RandomWalkError,
|
||||
PeerValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RTRefreshManager",
|
||||
"RandomWalk",
|
||||
"RoutingTableRefreshError",
|
||||
"RandomWalkError",
|
||||
"PeerValidationError",
|
||||
]
|
||||
16
libp2p/discovery/random_walk/config.py
Normal file
16
libp2p/discovery/random_walk/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Final
|
||||
|
||||
# Timing constants (matching go-libp2p)
|
||||
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
|
||||
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
|
||||
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
|
||||
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
|
||||
|
||||
# Routing table thresholds
|
||||
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
|
||||
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
|
||||
|
||||
# Random walk specific
|
||||
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
|
||||
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
|
||||
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback
|
||||
19
libp2p/discovery/random_walk/exceptions.py
Normal file
19
libp2p/discovery/random_walk/exceptions.py
Normal file
@ -0,0 +1,19 @@
|
||||
from libp2p.exceptions import BaseLibp2pError
|
||||
|
||||
|
||||
class RoutingTableRefreshError(BaseLibp2pError):
|
||||
"""Base exception for routing table refresh operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RandomWalkError(RoutingTableRefreshError):
|
||||
"""Exception raised during random walk operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PeerValidationError(RoutingTableRefreshError):
|
||||
"""Exception raised when peer validation fails."""
|
||||
|
||||
pass
|
||||
218
libp2p/discovery/random_walk/random_walk.py
Normal file
218
libp2p/discovery/random_walk/random_walk.py
Normal file
@ -0,0 +1,218 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_RT_THRESHOLD,
|
||||
REFRESH_QUERY_TIMEOUT,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RandomWalkError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk")
|
||||
|
||||
|
||||
class RandomWalk:
|
||||
"""
|
||||
Random Walk implementation for peer discovery in Kademlia DHT.
|
||||
|
||||
Generates random peer IDs and performs FIND_NODE queries to discover
|
||||
new peers and populate the routing table.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
):
|
||||
"""
|
||||
Initialize Random Walk module.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
def generate_random_peer_id(self) -> str:
|
||||
"""
|
||||
Generate a completely random peer ID
|
||||
for random walk queries.
|
||||
|
||||
Returns:
|
||||
Random peer ID as string
|
||||
|
||||
"""
|
||||
# Generate 32 random bytes (256 bits) - same as go-libp2p
|
||||
random_bytes = secrets.token_bytes(32)
|
||||
# Convert to hex string for query
|
||||
return random_bytes.hex()
|
||||
|
||||
async def perform_random_walk(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Perform a single random walk operation.
|
||||
|
||||
Returns:
|
||||
List of validated peers discovered during the walk
|
||||
|
||||
"""
|
||||
try:
|
||||
# Generate random peer ID
|
||||
random_peer_id = self.generate_random_peer_id()
|
||||
logger.info(f"Starting random walk for peer ID: {random_peer_id}")
|
||||
|
||||
# Perform FIND_NODE query
|
||||
discovered_peer_ids: list[ID] = []
|
||||
|
||||
with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
|
||||
# Call the query function with target key bytes
|
||||
target_key = bytes.fromhex(random_peer_id)
|
||||
discovered_peer_ids = await self.query_function(target_key) or []
|
||||
|
||||
if not discovered_peer_ids:
|
||||
logger.debug(f"No peers discovered in random walk for {random_peer_id}")
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(discovered_peer_ids)} peers in random walk "
|
||||
f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity
|
||||
)
|
||||
|
||||
# Convert peer IDs to PeerInfo objects and validate
|
||||
validated_peers: list[PeerInfo] = []
|
||||
|
||||
for peer_id in discovered_peer_ids:
|
||||
try:
|
||||
# Get addresses from peerstore
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
validated_peers.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
|
||||
continue
|
||||
|
||||
return validated_peers
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Random walk failed: {e}")
|
||||
raise RandomWalkError(f"Random walk operation failed: {e}") from e
|
||||
|
||||
async def run_concurrent_random_walks(
|
||||
self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
|
||||
) -> list[PeerInfo]:
|
||||
"""
|
||||
Run multiple random walks concurrently.
|
||||
|
||||
Args:
|
||||
count: Number of concurrent random walks to perform
|
||||
current_routing_table_size: Current size of routing table (for optimization)
|
||||
|
||||
Returns:
|
||||
Combined list of all validated peers discovered
|
||||
|
||||
"""
|
||||
all_validated_peers: list[PeerInfo] = []
|
||||
logger.info(f"Starting {count} concurrent random walks")
|
||||
|
||||
# First, try to add peers from peerstore if routing table is small
|
||||
if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
|
||||
try:
|
||||
peerstore_peers = self._get_peerstore_peers()
|
||||
if peerstore_peers:
|
||||
logger.debug(
|
||||
f"RT size ({current_routing_table_size}) below threshold, "
|
||||
f"adding {len(peerstore_peers)} peerstore peers"
|
||||
)
|
||||
all_validated_peers.extend(peerstore_peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing peerstore peers: {e}")
|
||||
|
||||
async def single_walk() -> None:
|
||||
try:
|
||||
peers = await self.perform_random_walk()
|
||||
all_validated_peers.extend(peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Concurrent random walk failed: {e}")
|
||||
return
|
||||
|
||||
# Run concurrent random walks
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(single_walk)
|
||||
|
||||
# Remove duplicates based on peer ID
|
||||
unique_peers = {}
|
||||
for peer in all_validated_peers:
|
||||
unique_peers[peer.peer_id] = peer
|
||||
|
||||
result = list(unique_peers.values())
|
||||
logger.info(
|
||||
f"Concurrent random walks completed: {len(result)} unique peers discovered"
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_peerstore_peers(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Get peer info objects from the host's peerstore.
|
||||
|
||||
Returns:
|
||||
List of PeerInfo objects from peerstore
|
||||
|
||||
"""
|
||||
try:
|
||||
peerstore = self.host.get_peerstore()
|
||||
peer_ids = peerstore.peers_with_addrs()
|
||||
|
||||
peer_infos = []
|
||||
for peer_id in peer_ids:
|
||||
try:
|
||||
# Skip local peer
|
||||
if peer_id == self.local_peer_id:
|
||||
continue
|
||||
|
||||
peer_info = peerstore.peer_info(peer_id)
|
||||
if peer_info and peer_info.addrs:
|
||||
# Filter for compatible addresses (TCP + IPv4)
|
||||
if self._has_compatible_addresses(peer_info):
|
||||
peer_infos.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting peer info for {peer_id}: {e}")
|
||||
|
||||
return peer_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error accessing peerstore: {e}")
|
||||
return []
|
||||
|
||||
def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
|
||||
"""
|
||||
Check if a peer has TCP+IPv4 compatible addresses.
|
||||
|
||||
Args:
|
||||
peer_info: PeerInfo to check
|
||||
|
||||
Returns:
|
||||
True if peer has compatible addresses
|
||||
|
||||
"""
|
||||
if not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
# Check for TCP and IPv4 compatibility, avoid QUIC
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
@ -0,0 +1,208 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_ENABLED,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class RoutingTableProtocol(Protocol):
|
||||
"""Protocol for routing table operations needed by RT refresh manager."""
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the current size of the routing table."""
|
||||
...
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo) -> bool:
|
||||
"""Add a peer to the routing table."""
|
||||
...
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
|
||||
|
||||
|
||||
class RTRefreshManager:
|
||||
"""
|
||||
Routing Table Refresh Manager for py-libp2p.
|
||||
|
||||
Manages periodic routing table refreshes and random walk operations
|
||||
to maintain routing table health and discover new peers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
routing_table: RoutingTableProtocol,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
|
||||
refresh_interval: float = REFRESH_INTERVAL,
|
||||
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
|
||||
):
|
||||
"""
|
||||
Initialize RT Refresh Manager.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
routing_table: Routing table of host
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
enable_auto_refresh: Whether to enable automatic refresh
|
||||
refresh_interval: Interval between refreshes in seconds
|
||||
min_refresh_threshold: Minimum RT size before triggering refresh
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
self.enable_auto_refresh = enable_auto_refresh
|
||||
self.refresh_interval = refresh_interval
|
||||
self.min_refresh_threshold = min_refresh_threshold
|
||||
|
||||
# Initialize random walk module
|
||||
self.random_walk = RandomWalk(
|
||||
host=host,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=query_function,
|
||||
)
|
||||
|
||||
# Control variables
|
||||
self._running = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
# Tracking
|
||||
self._last_refresh_time = 0.0
|
||||
self._refresh_done_callbacks: list[Callable[[], None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the RT Refresh Manager."""
|
||||
if self._running:
|
||||
logger.warning("RT Refresh Manager is already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
logger.info("Starting RT Refresh Manager")
|
||||
|
||||
# Start the main loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
nursery.start_soon(self._main_loop)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the RT Refresh Manager."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info("Stopping RT Refresh Manager")
|
||||
self._running = False
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
"""Main loop for the RT Refresh Manager."""
|
||||
logger.info("RT Refresh Manager main loop started")
|
||||
|
||||
# Initial refresh if auto-refresh is enabled
|
||||
if self.enable_auto_refresh:
|
||||
await self._do_refresh(force=True)
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Schedule periodic refresh if enabled
|
||||
if self.enable_auto_refresh:
|
||||
nursery.start_soon(self._periodic_refresh_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RT Refresh Manager main loop error: {e}")
|
||||
finally:
|
||||
logger.info("RT Refresh Manager main loop stopped")
|
||||
|
||||
async def _periodic_refresh_task(self) -> None:
|
||||
"""Task for periodic refreshes."""
|
||||
while self._running:
|
||||
await trio.sleep(self.refresh_interval)
|
||||
if self._running:
|
||||
await self._do_refresh()
|
||||
|
||||
async def _do_refresh(self, force: bool = False) -> None:
|
||||
"""
|
||||
Perform routing table refresh operation.
|
||||
|
||||
Args:
|
||||
force: Whether to force refresh regardless of timing
|
||||
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# Check if refresh is needed
|
||||
if not force:
|
||||
if current_time - self._last_refresh_time < self.refresh_interval:
|
||||
logger.debug("Skipping refresh: interval not elapsed")
|
||||
return
|
||||
|
||||
if self.routing_table.size() >= self.min_refresh_threshold:
|
||||
logger.debug("Skipping refresh: routing table size above threshold")
|
||||
return
|
||||
|
||||
logger.info(f"Starting routing table refresh (force={force})")
|
||||
start_time = current_time
|
||||
|
||||
# Perform random walks to discover new peers
|
||||
logger.info("Running concurrent random walks to discover new peers")
|
||||
current_rt_size = self.routing_table.size()
|
||||
discovered_peers = await self.random_walk.run_concurrent_random_walks(
|
||||
count=RANDOM_WALK_CONCURRENCY,
|
||||
current_routing_table_size=current_rt_size,
|
||||
)
|
||||
|
||||
# Add discovered peers to routing table
|
||||
added_count = 0
|
||||
for peer_info in discovered_peers:
|
||||
result = await self.routing_table.add_peer(peer_info)
|
||||
if result:
|
||||
added_count += 1
|
||||
|
||||
self._last_refresh_time = current_time
|
||||
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
f"Routing table refresh completed: "
|
||||
f"{added_count}/{len(discovered_peers)} peers added, "
|
||||
f"RT size: {self.routing_table.size()}, "
|
||||
f"duration: {duration:.2f}s"
|
||||
)
|
||||
|
||||
# Notify refresh completion
|
||||
for callback in self._refresh_done_callbacks:
|
||||
try:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.warning(f"Refresh callback error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Routing table refresh failed: {e}")
|
||||
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
|
||||
|
||||
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Add a callback to be called when refresh completes."""
|
||||
self._refresh_done_callbacks.append(callback)
|
||||
|
||||
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Remove a refresh completion callback."""
|
||||
if callback in self._refresh_done_callbacks:
|
||||
self._refresh_done_callbacks.remove(callback)
|
||||
@ -295,6 +295,13 @@ class BasicHost(IHost):
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
if protocol is None:
|
||||
logger.debug(
|
||||
"no protocol negotiated, closing stream from peer %s",
|
||||
net_stream.muxed_conn.peer_id,
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
if handler is None:
|
||||
logger.debug(
|
||||
|
||||
@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
|
||||
implementation based on the Kademlia algorithm and protocol.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from enum import (
|
||||
Enum,
|
||||
)
|
||||
@ -20,6 +21,7 @@ import varint
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
@ -73,14 +75,27 @@ class KadDHT(Service):
|
||||
|
||||
This class provides a DHT implementation that combines routing table management,
|
||||
peer discovery, content routing, and value storage.
|
||||
|
||||
Optional Random Walk feature enhances peer discovery by automatically
|
||||
performing periodic random queries to discover new peers and maintain
|
||||
routing table health.
|
||||
|
||||
Example:
|
||||
# Basic DHT without random walk (default)
|
||||
dht = KadDHT(host, DHTMode.SERVER)
|
||||
|
||||
# DHT with random walk enabled for enhanced peer discovery
|
||||
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, mode: DHTMode):
|
||||
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
|
||||
"""
|
||||
Initialize a new Kademlia DHT node.
|
||||
|
||||
:param host: The libp2p host.
|
||||
:param mode: The mode of host (Client or Server) - must be DHTMode enum
|
||||
:param enable_random_walk: Whether to enable automatic random walk
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -92,6 +107,7 @@ class KadDHT(Service):
|
||||
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
|
||||
|
||||
self.mode = mode
|
||||
self.enable_random_walk = enable_random_walk
|
||||
|
||||
# Initialize the routing table
|
||||
self.routing_table = RoutingTable(self.local_peer_id, self.host)
|
||||
@ -108,13 +124,56 @@ class KadDHT(Service):
|
||||
# Last time we republished provider records
|
||||
self._last_provider_republish = time.time()
|
||||
|
||||
# Initialize RT Refresh Manager (only if random walk is enabled)
|
||||
self.rt_refresh_manager: RTRefreshManager | None = None
|
||||
if self.enable_random_walk:
|
||||
self.rt_refresh_manager = RTRefreshManager(
|
||||
host=self.host,
|
||||
routing_table=self.routing_table,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=self._create_query_function(),
|
||||
enable_auto_refresh=True,
|
||||
)
|
||||
|
||||
# Set protocol handlers
|
||||
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
|
||||
|
||||
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
|
||||
"""
|
||||
Create a query function that wraps peer_routing.find_closest_peers_network.
|
||||
|
||||
This function is used by the RandomWalk module to query for peers without
|
||||
directly importing PeerRouting, avoiding circular import issues.
|
||||
|
||||
Returns:
|
||||
Callable that takes target_key bytes and returns list of peer IDs
|
||||
|
||||
"""
|
||||
|
||||
async def query_function(target_key: bytes) -> list[ID]:
|
||||
"""Query for closest peers to target key."""
|
||||
return await self.peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
return query_function
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the DHT service."""
|
||||
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
|
||||
|
||||
# Start the RT Refresh Manager in parallel with the main DHT service
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start the RT Refresh Manager only if random walk is enabled
|
||||
if self.rt_refresh_manager is not None:
|
||||
nursery.start_soon(self.rt_refresh_manager.start)
|
||||
logger.info("RT Refresh Manager started - Random Walk is now active")
|
||||
else:
|
||||
logger.info("Random Walk is disabled - RT Refresh Manager not started")
|
||||
|
||||
# Start the main DHT service loop
|
||||
nursery.start_soon(self._run_main_loop)
|
||||
|
||||
async def _run_main_loop(self) -> None:
|
||||
"""Run the main DHT service loop."""
|
||||
# Main service loop
|
||||
while self.manager.is_running:
|
||||
# Periodically refresh the routing table
|
||||
@ -135,6 +194,17 @@ class KadDHT(Service):
|
||||
# Wait before next maintenance cycle
|
||||
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DHT service and cleanup resources."""
|
||||
logger.info("Stopping Kademlia DHT")
|
||||
|
||||
# Stop the RT Refresh Manager only if it was started
|
||||
if self.rt_refresh_manager is not None:
|
||||
await self.rt_refresh_manager.stop()
|
||||
logger.info("RT Refresh Manager stopped")
|
||||
else:
|
||||
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
|
||||
|
||||
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
|
||||
"""
|
||||
Switch the DHT mode.
|
||||
@ -614,3 +684,15 @@ class KadDHT(Service):
|
||||
|
||||
"""
|
||||
return self.value_store.size()
|
||||
|
||||
def is_random_walk_enabled(self) -> bool:
|
||||
"""
|
||||
Check if random walk peer discovery is enabled.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if random walk is enabled, False otherwise.
|
||||
|
||||
"""
|
||||
return self.enable_random_walk
|
||||
|
||||
@ -170,7 +170,7 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Return early if we have no peers to start with
|
||||
if not closest_peers:
|
||||
logger.warning("No local peers available for network lookup")
|
||||
logger.debug("No local peers available for network lookup")
|
||||
return []
|
||||
|
||||
# Iterative lookup until convergence
|
||||
|
||||
@ -249,9 +249,11 @@ class Swarm(Service, INetworkService):
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
for maddr in multiaddrs:
|
||||
if str(maddr) in self.listeners:
|
||||
return True
|
||||
success_count += 1
|
||||
continue
|
||||
|
||||
async def conn_handler(
|
||||
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
||||
@ -302,13 +304,14 @@ class Swarm(Service, INetworkService):
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_listen(maddr)
|
||||
|
||||
return True
|
||||
success_count += 1
|
||||
logger.debug("successfully started listening on: %s", maddr)
|
||||
except OSError:
|
||||
# Failed. Continue looping.
|
||||
logger.debug("fail to listen on: %s", maddr)
|
||||
|
||||
# No maddr succeeded
|
||||
return False
|
||||
# Return true if at least one address succeeded
|
||||
return success_count > 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
|
||||
@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
|
||||
"""
|
||||
self.handlers[protocol] = handler
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
protocol_to_check = None if not command else TProtocol(command)
|
||||
if protocol_to_check in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(command)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
return protocol_to_check, self.handlers[protocol_to_check]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
|
||||
@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
:return: selected protocol
|
||||
"""
|
||||
# Represent `None` protocol as an empty string.
|
||||
protocol_str = protocol if protocol is not None else ""
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(protocol_str)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if response == protocol:
|
||||
if response == protocol_str:
|
||||
return protocol
|
||||
if response == PROTOCOL_NOT_FOUND_MSG:
|
||||
raise MultiselectClientError("protocol not supported")
|
||||
|
||||
@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||
""" # noqa: E501
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
if msg_str is None:
|
||||
msg_bytes = encode_delim(b"")
|
||||
else:
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
|
||||
@ -777,14 +777,18 @@ class GossipSub(IPubsubRouter, Service):
|
||||
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
|
||||
# seen_messages cache
|
||||
seen_seqnos_and_peers = [
|
||||
seqno_and_from for seqno_and_from in self.pubsub.seen_messages.cache.keys()
|
||||
str(seqno_and_from)
|
||||
for seqno_and_from in self.pubsub.seen_messages.cache.keys()
|
||||
]
|
||||
|
||||
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
||||
# seen_seqnos) to list of messages we want to request
|
||||
msg_ids_wanted: list[str] = [
|
||||
msg_id
|
||||
msg_ids_wanted: list[MessageID] = [
|
||||
parse_message_id_safe(msg_id)
|
||||
for msg_id in ihave_msg.messageIDs
|
||||
if msg_id not in seen_seqnos_and_peers
|
||||
if msg_id not in str(seen_seqnos_and_peers)
|
||||
]
|
||||
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
|
||||
:param is_initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if is_initiator:
|
||||
# Select protocol if initiator
|
||||
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
|
||||
else:
|
||||
# Select protocol if non-initiator
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a security protocol")
|
||||
# Return transport from protocol
|
||||
return self.transports[protocol]
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -73,7 +76,7 @@ class MuxerMultistream:
|
||||
:param conn: conn to choose a transport over
|
||||
:return: selected muxer transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
@ -81,6 +84,8 @@ class MuxerMultistream:
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a stream muxer protocol")
|
||||
return self.transports[protocol]
|
||||
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
|
||||
@ -15,6 +15,13 @@ from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
)
|
||||
|
||||
from libp2p.utils.address_validation import (
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
expand_wildcard_address,
|
||||
find_free_port,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"decode_uvarint_from_stream",
|
||||
"encode_delim",
|
||||
@ -26,4 +33,8 @@ __all__ = [
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
"read_length_prefixed_protobuf",
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
|
||||
160
libp2p/utils/address_validation.py
Normal file
160
libp2p/utils/address_validation.py
Normal file
@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
try:
|
||||
from multiaddr.utils import ( # type: ignore
|
||||
get_network_addrs,
|
||||
get_thin_waist_addresses,
|
||||
)
|
||||
|
||||
_HAS_THIN_WAIST = True
|
||||
except ImportError: # pragma: no cover - only executed in older environments
|
||||
_HAS_THIN_WAIST = False
|
||||
get_thin_waist_addresses = None # type: ignore
|
||||
get_network_addrs = None # type: ignore
|
||||
|
||||
|
||||
def _safe_get_network_addrs(ip_version: int) -> list[str]:
|
||||
"""
|
||||
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
|
||||
Falls back to minimal defaults when Thin Waist helpers are missing.
|
||||
|
||||
:param ip_version: 4 or 6
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_network_addrs:
|
||||
try:
|
||||
return get_network_addrs(ip_version) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return []
|
||||
# Fallback behavior (very conservative)
|
||||
if ip_version == 4:
|
||||
return ["127.0.0.1"]
|
||||
if ip_version == 6:
|
||||
return ["::1"]
|
||||
return []
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0)) # Bind to a free port provided by the OS
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
|
||||
"""
|
||||
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
|
||||
If Thin Waist isn't available, returns [addr] (identity).
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_thin_waist_addresses:
|
||||
try:
|
||||
if port is not None:
|
||||
return get_thin_waist_addresses(addr, port=port) or []
|
||||
return get_thin_waist_addresses(addr) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return [addr]
|
||||
return [addr]
|
||||
|
||||
|
||||
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
|
||||
"""
|
||||
Discover available network interfaces (IPv4 + IPv6 if supported) for binding.
|
||||
|
||||
:param port: Port number to bind to.
|
||||
:param protocol: Transport protocol (e.g., "tcp" or "udp").
|
||||
:return: List of Multiaddr objects representing candidate interface addresses.
|
||||
"""
|
||||
addrs: list[Multiaddr] = []
|
||||
|
||||
# IPv4 enumeration
|
||||
seen_v4: set[str] = set()
|
||||
|
||||
for ip in _safe_get_network_addrs(4):
|
||||
seen_v4.add(ip)
|
||||
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
||||
|
||||
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
|
||||
if seen_v4 and "127.0.0.1" not in seen_v4:
|
||||
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
|
||||
|
||||
# TODO: IPv6 support temporarily disabled due to libp2p handshake issues
|
||||
# IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure)
|
||||
# Re-enable IPv6 support once the following issues are resolved:
|
||||
# - libp2p security handshake over IPv6
|
||||
# - multiselect protocol over IPv6
|
||||
# - connection establishment over IPv6
|
||||
#
|
||||
# seen_v6: set[str] = set()
|
||||
# for ip in _safe_get_network_addrs(6):
|
||||
# seen_v6.add(ip)
|
||||
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
||||
#
|
||||
# # Always include IPv6 loopback for testing purposes when IPv6 is available
|
||||
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
|
||||
# if "::1" not in seen_v6:
|
||||
# addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}"))
|
||||
|
||||
# Fallback if nothing discovered
|
||||
if not addrs:
|
||||
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
|
||||
|
||||
return addrs
|
||||
|
||||
|
||||
def expand_wildcard_address(
|
||||
addr: Multiaddr, port: int | None = None
|
||||
) -> list[Multiaddr]:
|
||||
"""
|
||||
Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces.
|
||||
|
||||
:param addr: Multiaddr to expand.
|
||||
:param port: Optional override for port selection.
|
||||
:return: List of concrete Multiaddr instances.
|
||||
"""
|
||||
expanded = _safe_expand(addr, port=port)
|
||||
if not expanded: # Safety fallback
|
||||
return [addr]
|
||||
return expanded
|
||||
|
||||
|
||||
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||
"""
|
||||
Choose an optimal address for an example to bind to:
|
||||
- Prefer non-loopback IPv4
|
||||
- Then non-loopback IPv6
|
||||
- Fallback to loopback
|
||||
- Fallback to wildcard
|
||||
|
||||
:param port: Port number.
|
||||
:param protocol: Transport protocol.
|
||||
:return: A single Multiaddr chosen heuristically.
|
||||
"""
|
||||
candidates = get_available_interfaces(port, protocol)
|
||||
|
||||
def is_non_loopback(ma: Multiaddr) -> bool:
|
||||
s = str(ma)
|
||||
return not ("/ip4/127." in s or "/ip6/::1" in s)
|
||||
|
||||
for c in candidates:
|
||||
if "/ip4/" in str(c) and is_non_loopback(c):
|
||||
return c
|
||||
for c in candidates:
|
||||
if "/ip6/" in str(c) and is_non_loopback(c):
|
||||
return c
|
||||
for c in candidates:
|
||||
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
|
||||
return c
|
||||
|
||||
# As a final fallback, produce a wildcard
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
Reference in New Issue
Block a user