Merge branch 'libp2p:main' into tests/notifee-coverage

Mercy Boma Naps Nkari
2025-08-21 08:07:53 +01:00
committed by GitHub
26 changed files with 1685 additions and 22 deletions


@ -0,0 +1,131 @@
Random Walk Example
===================

This example demonstrates the Random Walk module's peer discovery capabilities using real libp2p hosts and Kademlia DHT.
It shows how the Random Walk module automatically discovers new peers and maintains routing table health.

The Random Walk implementation performs the following key operations:

* **Automatic Peer Discovery**: Generates random peer IDs and queries the DHT network to discover new peers
* **Routing Table Maintenance**: Periodically refreshes the routing table to maintain network connectivity
* **Connection Management**: Maintains optimal connections to healthy peers in the network
* **Real-time Statistics**: Displays routing table size, connected peers, and peerstore statistics

.. code-block:: console

    $ python -m pip install libp2p
    Collecting libp2p
    ...
    Successfully installed libp2p-x.x.x
    $ cd examples/random_walk
    $ python random_walk.py --mode server
    2025-08-12 19:51:25,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
    2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
    2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
    2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
    2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
    2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
    2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
    2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
    2025-08-12 19:51:55,432 - random-walk-example - INFO - --- Iteration 1 ---
    2025-08-12 19:51:55,432 - random-walk-example - INFO - Routing table size: 15
    2025-08-12 19:51:55,432 - random-walk-example - INFO - Connected peers: 8
    2025-08-12 19:51:55,432 - random-walk-example - INFO - Peerstore size: 42

You can also run the example in client mode:

.. code-block:: console

    $ python random_walk.py --mode client
    2025-08-12 19:52:15,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
    2025-08-12 19:52:15,424 - random-walk-example - INFO - Mode: client, Port: 0 Demo interval: 30s
    2025-08-12 19:52:15,426 - random-walk-example - INFO - Starting client node on port 51234
    2025-08-12 19:52:15,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAmAbc123xyz...
    2025-08-12 19:52:15,427 - random-walk-example - INFO - DHT service started in CLIENT mode
    2025-08-12 19:52:45,432 - random-walk-example - INFO - --- Iteration 1 ---
    2025-08-12 19:52:45,432 - random-walk-example - INFO - Routing table size: 8
    2025-08-12 19:52:45,432 - random-walk-example - INFO - Connected peers: 5
    2025-08-12 19:52:45,432 - random-walk-example - INFO - Peerstore size: 25

Command Line Options
--------------------

The example supports several command-line options:

.. code-block:: console

    $ python random_walk.py --help
    usage: random_walk.py [-h] [--mode {server,client}] [--port PORT]
                          [--demo-interval DEMO_INTERVAL] [--verbose]

    Random Walk Example for py-libp2p Kademlia DHT

    optional arguments:
      -h, --help            show this help message and exit
      --mode {server,client}
                            Node mode: server (DHT server), or client (DHT client)
      --port PORT           Port to listen on (0 for random)
      --demo-interval DEMO_INTERVAL
                            Interval between random walk demonstrations in seconds
      --verbose             Enable verbose logging

Key Features Demonstrated
-------------------------

**Automatic Random Walk Discovery**

The example shows how the Random Walk module automatically:

* Generates random 256-bit keys (32 random bytes) to use as target peer IDs in discovery queries
* Performs concurrent random walks to maximize peer discovery
* Validates discovered peers and adds them to the routing table
* Maintains routing table health through periodic refreshes
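
As a rough illustration of this flow, the standard-library sketch below covers key generation, concurrent walks, and de-duplication; peer validation and routing table insertion are left out. It uses ``asyncio`` in place of ``trio`` (which the real module uses) so that it runs without extra dependencies. ``fake_query`` and the string peer IDs are hypothetical placeholders for the DHT lookup and are not part of py-libp2p.

```python
import asyncio
import secrets


def random_target_key() -> bytes:
    """Return 32 random bytes (256 bits) to use as a lookup target."""
    return secrets.token_bytes(32)


async def fake_query(target_key: bytes) -> list[str]:
    """Hypothetical stand-in for a FIND_NODE lookup; returns made-up peer IDs."""
    await asyncio.sleep(0)  # yield control, as a real network call would
    # Two fake peers derived from the key, plus one peer that every walk finds.
    return [f"peer-{target_key[0] % 4}", f"peer-{target_key[1] % 4}", "peer-shared"]


async def run_concurrent_walks(count: int = 3) -> list[str]:
    """Run `count` walks at once and merge the results without duplicates."""
    results = await asyncio.gather(
        *(fake_query(random_target_key()) for _ in range(count))
    )
    unique: dict[str, None] = {}
    for peers in results:
        for peer in peers:
            unique[peer] = None  # dict keys keep first-seen order and drop repeats
    return list(unique)
```

Calling ``asyncio.run(run_concurrent_walks())`` returns a list in which ``peer-shared`` appears once, even though all three walks report it.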

**Real-time Network Statistics**

The example logs live statistics every 30 seconds by default (configurable with ``--demo-interval``):

* **Routing Table Size**: Number of peers in the Kademlia routing table
* **Connected Peers**: Number of actively connected peers
* **Peerstore Size**: Total number of known peers with addresses

**Connection Management**

The example runs a background task that keeps the node connected:

* Checks every 15 seconds and, when fewer than 20 peers are connected, dials a random sample of up to 50 known peers
* Filters for compatible peers (TCP + IPv4 addresses, excluding QUIC)
* Gives each connection attempt a 5-second timeout
* Logs connection failures and moves on to the next peer
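
The TCP + IPv4 filter amounts to a substring test on the string form of each multiaddr. A minimal standalone sketch, using plain strings in place of ``Multiaddr`` objects (the function name ``has_compatible_address`` is illustrative, not the example's own):

```python
def has_compatible_address(addrs: list[str]) -> bool:
    """Return True if any address uses TCP over IPv4 and is not QUIC."""
    for addr in addrs:
        if "/ip4/" in addr and "/tcp/" in addr and "/quic" not in addr:
            return True
    return False
```

A peer whose only addresses are IPv6 or QUIC is skipped, so the example never tries to dial it.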

**DHT Integration**

Shows seamless integration between Random Walk and Kademlia DHT:

* RT Refresh Manager coordinates with the DHT routing table
* Peer discovery feeds directly into DHT operations
* Both SERVER and CLIENT modes supported
* Bootstrap connectivity to public IPFS nodes

Understanding the Output
------------------------

When you run the example, you'll see periodic statistics that show how the Random Walk module is working:

* **Initial Phase**: Routing table starts empty and quickly discovers peers
* **Growth Phase**: Routing table size increases as more peers are discovered
* **Maintenance Phase**: Routing table size stabilizes as the system maintains optimal peer connections

The Random Walk module runs automatically in the background. After an initial forced refresh at startup, it wakes every ``REFRESH_INTERVAL`` seconds (300 by default) and runs new discovery queries only if the routing table holds fewer than ``MIN_RT_REFRESH_THRESHOLD`` peers (4 by default).
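
Judging from the refresh manager added in this change, a periodic refresh appears to run only when enough time has passed since the last one and the routing table is small. The sketch below restates that decision as a pure function; ``should_refresh`` is a hypothetical helper written for illustration, and the two constants simply mirror the documented defaults.

```python
REFRESH_INTERVAL = 300.0  # seconds, mirrors the documented default
MIN_RT_REFRESH_THRESHOLD = 4  # peers, mirrors the documented default


def should_refresh(
    now: float,
    last_refresh: float,
    rt_size: int,
    force: bool = False,
    interval: float = REFRESH_INTERVAL,
    threshold: int = MIN_RT_REFRESH_THRESHOLD,
) -> bool:
    """Decide whether a routing table refresh should run."""
    if force:
        return True  # e.g. the initial refresh at startup
    if now - last_refresh < interval:
        return False  # too soon since the previous refresh
    if rt_size >= threshold:
        return False  # routing table is already populated enough
    return True
```

With the defaults, a node whose routing table already holds four or more peers skips its periodic refreshes, which is one reason the routing table size levels off in the maintenance phase.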

Configuration
-------------

The Random Walk module takes its defaults from constants defined in ``libp2p.discovery.random_walk.config``:

* ``RANDOM_WALK_ENABLED``: Enable/disable automatic random walks (default: ``True``)
* ``REFRESH_INTERVAL``: Time between automatic refreshes in seconds (default: ``300.0``)
* ``RANDOM_WALK_CONCURRENCY``: Number of concurrent random walks per refresh (default: ``3``)
* ``MIN_RT_REFRESH_THRESHOLD``: Routing table size below which a periodic refresh runs; at or above this size the refresh is skipped (default: ``4``)

See Also
--------

* :doc:`examples.kademlia` - Kademlia DHT value storage and content routing
* :doc:`libp2p.discovery.random_walk` - Random Walk module API documentation


@ -14,3 +14,4 @@ Examples
examples.circuit_relay
examples.kademlia
examples.mDNS
examples.random_walk


@ -0,0 +1,48 @@
libp2p.discovery.random_walk package
====================================
The Random Walk module implements a peer discovery mechanism.
It performs random walks through the DHT network to discover new peers and maintain routing table health through periodic refreshes.

Submodules
----------

libp2p.discovery.random_walk.config module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: libp2p.discovery.random_walk.config
    :members:
    :undoc-members:
    :show-inheritance:

libp2p.discovery.random_walk.exceptions module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: libp2p.discovery.random_walk.exceptions
    :members:
    :undoc-members:
    :show-inheritance:

libp2p.discovery.random_walk.random_walk module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: libp2p.discovery.random_walk.random_walk
    :members:
    :undoc-members:
    :show-inheritance:

libp2p.discovery.random_walk.rt_refresh_manager module
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: libp2p.discovery.random_walk.rt_refresh_manager
    :members:
    :undoc-members:
    :show-inheritance:

Module contents
---------------

.. automodule:: libp2p.discovery.random_walk
    :members:
    :undoc-members:
    :show-inheritance:


@ -10,6 +10,7 @@ Subpackages
libp2p.discovery.bootstrap
libp2p.discovery.events
libp2p.discovery.mdns
libp2p.discovery.random_walk
Submodules
----------


@ -227,7 +227,7 @@ async def run_node(
 # Keep the node running
 while True:
-    logger.debug(
+    logger.info(
         "Status - Connected peers: %d,"
         "Peers in store: %d, Values in store: %d",
         len(dht.host.get_connected_peers()),


@ -0,0 +1,221 @@
"""
Random Walk Example for py-libp2p Kademlia DHT
This example demonstrates the Random Walk module's peer discovery capabilities
using real libp2p hosts and Kademlia DHT. It shows how the Random Walk module
automatically discovers new peers and maintains routing table health.
Usage:
# Start server nodes (they will discover peers via random walk)
python3 random_walk.py --mode server
"""
import argparse
import logging
import random
import secrets
import sys
from multiaddr import Multiaddr
import trio
from libp2p import new_host
from libp2p.abc import IHost
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
from libp2p.tools.async_service import background_trio_service


# Simple logging configuration
def setup_logging(verbose: bool = False):
    """Setup unified logging configuration."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler()],
    )

    # Configure key module loggers
    for module in ["libp2p.discovery.random_walk", "libp2p.kad_dht"]:
        logging.getLogger(module).setLevel(level)

    # Suppress noisy logs
    logging.getLogger("multiaddr").setLevel(logging.WARNING)


logger = logging.getLogger("random-walk-example")

# Default bootstrap nodes
DEFAULT_BOOTSTRAP_NODES = [
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"
]


def filter_compatible_peer_info(peer_info) -> bool:
    """Filter peer info to check if it has compatible addresses (TCP + IPv4)."""
    if not hasattr(peer_info, "addrs") or not peer_info.addrs:
        return False

    for addr in peer_info.addrs:
        addr_str = str(addr)
        if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
            return True
    return False


async def maintain_connections(host: IHost) -> None:
    """Maintain connections to ensure the host remains connected to healthy peers."""
    while True:
        try:
            connected_peers = host.get_connected_peers()
            list_peers = host.get_peerstore().peers_with_addrs()

            if len(connected_peers) < 20:
                logger.debug("Reconnecting to maintain peer connections...")

                # Find compatible peers
                compatible_peers = []
                for peer_id in list_peers:
                    try:
                        peer_info = host.get_peerstore().peer_info(peer_id)
                        if filter_compatible_peer_info(peer_info):
                            compatible_peers.append(peer_id)
                    except Exception:
                        continue

                # Connect to random subset of compatible peers
                if compatible_peers:
                    random_peers = random.sample(
                        compatible_peers, min(50, len(compatible_peers))
                    )
                    for peer_id in random_peers:
                        if peer_id not in connected_peers:
                            try:
                                with trio.move_on_after(5):
                                    peer_info = host.get_peerstore().peer_info(peer_id)
                                    await host.connect(peer_info)
                                    logger.debug(f"Connected to peer: {peer_id}")
                            except Exception as e:
                                logger.debug(f"Failed to connect to {peer_id}: {e}")

            await trio.sleep(15)
        except Exception as e:
            logger.error(f"Error maintaining connections: {e}")


async def demonstrate_random_walk_discovery(dht: KadDHT, interval: int = 30) -> None:
    """Demonstrate Random Walk peer discovery with periodic statistics."""
    iteration = 0
    while True:
        iteration += 1
        logger.info(f"--- Iteration {iteration} ---")
        logger.info(f"Routing table size: {dht.get_routing_table_size()}")
        logger.info(f"Connected peers: {len(dht.host.get_connected_peers())}")
        logger.info(f"Peerstore size: {len(dht.host.get_peerstore().peer_ids())}")
        await trio.sleep(interval)


async def run_node(port: int, mode: str, demo_interval: int = 30) -> None:
    """Run a node that demonstrates Random Walk peer discovery."""
    try:
        if port <= 0:
            port = random.randint(10000, 60000)

        logger.info(f"Starting {mode} node on port {port}")

        # Determine DHT mode
        dht_mode = DHTMode.SERVER if mode == "server" else DHTMode.CLIENT

        # Create host and DHT
        key_pair = create_new_key_pair(secrets.token_bytes(32))
        host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
        listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")

        async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
            # Start maintenance tasks
            nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
            nursery.start_soon(maintain_connections, host)

            peer_id = host.get_id().pretty()
            logger.info(f"Node peer ID: {peer_id}")
            logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}")

            # Create and start DHT with Random Walk enabled
            dht = KadDHT(host, dht_mode, enable_random_walk=True)
            logger.info(f"Initial routing table size: {dht.get_routing_table_size()}")

            async with background_trio_service(dht):
                logger.info(f"DHT service started in {dht_mode.value} mode")
                logger.info(f"Random Walk enabled: {dht.is_random_walk_enabled()}")

                async with trio.open_nursery() as task_nursery:
                    # Start demonstration and status reporting
                    task_nursery.start_soon(
                        demonstrate_random_walk_discovery, dht, demo_interval
                    )

                    # Periodic status updates
                    async def status_reporter():
                        while True:
                            await trio.sleep(30)
                            logger.debug(
                                f"Connected: {len(dht.host.get_connected_peers())}, "
                                f"Routing table: {dht.get_routing_table_size()}, "
                                f"Peerstore: {len(dht.host.get_peerstore().peer_ids())}"
                            )

                    task_nursery.start_soon(status_reporter)
                    await trio.sleep_forever()

    except Exception as e:
        logger.error(f"Node error: {e}", exc_info=True)
        sys.exit(1)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Random Walk Example for py-libp2p Kademlia DHT",
    )
    parser.add_argument(
        "--mode",
        choices=["server", "client"],
        default="server",
        help="Node mode: server (DHT server), or client (DHT client)",
    )
    parser.add_argument(
        "--port", type=int, default=0, help="Port to listen on (0 for random)"
    )
    parser.add_argument(
        "--demo-interval",
        type=int,
        default=30,
        help="Interval between random walk demonstrations in seconds",
    )
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
    return parser.parse_args()


def main():
    """Main entry point for the random walk example."""
    try:
        args = parse_args()
        setup_logging(args.verbose)

        logger.info("=== Random Walk Example for py-libp2p ===")
        logger.info(
            f"Mode: {args.mode}, Port: {args.port} Demo interval: {args.demo_interval}s"
        )

        trio.run(run_node, args.port, args.mode, args.demo_interval)

    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down...")
    except Exception as e:
        logger.critical(f"Example failed: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()


@ -0,0 +1,17 @@
"""Random walk discovery modules for py-libp2p."""
from .rt_refresh_manager import RTRefreshManager
from .random_walk import RandomWalk
from .exceptions import (
RoutingTableRefreshError,
RandomWalkError,
PeerValidationError,
)
__all__ = [
"RTRefreshManager",
"RandomWalk",
"RoutingTableRefreshError",
"RandomWalkError",
"PeerValidationError",
]


@ -0,0 +1,16 @@
from typing import Final
# Timing constants (matching go-libp2p)
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
# Routing table thresholds
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
# Random walk specific
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback


@ -0,0 +1,19 @@
from libp2p.exceptions import BaseLibp2pError


class RoutingTableRefreshError(BaseLibp2pError):
    """Base exception for routing table refresh operations."""

    pass


class RandomWalkError(RoutingTableRefreshError):
    """Exception raised during random walk operations."""

    pass


class PeerValidationError(RoutingTableRefreshError):
    """Exception raised when peer validation fails."""

    pass


@ -0,0 +1,218 @@
from collections.abc import Awaitable, Callable
import logging
import secrets
import trio
from libp2p.abc import IHost
from libp2p.discovery.random_walk.config import (
RANDOM_WALK_CONCURRENCY,
RANDOM_WALK_RT_THRESHOLD,
REFRESH_QUERY_TIMEOUT,
)
from libp2p.discovery.random_walk.exceptions import RandomWalkError
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
logger = logging.getLogger("libp2p.discovery.random_walk")


class RandomWalk:
    """
    Random Walk implementation for peer discovery in Kademlia DHT.

    Generates random peer IDs and performs FIND_NODE queries to discover
    new peers and populate the routing table.
    """

    def __init__(
        self,
        host: IHost,
        local_peer_id: ID,
        query_function: Callable[[bytes], Awaitable[list[ID]]],
    ):
        """
        Initialize Random Walk module.

        Args:
            host: The libp2p host instance
            local_peer_id: Local peer ID
            query_function: Function to query for closest peers given target key bytes

        """
        self.host = host
        self.local_peer_id = local_peer_id
        self.query_function = query_function

    def generate_random_peer_id(self) -> str:
        """
        Generate a completely random peer ID
        for random walk queries.

        Returns:
            Random peer ID as string

        """
        # Generate 32 random bytes (256 bits) - same as go-libp2p
        random_bytes = secrets.token_bytes(32)
        # Convert to hex string for query
        return random_bytes.hex()

    async def perform_random_walk(self) -> list[PeerInfo]:
        """
        Perform a single random walk operation.

        Returns:
            List of validated peers discovered during the walk

        """
        try:
            # Generate random peer ID
            random_peer_id = self.generate_random_peer_id()
            logger.info(f"Starting random walk for peer ID: {random_peer_id}")

            # Perform FIND_NODE query
            discovered_peer_ids: list[ID] = []

            with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
                # Call the query function with target key bytes
                target_key = bytes.fromhex(random_peer_id)
                discovered_peer_ids = await self.query_function(target_key) or []

            if not discovered_peer_ids:
                logger.debug(f"No peers discovered in random walk for {random_peer_id}")
                return []

            logger.info(
                f"Discovered {len(discovered_peer_ids)} peers in random walk "
                f"for {random_peer_id[:8]}..."  # Show only first 8 chars for brevity
            )

            # Convert peer IDs to PeerInfo objects and validate
            validated_peers: list[PeerInfo] = []

            for peer_id in discovered_peer_ids:
                try:
                    # Get addresses from peerstore
                    addrs = self.host.get_peerstore().addrs(peer_id)
                    if addrs:
                        peer_info = PeerInfo(peer_id, addrs)
                        validated_peers.append(peer_info)
                except Exception as e:
                    logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
                    continue

            return validated_peers

        except Exception as e:
            logger.error(f"Random walk failed: {e}")
            raise RandomWalkError(f"Random walk operation failed: {e}") from e

    async def run_concurrent_random_walks(
        self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
    ) -> list[PeerInfo]:
        """
        Run multiple random walks concurrently.

        Args:
            count: Number of concurrent random walks to perform
            current_routing_table_size: Current size of routing table (for optimization)

        Returns:
            Combined list of all validated peers discovered

        """
        all_validated_peers: list[PeerInfo] = []
        logger.info(f"Starting {count} concurrent random walks")

        # First, try to add peers from peerstore if routing table is small
        if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
            try:
                peerstore_peers = self._get_peerstore_peers()
                if peerstore_peers:
                    logger.debug(
                        f"RT size ({current_routing_table_size}) below threshold, "
                        f"adding {len(peerstore_peers)} peerstore peers"
                    )
                    all_validated_peers.extend(peerstore_peers)
            except Exception as e:
                logger.warning(f"Error processing peerstore peers: {e}")

        async def single_walk() -> None:
            try:
                peers = await self.perform_random_walk()
                all_validated_peers.extend(peers)
            except Exception as e:
                logger.warning(f"Concurrent random walk failed: {e}")
                return

        # Run concurrent random walks
        async with trio.open_nursery() as nursery:
            for _ in range(count):
                nursery.start_soon(single_walk)

        # Remove duplicates based on peer ID
        unique_peers = {}
        for peer in all_validated_peers:
            unique_peers[peer.peer_id] = peer

        result = list(unique_peers.values())
        logger.info(
            f"Concurrent random walks completed: {len(result)} unique peers discovered"
        )
        return result

    def _get_peerstore_peers(self) -> list[PeerInfo]:
        """
        Get peer info objects from the host's peerstore.

        Returns:
            List of PeerInfo objects from peerstore

        """
        try:
            peerstore = self.host.get_peerstore()
            peer_ids = peerstore.peers_with_addrs()

            peer_infos = []
            for peer_id in peer_ids:
                try:
                    # Skip local peer
                    if peer_id == self.local_peer_id:
                        continue

                    peer_info = peerstore.peer_info(peer_id)
                    if peer_info and peer_info.addrs:
                        # Filter for compatible addresses (TCP + IPv4)
                        if self._has_compatible_addresses(peer_info):
                            peer_infos.append(peer_info)
                except Exception as e:
                    logger.debug(f"Error getting peer info for {peer_id}: {e}")

            return peer_infos

        except Exception as e:
            logger.warning(f"Error accessing peerstore: {e}")
            return []

    def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
        """
        Check if a peer has TCP+IPv4 compatible addresses.

        Args:
            peer_info: PeerInfo to check

        Returns:
            True if peer has compatible addresses

        """
        if not peer_info.addrs:
            return False

        for addr in peer_info.addrs:
            addr_str = str(addr)
            # Check for TCP and IPv4 compatibility, avoid QUIC
            if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
                return True

        return False


@ -0,0 +1,208 @@
from collections.abc import Awaitable, Callable
import logging
import time
from typing import Protocol
import trio
from libp2p.abc import IHost
from libp2p.discovery.random_walk.config import (
MIN_RT_REFRESH_THRESHOLD,
RANDOM_WALK_CONCURRENCY,
RANDOM_WALK_ENABLED,
REFRESH_INTERVAL,
)
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo


class RoutingTableProtocol(Protocol):
    """Protocol for routing table operations needed by RT refresh manager."""

    def size(self) -> int:
        """Return the current size of the routing table."""
        ...

    async def add_peer(self, peer_obj: PeerInfo) -> bool:
        """Add a peer to the routing table."""
        ...


logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")


class RTRefreshManager:
    """
    Routing Table Refresh Manager for py-libp2p.

    Manages periodic routing table refreshes and random walk operations
    to maintain routing table health and discover new peers.
    """

    def __init__(
        self,
        host: IHost,
        routing_table: RoutingTableProtocol,
        local_peer_id: ID,
        query_function: Callable[[bytes], Awaitable[list[ID]]],
        enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
        refresh_interval: float = REFRESH_INTERVAL,
        min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
    ):
        """
        Initialize RT Refresh Manager.

        Args:
            host: The libp2p host instance
            routing_table: Routing table of host
            local_peer_id: Local peer ID
            query_function: Function to query for closest peers given target key bytes
            enable_auto_refresh: Whether to enable automatic refresh
            refresh_interval: Interval between refreshes in seconds
            min_refresh_threshold: Minimum RT size before triggering refresh

        """
        self.host = host
        self.routing_table = routing_table
        self.local_peer_id = local_peer_id
        self.query_function = query_function

        self.enable_auto_refresh = enable_auto_refresh
        self.refresh_interval = refresh_interval
        self.min_refresh_threshold = min_refresh_threshold

        # Initialize random walk module
        self.random_walk = RandomWalk(
            host=host,
            local_peer_id=self.local_peer_id,
            query_function=query_function,
        )

        # Control variables
        self._running = False
        self._nursery: trio.Nursery | None = None

        # Tracking
        self._last_refresh_time = 0.0
        self._refresh_done_callbacks: list[Callable[[], None]] = []

    async def start(self) -> None:
        """Start the RT Refresh Manager."""
        if self._running:
            logger.warning("RT Refresh Manager is already running")
            return

        self._running = True

        logger.info("Starting RT Refresh Manager")

        # Start the main loop
        async with trio.open_nursery() as nursery:
            self._nursery = nursery
            nursery.start_soon(self._main_loop)

    async def stop(self) -> None:
        """Stop the RT Refresh Manager."""
        if not self._running:
            return

        logger.info("Stopping RT Refresh Manager")
        self._running = False

    async def _main_loop(self) -> None:
        """Main loop for the RT Refresh Manager."""
        logger.info("RT Refresh Manager main loop started")

        # Initial refresh if auto-refresh is enabled
        if self.enable_auto_refresh:
            await self._do_refresh(force=True)

        try:
            while self._running:
                async with trio.open_nursery() as nursery:
                    # Schedule periodic refresh if enabled
                    if self.enable_auto_refresh:
                        nursery.start_soon(self._periodic_refresh_task)

        except Exception as e:
            logger.error(f"RT Refresh Manager main loop error: {e}")
        finally:
            logger.info("RT Refresh Manager main loop stopped")

    async def _periodic_refresh_task(self) -> None:
        """Task for periodic refreshes."""
        while self._running:
            await trio.sleep(self.refresh_interval)
            if self._running:
                await self._do_refresh()

    async def _do_refresh(self, force: bool = False) -> None:
        """
        Perform routing table refresh operation.

        Args:
            force: Whether to force refresh regardless of timing

        """
        try:
            current_time = time.time()

            # Check if refresh is needed
            if not force:
                if current_time - self._last_refresh_time < self.refresh_interval:
                    logger.debug("Skipping refresh: interval not elapsed")
                    return

                if self.routing_table.size() >= self.min_refresh_threshold:
                    logger.debug("Skipping refresh: routing table size above threshold")
                    return

            logger.info(f"Starting routing table refresh (force={force})")
            start_time = current_time

            # Perform random walks to discover new peers
            logger.info("Running concurrent random walks to discover new peers")
            current_rt_size = self.routing_table.size()
            discovered_peers = await self.random_walk.run_concurrent_random_walks(
                count=RANDOM_WALK_CONCURRENCY,
                current_routing_table_size=current_rt_size,
            )

            # Add discovered peers to routing table
            added_count = 0
            for peer_info in discovered_peers:
                result = await self.routing_table.add_peer(peer_info)
                if result:
                    added_count += 1

            self._last_refresh_time = current_time

            duration = time.time() - start_time
            logger.info(
                f"Routing table refresh completed: "
                f"{added_count}/{len(discovered_peers)} peers added, "
                f"RT size: {self.routing_table.size()}, "
                f"duration: {duration:.2f}s"
            )

            # Notify refresh completion
            for callback in self._refresh_done_callbacks:
                try:
                    callback()
                except Exception as e:
                    logger.warning(f"Refresh callback error: {e}")

        except Exception as e:
            logger.error(f"Routing table refresh failed: {e}")
            raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e

    def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
        """Add a callback to be called when refresh completes."""
        self._refresh_done_callbacks.append(callback)

    def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
        """Remove a refresh completion callback."""
        if callback in self._refresh_done_callbacks:
            self._refresh_done_callbacks.remove(callback)


@ -295,6 +295,13 @@ class BasicHost(IHost):
)
await net_stream.reset()
return
if protocol is None:
logger.debug(
"no protocol negotiated, closing stream from peer %s",
net_stream.muxed_conn.peer_id,
)
await net_stream.reset()
return
net_stream.set_protocol(protocol)
if handler is None:
logger.debug(


@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
implementation based on the Kademlia algorithm and protocol.
"""
from collections.abc import Awaitable, Callable
from enum import (
Enum,
)
@ -20,6 +21,7 @@ import varint
from libp2p.abc import (
IHost,
)
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.network.stream.net_stream import (
INetStream,
)
@ -73,14 +75,27 @@ class KadDHT(Service):
This class provides a DHT implementation that combines routing table management,
peer discovery, content routing, and value storage.
Optional Random Walk feature enhances peer discovery by automatically
performing periodic random queries to discover new peers and maintain
routing table health.
Example:
# Basic DHT without random walk (default)
dht = KadDHT(host, DHTMode.SERVER)
# DHT with random walk enabled for enhanced peer discovery
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
"""
-    def __init__(self, host: IHost, mode: DHTMode):
+    def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
"""
Initialize a new Kademlia DHT node.
:param host: The libp2p host.
:param mode: The mode of host (Client or Server) - must be DHTMode enum
:param enable_random_walk: Whether to enable automatic random walk
"""
super().__init__()
@ -92,6 +107,7 @@ class KadDHT(Service):
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
self.mode = mode
self.enable_random_walk = enable_random_walk
# Initialize the routing table
self.routing_table = RoutingTable(self.local_peer_id, self.host)
@ -108,13 +124,56 @@ class KadDHT(Service):
# Last time we republished provider records
self._last_provider_republish = time.time()
# Initialize RT Refresh Manager (only if random walk is enabled)
self.rt_refresh_manager: RTRefreshManager | None = None
if self.enable_random_walk:
self.rt_refresh_manager = RTRefreshManager(
host=self.host,
routing_table=self.routing_table,
local_peer_id=self.local_peer_id,
query_function=self._create_query_function(),
enable_auto_refresh=True,
)
# Set protocol handlers
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
"""
Create a query function that wraps peer_routing.find_closest_peers_network.
This function is used by the RandomWalk module to query for peers without
directly importing PeerRouting, avoiding circular import issues.
Returns:
Callable that takes target_key bytes and returns list of peer IDs
"""
async def query_function(target_key: bytes) -> list[ID]:
"""Query for closest peers to target key."""
return await self.peer_routing.find_closest_peers_network(target_key)
return query_function
async def run(self) -> None:
"""Run the DHT service."""
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
# Start the RT Refresh Manager in parallel with the main DHT service
async with trio.open_nursery() as nursery:
# Start the RT Refresh Manager only if random walk is enabled
if self.rt_refresh_manager is not None:
nursery.start_soon(self.rt_refresh_manager.start)
logger.info("RT Refresh Manager started - Random Walk is now active")
else:
logger.info("Random Walk is disabled - RT Refresh Manager not started")
# Start the main DHT service loop
nursery.start_soon(self._run_main_loop)
async def _run_main_loop(self) -> None:
"""Run the main DHT service loop."""
# Main service loop
while self.manager.is_running:
# Periodically refresh the routing table
@ -135,6 +194,17 @@ class KadDHT(Service):
# Wait before next maintenance cycle
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
async def stop(self) -> None:
"""Stop the DHT service and cleanup resources."""
logger.info("Stopping Kademlia DHT")
# Stop the RT Refresh Manager only if it was started
if self.rt_refresh_manager is not None:
await self.rt_refresh_manager.stop()
logger.info("RT Refresh Manager stopped")
else:
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
"""
Switch the DHT mode.
@ -614,3 +684,15 @@ class KadDHT(Service):
"""
return self.value_store.size()
def is_random_walk_enabled(self) -> bool:
"""
Check if random walk peer discovery is enabled.
Returns
-------
bool
True if random walk is enabled, False otherwise.
"""
return self.enable_random_walk


@ -170,7 +170,7 @@ class PeerRouting(IPeerRouting):
 # Return early if we have no peers to start with
 if not closest_peers:
-    logger.warning("No local peers available for network lookup")
+    logger.debug("No local peers available for network lookup")
     return []

 # Iterative lookup until convergence

@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
"""
self.handlers[protocol] = handler
- # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate(
self,
communicator: IMultiselectCommunicator,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
- ) -> tuple[TProtocol, StreamHandlerFn | None]:
+ ) -> tuple[TProtocol | None, StreamHandlerFn | None]:
"""
Negotiate performs protocol selection.
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
raise MultiselectError() from error
else:
- protocol = TProtocol(command)
- if protocol in self.handlers:
+ protocol_to_check = None if not command else TProtocol(command)
+ if protocol_to_check in self.handlers:
try:
- await communicator.write(protocol)
+ await communicator.write(command)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
- return protocol, self.handlers[protocol]
+ return protocol_to_check, self.handlers[protocol_to_check]
try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:

@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
:raise MultiselectClientError: raised when protocol negotiation failed
:return: selected protocol
"""
+ # Represent `None` protocol as an empty string.
+ protocol_str = protocol if protocol is not None else ""
try:
- await communicator.write(protocol)
+ await communicator.write(protocol_str)
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
- if response == protocol:
+ if response == protocol_str:
return protocol
if response == PROTOCOL_NOT_FOUND_MSG:
raise MultiselectClientError("protocol not supported")

@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
"""
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
""" # noqa: E501
- msg_bytes = encode_delim(msg_str.encode())
+ if msg_str is None:
+ msg_bytes = encode_delim(b"")
+ else:
+ msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:

@ -17,6 +17,9 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
+ from libp2p.protocol_muxer.exceptions import (
+ MultiselectError,
+ )
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
:param is_initiator: true if we are the initiator, false otherwise
:return: selected secure transport
"""
- protocol: TProtocol
+ protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if is_initiator:
# Select protocol if initiator
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
else:
# Select protocol if non-initiator
protocol, _ = await self.multiselect.negotiate(communicator)
+ if protocol is None:
+ raise MultiselectError("fail to negotiate a security protocol")
# Return transport from protocol
return self.transports[protocol]

@ -17,6 +17,9 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
+ from libp2p.protocol_muxer.exceptions import (
+ MultiselectError,
+ )
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
@ -73,7 +76,7 @@ class MuxerMultistream:
:param conn: conn to choose a transport over
:return: selected muxer transport
"""
- protocol: TProtocol
+ protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
@ -81,6 +84,8 @@ class MuxerMultistream:
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
+ if protocol is None:
+ raise MultiselectError("fail to negotiate a stream muxer protocol")
return self.transports[protocol]
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:

@ -0,0 +1 @@
Made TProtocol optional (Optional[TProtocol]) to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py
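
The multiselect changes in this commit represent a `None` protocol as an empty string on the wire and map an empty command back to `None` on the server side. The following is a minimal, self-contained sketch of that round trip. The helper names `to_wire` and `from_wire` are hypothetical and are not part of py-libp2p; plain `str` stands in for `TProtocol`.

```python
from typing import Optional


def to_wire(protocol: Optional[str]) -> str:
    """Encode an optional protocol ID for sending (hypothetical helper)."""
    # Mirrors `protocol if protocol is not None else ""` from the client hunk.
    return protocol if protocol is not None else ""


def from_wire(command: str) -> Optional[str]:
    """Decode a received command into an optional protocol ID (hypothetical helper)."""
    # Mirrors `None if not command else TProtocol(command)` from the server hunk.
    return None if not command else command


if __name__ == "__main__":
    # A `None` protocol travels as an empty string and comes back as `None`.
    print(repr(to_wire(None)))
    print(repr(from_wire(to_wire(None))))
    # A regular protocol ID is passed through unchanged.
    print(repr(from_wire(to_wire("/echo/1.0.0"))))
```

One consequence of this encoding, under the assumption above, is that an empty protocol string and `None` are indistinguishable after a round trip.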

@ -0,0 +1 @@
Added `Random Walk` peer discovery module that enables random peer exploration for improved peer discovery.

@ -11,9 +11,9 @@ requires-python = ">=3.10, <4.0"
license = { text = "MIT AND Apache-2.0" }
keywords = ["libp2p", "p2p"]
maintainers = [
- { name = "pacrob", email = "pacrob@protonmail.com" },
+ { name = "pacrob", email = "pacrob-py-libp2p@proton.me" },
{ name = "Manu Sheel Gupta", email = "manu@seeta.in" },
- { name = "Dave Grantham", email = "dave@aviation.community" },
+ { name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
]
dependencies = [
"base58>=1.0.3",

@ -1,9 +1,9 @@
from collections import deque
import pytest
import trio
- from libp2p.abc import (
- IMultiselectCommunicator,
- )
+ from libp2p.abc import IMultiselectCommunicator, INetStream
from libp2p.custom_types import TProtocol
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
async def dummy_handler(stream: INetStream) -> None:
pass
class DummyMultiselectCommunicator(IMultiselectCommunicator):
"""
Dummy MultiSelectCommunicator to test out negotiate timeout.
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
@pytest.mark.trio
- async def test_select_one_of_timeout():
+ async def test_select_one_of_timeout() -> None:
ECHO = TProtocol("/echo/1.0.0")
communicator = DummyMultiselectCommunicator()
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
@pytest.mark.trio
- async def test_query_multistream_command_timeout():
+ async def test_query_multistream_command_timeout() -> None:
communicator = DummyMultiselectCommunicator()
client = MultiselectClient()
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
@pytest.mark.trio
- async def test_negotiate_timeout():
+ async def test_negotiate_timeout() -> None:
communicator = DummyMultiselectCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 2)
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
handshaked: bool
def __init__(self) -> None:
self.handshaked = False
async def write(self, msg_str: str) -> None:
if msg_str == "/multistream/1.0.0":
self.handshaked = True
return
async def read(self) -> str:
if not self.handshaked:
return "/multistream/1.0.0"
# After handshake, hang on read.
await trio.sleep_forever()
# Should not be reached.
return ""
@pytest.mark.trio
async def test_negotiate_timeout_post_handshake() -> None:
communicator = HandshakeThenHangCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 1)
class MockCommunicator(IMultiselectCommunicator):
def __init__(self, commands_to_read: list[str]):
self.read_queue = deque(commands_to_read)
self.written_data: list[str] = []
async def write(self, msg_str: str) -> None:
self.written_data.append(msg_str)
async def read(self) -> str:
if not self.read_queue:
raise EOFError
return self.read_queue.popleft()
@pytest.mark.trio
async def test_negotiate_empty_string_command() -> None:
# server receives an empty string, which means client wants `None` protocol.
server = Multiselect({None: dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check that server sent back handshake and the protocol confirmation (empty string)
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler() -> None:
# server has None handler, client sends "" to select it.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, protocol confirmation
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler_ls() -> None:
# server has None handler, client sends "ls" then empty string.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, ls, empty command
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, ls response, protocol confirmation
assert communicator.written_data[0] == "/multistream/1.0.0"
assert "/proto1" in communicator.written_data[1]
# Note: `ls` should not list the `None` protocol.
assert "None" not in communicator.written_data[1]
assert "\n\n" not in communicator.written_data[1]
assert communicator.written_data[2] == ""

@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
protocols = ms.get_protocols()
assert set(protocols) == {p1, p2, p3}
@pytest.mark.trio
async def test_negotiate_optional_tprotocol(security_protocol):
with pytest.raises(Exception):
await perform_simple_test(
None,
[None],
[None],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol,
[None, PROTOCOL_ECHO],
[PROTOCOL_ECHO],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_server_none_client_other(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)

@ -0,0 +1,99 @@
"""
Unit tests for the RandomWalk module in libp2p.discovery.random_walk.
"""
from unittest.mock import AsyncMock, Mock
import pytest
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
@pytest.fixture
def mock_host():
host = Mock()
peerstore = Mock()
peerstore.peers_with_addrs.return_value = []
peerstore.addrs.return_value = [Mock()]
host.get_peerstore.return_value = peerstore
host.new_stream = AsyncMock()
return host
@pytest.fixture
def dummy_query_function():
async def query(key_bytes):
return []
return query
@pytest.fixture
def dummy_peer_id():
return b"\x01" * 32
@pytest.mark.trio
async def test_random_walk_initialization(
mock_host, dummy_peer_id, dummy_query_function
):
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
assert rw.host == mock_host
assert rw.local_peer_id == dummy_peer_id
assert rw.query_function == dummy_query_function
def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function):
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
peer_id = rw.generate_random_peer_id()
assert isinstance(peer_id, str)
assert len(peer_id) == 64 # 32 bytes hex
@pytest.mark.trio
async def test_run_concurrent_random_walks(mock_host, dummy_peer_id):
# Dummy query function returns different peer IDs for each walk
call_count = {"count": 0}
async def query(key_bytes):
call_count["count"] += 1
# Return a unique peer ID for each call
return [ID(bytes([call_count["count"]] * 32))]
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.run_concurrent_random_walks(count=3)
# Should get 3 unique peers
assert len(peers) == 3
peer_ids = [peer.peer_id for peer in peers]
assert len(set(peer_ids)) == 3
@pytest.mark.trio
async def test_perform_random_walk_running(mock_host, dummy_peer_id):
# Query function returns a single peer ID
async def query(key_bytes):
return [ID(b"\x02" * 32)]
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.perform_random_walk()
assert isinstance(peers, list)
if peers:
assert isinstance(peers[0], PeerInfo)
@pytest.mark.trio
async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id):
"""Test perform_random_walk when no peers are discovered."""
# Query function returns empty list (no peers found)
async def query(key_bytes):
return []
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.perform_random_walk()
# Should return empty list when no peers are found
assert isinstance(peers, list)
assert len(peers) == 0
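
The test `test_generate_random_peer_id` above only checks that the generated ID is a `str` of length 64, that is, 32 bytes in hex. A minimal sketch that satisfies this contract is shown below. The function name `generate_random_key_hex` is hypothetical, and the use of `secrets` is an assumption; the actual `RandomWalk.generate_random_peer_id` implementation is not shown in this diff and may differ.

```python
import secrets


def generate_random_key_hex(num_bytes: int = 32) -> str:
    """Return `num_bytes` random bytes as a lowercase hex string (hypothetical helper)."""
    # Each byte becomes two hex characters, so 32 bytes yield 64 characters.
    return secrets.token_bytes(num_bytes).hex()


if __name__ == "__main__":
    key = generate_random_key_hex()
    print(len(key))  # 64 for the default of 32 bytes
```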

@ -0,0 +1,451 @@
"""
Unit tests for the RTRefreshManager and related random walk logic.
"""
import time
from unittest.mock import AsyncMock, Mock, patch
import pytest
import trio
from libp2p.discovery.random_walk.config import (
MIN_RT_REFRESH_THRESHOLD,
RANDOM_WALK_CONCURRENCY,
REFRESH_INTERVAL,
)
from libp2p.discovery.random_walk.exceptions import (
RandomWalkError,
)
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
class DummyRoutingTable:
def __init__(self, size=0):
self._size = size
self.added_peers = []
def size(self):
return self._size
async def add_peer(self, peer_obj):
self.added_peers.append(peer_obj)
self._size += 1
return True
@pytest.fixture
def mock_host():
host = Mock()
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
return host
@pytest.fixture
def local_peer_id():
return ID(b"\x01" * 32)
@pytest.fixture
def dummy_query_function():
async def query(key_bytes):
return [ID(b"\x02" * 32)]
return query
@pytest.mark.trio
async def test_rt_refresh_manager_initialization(
mock_host, local_peer_id, dummy_query_function
):
rt = DummyRoutingTable(size=5)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
assert manager.host == mock_host
assert manager.routing_table == rt
assert manager.local_peer_id == local_peer_id
assert manager.query_function == dummy_query_function
@pytest.mark.trio
async def test_rt_refresh_manager_refresh_logic(
mock_host, local_peer_id, dummy_query_function
):
rt = DummyRoutingTable(size=2)
# Simulate refresh logic
if rt.size() < MIN_RT_REFRESH_THRESHOLD:
await rt.add_peer(Mock())
assert rt.size() >= 3
@pytest.mark.trio
async def test_rt_refresh_manager_random_walk_integration(
mock_host, local_peer_id, dummy_query_function
):
# Simulate random walk usage
rw = RandomWalk(mock_host, local_peer_id, dummy_query_function)
random_peer_id = rw.generate_random_peer_id()
assert isinstance(random_peer_id, str)
assert len(random_peer_id) == 64
@pytest.mark.trio
async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id):
rt = DummyRoutingTable(size=0)
async def failing_query(_):
raise RandomWalkError("Query failed")
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=failing_query,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
with pytest.raises(RandomWalkError):
await manager.query_function(b"key")
@pytest.mark.trio
async def test_rt_refresh_manager_start_method(
mock_host, local_peer_id, dummy_query_function
):
"""Test the start method functionality."""
rt = DummyRoutingTable(size=2)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=False, # Disable auto-refresh to control the test
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
):
# Test starting the manager
assert not manager._running
# Start the manager in a nursery that we can control
async with trio.open_nursery() as nursery:
nursery.start_soon(manager.start)
await trio.sleep(0.01) # Let it start
# Verify it's running
assert manager._running
# Stop the manager
await manager.stop()
assert not manager._running
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_with_auto_refresh(
mock_host, local_peer_id, dummy_query_function
):
"""Test the _main_loop method with auto-refresh enabled."""
rt = DummyRoutingTable(size=1) # Small size to trigger refresh
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
) as mock_random_walk:
manager._running = True
# Run the main loop for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._main_loop)
await trio.sleep(0.05) # Let it run briefly
manager._running = False # Stop the loop
# Verify that random walk was called (initial refresh)
mock_random_walk.assert_called()
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_without_auto_refresh(
mock_host, local_peer_id, dummy_query_function
):
"""Test the _main_loop method with auto-refresh disabled."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=False,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
manager._running = True
# Run the main loop for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._main_loop)
await trio.sleep(0.05)
manager._running = False
# Verify that random walk was not called since auto-refresh is disabled
mock_random_walk.assert_not_called()
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_initial_refresh_exception(
mock_host, local_peer_id, dummy_query_function
):
"""Test that _main_loop propagates exceptions from initial refresh."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock _do_refresh to raise an exception on the initial call
with patch.object(
manager, "_do_refresh", side_effect=Exception("Initial refresh failed")
):
manager._running = True
# The initial refresh exception should propagate
with pytest.raises(Exception, match="Initial refresh failed"):
await manager._main_loop()
@pytest.mark.trio
async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function):
"""Test _do_refresh method with force=True."""
rt = DummyRoutingTable(size=10) # Large size, but force should override
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info1 = Mock(spec=PeerInfo)
mock_peer_info2 = Mock(spec=PeerInfo)
discovered_peers = [mock_peer_info1, mock_peer_info2]
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=discovered_peers,
) as mock_random_walk:
# Force refresh should work regardless of RT size
await manager._do_refresh(force=True)
# Verify random walk was called
mock_random_walk.assert_called_once_with(
count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10
)
# Verify peers were added to routing table
assert len(rt.added_peers) == 2
assert manager._last_refresh_time > 0
@pytest.mark.trio
async def test_do_refresh_skip_due_to_interval(
mock_host, local_peer_id, dummy_query_function
):
"""Test _do_refresh skips refresh when interval hasn't elapsed."""
rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=100.0, # Long interval
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Set last refresh time to recent
manager._last_refresh_time = time.time()
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=False)
# Verify refresh was skipped
mock_random_walk.assert_not_called()
mock_logger.debug.assert_called_with(
"Skipping refresh: interval not elapsed"
)
@pytest.mark.trio
async def test_do_refresh_skip_due_to_rt_size(
mock_host, local_peer_id, dummy_query_function
):
"""Test _do_refresh skips refresh when RT size is above threshold."""
rt = DummyRoutingTable(size=20) # Large size above threshold
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1, # Short interval
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Set last refresh time to old
manager._last_refresh_time = 0.0
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=False)
# Verify refresh was skipped
mock_random_walk.assert_not_called()
mock_logger.debug.assert_called_with(
"Skipping refresh: routing table size above threshold"
)
@pytest.mark.trio
async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function):
"""Test refresh completion callbacks functionality."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Create mock callbacks
callback1 = Mock()
callback2 = Mock()
failing_callback = Mock(side_effect=Exception("Callback failed"))
# Add callbacks
manager.add_refresh_done_callback(callback1)
manager.add_refresh_done_callback(callback2)
manager.add_refresh_done_callback(failing_callback)
# Mock the random walk
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
):
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=True)
# Verify all callbacks were called
callback1.assert_called_once()
callback2.assert_called_once()
failing_callback.assert_called_once()
# Verify warning was logged for failing callback
mock_logger.warning.assert_called()
@pytest.mark.trio
async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function):
"""Test stop method when manager is not running."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Manager is not running
assert not manager._running
# Stop should return without doing anything
await manager.stop()
assert not manager._running
@pytest.mark.trio
async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function):
"""Test the _periodic_refresh_task method."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.05, # Very short interval for testing
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock _do_refresh to track calls
with patch.object(manager, "_do_refresh") as mock_do_refresh:
manager._running = True
# Run periodic refresh task for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._periodic_refresh_task)
await trio.sleep(0.12) # Let it run for ~2 intervals
manager._running = False # Stop the task
# Verify _do_refresh was called at least once
assert mock_do_refresh.call_count >= 1
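
The `_do_refresh` tests above imply a gating rule: a forced refresh always runs, otherwise the refresh is skipped when the interval has not elapsed or when the routing table is already large enough. The sketch below restates that rule as a pure function. The name `should_refresh`, the order of the two checks, and the `>=` comparison against the threshold are assumptions; only the two skip messages are taken from the tests.

```python
import time
from typing import Optional


def should_refresh(
    rt_size: int,
    last_refresh_time: float,
    refresh_interval: float,
    min_refresh_threshold: int,
    force: bool = False,
    now: Optional[float] = None,
) -> tuple[bool, str]:
    """Decide whether a routing table refresh should run (hypothetical helper)."""
    if force:
        # A forced refresh bypasses both checks, as in test_do_refresh_force_refresh.
        return True, "forced"
    current = time.time() if now is None else now
    # Assumption: the interval check comes first.
    if current - last_refresh_time < refresh_interval:
        return False, "Skipping refresh: interval not elapsed"
    # Assumption: a table at or above the threshold counts as large enough.
    if rt_size >= min_refresh_threshold:
        return False, "Skipping refresh: routing table size above threshold"
    return True, "refresh needed"


if __name__ == "__main__":
    # The threshold value 4 is illustrative, not MIN_RT_REFRESH_THRESHOLD.
    print(should_refresh(1, 100.0, 100.0, 4, now=101.0))
    print(should_refresh(20, 0.0, 0.1, 4, now=101.0))
    print(should_refresh(10, 100.0, 100.0, 4, force=True, now=101.0))
```

Keeping the decision separate from the side effects (random walks, routing table updates, callbacks) makes the rule testable without mocks, which is a design choice of this sketch rather than of the module under test.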