mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'libp2p:main' into tests/notifee-coverage
This commit is contained in:
131
docs/examples.random_walk.rst
Normal file
131
docs/examples.random_walk.rst
Normal file
@ -0,0 +1,131 @@
|
||||
Random Walk Example
|
||||
===================
|
||||
|
||||
This example demonstrates the Random Walk module's peer discovery capabilities using real libp2p hosts and Kademlia DHT.
|
||||
It shows how the Random Walk module automatically discovers new peers and maintains routing table health.
|
||||
|
||||
The Random Walk implementation performs the following key operations:
|
||||
|
||||
* **Automatic Peer Discovery**: Generates random peer IDs and queries the DHT network to discover new peers
|
||||
* **Routing Table Maintenance**: Periodically refreshes the routing table to maintain network connectivity
|
||||
* **Connection Management**: Maintains optimal connections to healthy peers in the network
|
||||
* **Real-time Statistics**: Displays routing table size, connected peers, and peerstore statistics
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
Collecting libp2p
|
||||
...
|
||||
Successfully installed libp2p-x.x.x
|
||||
$ cd examples/random_walk
|
||||
$ python random_walk.py --mode server
|
||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
|
||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
|
||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
|
||||
2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - --- Iteration 1 ---
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Routing table size: 15
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Connected peers: 8
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Peerstore size: 42
|
||||
|
||||
You can also run the example in client mode:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python random_walk.py --mode client
|
||||
2025-08-12 19:52:15,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
|
||||
2025-08-12 19:52:15,424 - random-walk-example - INFO - Mode: client, Port: 0 Demo interval: 30s
|
||||
2025-08-12 19:52:15,426 - random-walk-example - INFO - Starting client node on port 51234
|
||||
2025-08-12 19:52:15,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAmAbc123xyz...
|
||||
2025-08-12 19:52:15,427 - random-walk-example - INFO - DHT service started in CLIENT mode
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - --- Iteration 1 ---
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Routing table size: 8
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Connected peers: 5
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Peerstore size: 25
|
||||
|
||||
Command Line Options
|
||||
--------------------
|
||||
|
||||
The example supports several command-line options:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python random_walk.py --help
|
||||
usage: random_walk.py [-h] [--mode {server,client}] [--port PORT]
|
||||
[--demo-interval DEMO_INTERVAL] [--verbose]
|
||||
|
||||
Random Walk Example for py-libp2p Kademlia DHT
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--mode {server,client}
|
||||
Node mode: server (DHT server), or client (DHT client)
|
||||
--port PORT Port to listen on (0 for random)
|
||||
--demo-interval DEMO_INTERVAL
|
||||
Interval between random walk demonstrations in seconds
|
||||
--verbose Enable verbose logging
|
||||
|
||||
Key Features Demonstrated
|
||||
-------------------------
|
||||
|
||||
**Automatic Random Walk Discovery**
|
||||
The example shows how the Random Walk module automatically:
|
||||
|
||||
* Generates random 256-bit peer IDs for discovery queries
|
||||
* Performs concurrent random walks to maximize peer discovery
|
||||
* Validates discovered peers and adds them to the routing table
|
||||
* Maintains routing table health through periodic refreshes
|
||||
|
||||
**Real-time Network Statistics**
|
||||
The example displays live statistics every 30 seconds (configurable):
|
||||
|
||||
* **Routing Table Size**: Number of peers in the Kademlia routing table
|
||||
* **Connected Peers**: Number of actively connected peers
|
||||
* **Peerstore Size**: Total number of known peers with addresses
|
||||
|
||||
**Connection Management**
|
||||
The example includes sophisticated connection management:
|
||||
|
||||
* Automatically maintains connections to healthy peers
|
||||
* Filters for compatible peers (TCP + IPv4 addresses)
|
||||
* Reconnects to maintain optimal network connectivity
|
||||
* Handles connection failures gracefully
|
||||
|
||||
**DHT Integration**
|
||||
Shows seamless integration between Random Walk and Kademlia DHT:
|
||||
|
||||
* RT Refresh Manager coordinates with the DHT routing table
|
||||
* Peer discovery feeds directly into DHT operations
|
||||
* Both SERVER and CLIENT modes supported
|
||||
* Bootstrap connectivity to public IPFS nodes
|
||||
|
||||
Understanding the Output
|
||||
------------------------
|
||||
|
||||
When you run the example, you'll see periodic statistics that show how the Random Walk module is working:
|
||||
|
||||
* **Initial Phase**: Routing table starts empty and quickly discovers peers
|
||||
* **Growth Phase**: Routing table size increases as more peers are discovered
|
||||
* **Maintenance Phase**: Routing table size stabilizes as the system maintains optimal peer connections
|
||||
|
||||
The Random Walk module runs automatically in the background, performing peer discovery queries every few minutes to ensure the routing table remains populated with fresh, reachable peers.
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
The Random Walk module can be configured through the following parameters in ``libp2p.discovery.random_walk.config``:
|
||||
|
||||
* ``RANDOM_WALK_ENABLED``: Enable/disable automatic random walks (default: True)
|
||||
* ``REFRESH_INTERVAL``: Time between automatic refreshes in seconds (default: 300)
|
||||
* ``RANDOM_WALK_CONCURRENCY``: Number of concurrent random walks (default: 3)
|
||||
* ``MIN_RT_REFRESH_THRESHOLD``: Minimum routing table size before triggering refresh (default: 4)
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
* :doc:`examples.kademlia` - Kademlia DHT value storage and content routing
|
||||
* :doc:`libp2p.discovery.random_walk` - Random Walk module API documentation
|
||||
@ -14,3 +14,4 @@ Examples
|
||||
examples.circuit_relay
|
||||
examples.kademlia
|
||||
examples.mDNS
|
||||
examples.random_walk
|
||||
|
||||
48
docs/libp2p.discovery.random_walk.rst
Normal file
48
docs/libp2p.discovery.random_walk.rst
Normal file
@ -0,0 +1,48 @@
|
||||
libp2p.discovery.random_walk package
|
||||
====================================
|
||||
|
||||
The Random Walk module implements a peer discovery mechanism.
|
||||
It performs random walks through the DHT network to discover new peers and maintain routing table health through periodic refreshes.
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.random_walk.config module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.config
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.exceptions module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.exceptions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.random_walk module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.random_walk
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.rt_refresh_manager module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.rt_refresh_manager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -10,6 +10,7 @@ Subpackages
|
||||
libp2p.discovery.bootstrap
|
||||
libp2p.discovery.events
|
||||
libp2p.discovery.mdns
|
||||
libp2p.discovery.random_walk
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
@ -227,7 +227,7 @@ async def run_node(
|
||||
|
||||
# Keep the node running
|
||||
while True:
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Status - Connected peers: %d,"
|
||||
"Peers in store: %d, Values in store: %d",
|
||||
len(dht.host.get_connected_peers()),
|
||||
|
||||
221
examples/random_walk/random_walk.py
Normal file
221
examples/random_walk/random_walk.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""
|
||||
Random Walk Example for py-libp2p Kademlia DHT
|
||||
|
||||
This example demonstrates the Random Walk module's peer discovery capabilities
|
||||
using real libp2p hosts and Kademlia DHT. It shows how the Random Walk module
|
||||
automatically discovers new peers and maintains routing table health.
|
||||
|
||||
Usage:
|
||||
# Start server nodes (they will discover peers via random walk)
|
||||
python3 random_walk.py --mode server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
|
||||
|
||||
# Simple logging configuration
|
||||
def setup_logging(verbose: bool = False):
|
||||
"""Setup unified logging configuration."""
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
|
||||
# Configure key module loggers
|
||||
for module in ["libp2p.discovery.random_walk", "libp2p.kad_dht"]:
|
||||
logging.getLogger(module).setLevel(level)
|
||||
|
||||
# Suppress noisy logs
|
||||
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
logger = logging.getLogger("random-walk-example")
|
||||
|
||||
# Default bootstrap nodes
|
||||
DEFAULT_BOOTSTRAP_NODES = [
|
||||
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"
|
||||
]
|
||||
|
||||
|
||||
def filter_compatible_peer_info(peer_info) -> bool:
|
||||
"""Filter peer info to check if it has compatible addresses (TCP + IPv4)."""
|
||||
if not hasattr(peer_info, "addrs") or not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def maintain_connections(host: IHost) -> None:
|
||||
"""Maintain connections to ensure the host remains connected to healthy peers."""
|
||||
while True:
|
||||
try:
|
||||
connected_peers = host.get_connected_peers()
|
||||
list_peers = host.get_peerstore().peers_with_addrs()
|
||||
|
||||
if len(connected_peers) < 20:
|
||||
logger.debug("Reconnecting to maintain peer connections...")
|
||||
|
||||
# Find compatible peers
|
||||
compatible_peers = []
|
||||
for peer_id in list_peers:
|
||||
try:
|
||||
peer_info = host.get_peerstore().peer_info(peer_id)
|
||||
if filter_compatible_peer_info(peer_info):
|
||||
compatible_peers.append(peer_id)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Connect to random subset of compatible peers
|
||||
if compatible_peers:
|
||||
random_peers = random.sample(
|
||||
compatible_peers, min(50, len(compatible_peers))
|
||||
)
|
||||
for peer_id in random_peers:
|
||||
if peer_id not in connected_peers:
|
||||
try:
|
||||
with trio.move_on_after(5):
|
||||
peer_info = host.get_peerstore().peer_info(peer_id)
|
||||
await host.connect(peer_info)
|
||||
logger.debug(f"Connected to peer: {peer_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to connect to {peer_id}: {e}")
|
||||
|
||||
await trio.sleep(15)
|
||||
except Exception as e:
|
||||
logger.error(f"Error maintaining connections: {e}")
|
||||
|
||||
|
||||
async def demonstrate_random_walk_discovery(dht: KadDHT, interval: int = 30) -> None:
|
||||
"""Demonstrate Random Walk peer discovery with periodic statistics."""
|
||||
iteration = 0
|
||||
while True:
|
||||
iteration += 1
|
||||
logger.info(f"--- Iteration {iteration} ---")
|
||||
logger.info(f"Routing table size: {dht.get_routing_table_size()}")
|
||||
logger.info(f"Connected peers: {len(dht.host.get_connected_peers())}")
|
||||
logger.info(f"Peerstore size: {len(dht.host.get_peerstore().peer_ids())}")
|
||||
await trio.sleep(interval)
|
||||
|
||||
|
||||
async def run_node(port: int, mode: str, demo_interval: int = 30) -> None:
|
||||
"""Run a node that demonstrates Random Walk peer discovery."""
|
||||
try:
|
||||
if port <= 0:
|
||||
port = random.randint(10000, 60000)
|
||||
|
||||
logger.info(f"Starting {mode} node on port {port}")
|
||||
|
||||
# Determine DHT mode
|
||||
dht_mode = DHTMode.SERVER if mode == "server" else DHTMode.CLIENT
|
||||
|
||||
# Create host and DHT
|
||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||
host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
|
||||
listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start maintenance tasks
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
nursery.start_soon(maintain_connections, host)
|
||||
|
||||
peer_id = host.get_id().pretty()
|
||||
logger.info(f"Node peer ID: {peer_id}")
|
||||
logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}")
|
||||
|
||||
# Create and start DHT with Random Walk enabled
|
||||
dht = KadDHT(host, dht_mode, enable_random_walk=True)
|
||||
logger.info(f"Initial routing table size: {dht.get_routing_table_size()}")
|
||||
|
||||
async with background_trio_service(dht):
|
||||
logger.info(f"DHT service started in {dht_mode.value} mode")
|
||||
logger.info(f"Random Walk enabled: {dht.is_random_walk_enabled()}")
|
||||
|
||||
async with trio.open_nursery() as task_nursery:
|
||||
# Start demonstration and status reporting
|
||||
task_nursery.start_soon(
|
||||
demonstrate_random_walk_discovery, dht, demo_interval
|
||||
)
|
||||
|
||||
# Periodic status updates
|
||||
async def status_reporter():
|
||||
while True:
|
||||
await trio.sleep(30)
|
||||
logger.debug(
|
||||
f"Connected: {len(dht.host.get_connected_peers())}, "
|
||||
f"Routing table: {dht.get_routing_table_size()}, "
|
||||
f"Peerstore: {len(dht.host.get_peerstore().peer_ids())}"
|
||||
)
|
||||
|
||||
task_nursery.start_soon(status_reporter)
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Node error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Random Walk Example for py-libp2p Kademlia DHT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["server", "client"],
|
||||
default="server",
|
||||
help="Node mode: server (DHT server), or client (DHT client)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=0, help="Port to listen on (0 for random)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--demo-interval",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Interval between random walk demonstrations in seconds",
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the random walk example."""
|
||||
try:
|
||||
args = parse_args()
|
||||
setup_logging(args.verbose)
|
||||
|
||||
logger.info("=== Random Walk Example for py-libp2p ===")
|
||||
logger.info(
|
||||
f"Mode: {args.mode}, Port: {args.port} Demo interval: {args.demo_interval}s"
|
||||
)
|
||||
|
||||
trio.run(run_node, args.port, args.mode, args.demo_interval)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt signal, shutting down...")
|
||||
except Exception as e:
|
||||
logger.critical(f"Example failed: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
17
libp2p/discovery/random_walk/__init__.py
Normal file
17
libp2p/discovery/random_walk/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Random walk discovery modules for py-libp2p."""
|
||||
|
||||
from .rt_refresh_manager import RTRefreshManager
|
||||
from .random_walk import RandomWalk
|
||||
from .exceptions import (
|
||||
RoutingTableRefreshError,
|
||||
RandomWalkError,
|
||||
PeerValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RTRefreshManager",
|
||||
"RandomWalk",
|
||||
"RoutingTableRefreshError",
|
||||
"RandomWalkError",
|
||||
"PeerValidationError",
|
||||
]
|
||||
16
libp2p/discovery/random_walk/config.py
Normal file
16
libp2p/discovery/random_walk/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Final
|
||||
|
||||
# Timing constants (matching go-libp2p)
|
||||
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
|
||||
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
|
||||
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
|
||||
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
|
||||
|
||||
# Routing table thresholds
|
||||
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
|
||||
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
|
||||
|
||||
# Random walk specific
|
||||
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
|
||||
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
|
||||
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback
|
||||
19
libp2p/discovery/random_walk/exceptions.py
Normal file
19
libp2p/discovery/random_walk/exceptions.py
Normal file
@ -0,0 +1,19 @@
|
||||
from libp2p.exceptions import BaseLibp2pError
|
||||
|
||||
|
||||
class RoutingTableRefreshError(BaseLibp2pError):
|
||||
"""Base exception for routing table refresh operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RandomWalkError(RoutingTableRefreshError):
|
||||
"""Exception raised during random walk operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PeerValidationError(RoutingTableRefreshError):
|
||||
"""Exception raised when peer validation fails."""
|
||||
|
||||
pass
|
||||
218
libp2p/discovery/random_walk/random_walk.py
Normal file
218
libp2p/discovery/random_walk/random_walk.py
Normal file
@ -0,0 +1,218 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_RT_THRESHOLD,
|
||||
REFRESH_QUERY_TIMEOUT,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RandomWalkError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk")
|
||||
|
||||
|
||||
class RandomWalk:
|
||||
"""
|
||||
Random Walk implementation for peer discovery in Kademlia DHT.
|
||||
|
||||
Generates random peer IDs and performs FIND_NODE queries to discover
|
||||
new peers and populate the routing table.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
):
|
||||
"""
|
||||
Initialize Random Walk module.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
def generate_random_peer_id(self) -> str:
|
||||
"""
|
||||
Generate a completely random peer ID
|
||||
for random walk queries.
|
||||
|
||||
Returns:
|
||||
Random peer ID as string
|
||||
|
||||
"""
|
||||
# Generate 32 random bytes (256 bits) - same as go-libp2p
|
||||
random_bytes = secrets.token_bytes(32)
|
||||
# Convert to hex string for query
|
||||
return random_bytes.hex()
|
||||
|
||||
async def perform_random_walk(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Perform a single random walk operation.
|
||||
|
||||
Returns:
|
||||
List of validated peers discovered during the walk
|
||||
|
||||
"""
|
||||
try:
|
||||
# Generate random peer ID
|
||||
random_peer_id = self.generate_random_peer_id()
|
||||
logger.info(f"Starting random walk for peer ID: {random_peer_id}")
|
||||
|
||||
# Perform FIND_NODE query
|
||||
discovered_peer_ids: list[ID] = []
|
||||
|
||||
with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
|
||||
# Call the query function with target key bytes
|
||||
target_key = bytes.fromhex(random_peer_id)
|
||||
discovered_peer_ids = await self.query_function(target_key) or []
|
||||
|
||||
if not discovered_peer_ids:
|
||||
logger.debug(f"No peers discovered in random walk for {random_peer_id}")
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(discovered_peer_ids)} peers in random walk "
|
||||
f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity
|
||||
)
|
||||
|
||||
# Convert peer IDs to PeerInfo objects and validate
|
||||
validated_peers: list[PeerInfo] = []
|
||||
|
||||
for peer_id in discovered_peer_ids:
|
||||
try:
|
||||
# Get addresses from peerstore
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
validated_peers.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
|
||||
continue
|
||||
|
||||
return validated_peers
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Random walk failed: {e}")
|
||||
raise RandomWalkError(f"Random walk operation failed: {e}") from e
|
||||
|
||||
async def run_concurrent_random_walks(
|
||||
self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
|
||||
) -> list[PeerInfo]:
|
||||
"""
|
||||
Run multiple random walks concurrently.
|
||||
|
||||
Args:
|
||||
count: Number of concurrent random walks to perform
|
||||
current_routing_table_size: Current size of routing table (for optimization)
|
||||
|
||||
Returns:
|
||||
Combined list of all validated peers discovered
|
||||
|
||||
"""
|
||||
all_validated_peers: list[PeerInfo] = []
|
||||
logger.info(f"Starting {count} concurrent random walks")
|
||||
|
||||
# First, try to add peers from peerstore if routing table is small
|
||||
if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
|
||||
try:
|
||||
peerstore_peers = self._get_peerstore_peers()
|
||||
if peerstore_peers:
|
||||
logger.debug(
|
||||
f"RT size ({current_routing_table_size}) below threshold, "
|
||||
f"adding {len(peerstore_peers)} peerstore peers"
|
||||
)
|
||||
all_validated_peers.extend(peerstore_peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing peerstore peers: {e}")
|
||||
|
||||
async def single_walk() -> None:
|
||||
try:
|
||||
peers = await self.perform_random_walk()
|
||||
all_validated_peers.extend(peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Concurrent random walk failed: {e}")
|
||||
return
|
||||
|
||||
# Run concurrent random walks
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(single_walk)
|
||||
|
||||
# Remove duplicates based on peer ID
|
||||
unique_peers = {}
|
||||
for peer in all_validated_peers:
|
||||
unique_peers[peer.peer_id] = peer
|
||||
|
||||
result = list(unique_peers.values())
|
||||
logger.info(
|
||||
f"Concurrent random walks completed: {len(result)} unique peers discovered"
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_peerstore_peers(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Get peer info objects from the host's peerstore.
|
||||
|
||||
Returns:
|
||||
List of PeerInfo objects from peerstore
|
||||
|
||||
"""
|
||||
try:
|
||||
peerstore = self.host.get_peerstore()
|
||||
peer_ids = peerstore.peers_with_addrs()
|
||||
|
||||
peer_infos = []
|
||||
for peer_id in peer_ids:
|
||||
try:
|
||||
# Skip local peer
|
||||
if peer_id == self.local_peer_id:
|
||||
continue
|
||||
|
||||
peer_info = peerstore.peer_info(peer_id)
|
||||
if peer_info and peer_info.addrs:
|
||||
# Filter for compatible addresses (TCP + IPv4)
|
||||
if self._has_compatible_addresses(peer_info):
|
||||
peer_infos.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting peer info for {peer_id}: {e}")
|
||||
|
||||
return peer_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error accessing peerstore: {e}")
|
||||
return []
|
||||
|
||||
def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
|
||||
"""
|
||||
Check if a peer has TCP+IPv4 compatible addresses.
|
||||
|
||||
Args:
|
||||
peer_info: PeerInfo to check
|
||||
|
||||
Returns:
|
||||
True if peer has compatible addresses
|
||||
|
||||
"""
|
||||
if not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
# Check for TCP and IPv4 compatibility, avoid QUIC
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
@ -0,0 +1,208 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_ENABLED,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class RoutingTableProtocol(Protocol):
|
||||
"""Protocol for routing table operations needed by RT refresh manager."""
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the current size of the routing table."""
|
||||
...
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo) -> bool:
|
||||
"""Add a peer to the routing table."""
|
||||
...
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
|
||||
|
||||
|
||||
class RTRefreshManager:
|
||||
"""
|
||||
Routing Table Refresh Manager for py-libp2p.
|
||||
|
||||
Manages periodic routing table refreshes and random walk operations
|
||||
to maintain routing table health and discover new peers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
routing_table: RoutingTableProtocol,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
|
||||
refresh_interval: float = REFRESH_INTERVAL,
|
||||
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
|
||||
):
|
||||
"""
|
||||
Initialize RT Refresh Manager.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
routing_table: Routing table of host
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
enable_auto_refresh: Whether to enable automatic refresh
|
||||
refresh_interval: Interval between refreshes in seconds
|
||||
min_refresh_threshold: Minimum RT size before triggering refresh
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
self.enable_auto_refresh = enable_auto_refresh
|
||||
self.refresh_interval = refresh_interval
|
||||
self.min_refresh_threshold = min_refresh_threshold
|
||||
|
||||
# Initialize random walk module
|
||||
self.random_walk = RandomWalk(
|
||||
host=host,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=query_function,
|
||||
)
|
||||
|
||||
# Control variables
|
||||
self._running = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
# Tracking
|
||||
self._last_refresh_time = 0.0
|
||||
self._refresh_done_callbacks: list[Callable[[], None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the RT Refresh Manager."""
|
||||
if self._running:
|
||||
logger.warning("RT Refresh Manager is already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
logger.info("Starting RT Refresh Manager")
|
||||
|
||||
# Start the main loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
nursery.start_soon(self._main_loop)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the RT Refresh Manager."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info("Stopping RT Refresh Manager")
|
||||
self._running = False
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
"""Main loop for the RT Refresh Manager."""
|
||||
logger.info("RT Refresh Manager main loop started")
|
||||
|
||||
# Initial refresh if auto-refresh is enabled
|
||||
if self.enable_auto_refresh:
|
||||
await self._do_refresh(force=True)
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Schedule periodic refresh if enabled
|
||||
if self.enable_auto_refresh:
|
||||
nursery.start_soon(self._periodic_refresh_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RT Refresh Manager main loop error: {e}")
|
||||
finally:
|
||||
logger.info("RT Refresh Manager main loop stopped")
|
||||
|
||||
async def _periodic_refresh_task(self) -> None:
|
||||
"""Task for periodic refreshes."""
|
||||
while self._running:
|
||||
await trio.sleep(self.refresh_interval)
|
||||
if self._running:
|
||||
await self._do_refresh()
|
||||
|
||||
async def _do_refresh(self, force: bool = False) -> None:
|
||||
"""
|
||||
Perform routing table refresh operation.
|
||||
|
||||
Args:
|
||||
force: Whether to force refresh regardless of timing
|
||||
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# Check if refresh is needed
|
||||
if not force:
|
||||
if current_time - self._last_refresh_time < self.refresh_interval:
|
||||
logger.debug("Skipping refresh: interval not elapsed")
|
||||
return
|
||||
|
||||
if self.routing_table.size() >= self.min_refresh_threshold:
|
||||
logger.debug("Skipping refresh: routing table size above threshold")
|
||||
return
|
||||
|
||||
logger.info(f"Starting routing table refresh (force={force})")
|
||||
start_time = current_time
|
||||
|
||||
# Perform random walks to discover new peers
|
||||
logger.info("Running concurrent random walks to discover new peers")
|
||||
current_rt_size = self.routing_table.size()
|
||||
discovered_peers = await self.random_walk.run_concurrent_random_walks(
|
||||
count=RANDOM_WALK_CONCURRENCY,
|
||||
current_routing_table_size=current_rt_size,
|
||||
)
|
||||
|
||||
# Add discovered peers to routing table
|
||||
added_count = 0
|
||||
for peer_info in discovered_peers:
|
||||
result = await self.routing_table.add_peer(peer_info)
|
||||
if result:
|
||||
added_count += 1
|
||||
|
||||
self._last_refresh_time = current_time
|
||||
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
f"Routing table refresh completed: "
|
||||
f"{added_count}/{len(discovered_peers)} peers added, "
|
||||
f"RT size: {self.routing_table.size()}, "
|
||||
f"duration: {duration:.2f}s"
|
||||
)
|
||||
|
||||
# Notify refresh completion
|
||||
for callback in self._refresh_done_callbacks:
|
||||
try:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.warning(f"Refresh callback error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Routing table refresh failed: {e}")
|
||||
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
|
||||
|
||||
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Add a callback to be called when refresh completes."""
|
||||
self._refresh_done_callbacks.append(callback)
|
||||
|
||||
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Remove a refresh completion callback."""
|
||||
if callback in self._refresh_done_callbacks:
|
||||
self._refresh_done_callbacks.remove(callback)
|
||||
@ -295,6 +295,13 @@ class BasicHost(IHost):
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
if protocol is None:
|
||||
logger.debug(
|
||||
"no protocol negotiated, closing stream from peer %s",
|
||||
net_stream.muxed_conn.peer_id,
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
if handler is None:
|
||||
logger.debug(
|
||||
|
||||
@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
|
||||
implementation based on the Kademlia algorithm and protocol.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from enum import (
|
||||
Enum,
|
||||
)
|
||||
@ -20,6 +21,7 @@ import varint
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
@ -73,14 +75,27 @@ class KadDHT(Service):
|
||||
|
||||
This class provides a DHT implementation that combines routing table management,
|
||||
peer discovery, content routing, and value storage.
|
||||
|
||||
Optional Random Walk feature enhances peer discovery by automatically
|
||||
performing periodic random queries to discover new peers and maintain
|
||||
routing table health.
|
||||
|
||||
Example:
|
||||
# Basic DHT without random walk (default)
|
||||
dht = KadDHT(host, DHTMode.SERVER)
|
||||
|
||||
# DHT with random walk enabled for enhanced peer discovery
|
||||
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, mode: DHTMode):
|
||||
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
|
||||
"""
|
||||
Initialize a new Kademlia DHT node.
|
||||
|
||||
:param host: The libp2p host.
|
||||
:param mode: The mode of host (Client or Server) - must be DHTMode enum
|
||||
:param enable_random_walk: Whether to enable automatic random walk
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -92,6 +107,7 @@ class KadDHT(Service):
|
||||
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
|
||||
|
||||
self.mode = mode
|
||||
self.enable_random_walk = enable_random_walk
|
||||
|
||||
# Initialize the routing table
|
||||
self.routing_table = RoutingTable(self.local_peer_id, self.host)
|
||||
@ -108,13 +124,56 @@ class KadDHT(Service):
|
||||
# Last time we republished provider records
|
||||
self._last_provider_republish = time.time()
|
||||
|
||||
# Initialize RT Refresh Manager (only if random walk is enabled)
|
||||
self.rt_refresh_manager: RTRefreshManager | None = None
|
||||
if self.enable_random_walk:
|
||||
self.rt_refresh_manager = RTRefreshManager(
|
||||
host=self.host,
|
||||
routing_table=self.routing_table,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=self._create_query_function(),
|
||||
enable_auto_refresh=True,
|
||||
)
|
||||
|
||||
# Set protocol handlers
|
||||
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
|
||||
|
||||
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
|
||||
"""
|
||||
Create a query function that wraps peer_routing.find_closest_peers_network.
|
||||
|
||||
This function is used by the RandomWalk module to query for peers without
|
||||
directly importing PeerRouting, avoiding circular import issues.
|
||||
|
||||
Returns:
|
||||
Callable that takes target_key bytes and returns list of peer IDs
|
||||
|
||||
"""
|
||||
|
||||
async def query_function(target_key: bytes) -> list[ID]:
|
||||
"""Query for closest peers to target key."""
|
||||
return await self.peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
return query_function
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the DHT service."""
|
||||
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
|
||||
|
||||
# Start the RT Refresh Manager in parallel with the main DHT service
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start the RT Refresh Manager only if random walk is enabled
|
||||
if self.rt_refresh_manager is not None:
|
||||
nursery.start_soon(self.rt_refresh_manager.start)
|
||||
logger.info("RT Refresh Manager started - Random Walk is now active")
|
||||
else:
|
||||
logger.info("Random Walk is disabled - RT Refresh Manager not started")
|
||||
|
||||
# Start the main DHT service loop
|
||||
nursery.start_soon(self._run_main_loop)
|
||||
|
||||
async def _run_main_loop(self) -> None:
|
||||
"""Run the main DHT service loop."""
|
||||
# Main service loop
|
||||
while self.manager.is_running:
|
||||
# Periodically refresh the routing table
|
||||
@ -135,6 +194,17 @@ class KadDHT(Service):
|
||||
# Wait before next maintenance cycle
|
||||
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DHT service and cleanup resources."""
|
||||
logger.info("Stopping Kademlia DHT")
|
||||
|
||||
# Stop the RT Refresh Manager only if it was started
|
||||
if self.rt_refresh_manager is not None:
|
||||
await self.rt_refresh_manager.stop()
|
||||
logger.info("RT Refresh Manager stopped")
|
||||
else:
|
||||
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
|
||||
|
||||
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
|
||||
"""
|
||||
Switch the DHT mode.
|
||||
@ -614,3 +684,15 @@ class KadDHT(Service):
|
||||
|
||||
"""
|
||||
return self.value_store.size()
|
||||
|
||||
def is_random_walk_enabled(self) -> bool:
|
||||
"""
|
||||
Check if random walk peer discovery is enabled.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if random walk is enabled, False otherwise.
|
||||
|
||||
"""
|
||||
return self.enable_random_walk
|
||||
|
||||
@ -170,7 +170,7 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Return early if we have no peers to start with
|
||||
if not closest_peers:
|
||||
logger.warning("No local peers available for network lookup")
|
||||
logger.debug("No local peers available for network lookup")
|
||||
return []
|
||||
|
||||
# Iterative lookup until convergence
|
||||
|
||||
@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
|
||||
"""
|
||||
self.handlers[protocol] = handler
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
protocol_to_check = None if not command else TProtocol(command)
|
||||
if protocol_to_check in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(command)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
return protocol_to_check, self.handlers[protocol_to_check]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
|
||||
@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
:return: selected protocol
|
||||
"""
|
||||
# Represent `None` protocol as an empty string.
|
||||
protocol_str = protocol if protocol is not None else ""
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(protocol_str)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if response == protocol:
|
||||
if response == protocol_str:
|
||||
return protocol
|
||||
if response == PROTOCOL_NOT_FOUND_MSG:
|
||||
raise MultiselectClientError("protocol not supported")
|
||||
|
||||
@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||
""" # noqa: E501
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
if msg_str is None:
|
||||
msg_bytes = encode_delim(b"")
|
||||
else:
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
|
||||
:param is_initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if is_initiator:
|
||||
# Select protocol if initiator
|
||||
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
|
||||
else:
|
||||
# Select protocol if non-initiator
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a security protocol")
|
||||
# Return transport from protocol
|
||||
return self.transports[protocol]
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -73,7 +76,7 @@ class MuxerMultistream:
|
||||
:param conn: conn to choose a transport over
|
||||
:return: selected muxer transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
@ -81,6 +84,8 @@ class MuxerMultistream:
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a stream muxer protocol")
|
||||
return self.transports[protocol]
|
||||
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
|
||||
1
newsfragments/770.internal.rst
Normal file
1
newsfragments/770.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py
|
||||
1
newsfragments/822.feature.rst
Normal file
1
newsfragments/822.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Added `Random Walk` peer discovery module that enables random peer exploration for improved peer discovery.
|
||||
@ -11,9 +11,9 @@ requires-python = ">=3.10, <4.0"
|
||||
license = { text = "MIT AND Apache-2.0" }
|
||||
keywords = ["libp2p", "p2p"]
|
||||
maintainers = [
|
||||
{ name = "pacrob", email = "pacrob@protonmail.com" },
|
||||
{ name = "pacrob", email = "pacrob-py-libp2p@proton.me" },
|
||||
{ name = "Manu Sheel Gupta", email = "manu@seeta.in" },
|
||||
{ name = "Dave Grantham", email = "dave@aviation.community" },
|
||||
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
|
||||
]
|
||||
dependencies = [
|
||||
"base58>=1.0.3",
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from collections import deque
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
from libp2p.abc import IMultiselectCommunicator, INetStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
|
||||
|
||||
async def dummy_handler(stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_one_of_timeout():
|
||||
async def test_select_one_of_timeout() -> None:
|
||||
ECHO = TProtocol("/echo/1.0.0")
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
|
||||
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_multistream_command_timeout():
|
||||
async def test_query_multistream_command_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
client = MultiselectClient()
|
||||
|
||||
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout():
|
||||
async def test_negotiate_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
server = Multiselect()
|
||||
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 2)
|
||||
|
||||
|
||||
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
|
||||
handshaked: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handshaked = False
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
if msg_str == "/multistream/1.0.0":
|
||||
self.handshaked = True
|
||||
return
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.handshaked:
|
||||
return "/multistream/1.0.0"
|
||||
# After handshake, hang on read.
|
||||
await trio.sleep_forever()
|
||||
# Should not be reached.
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout_post_handshake() -> None:
|
||||
communicator = HandshakeThenHangCommunicator()
|
||||
server = Multiselect()
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 1)
|
||||
|
||||
|
||||
class MockCommunicator(IMultiselectCommunicator):
|
||||
def __init__(self, commands_to_read: list[str]):
|
||||
self.read_queue = deque(commands_to_read)
|
||||
self.written_data: list[str] = []
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
self.written_data.append(msg_str)
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.read_queue:
|
||||
raise EOFError
|
||||
return self.read_queue.popleft()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_empty_string_command() -> None:
|
||||
# server receives an empty string, which means client wants `None` protocol.
|
||||
server = Multiselect({None: dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check that server sent back handshake and the protocol confirmation (empty string)
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler() -> None:
|
||||
# server has None handler, client sends "" to select it.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, protocol confirmation
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler_ls() -> None:
|
||||
# server has None handler, client sends "ls" then empty string.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, ls, empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, ls response, protocol confirmation
|
||||
assert communicator.written_data[0] == "/multistream/1.0.0"
|
||||
assert "/proto1" in communicator.written_data[1]
|
||||
# Note: `ls` should not list the `None` protocol.
|
||||
assert "None" not in communicator.written_data[1]
|
||||
assert "\n\n" not in communicator.written_data[1]
|
||||
assert communicator.written_data[2] == ""
|
||||
|
||||
@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
|
||||
protocols = ms.get_protocols()
|
||||
|
||||
assert set(protocols) == {p1, p2, p3}
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol(security_protocol):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(
|
||||
None,
|
||||
[None],
|
||||
[None],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
|
||||
expected_selected_protocol = PROTOCOL_ECHO
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol,
|
||||
[None, PROTOCOL_ECHO],
|
||||
[PROTOCOL_ECHO],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_server_none_client_other(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)
|
||||
|
||||
99
tests/discovery/random_walk/test_random_walk.py
Normal file
99
tests/discovery/random_walk/test_random_walk.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""
|
||||
Unit tests for the RandomWalk module in libp2p.discovery.random_walk.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host():
|
||||
host = Mock()
|
||||
peerstore = Mock()
|
||||
peerstore.peers_with_addrs.return_value = []
|
||||
peerstore.addrs.return_value = [Mock()]
|
||||
host.get_peerstore.return_value = peerstore
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_query_function():
|
||||
async def query(key_bytes):
|
||||
return []
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_peer_id():
|
||||
return b"\x01" * 32
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_random_walk_initialization(
|
||||
mock_host, dummy_peer_id, dummy_query_function
|
||||
):
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
|
||||
assert rw.host == mock_host
|
||||
assert rw.local_peer_id == dummy_peer_id
|
||||
assert rw.query_function == dummy_query_function
|
||||
|
||||
|
||||
def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function):
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
|
||||
peer_id = rw.generate_random_peer_id()
|
||||
assert isinstance(peer_id, str)
|
||||
assert len(peer_id) == 64 # 32 bytes hex
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_run_concurrent_random_walks(mock_host, dummy_peer_id):
|
||||
# Dummy query function returns different peer IDs for each walk
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def query(key_bytes):
|
||||
call_count["count"] += 1
|
||||
# Return a unique peer ID for each call
|
||||
return [ID(bytes([call_count["count"]] * 32))]
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.run_concurrent_random_walks(count=3)
|
||||
# Should get 3 unique peers
|
||||
assert len(peers) == 3
|
||||
peer_ids = [peer.peer_id for peer in peers]
|
||||
assert len(set(peer_ids)) == 3
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_perform_random_walk_running(mock_host, dummy_peer_id):
|
||||
# Query function returns a single peer ID
|
||||
async def query(key_bytes):
|
||||
return [ID(b"\x02" * 32)]
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.perform_random_walk()
|
||||
assert isinstance(peers, list)
|
||||
if peers:
|
||||
assert isinstance(peers[0], PeerInfo)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id):
|
||||
"""Test perform_random_walk when no peers are discovered."""
|
||||
|
||||
# Query function returns empty list (no peers found)
|
||||
async def query(key_bytes):
|
||||
return []
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.perform_random_walk()
|
||||
|
||||
# Should return empty list when no peers are found
|
||||
assert isinstance(peers, list)
|
||||
assert len(peers) == 0
|
||||
451
tests/discovery/random_walk/test_rt_refresh_manager.py
Normal file
451
tests/discovery/random_walk/test_rt_refresh_manager.py
Normal file
@ -0,0 +1,451 @@
|
||||
"""
|
||||
Unit tests for the RTRefreshManager and related random walk logic.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import (
|
||||
RandomWalkError,
|
||||
)
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class DummyRoutingTable:
|
||||
def __init__(self, size=0):
|
||||
self._size = size
|
||||
self.added_peers = []
|
||||
|
||||
def size(self):
|
||||
return self._size
|
||||
|
||||
async def add_peer(self, peer_obj):
|
||||
self.added_peers.append(peer_obj)
|
||||
self._size += 1
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host():
|
||||
host = Mock()
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_peer_id():
|
||||
return ID(b"\x01" * 32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_query_function():
|
||||
async def query(key_bytes):
|
||||
return [ID(b"\x02" * 32)]
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_initialization(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
rt = DummyRoutingTable(size=5)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
assert manager.host == mock_host
|
||||
assert manager.routing_table == rt
|
||||
assert manager.local_peer_id == local_peer_id
|
||||
assert manager.query_function == dummy_query_function
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_refresh_logic(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
rt = DummyRoutingTable(size=2)
|
||||
# Simulate refresh logic
|
||||
if rt.size() < MIN_RT_REFRESH_THRESHOLD:
|
||||
await rt.add_peer(Mock())
|
||||
assert rt.size() >= 3
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_random_walk_integration(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
# Simulate random walk usage
|
||||
rw = RandomWalk(mock_host, local_peer_id, dummy_query_function)
|
||||
random_peer_id = rw.generate_random_peer_id()
|
||||
assert isinstance(random_peer_id, str)
|
||||
assert len(random_peer_id) == 64
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id):
|
||||
rt = DummyRoutingTable(size=0)
|
||||
|
||||
async def failing_query(_):
|
||||
raise RandomWalkError("Query failed")
|
||||
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=failing_query,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
with pytest.raises(RandomWalkError):
|
||||
await manager.query_function(b"key")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_start_method(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the start method functionality."""
|
||||
rt = DummyRoutingTable(size=2)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=False, # Disable auto-refresh to control the test
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
):
|
||||
# Test starting the manager
|
||||
assert not manager._running
|
||||
|
||||
# Start the manager in a nursery that we can control
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager.start)
|
||||
await trio.sleep(0.01) # Let it start
|
||||
|
||||
# Verify it's running
|
||||
assert manager._running
|
||||
|
||||
# Stop the manager
|
||||
await manager.stop()
|
||||
assert not manager._running
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_with_auto_refresh(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the _main_loop method with auto-refresh enabled."""
|
||||
rt = DummyRoutingTable(size=1) # Small size to trigger refresh
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
) as mock_random_walk:
|
||||
manager._running = True
|
||||
|
||||
# Run the main loop for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._main_loop)
|
||||
await trio.sleep(0.05) # Let it run briefly
|
||||
manager._running = False # Stop the loop
|
||||
|
||||
# Verify that random walk was called (initial refresh)
|
||||
mock_random_walk.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_without_auto_refresh(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the _main_loop method with auto-refresh disabled."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=False,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
manager._running = True
|
||||
|
||||
# Run the main loop for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._main_loop)
|
||||
await trio.sleep(0.05)
|
||||
manager._running = False
|
||||
|
||||
# Verify that random walk was not called since auto-refresh is disabled
|
||||
mock_random_walk.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_initial_refresh_exception(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test that _main_loop propagates exceptions from initial refresh."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock _do_refresh to raise an exception on the initial call
|
||||
with patch.object(
|
||||
manager, "_do_refresh", side_effect=Exception("Initial refresh failed")
|
||||
):
|
||||
manager._running = True
|
||||
|
||||
# The initial refresh exception should propagate
|
||||
with pytest.raises(Exception, match="Initial refresh failed"):
|
||||
await manager._main_loop()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test _do_refresh method with force=True."""
|
||||
rt = DummyRoutingTable(size=10) # Large size, but force should override
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info1 = Mock(spec=PeerInfo)
|
||||
mock_peer_info2 = Mock(spec=PeerInfo)
|
||||
discovered_peers = [mock_peer_info1, mock_peer_info2]
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=discovered_peers,
|
||||
) as mock_random_walk:
|
||||
# Force refresh should work regardless of RT size
|
||||
await manager._do_refresh(force=True)
|
||||
|
||||
# Verify random walk was called
|
||||
mock_random_walk.assert_called_once_with(
|
||||
count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10
|
||||
)
|
||||
|
||||
# Verify peers were added to routing table
|
||||
assert len(rt.added_peers) == 2
|
||||
assert manager._last_refresh_time > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_skip_due_to_interval(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test _do_refresh skips refresh when interval hasn't elapsed."""
|
||||
rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=100.0, # Long interval
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Set last refresh time to recent
|
||||
manager._last_refresh_time = time.time()
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=False)
|
||||
|
||||
# Verify refresh was skipped
|
||||
mock_random_walk.assert_not_called()
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Skipping refresh: interval not elapsed"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_skip_due_to_rt_size(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test _do_refresh skips refresh when RT size is above threshold."""
|
||||
rt = DummyRoutingTable(size=20) # Large size above threshold
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1, # Short interval
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Set last refresh time to old
|
||||
manager._last_refresh_time = 0.0
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=False)
|
||||
|
||||
# Verify refresh was skipped
|
||||
mock_random_walk.assert_not_called()
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Skipping refresh: routing table size above threshold"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test refresh completion callbacks functionality."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Create mock callbacks
|
||||
callback1 = Mock()
|
||||
callback2 = Mock()
|
||||
failing_callback = Mock(side_effect=Exception("Callback failed"))
|
||||
|
||||
# Add callbacks
|
||||
manager.add_refresh_done_callback(callback1)
|
||||
manager.add_refresh_done_callback(callback2)
|
||||
manager.add_refresh_done_callback(failing_callback)
|
||||
|
||||
# Mock the random walk
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
):
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=True)
|
||||
|
||||
# Verify all callbacks were called
|
||||
callback1.assert_called_once()
|
||||
callback2.assert_called_once()
|
||||
failing_callback.assert_called_once()
|
||||
|
||||
# Verify warning was logged for failing callback
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test stop method when manager is not running."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Manager is not running
|
||||
assert not manager._running
|
||||
|
||||
# Stop should return without doing anything
|
||||
await manager.stop()
|
||||
assert not manager._running
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test the _periodic_refresh_task method."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.05, # Very short interval for testing
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock _do_refresh to track calls
|
||||
with patch.object(manager, "_do_refresh") as mock_do_refresh:
|
||||
manager._running = True
|
||||
|
||||
# Run periodic refresh task for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._periodic_refresh_task)
|
||||
await trio.sleep(0.12) # Let it run for ~2 intervals
|
||||
manager._running = False # Stop the task
|
||||
|
||||
# Verify _do_refresh was called at least once
|
||||
assert mock_do_refresh.call_count >= 1
|
||||
Reference in New Issue
Block a user