mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge remote-tracking branch 'origin/main' into fix_pubsub_msg_id_type_inconsistency
This commit is contained in:
131
docs/examples.random_walk.rst
Normal file
131
docs/examples.random_walk.rst
Normal file
@ -0,0 +1,131 @@
|
||||
Random Walk Example
|
||||
===================
|
||||
|
||||
This example demonstrates the Random Walk module's peer discovery capabilities using real libp2p hosts and Kademlia DHT.
|
||||
It shows how the Random Walk module automatically discovers new peers and maintains routing table health.
|
||||
|
||||
The Random Walk implementation performs the following key operations:
|
||||
|
||||
* **Automatic Peer Discovery**: Generates random peer IDs and queries the DHT network to discover new peers
|
||||
* **Routing Table Maintenance**: Periodically refreshes the routing table to maintain network connectivity
|
||||
* **Connection Management**: Maintains optimal connections to healthy peers in the network
|
||||
* **Real-time Statistics**: Displays routing table size, connected peers, and peerstore statistics
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
Collecting libp2p
|
||||
...
|
||||
Successfully installed libp2p-x.x.x
|
||||
$ cd examples/random_walk
|
||||
$ python random_walk.py --mode server
|
||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
|
||||
2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
|
||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0
|
||||
2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode
|
||||
2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - --- Iteration 1 ---
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Routing table size: 15
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Connected peers: 8
|
||||
2025-08-12 19:51:55,432 - random-walk-example - INFO - Peerstore size: 42
|
||||
|
||||
You can also run the example in client mode:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python random_walk.py --mode client
|
||||
2025-08-12 19:52:15,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p ===
|
||||
2025-08-12 19:52:15,424 - random-walk-example - INFO - Mode: client, Port: 0 Demo interval: 30s
|
||||
2025-08-12 19:52:15,426 - random-walk-example - INFO - Starting client node on port 51234
|
||||
2025-08-12 19:52:15,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAmAbc123xyz...
|
||||
2025-08-12 19:52:15,427 - random-walk-example - INFO - DHT service started in CLIENT mode
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - --- Iteration 1 ---
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Routing table size: 8
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Connected peers: 5
|
||||
2025-08-12 19:52:45,432 - random-walk-example - INFO - Peerstore size: 25
|
||||
|
||||
Command Line Options
|
||||
--------------------
|
||||
|
||||
The example supports several command-line options:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python random_walk.py --help
|
||||
usage: random_walk.py [-h] [--mode {server,client}] [--port PORT]
|
||||
[--demo-interval DEMO_INTERVAL] [--verbose]
|
||||
|
||||
Random Walk Example for py-libp2p Kademlia DHT
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--mode {server,client}
|
||||
Node mode: server (DHT server), or client (DHT client)
|
||||
--port PORT Port to listen on (0 for random)
|
||||
--demo-interval DEMO_INTERVAL
|
||||
Interval between random walk demonstrations in seconds
|
||||
--verbose Enable verbose logging
|
||||
|
||||
Key Features Demonstrated
|
||||
-------------------------
|
||||
|
||||
**Automatic Random Walk Discovery**
|
||||
The example shows how the Random Walk module automatically:
|
||||
|
||||
* Generates random 256-bit peer IDs for discovery queries
|
||||
* Performs concurrent random walks to maximize peer discovery
|
||||
* Validates discovered peers and adds them to the routing table
|
||||
* Maintains routing table health through periodic refreshes
|
||||
|
||||
**Real-time Network Statistics**
|
||||
The example displays live statistics every 30 seconds (configurable):
|
||||
|
||||
* **Routing Table Size**: Number of peers in the Kademlia routing table
|
||||
* **Connected Peers**: Number of actively connected peers
|
||||
* **Peerstore Size**: Total number of known peers with addresses
|
||||
|
||||
**Connection Management**
|
||||
The example includes sophisticated connection management:
|
||||
|
||||
* Automatically maintains connections to healthy peers
|
||||
* Filters for compatible peers (TCP + IPv4 addresses)
|
||||
* Reconnects to maintain optimal network connectivity
|
||||
* Handles connection failures gracefully
|
||||
|
||||
**DHT Integration**
|
||||
Shows seamless integration between Random Walk and Kademlia DHT:
|
||||
|
||||
* RT Refresh Manager coordinates with the DHT routing table
|
||||
* Peer discovery feeds directly into DHT operations
|
||||
* Both SERVER and CLIENT modes supported
|
||||
* Bootstrap connectivity to public IPFS nodes
|
||||
|
||||
Understanding the Output
|
||||
------------------------
|
||||
|
||||
When you run the example, you'll see periodic statistics that show how the Random Walk module is working:
|
||||
|
||||
* **Initial Phase**: Routing table starts empty and quickly discovers peers
|
||||
* **Growth Phase**: Routing table size increases as more peers are discovered
|
||||
* **Maintenance Phase**: Routing table size stabilizes as the system maintains optimal peer connections
|
||||
|
||||
The Random Walk module runs automatically in the background, performing peer discovery queries every few minutes to ensure the routing table remains populated with fresh, reachable peers.
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
|
||||
The Random Walk module can be configured through the following parameters in ``libp2p.discovery.random_walk.config``:
|
||||
|
||||
* ``RANDOM_WALK_ENABLED``: Enable/disable automatic random walks (default: True)
|
||||
* ``REFRESH_INTERVAL``: Time between automatic refreshes in seconds (default: 300)
|
||||
* ``RANDOM_WALK_CONCURRENCY``: Number of concurrent random walks (default: 3)
|
||||
* ``MIN_RT_REFRESH_THRESHOLD``: Minimum routing table size before triggering refresh (default: 4)
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
* :doc:`examples.kademlia` - Kademlia DHT value storage and content routing
|
||||
* :doc:`libp2p.discovery.random_walk` - Random Walk module API documentation
|
||||
@ -14,3 +14,4 @@ Examples
|
||||
examples.circuit_relay
|
||||
examples.kademlia
|
||||
examples.mDNS
|
||||
examples.random_walk
|
||||
|
||||
48
docs/libp2p.discovery.random_walk.rst
Normal file
48
docs/libp2p.discovery.random_walk.rst
Normal file
@ -0,0 +1,48 @@
|
||||
libp2p.discovery.random_walk package
|
||||
====================================
|
||||
|
||||
The Random Walk module implements a peer discovery mechanism.
|
||||
It performs random walks through the DHT network to discover new peers and maintain routing table health through periodic refreshes.
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.discovery.random_walk.config module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.config
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.exceptions module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.exceptions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.random_walk module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.random_walk
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.discovery.random_walk.rt_refresh_manager module
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk.rt_refresh_manager
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.discovery.random_walk
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -10,6 +10,7 @@ Subpackages
|
||||
libp2p.discovery.bootstrap
|
||||
libp2p.discovery.events
|
||||
libp2p.discovery.mdns
|
||||
libp2p.discovery.random_walk
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
63
examples/advanced/network_discover.py
Normal file
63
examples/advanced/network_discover.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Advanced demonstration of Thin Waist address handling.
|
||||
|
||||
Run:
|
||||
python -m examples.advanced.network_discovery
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
try:
|
||||
from libp2p.utils.address_validation import (
|
||||
expand_wildcard_address,
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
)
|
||||
except ImportError:
|
||||
# Fallbacks if utilities are missing
|
||||
def get_available_interfaces(port: int, protocol: str = "tcp"):
|
||||
return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")]
|
||||
|
||||
def expand_wildcard_address(addr: Multiaddr, port: int | None = None):
|
||||
if port is None:
|
||||
return [addr]
|
||||
addr_str = str(addr).rsplit("/", 1)[0]
|
||||
return [Multiaddr(addr_str + f"/{port}")]
|
||||
|
||||
def get_optimal_binding_address(port: int, protocol: str = "tcp"):
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
port = 8080
|
||||
interfaces = get_available_interfaces(port)
|
||||
print(f"Discovered interfaces for port {port}:")
|
||||
for a in interfaces:
|
||||
print(f" - {a}")
|
||||
|
||||
wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
expanded_v4 = expand_wildcard_address(wildcard_v4)
|
||||
print("\nExpanded IPv4 wildcard:")
|
||||
for a in expanded_v4:
|
||||
print(f" - {a}")
|
||||
|
||||
wildcard_v6 = Multiaddr(f"/ip6/::/tcp/{port}")
|
||||
expanded_v6 = expand_wildcard_address(wildcard_v6)
|
||||
print("\nExpanded IPv6 wildcard:")
|
||||
for a in expanded_v6:
|
||||
print(f" - {a}")
|
||||
|
||||
print("\nOptimal binding address heuristic result:")
|
||||
print(f" -> {get_optimal_binding_address(port)}")
|
||||
|
||||
override_port = 9000
|
||||
overridden = expand_wildcard_address(wildcard_v4, port=override_port)
|
||||
print(f"\nPort override expansion to {override_port}:")
|
||||
for a in overridden:
|
||||
print(f" - {a}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,4 +1,6 @@
|
||||
import argparse
|
||||
import random
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
@ -12,40 +14,54 @@ from libp2p.crypto.secp256k1 import (
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamEOF,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
from libp2p.utils.address_validation import (
|
||||
find_free_port,
|
||||
get_available_interfaces,
|
||||
)
|
||||
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
MAX_READ_LEN = 2**32 - 1
|
||||
|
||||
|
||||
async def _echo_stream_handler(stream: INetStream) -> None:
|
||||
# Wait until EOF
|
||||
msg = await stream.read(MAX_READ_LEN)
|
||||
await stream.write(msg)
|
||||
await stream.close()
|
||||
try:
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print(f"Received connection from {peer_id}")
|
||||
# Wait until EOF
|
||||
msg = await stream.read(MAX_READ_LEN)
|
||||
print(f"Echoing message: {msg.decode('utf-8')}")
|
||||
await stream.write(msg)
|
||||
except StreamEOF:
|
||||
print("Stream closed by remote peer.")
|
||||
except Exception as e:
|
||||
print(f"Error in echo handler: {e}")
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
|
||||
async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
if port <= 0:
|
||||
port = find_free_port()
|
||||
listen_addr = get_available_interfaces(port)
|
||||
|
||||
if seed:
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
secret_number = random.getrandbits(32 * 8)
|
||||
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||
else:
|
||||
import secrets
|
||||
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
host = new_host(key_pair=create_new_key_pair(secret))
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery:
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
@ -54,10 +70,15 @@ async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
if not destination: # its the server
|
||||
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
|
||||
|
||||
# Print all listen addresses with peer ID (JS parity)
|
||||
print("Listener ready, listening on:\n")
|
||||
peer_id = host.get_id().to_string()
|
||||
for addr in listen_addr:
|
||||
print(f"{addr}/p2p/{peer_id}")
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console:\n\n"
|
||||
f"echo-demo "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
"\nRun this from the same folder in another console:\n\n"
|
||||
f"echo-demo -d {host.get_addrs()[0]}\n"
|
||||
)
|
||||
print("Waiting for incoming connections...")
|
||||
await trio.sleep_forever()
|
||||
|
||||
@ -227,7 +227,7 @@ async def run_node(
|
||||
|
||||
# Keep the node running
|
||||
while True:
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"Status - Connected peers: %d,"
|
||||
"Peers in store: %d, Values in store: %d",
|
||||
len(dht.host.get_connected_peers()),
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import base58
|
||||
import multiaddr
|
||||
@ -31,6 +30,9 @@ from libp2p.stream_muxer.mplex.mplex import (
|
||||
from libp2p.tools.async_service.trio_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.utils.address_validation import (
|
||||
find_free_port,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -77,13 +79,6 @@ async def publish_loop(pubsub, topic, termination_event):
|
||||
await trio.sleep(1) # Avoid tight loop on error
|
||||
|
||||
|
||||
def find_free_port():
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0)) # Bind to a free port provided by the OS
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
async def monitor_peer_topics(pubsub, nursery, termination_event):
|
||||
"""
|
||||
Monitor for new topics that peers are subscribed to and
|
||||
|
||||
221
examples/random_walk/random_walk.py
Normal file
221
examples/random_walk/random_walk.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""
|
||||
Random Walk Example for py-libp2p Kademlia DHT
|
||||
|
||||
This example demonstrates the Random Walk module's peer discovery capabilities
|
||||
using real libp2p hosts and Kademlia DHT. It shows how the Random Walk module
|
||||
automatically discovers new peers and maintains routing table health.
|
||||
|
||||
Usage:
|
||||
# Start server nodes (they will discover peers via random walk)
|
||||
python3 random_walk.py --mode server
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
|
||||
|
||||
# Simple logging configuration
|
||||
def setup_logging(verbose: bool = False):
|
||||
"""Setup unified logging configuration."""
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
|
||||
# Configure key module loggers
|
||||
for module in ["libp2p.discovery.random_walk", "libp2p.kad_dht"]:
|
||||
logging.getLogger(module).setLevel(level)
|
||||
|
||||
# Suppress noisy logs
|
||||
logging.getLogger("multiaddr").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
logger = logging.getLogger("random-walk-example")
|
||||
|
||||
# Default bootstrap nodes
|
||||
DEFAULT_BOOTSTRAP_NODES = [
|
||||
"/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ"
|
||||
]
|
||||
|
||||
|
||||
def filter_compatible_peer_info(peer_info) -> bool:
|
||||
"""Filter peer info to check if it has compatible addresses (TCP + IPv4)."""
|
||||
if not hasattr(peer_info, "addrs") or not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def maintain_connections(host: IHost) -> None:
|
||||
"""Maintain connections to ensure the host remains connected to healthy peers."""
|
||||
while True:
|
||||
try:
|
||||
connected_peers = host.get_connected_peers()
|
||||
list_peers = host.get_peerstore().peers_with_addrs()
|
||||
|
||||
if len(connected_peers) < 20:
|
||||
logger.debug("Reconnecting to maintain peer connections...")
|
||||
|
||||
# Find compatible peers
|
||||
compatible_peers = []
|
||||
for peer_id in list_peers:
|
||||
try:
|
||||
peer_info = host.get_peerstore().peer_info(peer_id)
|
||||
if filter_compatible_peer_info(peer_info):
|
||||
compatible_peers.append(peer_id)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Connect to random subset of compatible peers
|
||||
if compatible_peers:
|
||||
random_peers = random.sample(
|
||||
compatible_peers, min(50, len(compatible_peers))
|
||||
)
|
||||
for peer_id in random_peers:
|
||||
if peer_id not in connected_peers:
|
||||
try:
|
||||
with trio.move_on_after(5):
|
||||
peer_info = host.get_peerstore().peer_info(peer_id)
|
||||
await host.connect(peer_info)
|
||||
logger.debug(f"Connected to peer: {peer_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to connect to {peer_id}: {e}")
|
||||
|
||||
await trio.sleep(15)
|
||||
except Exception as e:
|
||||
logger.error(f"Error maintaining connections: {e}")
|
||||
|
||||
|
||||
async def demonstrate_random_walk_discovery(dht: KadDHT, interval: int = 30) -> None:
|
||||
"""Demonstrate Random Walk peer discovery with periodic statistics."""
|
||||
iteration = 0
|
||||
while True:
|
||||
iteration += 1
|
||||
logger.info(f"--- Iteration {iteration} ---")
|
||||
logger.info(f"Routing table size: {dht.get_routing_table_size()}")
|
||||
logger.info(f"Connected peers: {len(dht.host.get_connected_peers())}")
|
||||
logger.info(f"Peerstore size: {len(dht.host.get_peerstore().peer_ids())}")
|
||||
await trio.sleep(interval)
|
||||
|
||||
|
||||
async def run_node(port: int, mode: str, demo_interval: int = 30) -> None:
|
||||
"""Run a node that demonstrates Random Walk peer discovery."""
|
||||
try:
|
||||
if port <= 0:
|
||||
port = random.randint(10000, 60000)
|
||||
|
||||
logger.info(f"Starting {mode} node on port {port}")
|
||||
|
||||
# Determine DHT mode
|
||||
dht_mode = DHTMode.SERVER if mode == "server" else DHTMode.CLIENT
|
||||
|
||||
# Create host and DHT
|
||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||
host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES)
|
||||
listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
# Start maintenance tasks
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
nursery.start_soon(maintain_connections, host)
|
||||
|
||||
peer_id = host.get_id().pretty()
|
||||
logger.info(f"Node peer ID: {peer_id}")
|
||||
logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}")
|
||||
|
||||
# Create and start DHT with Random Walk enabled
|
||||
dht = KadDHT(host, dht_mode, enable_random_walk=True)
|
||||
logger.info(f"Initial routing table size: {dht.get_routing_table_size()}")
|
||||
|
||||
async with background_trio_service(dht):
|
||||
logger.info(f"DHT service started in {dht_mode.value} mode")
|
||||
logger.info(f"Random Walk enabled: {dht.is_random_walk_enabled()}")
|
||||
|
||||
async with trio.open_nursery() as task_nursery:
|
||||
# Start demonstration and status reporting
|
||||
task_nursery.start_soon(
|
||||
demonstrate_random_walk_discovery, dht, demo_interval
|
||||
)
|
||||
|
||||
# Periodic status updates
|
||||
async def status_reporter():
|
||||
while True:
|
||||
await trio.sleep(30)
|
||||
logger.debug(
|
||||
f"Connected: {len(dht.host.get_connected_peers())}, "
|
||||
f"Routing table: {dht.get_routing_table_size()}, "
|
||||
f"Peerstore: {len(dht.host.get_peerstore().peer_ids())}"
|
||||
)
|
||||
|
||||
task_nursery.start_soon(status_reporter)
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Node error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Random Walk Example for py-libp2p Kademlia DHT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["server", "client"],
|
||||
default="server",
|
||||
help="Node mode: server (DHT server), or client (DHT client)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=0, help="Port to listen on (0 for random)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--demo-interval",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Interval between random walk demonstrations in seconds",
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the random walk example."""
|
||||
try:
|
||||
args = parse_args()
|
||||
setup_logging(args.verbose)
|
||||
|
||||
logger.info("=== Random Walk Example for py-libp2p ===")
|
||||
logger.info(
|
||||
f"Mode: {args.mode}, Port: {args.port} Demo interval: {args.demo_interval}s"
|
||||
)
|
||||
|
||||
trio.run(run_node, args.port, args.mode, args.demo_interval)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received interrupt signal, shutting down...")
|
||||
except Exception as e:
|
||||
logger.critical(f"Example failed: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
17
libp2p/discovery/random_walk/__init__.py
Normal file
17
libp2p/discovery/random_walk/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Random walk discovery modules for py-libp2p."""
|
||||
|
||||
from .rt_refresh_manager import RTRefreshManager
|
||||
from .random_walk import RandomWalk
|
||||
from .exceptions import (
|
||||
RoutingTableRefreshError,
|
||||
RandomWalkError,
|
||||
PeerValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RTRefreshManager",
|
||||
"RandomWalk",
|
||||
"RoutingTableRefreshError",
|
||||
"RandomWalkError",
|
||||
"PeerValidationError",
|
||||
]
|
||||
16
libp2p/discovery/random_walk/config.py
Normal file
16
libp2p/discovery/random_walk/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Final
|
||||
|
||||
# Timing constants (matching go-libp2p)
|
||||
PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds
|
||||
REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds
|
||||
REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes
|
||||
SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute
|
||||
|
||||
# Routing table thresholds
|
||||
MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh
|
||||
MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try
|
||||
|
||||
# Random walk specific
|
||||
RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks
|
||||
RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks
|
||||
RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback
|
||||
19
libp2p/discovery/random_walk/exceptions.py
Normal file
19
libp2p/discovery/random_walk/exceptions.py
Normal file
@ -0,0 +1,19 @@
|
||||
from libp2p.exceptions import BaseLibp2pError
|
||||
|
||||
|
||||
class RoutingTableRefreshError(BaseLibp2pError):
|
||||
"""Base exception for routing table refresh operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RandomWalkError(RoutingTableRefreshError):
|
||||
"""Exception raised during random walk operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PeerValidationError(RoutingTableRefreshError):
|
||||
"""Exception raised when peer validation fails."""
|
||||
|
||||
pass
|
||||
218
libp2p/discovery/random_walk/random_walk.py
Normal file
218
libp2p/discovery/random_walk/random_walk.py
Normal file
@ -0,0 +1,218 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_RT_THRESHOLD,
|
||||
REFRESH_QUERY_TIMEOUT,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RandomWalkError
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk")
|
||||
|
||||
|
||||
class RandomWalk:
|
||||
"""
|
||||
Random Walk implementation for peer discovery in Kademlia DHT.
|
||||
|
||||
Generates random peer IDs and performs FIND_NODE queries to discover
|
||||
new peers and populate the routing table.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
):
|
||||
"""
|
||||
Initialize Random Walk module.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
def generate_random_peer_id(self) -> str:
|
||||
"""
|
||||
Generate a completely random peer ID
|
||||
for random walk queries.
|
||||
|
||||
Returns:
|
||||
Random peer ID as string
|
||||
|
||||
"""
|
||||
# Generate 32 random bytes (256 bits) - same as go-libp2p
|
||||
random_bytes = secrets.token_bytes(32)
|
||||
# Convert to hex string for query
|
||||
return random_bytes.hex()
|
||||
|
||||
async def perform_random_walk(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Perform a single random walk operation.
|
||||
|
||||
Returns:
|
||||
List of validated peers discovered during the walk
|
||||
|
||||
"""
|
||||
try:
|
||||
# Generate random peer ID
|
||||
random_peer_id = self.generate_random_peer_id()
|
||||
logger.info(f"Starting random walk for peer ID: {random_peer_id}")
|
||||
|
||||
# Perform FIND_NODE query
|
||||
discovered_peer_ids: list[ID] = []
|
||||
|
||||
with trio.move_on_after(REFRESH_QUERY_TIMEOUT):
|
||||
# Call the query function with target key bytes
|
||||
target_key = bytes.fromhex(random_peer_id)
|
||||
discovered_peer_ids = await self.query_function(target_key) or []
|
||||
|
||||
if not discovered_peer_ids:
|
||||
logger.debug(f"No peers discovered in random walk for {random_peer_id}")
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(discovered_peer_ids)} peers in random walk "
|
||||
f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity
|
||||
)
|
||||
|
||||
# Convert peer IDs to PeerInfo objects and validate
|
||||
validated_peers: list[PeerInfo] = []
|
||||
|
||||
for peer_id in discovered_peer_ids:
|
||||
try:
|
||||
# Get addresses from peerstore
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
validated_peers.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}")
|
||||
continue
|
||||
|
||||
return validated_peers
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Random walk failed: {e}")
|
||||
raise RandomWalkError(f"Random walk operation failed: {e}") from e
|
||||
|
||||
async def run_concurrent_random_walks(
|
||||
self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0
|
||||
) -> list[PeerInfo]:
|
||||
"""
|
||||
Run multiple random walks concurrently.
|
||||
|
||||
Args:
|
||||
count: Number of concurrent random walks to perform
|
||||
current_routing_table_size: Current size of routing table (for optimization)
|
||||
|
||||
Returns:
|
||||
Combined list of all validated peers discovered
|
||||
|
||||
"""
|
||||
all_validated_peers: list[PeerInfo] = []
|
||||
logger.info(f"Starting {count} concurrent random walks")
|
||||
|
||||
# First, try to add peers from peerstore if routing table is small
|
||||
if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD:
|
||||
try:
|
||||
peerstore_peers = self._get_peerstore_peers()
|
||||
if peerstore_peers:
|
||||
logger.debug(
|
||||
f"RT size ({current_routing_table_size}) below threshold, "
|
||||
f"adding {len(peerstore_peers)} peerstore peers"
|
||||
)
|
||||
all_validated_peers.extend(peerstore_peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing peerstore peers: {e}")
|
||||
|
||||
async def single_walk() -> None:
|
||||
try:
|
||||
peers = await self.perform_random_walk()
|
||||
all_validated_peers.extend(peers)
|
||||
except Exception as e:
|
||||
logger.warning(f"Concurrent random walk failed: {e}")
|
||||
return
|
||||
|
||||
# Run concurrent random walks
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(count):
|
||||
nursery.start_soon(single_walk)
|
||||
|
||||
# Remove duplicates based on peer ID
|
||||
unique_peers = {}
|
||||
for peer in all_validated_peers:
|
||||
unique_peers[peer.peer_id] = peer
|
||||
|
||||
result = list(unique_peers.values())
|
||||
logger.info(
|
||||
f"Concurrent random walks completed: {len(result)} unique peers discovered"
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_peerstore_peers(self) -> list[PeerInfo]:
|
||||
"""
|
||||
Get peer info objects from the host's peerstore.
|
||||
|
||||
Returns:
|
||||
List of PeerInfo objects from peerstore
|
||||
|
||||
"""
|
||||
try:
|
||||
peerstore = self.host.get_peerstore()
|
||||
peer_ids = peerstore.peers_with_addrs()
|
||||
|
||||
peer_infos = []
|
||||
for peer_id in peer_ids:
|
||||
try:
|
||||
# Skip local peer
|
||||
if peer_id == self.local_peer_id:
|
||||
continue
|
||||
|
||||
peer_info = peerstore.peer_info(peer_id)
|
||||
if peer_info and peer_info.addrs:
|
||||
# Filter for compatible addresses (TCP + IPv4)
|
||||
if self._has_compatible_addresses(peer_info):
|
||||
peer_infos.append(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting peer info for {peer_id}: {e}")
|
||||
|
||||
return peer_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error accessing peerstore: {e}")
|
||||
return []
|
||||
|
||||
def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool:
|
||||
"""
|
||||
Check if a peer has TCP+IPv4 compatible addresses.
|
||||
|
||||
Args:
|
||||
peer_info: PeerInfo to check
|
||||
|
||||
Returns:
|
||||
True if peer has compatible addresses
|
||||
|
||||
"""
|
||||
if not peer_info.addrs:
|
||||
return False
|
||||
|
||||
for addr in peer_info.addrs:
|
||||
addr_str = str(addr)
|
||||
# Check for TCP and IPv4 compatibility, avoid QUIC
|
||||
if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str:
|
||||
return True
|
||||
|
||||
return False
|
||||
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
208
libp2p/discovery/random_walk/rt_refresh_manager.py
Normal file
@ -0,0 +1,208 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IHost
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
RANDOM_WALK_ENABLED,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class RoutingTableProtocol(Protocol):
|
||||
"""Protocol for routing table operations needed by RT refresh manager."""
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the current size of the routing table."""
|
||||
...
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo) -> bool:
|
||||
"""Add a peer to the routing table."""
|
||||
...
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager")
|
||||
|
||||
|
||||
class RTRefreshManager:
|
||||
"""
|
||||
Routing Table Refresh Manager for py-libp2p.
|
||||
|
||||
Manages periodic routing table refreshes and random walk operations
|
||||
to maintain routing table health and discover new peers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
routing_table: RoutingTableProtocol,
|
||||
local_peer_id: ID,
|
||||
query_function: Callable[[bytes], Awaitable[list[ID]]],
|
||||
enable_auto_refresh: bool = RANDOM_WALK_ENABLED,
|
||||
refresh_interval: float = REFRESH_INTERVAL,
|
||||
min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD,
|
||||
):
|
||||
"""
|
||||
Initialize RT Refresh Manager.
|
||||
|
||||
Args:
|
||||
host: The libp2p host instance
|
||||
routing_table: Routing table of host
|
||||
local_peer_id: Local peer ID
|
||||
query_function: Function to query for closest peers given target key bytes
|
||||
enable_auto_refresh: Whether to enable automatic refresh
|
||||
refresh_interval: Interval between refreshes in seconds
|
||||
min_refresh_threshold: Minimum RT size before triggering refresh
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
self.local_peer_id = local_peer_id
|
||||
self.query_function = query_function
|
||||
|
||||
self.enable_auto_refresh = enable_auto_refresh
|
||||
self.refresh_interval = refresh_interval
|
||||
self.min_refresh_threshold = min_refresh_threshold
|
||||
|
||||
# Initialize random walk module
|
||||
self.random_walk = RandomWalk(
|
||||
host=host,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=query_function,
|
||||
)
|
||||
|
||||
# Control variables
|
||||
self._running = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
# Tracking
|
||||
self._last_refresh_time = 0.0
|
||||
self._refresh_done_callbacks: list[Callable[[], None]] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the RT Refresh Manager."""
|
||||
if self._running:
|
||||
logger.warning("RT Refresh Manager is already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
logger.info("Starting RT Refresh Manager")
|
||||
|
||||
# Start the main loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
self._nursery = nursery
|
||||
nursery.start_soon(self._main_loop)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the RT Refresh Manager."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
logger.info("Stopping RT Refresh Manager")
|
||||
self._running = False
|
||||
|
||||
async def _main_loop(self) -> None:
|
||||
"""Main loop for the RT Refresh Manager."""
|
||||
logger.info("RT Refresh Manager main loop started")
|
||||
|
||||
# Initial refresh if auto-refresh is enabled
|
||||
if self.enable_auto_refresh:
|
||||
await self._do_refresh(force=True)
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Schedule periodic refresh if enabled
|
||||
if self.enable_auto_refresh:
|
||||
nursery.start_soon(self._periodic_refresh_task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RT Refresh Manager main loop error: {e}")
|
||||
finally:
|
||||
logger.info("RT Refresh Manager main loop stopped")
|
||||
|
||||
async def _periodic_refresh_task(self) -> None:
|
||||
"""Task for periodic refreshes."""
|
||||
while self._running:
|
||||
await trio.sleep(self.refresh_interval)
|
||||
if self._running:
|
||||
await self._do_refresh()
|
||||
|
||||
async def _do_refresh(self, force: bool = False) -> None:
|
||||
"""
|
||||
Perform routing table refresh operation.
|
||||
|
||||
Args:
|
||||
force: Whether to force refresh regardless of timing
|
||||
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
# Check if refresh is needed
|
||||
if not force:
|
||||
if current_time - self._last_refresh_time < self.refresh_interval:
|
||||
logger.debug("Skipping refresh: interval not elapsed")
|
||||
return
|
||||
|
||||
if self.routing_table.size() >= self.min_refresh_threshold:
|
||||
logger.debug("Skipping refresh: routing table size above threshold")
|
||||
return
|
||||
|
||||
logger.info(f"Starting routing table refresh (force={force})")
|
||||
start_time = current_time
|
||||
|
||||
# Perform random walks to discover new peers
|
||||
logger.info("Running concurrent random walks to discover new peers")
|
||||
current_rt_size = self.routing_table.size()
|
||||
discovered_peers = await self.random_walk.run_concurrent_random_walks(
|
||||
count=RANDOM_WALK_CONCURRENCY,
|
||||
current_routing_table_size=current_rt_size,
|
||||
)
|
||||
|
||||
# Add discovered peers to routing table
|
||||
added_count = 0
|
||||
for peer_info in discovered_peers:
|
||||
result = await self.routing_table.add_peer(peer_info)
|
||||
if result:
|
||||
added_count += 1
|
||||
|
||||
self._last_refresh_time = current_time
|
||||
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
f"Routing table refresh completed: "
|
||||
f"{added_count}/{len(discovered_peers)} peers added, "
|
||||
f"RT size: {self.routing_table.size()}, "
|
||||
f"duration: {duration:.2f}s"
|
||||
)
|
||||
|
||||
# Notify refresh completion
|
||||
for callback in self._refresh_done_callbacks:
|
||||
try:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.warning(f"Refresh callback error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Routing table refresh failed: {e}")
|
||||
raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e
|
||||
|
||||
def add_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Add a callback to be called when refresh completes."""
|
||||
self._refresh_done_callbacks.append(callback)
|
||||
|
||||
def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None:
|
||||
"""Remove a refresh completion callback."""
|
||||
if callback in self._refresh_done_callbacks:
|
||||
self._refresh_done_callbacks.remove(callback)
|
||||
@ -295,6 +295,13 @@ class BasicHost(IHost):
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
if protocol is None:
|
||||
logger.debug(
|
||||
"no protocol negotiated, closing stream from peer %s",
|
||||
net_stream.muxed_conn.peer_id,
|
||||
)
|
||||
await net_stream.reset()
|
||||
return
|
||||
net_stream.set_protocol(protocol)
|
||||
if handler is None:
|
||||
logger.debug(
|
||||
|
||||
@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT)
|
||||
implementation based on the Kademlia algorithm and protocol.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from enum import (
|
||||
Enum,
|
||||
)
|
||||
@ -20,6 +21,7 @@ import varint
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
@ -73,14 +75,27 @@ class KadDHT(Service):
|
||||
|
||||
This class provides a DHT implementation that combines routing table management,
|
||||
peer discovery, content routing, and value storage.
|
||||
|
||||
Optional Random Walk feature enhances peer discovery by automatically
|
||||
performing periodic random queries to discover new peers and maintain
|
||||
routing table health.
|
||||
|
||||
Example:
|
||||
# Basic DHT without random walk (default)
|
||||
dht = KadDHT(host, DHTMode.SERVER)
|
||||
|
||||
# DHT with random walk enabled for enhanced peer discovery
|
||||
dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, mode: DHTMode):
|
||||
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
|
||||
"""
|
||||
Initialize a new Kademlia DHT node.
|
||||
|
||||
:param host: The libp2p host.
|
||||
:param mode: The mode of host (Client or Server) - must be DHTMode enum
|
||||
:param enable_random_walk: Whether to enable automatic random walk
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -92,6 +107,7 @@ class KadDHT(Service):
|
||||
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
|
||||
|
||||
self.mode = mode
|
||||
self.enable_random_walk = enable_random_walk
|
||||
|
||||
# Initialize the routing table
|
||||
self.routing_table = RoutingTable(self.local_peer_id, self.host)
|
||||
@ -108,13 +124,56 @@ class KadDHT(Service):
|
||||
# Last time we republished provider records
|
||||
self._last_provider_republish = time.time()
|
||||
|
||||
# Initialize RT Refresh Manager (only if random walk is enabled)
|
||||
self.rt_refresh_manager: RTRefreshManager | None = None
|
||||
if self.enable_random_walk:
|
||||
self.rt_refresh_manager = RTRefreshManager(
|
||||
host=self.host,
|
||||
routing_table=self.routing_table,
|
||||
local_peer_id=self.local_peer_id,
|
||||
query_function=self._create_query_function(),
|
||||
enable_auto_refresh=True,
|
||||
)
|
||||
|
||||
# Set protocol handlers
|
||||
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
|
||||
|
||||
def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]:
|
||||
"""
|
||||
Create a query function that wraps peer_routing.find_closest_peers_network.
|
||||
|
||||
This function is used by the RandomWalk module to query for peers without
|
||||
directly importing PeerRouting, avoiding circular import issues.
|
||||
|
||||
Returns:
|
||||
Callable that takes target_key bytes and returns list of peer IDs
|
||||
|
||||
"""
|
||||
|
||||
async def query_function(target_key: bytes) -> list[ID]:
|
||||
"""Query for closest peers to target key."""
|
||||
return await self.peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
return query_function
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the DHT service."""
|
||||
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
|
||||
|
||||
# Start the RT Refresh Manager in parallel with the main DHT service
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start the RT Refresh Manager only if random walk is enabled
|
||||
if self.rt_refresh_manager is not None:
|
||||
nursery.start_soon(self.rt_refresh_manager.start)
|
||||
logger.info("RT Refresh Manager started - Random Walk is now active")
|
||||
else:
|
||||
logger.info("Random Walk is disabled - RT Refresh Manager not started")
|
||||
|
||||
# Start the main DHT service loop
|
||||
nursery.start_soon(self._run_main_loop)
|
||||
|
||||
async def _run_main_loop(self) -> None:
|
||||
"""Run the main DHT service loop."""
|
||||
# Main service loop
|
||||
while self.manager.is_running:
|
||||
# Periodically refresh the routing table
|
||||
@ -135,6 +194,17 @@ class KadDHT(Service):
|
||||
# Wait before next maintenance cycle
|
||||
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DHT service and cleanup resources."""
|
||||
logger.info("Stopping Kademlia DHT")
|
||||
|
||||
# Stop the RT Refresh Manager only if it was started
|
||||
if self.rt_refresh_manager is not None:
|
||||
await self.rt_refresh_manager.stop()
|
||||
logger.info("RT Refresh Manager stopped")
|
||||
else:
|
||||
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
|
||||
|
||||
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
|
||||
"""
|
||||
Switch the DHT mode.
|
||||
@ -614,3 +684,15 @@ class KadDHT(Service):
|
||||
|
||||
"""
|
||||
return self.value_store.size()
|
||||
|
||||
def is_random_walk_enabled(self) -> bool:
|
||||
"""
|
||||
Check if random walk peer discovery is enabled.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if random walk is enabled, False otherwise.
|
||||
|
||||
"""
|
||||
return self.enable_random_walk
|
||||
|
||||
@ -170,7 +170,7 @@ class PeerRouting(IPeerRouting):
|
||||
|
||||
# Return early if we have no peers to start with
|
||||
if not closest_peers:
|
||||
logger.warning("No local peers available for network lookup")
|
||||
logger.debug("No local peers available for network lookup")
|
||||
return []
|
||||
|
||||
# Iterative lookup until convergence
|
||||
|
||||
@ -249,9 +249,11 @@ class Swarm(Service, INetworkService):
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
for maddr in multiaddrs:
|
||||
if str(maddr) in self.listeners:
|
||||
return True
|
||||
success_count += 1
|
||||
continue
|
||||
|
||||
async def conn_handler(
|
||||
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
||||
@ -302,13 +304,14 @@ class Swarm(Service, INetworkService):
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_listen(maddr)
|
||||
|
||||
return True
|
||||
success_count += 1
|
||||
logger.debug("successfully started listening on: %s", maddr)
|
||||
except OSError:
|
||||
# Failed. Continue looping.
|
||||
logger.debug("fail to listen on: %s", maddr)
|
||||
|
||||
# No maddr succeeded
|
||||
return False
|
||||
# Return true if at least one address succeeded
|
||||
return success_count > 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
|
||||
@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
|
||||
"""
|
||||
self.handlers[protocol] = handler
|
||||
|
||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||
async def negotiate(
|
||||
self,
|
||||
communicator: IMultiselectCommunicator,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
|
||||
"""
|
||||
Negotiate performs protocol selection.
|
||||
|
||||
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
|
||||
raise MultiselectError() from error
|
||||
|
||||
else:
|
||||
protocol = TProtocol(command)
|
||||
if protocol in self.handlers:
|
||||
protocol_to_check = None if not command else TProtocol(command)
|
||||
if protocol_to_check in self.handlers:
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(command)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectError() from error
|
||||
|
||||
return protocol, self.handlers[protocol]
|
||||
return protocol_to_check, self.handlers[protocol_to_check]
|
||||
try:
|
||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||
except MultiselectCommunicatorError as error:
|
||||
|
||||
@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
:return: selected protocol
|
||||
"""
|
||||
# Represent `None` protocol as an empty string.
|
||||
protocol_str = protocol if protocol is not None else ""
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
await communicator.write(protocol_str)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if response == protocol:
|
||||
if response == protocol_str:
|
||||
return protocol
|
||||
if response == PROTOCOL_NOT_FOUND_MSG:
|
||||
raise MultiselectClientError("protocol not supported")
|
||||
|
||||
@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
|
||||
""" # noqa: E501
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
if msg_str is None:
|
||||
msg_bytes = encode_delim(b"")
|
||||
else:
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
|
||||
@ -777,14 +777,18 @@ class GossipSub(IPubsubRouter, Service):
|
||||
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in
|
||||
# seen_messages cache
|
||||
seen_seqnos_and_peers = [
|
||||
seqno_and_from for seqno_and_from in self.pubsub.seen_messages.cache.keys()
|
||||
str(seqno_and_from)
|
||||
for seqno_and_from in self.pubsub.seen_messages.cache.keys()
|
||||
]
|
||||
|
||||
# Add all unknown message ids (ids that appear in ihave_msg but not in
|
||||
# seen_seqnos) to list of messages we want to request
|
||||
msg_ids_wanted: list[str] = [
|
||||
msg_id
|
||||
msg_ids_wanted: list[MessageID] = [
|
||||
parse_message_id_safe(msg_id)
|
||||
for msg_id in ihave_msg.messageIDs
|
||||
if msg_id not in seen_seqnos_and_peers
|
||||
if msg_id not in str(seen_seqnos_and_peers)
|
||||
]
|
||||
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
|
||||
:param is_initiator: true if we are the initiator, false otherwise
|
||||
:return: selected secure transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if is_initiator:
|
||||
# Select protocol if initiator
|
||||
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
|
||||
else:
|
||||
# Select protocol if non-initiator
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a security protocol")
|
||||
# Return transport from protocol
|
||||
return self.transports[protocol]
|
||||
|
||||
@ -17,6 +17,9 @@ from libp2p.custom_types import (
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
Multiselect,
|
||||
)
|
||||
@ -73,7 +76,7 @@ class MuxerMultistream:
|
||||
:param conn: conn to choose a transport over
|
||||
:return: selected muxer transport
|
||||
"""
|
||||
protocol: TProtocol
|
||||
protocol: TProtocol | None
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
@ -81,6 +84,8 @@ class MuxerMultistream:
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
if protocol is None:
|
||||
raise MultiselectError("fail to negotiate a stream muxer protocol")
|
||||
return self.transports[protocol]
|
||||
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
|
||||
@ -15,6 +15,13 @@ from libp2p.utils.version import (
|
||||
get_agent_version,
|
||||
)
|
||||
|
||||
from libp2p.utils.address_validation import (
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
expand_wildcard_address,
|
||||
find_free_port,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"decode_uvarint_from_stream",
|
||||
"encode_delim",
|
||||
@ -26,4 +33,8 @@ __all__ = [
|
||||
"decode_varint_from_bytes",
|
||||
"decode_varint_with_size",
|
||||
"read_length_prefixed_protobuf",
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
|
||||
160
libp2p/utils/address_validation.py
Normal file
160
libp2p/utils/address_validation.py
Normal file
@ -0,0 +1,160 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
try:
|
||||
from multiaddr.utils import ( # type: ignore
|
||||
get_network_addrs,
|
||||
get_thin_waist_addresses,
|
||||
)
|
||||
|
||||
_HAS_THIN_WAIST = True
|
||||
except ImportError: # pragma: no cover - only executed in older environments
|
||||
_HAS_THIN_WAIST = False
|
||||
get_thin_waist_addresses = None # type: ignore
|
||||
get_network_addrs = None # type: ignore
|
||||
|
||||
|
||||
def _safe_get_network_addrs(ip_version: int) -> list[str]:
|
||||
"""
|
||||
Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
|
||||
Falls back to minimal defaults when Thin Waist helpers are missing.
|
||||
|
||||
:param ip_version: 4 or 6
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_network_addrs:
|
||||
try:
|
||||
return get_network_addrs(ip_version) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return []
|
||||
# Fallback behavior (very conservative)
|
||||
if ip_version == 4:
|
||||
return ["127.0.0.1"]
|
||||
if ip_version == 6:
|
||||
return ["::1"]
|
||||
return []
|
||||
|
||||
|
||||
def find_free_port() -> int:
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0)) # Bind to a free port provided by the OS
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
|
||||
"""
|
||||
Internal safe expansion wrapper. Returns a list of Multiaddr objects.
|
||||
If Thin Waist isn't available, returns [addr] (identity).
|
||||
"""
|
||||
if _HAS_THIN_WAIST and get_thin_waist_addresses:
|
||||
try:
|
||||
if port is not None:
|
||||
return get_thin_waist_addresses(addr, port=port) or []
|
||||
return get_thin_waist_addresses(addr) or []
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return [addr]
|
||||
return [addr]
|
||||
|
||||
|
||||
def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
|
||||
"""
|
||||
Discover available network interfaces (IPv4 + IPv6 if supported) for binding.
|
||||
|
||||
:param port: Port number to bind to.
|
||||
:param protocol: Transport protocol (e.g., "tcp" or "udp").
|
||||
:return: List of Multiaddr objects representing candidate interface addresses.
|
||||
"""
|
||||
addrs: list[Multiaddr] = []
|
||||
|
||||
# IPv4 enumeration
|
||||
seen_v4: set[str] = set()
|
||||
|
||||
for ip in _safe_get_network_addrs(4):
|
||||
seen_v4.add(ip)
|
||||
addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
|
||||
|
||||
# Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
|
||||
if seen_v4 and "127.0.0.1" not in seen_v4:
|
||||
addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))
|
||||
|
||||
# TODO: IPv6 support temporarily disabled due to libp2p handshake issues
|
||||
# IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure)
|
||||
# Re-enable IPv6 support once the following issues are resolved:
|
||||
# - libp2p security handshake over IPv6
|
||||
# - multiselect protocol over IPv6
|
||||
# - connection establishment over IPv6
|
||||
#
|
||||
# seen_v6: set[str] = set()
|
||||
# for ip in _safe_get_network_addrs(6):
|
||||
# seen_v6.add(ip)
|
||||
# addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
|
||||
#
|
||||
# # Always include IPv6 loopback for testing purposes when IPv6 is available
|
||||
# # This ensures IPv6 functionality can be tested even without global IPv6 addresses
|
||||
# if "::1" not in seen_v6:
|
||||
# addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}"))
|
||||
|
||||
# Fallback if nothing discovered
|
||||
if not addrs:
|
||||
addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
|
||||
|
||||
return addrs
|
||||
|
||||
|
||||
def expand_wildcard_address(
|
||||
addr: Multiaddr, port: int | None = None
|
||||
) -> list[Multiaddr]:
|
||||
"""
|
||||
Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces.
|
||||
|
||||
:param addr: Multiaddr to expand.
|
||||
:param port: Optional override for port selection.
|
||||
:return: List of concrete Multiaddr instances.
|
||||
"""
|
||||
expanded = _safe_expand(addr, port=port)
|
||||
if not expanded: # Safety fallback
|
||||
return [addr]
|
||||
return expanded
|
||||
|
||||
|
||||
def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
|
||||
"""
|
||||
Choose an optimal address for an example to bind to:
|
||||
- Prefer non-loopback IPv4
|
||||
- Then non-loopback IPv6
|
||||
- Fallback to loopback
|
||||
- Fallback to wildcard
|
||||
|
||||
:param port: Port number.
|
||||
:param protocol: Transport protocol.
|
||||
:return: A single Multiaddr chosen heuristically.
|
||||
"""
|
||||
candidates = get_available_interfaces(port, protocol)
|
||||
|
||||
def is_non_loopback(ma: Multiaddr) -> bool:
|
||||
s = str(ma)
|
||||
return not ("/ip4/127." in s or "/ip6/::1" in s)
|
||||
|
||||
for c in candidates:
|
||||
if "/ip4/" in str(c) and is_non_loopback(c):
|
||||
return c
|
||||
for c in candidates:
|
||||
if "/ip6/" in str(c) and is_non_loopback(c):
|
||||
return c
|
||||
for c in candidates:
|
||||
if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
|
||||
return c
|
||||
|
||||
# As a final fallback, produce a wildcard
|
||||
return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_available_interfaces",
|
||||
"get_optimal_binding_address",
|
||||
"expand_wildcard_address",
|
||||
"find_free_port",
|
||||
]
|
||||
1
newsfragments/770.internal.rst
Normal file
1
newsfragments/770.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py
|
||||
1
newsfragments/811.feature.rst
Normal file
1
newsfragments/811.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Added Thin Waist address validation utilities (with support for interface enumeration, optimal binding, and wildcard expansion).
|
||||
7
newsfragments/811.internal.rst
Normal file
7
newsfragments/811.internal.rst
Normal file
@ -0,0 +1,7 @@
|
||||
Add Thin Waist address validation utilities and integrate into echo example
|
||||
|
||||
- Add ``libp2p/utils/address_validation.py`` with dynamic interface discovery
|
||||
- Implement ``get_available_interfaces()``, ``get_optimal_binding_address()``, and ``expand_wildcard_address()``
|
||||
- Update echo example to use dynamic address discovery instead of hardcoded wildcard
|
||||
- Add safe fallbacks for environments lacking Thin Waist support
|
||||
- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved)
|
||||
1
newsfragments/822.feature.rst
Normal file
1
newsfragments/822.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Added `Random Walk` peer discovery module that enables random peer exploration for improved peer discovery.
|
||||
1
newsfragments/855.internal.rst
Normal file
1
newsfragments/855.internal.rst
Normal file
@ -0,0 +1 @@
|
||||
Improved PubsubNotifee integration tests and added failure scenario coverage.
|
||||
1
newsfragments/859.feature.rst
Normal file
1
newsfragments/859.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Fix type for gossipsub_message_id for consistency and security
|
||||
5
newsfragments/863.bugfix.rst
Normal file
5
newsfragments/863.bugfix.rst
Normal file
@ -0,0 +1,5 @@
|
||||
Fix multi-address listening bug in swarm.listen()
|
||||
|
||||
- Fix early return in swarm.listen() that prevented listening on all addresses
|
||||
- Add comprehensive tests for multi-address listening functionality
|
||||
- Ensure all available interfaces are properly bound and connectable
|
||||
@ -11,9 +11,9 @@ requires-python = ">=3.10, <4.0"
|
||||
license = { text = "MIT AND Apache-2.0" }
|
||||
keywords = ["libp2p", "p2p"]
|
||||
maintainers = [
|
||||
{ name = "pacrob", email = "pacrob@protonmail.com" },
|
||||
{ name = "pacrob", email = "pacrob-py-libp2p@proton.me" },
|
||||
{ name = "Manu Sheel Gupta", email = "manu@seeta.in" },
|
||||
{ name = "Dave Grantham", email = "dave@aviation.community" },
|
||||
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
|
||||
]
|
||||
dependencies = [
|
||||
"base58>=1.0.3",
|
||||
|
||||
82
tests/core/network/test_notifee_performance.py
Normal file
82
tests/core/network/test_notifee_performance.py
Normal file
@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
INetConn,
|
||||
INetStream,
|
||||
INetwork,
|
||||
INotifee,
|
||||
)
|
||||
from libp2p.tools.utils import connect_swarm
|
||||
from tests.utils.factories import SwarmFactory
|
||||
|
||||
|
||||
class CountingNotifee(INotifee):
|
||||
def __init__(self, event: trio.Event) -> None:
|
||||
self._event = event
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
self._event.set()
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SlowNotifee(INotifee):
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
await trio.sleep(0.5)
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_many_notifees_receive_connected_quickly() -> None:
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
count = 200
|
||||
events = [trio.Event() for _ in range(count)]
|
||||
for ev in events:
|
||||
swarms[0].register_notifee(CountingNotifee(ev))
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
with trio.fail_after(1.5):
|
||||
for ev in events:
|
||||
await ev.wait()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_slow_notifee_does_not_block_others() -> None:
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
fast_events = [trio.Event() for _ in range(20)]
|
||||
for ev in fast_events:
|
||||
swarms[0].register_notifee(CountingNotifee(ev))
|
||||
swarms[0].register_notifee(SlowNotifee())
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# Fast notifees should complete quickly despite one slow notifee
|
||||
with trio.fail_after(0.3):
|
||||
for ev in fast_events:
|
||||
await ev.wait()
|
||||
76
tests/core/network/test_notify_listen_lifecycle.py
Normal file
76
tests/core/network/test_notify_listen_lifecycle.py
Normal file
@ -0,0 +1,76 @@
|
||||
import enum
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
INetConn,
|
||||
INetStream,
|
||||
INetwork,
|
||||
INotifee,
|
||||
)
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from tests.utils.factories import SwarmFactory
|
||||
|
||||
|
||||
class Event(enum.Enum):
|
||||
Listen = 0
|
||||
ListenClose = 1
|
||||
|
||||
|
||||
class MyNotifee(INotifee):
|
||||
def __init__(self, events: list[Event]):
|
||||
self.events = events
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
self.events.append(Event.Listen)
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
self.events.append(Event.ListenClose)
|
||||
|
||||
|
||||
async def wait_for_event(
|
||||
events_list: list[Event], event: Event, timeout: float = 1.0
|
||||
) -> bool:
|
||||
with trio.move_on_after(timeout):
|
||||
while event not in events_list:
|
||||
await trio.sleep(0.01)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listen_emitted_when_registered_before_listen():
|
||||
events: list[Event] = []
|
||||
swarm = SwarmFactory.build()
|
||||
swarm.register_notifee(MyNotifee(events))
|
||||
async with background_trio_service(swarm):
|
||||
# Start listening now; notifee was registered beforehand
|
||||
assert await swarm.listen(LISTEN_MADDR)
|
||||
assert await wait_for_event(events, Event.Listen)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_single_listener_close_emits_listen_close():
|
||||
events: list[Event] = []
|
||||
swarm = SwarmFactory.build()
|
||||
swarm.register_notifee(MyNotifee(events))
|
||||
async with background_trio_service(swarm):
|
||||
assert await swarm.listen(LISTEN_MADDR)
|
||||
# Explicitly notify listen_close (close path via manager doesn't emit it)
|
||||
await swarm.notify_listen_close(LISTEN_MADDR)
|
||||
assert await wait_for_event(events, Event.ListenClose)
|
||||
@ -16,6 +16,9 @@ from libp2p.network.exceptions import (
|
||||
from libp2p.network.swarm import (
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect_swarm,
|
||||
)
|
||||
@ -184,3 +187,116 @@ def test_new_swarm_quic_multiaddr_raises():
|
||||
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
|
||||
with pytest.raises(ValueError, match="QUIC not yet supported"):
|
||||
new_swarm(listen_addrs=[addr])
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_listen_multiple_addresses(security_protocol):
|
||||
"""Test that swarm can listen on multiple addresses simultaneously."""
|
||||
from libp2p.utils.address_validation import get_available_interfaces
|
||||
|
||||
# Get multiple addresses to listen on
|
||||
listen_addrs = get_available_interfaces(0) # Let OS choose ports
|
||||
|
||||
# Create a swarm and listen on multiple addresses
|
||||
swarm = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm):
|
||||
# Listen on all addresses
|
||||
success = await swarm.listen(*listen_addrs)
|
||||
assert success, "Should successfully listen on at least one address"
|
||||
|
||||
# Check that we have listeners for the addresses
|
||||
actual_listeners = list(swarm.listeners.keys())
|
||||
assert len(actual_listeners) > 0, "Should have at least one listener"
|
||||
|
||||
# Verify that all successful listeners are in the listeners dict
|
||||
successful_count = 0
|
||||
for addr in listen_addrs:
|
||||
addr_str = str(addr)
|
||||
if addr_str in actual_listeners:
|
||||
successful_count += 1
|
||||
# This address successfully started listening
|
||||
listener = swarm.listeners[addr_str]
|
||||
listener_addrs = listener.get_addrs()
|
||||
assert len(listener_addrs) > 0, (
|
||||
f"Listener for {addr} should have addresses"
|
||||
)
|
||||
|
||||
# Check that the listener address matches the expected address
|
||||
# (port might be different if we used port 0)
|
||||
expected_ip = addr.value_for_protocol("ip4")
|
||||
expected_protocol = addr.value_for_protocol("tcp")
|
||||
if expected_ip and expected_protocol:
|
||||
found_matching = False
|
||||
for listener_addr in listener_addrs:
|
||||
if (
|
||||
listener_addr.value_for_protocol("ip4") == expected_ip
|
||||
and listener_addr.value_for_protocol("tcp") is not None
|
||||
):
|
||||
found_matching = True
|
||||
break
|
||||
assert found_matching, (
|
||||
f"Listener for {addr} should have matching IP"
|
||||
)
|
||||
|
||||
assert successful_count == len(listen_addrs), (
|
||||
f"All {len(listen_addrs)} addresses should be listening, "
|
||||
f"but only {successful_count} succeeded"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_listen_multiple_addresses_connectivity(security_protocol):
|
||||
"""Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.utils.address_validation import get_available_interfaces
|
||||
|
||||
# Get multiple addresses to listen on
|
||||
listen_addrs = get_available_interfaces(0) # Let OS choose ports
|
||||
|
||||
# Create a swarm and listen on multiple addresses
|
||||
swarm1 = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm1):
|
||||
# Listen on all addresses
|
||||
success = await swarm1.listen(*listen_addrs)
|
||||
assert success, "Should successfully listen on at least one address"
|
||||
|
||||
# Verify all available interfaces are listening
|
||||
assert len(swarm1.listeners) == len(listen_addrs), (
|
||||
f"All {len(listen_addrs)} interfaces should be listening, "
|
||||
f"but only {len(swarm1.listeners)} are"
|
||||
)
|
||||
|
||||
# Create a second swarm to test connections
|
||||
swarm2 = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm2):
|
||||
# Test connectivity to each listening address using real libp2p connections
|
||||
for addr_str, listener in swarm1.listeners.items():
|
||||
listener_addrs = listener.get_addrs()
|
||||
for listener_addr in listener_addrs:
|
||||
# Create a full multiaddr with peer ID for libp2p connection
|
||||
peer_id = swarm1.get_peer_id()
|
||||
full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}")
|
||||
|
||||
# Test real libp2p connection
|
||||
try:
|
||||
peer_info = info_from_p2p_addr(full_addr)
|
||||
|
||||
# Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501
|
||||
swarm2.peerstore.add_addrs(
|
||||
peer_info.peer_id, [listener_addr], 10000
|
||||
)
|
||||
|
||||
await swarm2.dial_peer(peer_info.peer_id)
|
||||
|
||||
# Verify connection was established
|
||||
assert peer_info.peer_id in swarm2.connections, (
|
||||
f"Connection to {full_addr} should be established"
|
||||
)
|
||||
assert swarm2.get_peer_id() in swarm1.connections, (
|
||||
f"Connection from {full_addr} should be established"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Failed to establish libp2p connection to {full_addr}: {e}"
|
||||
)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from collections import deque
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
from libp2p.abc import IMultiselectCommunicator, INetStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
|
||||
|
||||
async def dummy_handler(stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_one_of_timeout():
|
||||
async def test_select_one_of_timeout() -> None:
|
||||
ECHO = TProtocol("/echo/1.0.0")
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
|
||||
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_multistream_command_timeout():
|
||||
async def test_query_multistream_command_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
client = MultiselectClient()
|
||||
|
||||
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout():
|
||||
async def test_negotiate_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
server = Multiselect()
|
||||
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 2)
|
||||
|
||||
|
||||
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
|
||||
handshaked: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handshaked = False
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
if msg_str == "/multistream/1.0.0":
|
||||
self.handshaked = True
|
||||
return
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.handshaked:
|
||||
return "/multistream/1.0.0"
|
||||
# After handshake, hang on read.
|
||||
await trio.sleep_forever()
|
||||
# Should not be reached.
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout_post_handshake() -> None:
|
||||
communicator = HandshakeThenHangCommunicator()
|
||||
server = Multiselect()
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 1)
|
||||
|
||||
|
||||
class MockCommunicator(IMultiselectCommunicator):
|
||||
def __init__(self, commands_to_read: list[str]):
|
||||
self.read_queue = deque(commands_to_read)
|
||||
self.written_data: list[str] = []
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
self.written_data.append(msg_str)
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.read_queue:
|
||||
raise EOFError
|
||||
return self.read_queue.popleft()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_empty_string_command() -> None:
|
||||
# server receives an empty string, which means client wants `None` protocol.
|
||||
server = Multiselect({None: dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check that server sent back handshake and the protocol confirmation (empty string)
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler() -> None:
|
||||
# server has None handler, client sends "" to select it.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, protocol confirmation
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler_ls() -> None:
|
||||
# server has None handler, client sends "ls" then empty string.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, ls, empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, ls response, protocol confirmation
|
||||
assert communicator.written_data[0] == "/multistream/1.0.0"
|
||||
assert "/proto1" in communicator.written_data[1]
|
||||
# Note: `ls` should not list the `None` protocol.
|
||||
assert "None" not in communicator.written_data[1]
|
||||
assert "\n\n" not in communicator.written_data[1]
|
||||
assert communicator.written_data[2] == ""
|
||||
|
||||
@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
|
||||
protocols = ms.get_protocols()
|
||||
|
||||
assert set(protocols) == {p1, p2, p3}
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol(security_protocol):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(
|
||||
None,
|
||||
[None],
|
||||
[None],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
|
||||
expected_selected_protocol = PROTOCOL_ECHO
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol,
|
||||
[None, PROTOCOL_ECHO],
|
||||
[PROTOCOL_ECHO],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_server_none_client_other(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)
|
||||
|
||||
90
tests/core/pubsub/test_pubsub_notifee_integration.py
Normal file
90
tests/core/pubsub/test_pubsub_notifee_integration.py
Normal file
@ -0,0 +1,90 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.utils import connect
|
||||
from tests.utils.factories import PubsubFactory
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connected_enqueues_and_adds_peer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Wait until peer is added via queue processing
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
assert p1.my_id in p0.peers
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_disconnected_enqueues_and_removes_peer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Ensure present first
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
# Now disconnect and expect removal via dead peer queue
|
||||
await p0.host.get_network().close_peer(p1.host.get_id())
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
assert p1.my_id not in p0.peers
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None:
|
||||
# Ensure PubsubNotifee catches BrokenResourceError from its send channel
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
# Find the PubsubNotifee registered on the network
|
||||
from libp2p.pubsub.pubsub_notifee import PubsubNotifee
|
||||
|
||||
network = p0.host.get_network()
|
||||
notifees = getattr(network, "notifees", [])
|
||||
target = None
|
||||
for nf in notifees:
|
||||
if isinstance(nf, cast(type, PubsubNotifee)):
|
||||
target = nf
|
||||
break
|
||||
assert target is not None, "PubsubNotifee not found on network"
|
||||
|
||||
async def failing_send(_peer_id): # type: ignore[no-redef]
|
||||
raise trio.BrokenResourceError
|
||||
|
||||
# Make initiator queue send fail; PubsubNotifee should swallow
|
||||
monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send)
|
||||
|
||||
# Connect peers; if exceptions are swallowed, service stays running
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_duplicate_connection_does_not_duplicate_peer_state():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
# Connect again should not add duplicates
|
||||
await connect(p0.host, p1.host)
|
||||
await trio.sleep(0.1)
|
||||
assert list(p0.peers.keys()).count(p1.my_id) == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_blocks_peer_added_by_notifee():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
# Blacklist before connecting
|
||||
p0.add_to_blacklist(p1.my_id)
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Give handler a chance to run
|
||||
await trio.sleep(0.1)
|
||||
assert p1.my_id not in p0.peers
|
||||
99
tests/discovery/random_walk/test_random_walk.py
Normal file
99
tests/discovery/random_walk/test_random_walk.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""
|
||||
Unit tests for the RandomWalk module in libp2p.discovery.random_walk.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host():
|
||||
host = Mock()
|
||||
peerstore = Mock()
|
||||
peerstore.peers_with_addrs.return_value = []
|
||||
peerstore.addrs.return_value = [Mock()]
|
||||
host.get_peerstore.return_value = peerstore
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_query_function():
|
||||
async def query(key_bytes):
|
||||
return []
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_peer_id():
|
||||
return b"\x01" * 32
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_random_walk_initialization(
|
||||
mock_host, dummy_peer_id, dummy_query_function
|
||||
):
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
|
||||
assert rw.host == mock_host
|
||||
assert rw.local_peer_id == dummy_peer_id
|
||||
assert rw.query_function == dummy_query_function
|
||||
|
||||
|
||||
def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function):
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
|
||||
peer_id = rw.generate_random_peer_id()
|
||||
assert isinstance(peer_id, str)
|
||||
assert len(peer_id) == 64 # 32 bytes hex
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_run_concurrent_random_walks(mock_host, dummy_peer_id):
|
||||
# Dummy query function returns different peer IDs for each walk
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def query(key_bytes):
|
||||
call_count["count"] += 1
|
||||
# Return a unique peer ID for each call
|
||||
return [ID(bytes([call_count["count"]] * 32))]
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.run_concurrent_random_walks(count=3)
|
||||
# Should get 3 unique peers
|
||||
assert len(peers) == 3
|
||||
peer_ids = [peer.peer_id for peer in peers]
|
||||
assert len(set(peer_ids)) == 3
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_perform_random_walk_running(mock_host, dummy_peer_id):
|
||||
# Query function returns a single peer ID
|
||||
async def query(key_bytes):
|
||||
return [ID(b"\x02" * 32)]
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.perform_random_walk()
|
||||
assert isinstance(peers, list)
|
||||
if peers:
|
||||
assert isinstance(peers[0], PeerInfo)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id):
|
||||
"""Test perform_random_walk when no peers are discovered."""
|
||||
|
||||
# Query function returns empty list (no peers found)
|
||||
async def query(key_bytes):
|
||||
return []
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.perform_random_walk()
|
||||
|
||||
# Should return empty list when no peers are found
|
||||
assert isinstance(peers, list)
|
||||
assert len(peers) == 0
|
||||
451
tests/discovery/random_walk/test_rt_refresh_manager.py
Normal file
451
tests/discovery/random_walk/test_rt_refresh_manager.py
Normal file
@ -0,0 +1,451 @@
|
||||
"""
|
||||
Unit tests for the RTRefreshManager and related random walk logic.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import (
|
||||
RandomWalkError,
|
||||
)
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class DummyRoutingTable:
|
||||
def __init__(self, size=0):
|
||||
self._size = size
|
||||
self.added_peers = []
|
||||
|
||||
def size(self):
|
||||
return self._size
|
||||
|
||||
async def add_peer(self, peer_obj):
|
||||
self.added_peers.append(peer_obj)
|
||||
self._size += 1
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host():
|
||||
host = Mock()
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_peer_id():
|
||||
return ID(b"\x01" * 32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_query_function():
|
||||
async def query(key_bytes):
|
||||
return [ID(b"\x02" * 32)]
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_initialization(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
rt = DummyRoutingTable(size=5)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
assert manager.host == mock_host
|
||||
assert manager.routing_table == rt
|
||||
assert manager.local_peer_id == local_peer_id
|
||||
assert manager.query_function == dummy_query_function
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_refresh_logic(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
rt = DummyRoutingTable(size=2)
|
||||
# Simulate refresh logic
|
||||
if rt.size() < MIN_RT_REFRESH_THRESHOLD:
|
||||
await rt.add_peer(Mock())
|
||||
assert rt.size() >= 3
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_random_walk_integration(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
# Simulate random walk usage
|
||||
rw = RandomWalk(mock_host, local_peer_id, dummy_query_function)
|
||||
random_peer_id = rw.generate_random_peer_id()
|
||||
assert isinstance(random_peer_id, str)
|
||||
assert len(random_peer_id) == 64
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id):
|
||||
rt = DummyRoutingTable(size=0)
|
||||
|
||||
async def failing_query(_):
|
||||
raise RandomWalkError("Query failed")
|
||||
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=failing_query,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
with pytest.raises(RandomWalkError):
|
||||
await manager.query_function(b"key")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_start_method(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the start method functionality."""
|
||||
rt = DummyRoutingTable(size=2)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=False, # Disable auto-refresh to control the test
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
):
|
||||
# Test starting the manager
|
||||
assert not manager._running
|
||||
|
||||
# Start the manager in a nursery that we can control
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager.start)
|
||||
await trio.sleep(0.01) # Let it start
|
||||
|
||||
# Verify it's running
|
||||
assert manager._running
|
||||
|
||||
# Stop the manager
|
||||
await manager.stop()
|
||||
assert not manager._running
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_with_auto_refresh(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the _main_loop method with auto-refresh enabled."""
|
||||
rt = DummyRoutingTable(size=1) # Small size to trigger refresh
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
) as mock_random_walk:
|
||||
manager._running = True
|
||||
|
||||
# Run the main loop for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._main_loop)
|
||||
await trio.sleep(0.05) # Let it run briefly
|
||||
manager._running = False # Stop the loop
|
||||
|
||||
# Verify that random walk was called (initial refresh)
|
||||
mock_random_walk.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_without_auto_refresh(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the _main_loop method with auto-refresh disabled."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=False,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
manager._running = True
|
||||
|
||||
# Run the main loop for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._main_loop)
|
||||
await trio.sleep(0.05)
|
||||
manager._running = False
|
||||
|
||||
# Verify that random walk was not called since auto-refresh is disabled
|
||||
mock_random_walk.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_initial_refresh_exception(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test that _main_loop propagates exceptions from initial refresh."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock _do_refresh to raise an exception on the initial call
|
||||
with patch.object(
|
||||
manager, "_do_refresh", side_effect=Exception("Initial refresh failed")
|
||||
):
|
||||
manager._running = True
|
||||
|
||||
# The initial refresh exception should propagate
|
||||
with pytest.raises(Exception, match="Initial refresh failed"):
|
||||
await manager._main_loop()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test _do_refresh method with force=True."""
|
||||
rt = DummyRoutingTable(size=10) # Large size, but force should override
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info1 = Mock(spec=PeerInfo)
|
||||
mock_peer_info2 = Mock(spec=PeerInfo)
|
||||
discovered_peers = [mock_peer_info1, mock_peer_info2]
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=discovered_peers,
|
||||
) as mock_random_walk:
|
||||
# Force refresh should work regardless of RT size
|
||||
await manager._do_refresh(force=True)
|
||||
|
||||
# Verify random walk was called
|
||||
mock_random_walk.assert_called_once_with(
|
||||
count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10
|
||||
)
|
||||
|
||||
# Verify peers were added to routing table
|
||||
assert len(rt.added_peers) == 2
|
||||
assert manager._last_refresh_time > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_skip_due_to_interval(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test _do_refresh skips refresh when interval hasn't elapsed."""
|
||||
rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=100.0, # Long interval
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Set last refresh time to recent
|
||||
manager._last_refresh_time = time.time()
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=False)
|
||||
|
||||
# Verify refresh was skipped
|
||||
mock_random_walk.assert_not_called()
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Skipping refresh: interval not elapsed"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_skip_due_to_rt_size(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test _do_refresh skips refresh when RT size is above threshold."""
|
||||
rt = DummyRoutingTable(size=20) # Large size above threshold
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1, # Short interval
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Set last refresh time to old
|
||||
manager._last_refresh_time = 0.0
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=False)
|
||||
|
||||
# Verify refresh was skipped
|
||||
mock_random_walk.assert_not_called()
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Skipping refresh: routing table size above threshold"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test refresh completion callbacks functionality."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Create mock callbacks
|
||||
callback1 = Mock()
|
||||
callback2 = Mock()
|
||||
failing_callback = Mock(side_effect=Exception("Callback failed"))
|
||||
|
||||
# Add callbacks
|
||||
manager.add_refresh_done_callback(callback1)
|
||||
manager.add_refresh_done_callback(callback2)
|
||||
manager.add_refresh_done_callback(failing_callback)
|
||||
|
||||
# Mock the random walk
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
):
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=True)
|
||||
|
||||
# Verify all callbacks were called
|
||||
callback1.assert_called_once()
|
||||
callback2.assert_called_once()
|
||||
failing_callback.assert_called_once()
|
||||
|
||||
# Verify warning was logged for failing callback
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test stop method when manager is not running."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Manager is not running
|
||||
assert not manager._running
|
||||
|
||||
# Stop should return without doing anything
|
||||
await manager.stop()
|
||||
assert not manager._running
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test the _periodic_refresh_task method."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.05, # Very short interval for testing
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock _do_refresh to track calls
|
||||
with patch.object(manager, "_do_refresh") as mock_do_refresh:
|
||||
manager._running = True
|
||||
|
||||
# Run periodic refresh task for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._periodic_refresh_task)
|
||||
await trio.sleep(0.12) # Let it run for ~2 intervals
|
||||
manager._running = False # Stop the task
|
||||
|
||||
# Verify _do_refresh was called at least once
|
||||
assert mock_do_refresh.call_count >= 1
|
||||
109
tests/examples/test_echo_thin_waist.py
Normal file
109
tests/examples/test_echo_thin_waist.py
Normal file
@ -0,0 +1,109 @@
|
||||
import contextlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP
|
||||
|
||||
# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging
|
||||
|
||||
# This test is intentionally lightweight and can be marked as 'integration'.
|
||||
# It ensures the echo example runs and prints the new Thin Waist lines using
|
||||
# Trio primitives.
|
||||
|
||||
current_file = Path(__file__)
|
||||
project_root = current_file.parent.parent.parent
|
||||
EXAMPLES_DIR: Path = project_root / "examples" / "echo"
|
||||
|
||||
|
||||
def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path):
|
||||
"""Run echo server and validate printed multiaddr and peer id."""
|
||||
# Run echo example as server
|
||||
cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"]
|
||||
env = {**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||
proc: subprocess.Popen[str] = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if proc.stdout is None:
|
||||
proc.terminate()
|
||||
raise RuntimeError("Process stdout is None")
|
||||
out_stream = proc.stdout
|
||||
|
||||
peer_id: str | None = None
|
||||
printed_multiaddr: str | None = None
|
||||
saw_waiting = False
|
||||
|
||||
start = time.time()
|
||||
timeout_s = 8.0
|
||||
try:
|
||||
while time.time() - start < timeout_s:
|
||||
line = out_stream.readline()
|
||||
if not line:
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
s = line.strip()
|
||||
if s.startswith("I am "):
|
||||
peer_id = s.partition("I am ")[2]
|
||||
if s.startswith("echo-demo -d "):
|
||||
printed_multiaddr = s.partition("echo-demo -d ")[2]
|
||||
if "Waiting for incoming connections..." in s:
|
||||
saw_waiting = True
|
||||
break
|
||||
finally:
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
proc.terminate()
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
proc.kill()
|
||||
|
||||
assert peer_id, "Did not capture peer ID line"
|
||||
assert printed_multiaddr, "Did not capture multiaddr line"
|
||||
assert saw_waiting, "Did not capture waiting-for-connections line"
|
||||
|
||||
# Validate multiaddr structure using py-multiaddr protocol methods
|
||||
ma = Multiaddr(printed_multiaddr) # should parse without error
|
||||
|
||||
# Check that the multiaddr contains the p2p protocol
|
||||
try:
|
||||
peer_id_from_multiaddr = ma.value_for_protocol("p2p")
|
||||
assert peer_id_from_multiaddr is not None, (
|
||||
"Multiaddr missing p2p protocol value"
|
||||
)
|
||||
assert peer_id_from_multiaddr == peer_id, (
|
||||
f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Failed to extract p2p protocol value: {e}")
|
||||
|
||||
# Validate the multiaddr structure by checking protocols
|
||||
protocols = ma.protocols()
|
||||
|
||||
# Should have at least IP, TCP, and P2P protocols
|
||||
assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), (
|
||||
"Missing IP protocol"
|
||||
)
|
||||
assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol"
|
||||
assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol"
|
||||
|
||||
# Extract the p2p part and validate it matches the captured peer ID
|
||||
p2p_part = Multiaddr(f"/p2p/{peer_id}")
|
||||
try:
|
||||
# Decapsulate the p2p part to get the transport address
|
||||
transport_addr = ma.decapsulate(p2p_part)
|
||||
# Verify the decapsulated address doesn't contain p2p
|
||||
transport_protocols = transport_addr.protocols()
|
||||
assert not any(p.code == P_P2P for p in transport_protocols), (
|
||||
"Decapsulation failed - still contains p2p"
|
||||
)
|
||||
# Verify the original multiaddr can be reconstructed
|
||||
reconstructed = transport_addr.encapsulate(p2p_part)
|
||||
assert str(reconstructed) == str(ma), "Reconstruction failed"
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Multiaddr decapsulation failed: {e}")
|
||||
56
tests/utils/test_address_validation.py
Normal file
56
tests/utils/test_address_validation.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.utils.address_validation import (
|
||||
expand_wildcard_address,
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("proto", ["tcp"])
|
||||
def test_get_available_interfaces(proto: str) -> None:
|
||||
interfaces = get_available_interfaces(0, protocol=proto)
|
||||
assert len(interfaces) > 0
|
||||
for addr in interfaces:
|
||||
assert isinstance(addr, Multiaddr)
|
||||
assert f"/{proto}/" in str(addr)
|
||||
|
||||
|
||||
def test_get_optimal_binding_address() -> None:
|
||||
addr = get_optimal_binding_address(0)
|
||||
assert isinstance(addr, Multiaddr)
|
||||
# At least IPv4 or IPv6 prefix present
|
||||
s = str(addr)
|
||||
assert ("/ip4/" in s) or ("/ip6/" in s)
|
||||
|
||||
|
||||
def test_expand_wildcard_address_ipv4() -> None:
|
||||
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0")
|
||||
expanded = expand_wildcard_address(wildcard)
|
||||
assert len(expanded) > 0
|
||||
for e in expanded:
|
||||
assert isinstance(e, Multiaddr)
|
||||
assert "/tcp/" in str(e)
|
||||
|
||||
|
||||
def test_expand_wildcard_address_port_override() -> None:
|
||||
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000")
|
||||
overridden = expand_wildcard_address(wildcard, port=9001)
|
||||
assert len(overridden) > 0
|
||||
for e in overridden:
|
||||
assert str(e).endswith("/tcp/9001")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("NO_IPV6") == "1",
|
||||
reason="Environment disallows IPv6",
|
||||
)
|
||||
def test_expand_wildcard_address_ipv6() -> None:
|
||||
wildcard = Multiaddr("/ip6/::/tcp/0")
|
||||
expanded = expand_wildcard_address(wildcard)
|
||||
assert len(expanded) > 0
|
||||
for e in expanded:
|
||||
assert "/ip6/" in str(e)
|
||||
Reference in New Issue
Block a user