feat: Implement Random walk in py-libp2p (#822)

* Implementing random walk in py libp2p

* Add documentation for Random Walk module implementation in py-libp2p

* Add Random Walk example for py-libp2p Kademlia DHT

* refactor: peer eviction from routing table stopped

* refactored location of random walk

* add nodesin routing table  from peerstore

* random walk working as expected

* removed extra functions

* Removed all manual triggers

* added newsfragments

* fix linting issues

* refacored logs and cleaned example file

* refactor: update RandomWalk and RTRefreshManager to use query function for peer discovery

* docs: added Random Walk example docs

* added optional argument to use random walk in kademlia DHT

* enabled random walk in example file

* Added tests for RandomWalk module

* fixed lint issues

* Update refresh interval and some more tests are added.

* Removed Random Walk module documentation file

* Extra parentheses have been removed from the random walk logs.

Co-authored-by: Paul Robinson <5199899+pacrob@users.noreply.github.com>

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
Co-authored-by: Paul Robinson <5199899+pacrob@users.noreply.github.com>
This commit is contained in:
Sumanjeet
2025-08-20 16:40:06 +05:30
committed by GitHub
parent dabb3a0962
commit 94d695c6bc
16 changed files with 1516 additions and 3 deletions

View File

@ -0,0 +1,17 @@
"""Random walk discovery modules for py-libp2p."""
from .rt_refresh_manager import RTRefreshManager
from .random_walk import RandomWalk
from .exceptions import (
RoutingTableRefreshError,
RandomWalkError,
PeerValidationError,
)
__all__ = [
"RTRefreshManager",
"RandomWalk",
"RoutingTableRefreshError",
"RandomWalkError",
"PeerValidationError",
]

View File

@ -0,0 +1,16 @@
from typing import Final
# Timing constants (matching go-libp2p)
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
# Routing table thresholds
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
# Random walk specific
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback

View File

@ -0,0 +1,19 @@
from libp2p.exceptions import BaseLibp2pError
class RoutingTableRefreshError(BaseLibp2pError):
"""Base exception for routing table refresh operations."""
pass
class RandomWalkError(RoutingTableRefreshError):
"""Exception raised during random walk operations."""
pass
class PeerValidationError(RoutingTableRefreshError):
"""Exception raised when peer validation fails."""
pass

View File

@ -0,0 +1,218 @@
from collections.abc import Awaitable, Callable
import logging
import secrets
import trio
from libp2p.abc import IHost
from libp2p.discovery.random_walk.config import (
RANDOM_WALK_CONCURRENCY,
RANDOM_WALK_RT_THRESHOLD,
REFRESH_QUERY_TIMEOUT,
)
from libp2p.discovery.random_walk.exceptions import RandomWalkError
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
logger = logging.getLogger("libp2p.discovery.random_walk")
class RandomWalk:
"""
Random Walk implementation for peer discovery in Kademlia DHT.
Generates random peer IDs and performs FIND_NODE queries to discover
new peers and populate the routing table.
"""
def __init__(
self,
host: IHost,
local_peer_id: ID,
query_function: Callable[[bytes], Awaitable[list[ID]]],
):
"""
Initialize Random Walk module.
Args:
host: The libp2p host instance
local_peer_id: Local peer ID
query_function: Function to query for closest peers given target key bytes
"""
self.host = host
self.local_peer_id = local_peer_id
self.query_function = query_function
def generate_random_peer_id(self) -> str:
"""
Generate a completely random peer ID
for random walk queries.
Returns:
Random peer ID as string
"""
# Generate 32 random bytes (256 bits) - same as go-libp2p
random_bytes = secrets.token_bytes(32)
# Convert to hex string for query
return random_bytes.hex()
async def perform_random_walk(self) -> list[PeerInfo]:
"""
Perform a single random walk operation.
Returns:
List of validated peers discovered during the walk
"""
try:
# Generate random peer ID
random_peer_id = self.generate_random_peer_id()
logger.info(f"Starting random walk for peer ID: {random_peer_id}")
# Perform FIND_NODE query
discovered_peer_ids: list[ID] = []
with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
# Call the query function with target key bytes
target_key = bytes.fromhex(random_peer_id)
discovered_peer_ids = await self.query_function(target_key) or []
if not discovered_peer_ids:
logger.debug(f"No peers discovered in random walk for {random_peer_id}")
return []
logger.info(
f"Discovered {len(discovered_peer_ids)} peers in random walk "
f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity
)
# Convert peer IDs to PeerInfo objects and validate
validated_peers: list[PeerInfo] = []
for peer_id in discovered_peer_ids:
try:
# Get addresses from peerstore
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
peer_info = PeerInfo(peer_id, addrs)
validated_peers.append(peer_info)
except Exception as e:
logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
continue
return validated_peers
except Exception as e:
logger.error(f"Random walk failed: {e}")
raise RandomWalkError(f"Random walk operation failed: {e}") from e
async def run_concurrent_random_walks(
self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
) -> list[PeerInfo]:
"""
Run multiple random walks concurrently.
Args:
count: Number of concurrent random walks to perform
current_routing_table_size: Current size of routing table (for optimization)
Returns:
Combined list of all validated peers discovered
"""
all_validated_peers: list[PeerInfo] = []
logger.info(f"Starting {count} concurrent random walks")
# First, try to add peers from peerstore if routing table is small
if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
try:
peerstore_peers = self._get_peerstore_peers()
if peerstore_peers:
logger.debug(
f"RT size ({current_routing_table_size}) below threshold, "
f"adding {len(peerstore_peers)} peerstore peers"
)
all_validated_peers.extend(peerstore_peers)
except Exception as e:
logger.warning(f"Error processing peerstore peers: {e}")
async def single_walk() -> None:
try:
peers = await self.perform_random_walk()
all_validated_peers.extend(peers)
except Exception as e:
logger.warning(f"Concurrent random walk failed: {e}")
return
# Run concurrent random walks
async with trio.open_nursery() as nursery:
for _ in range(count):
nursery.start_soon(single_walk)
# Remove duplicates based on peer ID
unique_peers = {}
for peer in all_validated_peers:
unique_peers[peer.peer_id] = peer
result = list(unique_peers.values())
logger.info(
f"Concurrent random walks completed: {len(result)} unique peers discovered"
)
return result
def _get_peerstore_peers(self) -> list[PeerInfo]:
"""
Get peer info objects from the host's peerstore.
Returns:
List of PeerInfo objects from peerstore
"""
try:
peerstore = self.host.get_peerstore()
peer_ids = peerstore.peers_with_addrs()
peer_infos = []
for peer_id in peer_ids:
try:
# Skip local peer
if peer_id == self.local_peer_id:
continue
peer_info = peerstore.peer_info(peer_id)
if peer_info and peer_info.addrs:
# Filter for compatible addresses (TCP + IPv4)
if self._has_compatible_addresses(peer_info):
peer_infos.append(peer_info)
except Exception as e:
logger.debug(f"Error getting peer info for {peer_id}: {e}")
return peer_infos
except Exception as e:
logger.warning(f"Error accessing peerstore: {e}")
return []
def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
"""
Check if a peer has TCP+IPv4 compatible addresses.
Args:
peer_info: PeerInfo to check
Returns:
True if peer has compatible addresses
"""
if not peer_info.addrs:
return False
for addr in peer_info.addrs:
addr_str = str(addr)
# Check for TCP and IPv4 compatibility, avoid QUIC
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
return True
return False

View File

@ -0,0 +1,208 @@
from collections.abc import Awaitable, Callable
import logging
import time
from typing import Protocol
import trio
from libp2p.abc import IHost
from libp2p.discovery.random_walk.config import (
MIN_RT_REFRESH_THRESHOLD,
RANDOM_WALK_CONCURRENCY,
RANDOM_WALK_ENABLED,
REFRESH_INTERVAL,
)
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
class RoutingTableProtocol(Protocol):
"""Protocol for routing table operations needed by RT refresh manager."""
def size(self) -> int:
"""Return the current size of the routing table."""
...
async def add_peer(self, peer_obj: PeerInfo) -> bool:
"""Add a peer to the routing table."""
...
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
class RTRefreshManager:
"""
Routing Table Refresh Manager for py-libp2p.
Manages periodic routing table refreshes and random walk operations
to maintain routing table health and discover new peers.
"""
def __init__(
self,
host: IHost,
routing_table: RoutingTableProtocol,
local_peer_id: ID,
query_function: Callable[[bytes], Awaitable[list[ID]]],
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
refresh_interval: float = REFRESH_INTERVAL,
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
):
"""
Initialize RT Refresh Manager.
Args:
host: The libp2p host instance
routing_table: Routing table of host
local_peer_id: Local peer ID
query_function: Function to query for closest peers given target key bytes
enable_auto_refresh: Whether to enable automatic refresh
refresh_interval: Interval between refreshes in seconds
min_refresh_threshold: Minimum RT size before triggering refresh
"""
self.host = host
self.routing_table = routing_table
self.local_peer_id = local_peer_id
self.query_function = query_function
self.enable_auto_refresh = enable_auto_refresh
self.refresh_interval = refresh_interval
self.min_refresh_threshold = min_refresh_threshold
# Initialize random walk module
self.random_walk = RandomWalk(
host=host,
local_peer_id=self.local_peer_id,
query_function=query_function,
)
# Control variables
self._running = False
self._nursery: trio.Nursery | None = None
# Tracking
self._last_refresh_time = 0.0
self._refresh_done_callbacks: list[Callable[[], None]] = []
async def start(self) -> None:
"""Start the RT Refresh Manager."""
if self._running:
logger.warning("RT Refresh Manager is already running")
return
self._running = True
logger.info("Starting RT Refresh Manager")
# Start the main loop
async with trio.open_nursery() as nursery:
self._nursery = nursery
nursery.start_soon(self._main_loop)
async def stop(self) -> None:
"""Stop the RT Refresh Manager."""
if not self._running:
return
logger.info("Stopping RT Refresh Manager")
self._running = False
async def _main_loop(self) -> None:
"""Main loop for the RT Refresh Manager."""
logger.info("RT Refresh Manager main loop started")
# Initial refresh if auto-refresh is enabled
if self.enable_auto_refresh:
await self._do_refresh(force=True)
try:
while self._running:
async with trio.open_nursery() as nursery:
# Schedule periodic refresh if enabled
if self.enable_auto_refresh:
nursery.start_soon(self._periodic_refresh_task)
except Exception as e:
logger.error(f"RT Refresh Manager main loop error: {e}")
finally:
logger.info("RT Refresh Manager main loop stopped")
async def _periodic_refresh_task(self) -> None:
"""Task for periodic refreshes."""
while self._running:
await trio.sleep(self.refresh_interval)
if self._running:
await self._do_refresh()
async def _do_refresh(self, force: bool = False) -> None:
"""
Perform routing table refresh operation.
Args:
force: Whether to force refresh regardless of timing
"""
try:
current_time = time.time()
# Check if refresh is needed
if not force:
if current_time - self._last_refresh_time < self.refresh_interval:
logger.debug("Skipping refresh: interval not elapsed")
return
if self.routing_table.size() >= self.min_refresh_threshold:
logger.debug("Skipping refresh: routing table size above threshold")
return
logger.info(f"Starting routing table refresh (force={force})")
start_time = current_time
# Perform random walks to discover new peers
logger.info("Running concurrent random walks to discover new peers")
current_rt_size = self.routing_table.size()
discovered_peers = await self.random_walk.run_concurrent_random_walks(
count=RANDOM_WALK_CONCURRENCY,
current_routing_table_size=current_rt_size,
)
# Add discovered peers to routing table
added_count = 0
for peer_info in discovered_peers:
result = await self.routing_table.add_peer(peer_info)
if result:
added_count += 1
self._last_refresh_time = current_time
duration = time.time() - start_time
logger.info(
f"Routing table refresh completed: "
f"{added_count}/{len(discovered_peers)} peers added, "
f"RT size: {self.routing_table.size()}, "
f"duration: {duration:.2f}s"
)
# Notify refresh completion
for callback in self._refresh_done_callbacks:
try:
callback()
except Exception as e:
logger.warning(f"Refresh callback error: {e}")
except Exception as e:
logger.error(f"Routing table refresh failed: {e}")
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback to be called when refresh completes."""
self._refresh_done_callbacks.append(callback)
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
"""Remove a refresh completion callback."""
if callback in self._refresh_done_callbacks:
self._refresh_done_callbacks.remove(callback)

View File

@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
implementation based on the Kademlia algorithm and protocol.
"""
from collections.abc import Awaitable, Callable
from enum import (
Enum,
)
@ -20,6 +21,7 @@ import varint
from libp2p.abc import (
IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.network.stream.net_stream import (
INetStream,
)
@ -73,14 +75,27 @@ class KadDHT(Service):
This class provides a DHT implementation that combines routing table management,
peer discovery, content routing, and value storage.
Optional Random Walk feature enhances peer discovery by automatically
performing periodic random queries to discover new peers and maintain
routing table health.
Example:
# Basic DHT without random walk (default)
dht = KadDHT(host, DHTMode.SERVER)
# DHT with random walk enabled for enhanced peer discovery
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
"""
def __init__(self, host: IHost, mode: DHTMode):
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
"""
Initialize a new Kademlia DHT node.
:param host: The libp2p host.
:param mode: The mode of host (Client or Server) - must be DHTMode enum
:param enable_random_walk: Whether to enable automatic random walk
"""
super().__init__()
@ -92,6 +107,7 @@ class KadDHT(Service):
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
self.mode = mode
self.enable_random_walk = enable_random_walk
# Initialize the routing table
self.routing_table = RoutingTable(self.local_peer_id, self.host)
@ -108,13 +124,56 @@ class KadDHT(Service):
# Last time we republished provider records
self._last_provider_republish = time.time()
# Initialize RT Refresh Manager (only if random walk is enabled)
self.rt_refresh_manager: RTRefreshManager | None = None
if self.enable_random_walk:
self.rt_refresh_manager = RTRefreshManager(
host=self.host,
routing_table=self.routing_table,
local_peer_id=self.local_peer_id,
query_function=self._create_query_function(),
enable_auto_refresh=True,
)
# Set protocol handlers
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
"""
Create a query function that wraps peer_routing.find_closest_peers_network.
This function is used by the RandomWalk module to query for peers without
directly importing PeerRouting, avoiding circular import issues.
Returns:
Callable that takes target_key bytes and returns list of peer IDs
"""
async def query_function(target_key: bytes) -> list[ID]:
"""Query for closest peers to target key."""
return await self.peer_routing.find_closest_peers_network(target_key)
return query_function
async def run(self) -> None:
"""Run the DHT service."""
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
# Start the RT Refresh Manager in parallel with the main DHT service
async with trio.open_nursery() as nursery:
# Start the RT Refresh Manager only if random walk is enabled
if self.rt_refresh_manager is not None:
nursery.start_soon(self.rt_refresh_manager.start)
logger.info("RT Refresh Manager started - Random Walk is now active")
else:
logger.info("Random Walk is disabled - RT Refresh Manager not started")
# Start the main DHT service loop
nursery.start_soon(self._run_main_loop)
async def _run_main_loop(self) -> None:
"""Run the main DHT service loop."""
# Main service loop
while self.manager.is_running:
# Periodically refresh the routing table
@ -135,6 +194,17 @@ class KadDHT(Service):
# Wait before next maintenance cycle
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
async def stop(self) -> None:
"""Stop the DHT service and cleanup resources."""
logger.info("Stopping Kademlia DHT")
# Stop the RT Refresh Manager only if it was started
if self.rt_refresh_manager is not None:
await self.rt_refresh_manager.stop()
logger.info("RT Refresh Manager stopped")
else:
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
"""
Switch the DHT mode.
@ -614,3 +684,15 @@ class KadDHT(Service):
"""
return self.value_store.size()
def is_random_walk_enabled(self) -> bool:
"""
Check if random walk peer discovery is enabled.
Returns
-------
bool
True if random walk is enabled, False otherwise.
"""
return self.enable_random_walk

View File

@ -170,7 +170,7 @@ class PeerRouting(IPeerRouting):
# Return early if we have no peers to start with
if not closest_peers:
logger.warning("No local peers available for network lookup")
logger.debug("No local peers available for network lookup")
return []
# Iterative lookup until convergence