diff --git a/docs/examples.random_walk.rst b/docs/examples.random_walk.rst new file mode 100644 index 00000000..baa3f81f --- /dev/null +++ b/docs/examples.random_walk.rst @@ -0,0 +1,131 @@ +Random Walk Example +=================== + +This example demonstrates the Random Walk module's peer discovery capabilities using real libp2p hosts and Kademlia DHT. +It shows how the Random Walk module automatically discovers new peers and maintains routing table health. + +The Random Walk implementation performs the following key operations: + +* **Automatic Peer Discovery**: Generates random peer IDs and queries the DHT network to discover new peers +* **Routing Table Maintenance**: Periodically refreshes the routing table to maintain network connectivity +* **Connection Management**: Maintains optimal connections to healthy peers in the network +* **Real-time Statistics**: Displays routing table size, connected peers, and peerstore statistics + +.. code-block:: console + + $ python -m pip install libp2p + Collecting libp2p + ... + Successfully installed libp2p-x.x.x + $ cd examples/random_walk + $ python random_walk.py --mode server + 2025-08-12 19:51:25,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p === + 2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s + 2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123 + 2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef + 2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef + 2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0 + 2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode + 2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started + 2025-08-12 19:51:55,432 - random-walk-example - INFO - --- Iteration 1 --- + 2025-08-12 19:51:55,432 - random-walk-example - INFO - Routing table size: 15 + 2025-08-12 19:51:55,432 - random-walk-example - INFO - Connected peers: 8 + 2025-08-12 19:51:55,432 - random-walk-example - INFO - Peerstore size: 42 + +You can also run the example in client mode: + +.. code-block:: console + + $ python random_walk.py --mode client + 2025-08-12 19:52:15,424 - random-walk-example - INFO - === Random Walk Example for py-libp2p === + 2025-08-12 19:52:15,424 - random-walk-example - INFO - Mode: client, Port: 0 Demo interval: 30s + 2025-08-12 19:52:15,426 - random-walk-example - INFO - Starting client node on port 51234 + 2025-08-12 19:52:15,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAmAbc123xyz... + 2025-08-12 19:52:15,427 - random-walk-example - INFO - DHT service started in CLIENT mode + 2025-08-12 19:52:45,432 - random-walk-example - INFO - --- Iteration 1 --- + 2025-08-12 19:52:45,432 - random-walk-example - INFO - Routing table size: 8 + 2025-08-12 19:52:45,432 - random-walk-example - INFO - Connected peers: 5 + 2025-08-12 19:52:45,432 - random-walk-example - INFO - Peerstore size: 25 + +Command Line Options +-------------------- + +The example supports several command-line options: + +.. code-block:: console + + $ python random_walk.py --help + usage: random_walk.py [-h] [--mode {server,client}] [--port PORT] + [--demo-interval DEMO_INTERVAL] [--verbose] + + Random Walk Example for py-libp2p Kademlia DHT + + optional arguments: + -h, --help show this help message and exit + --mode {server,client} + Node mode: server (DHT server), or client (DHT client) + --port PORT Port to listen on (0 for random) + --demo-interval DEMO_INTERVAL + Interval between random walk demonstrations in seconds + --verbose Enable verbose logging + +Key Features Demonstrated +------------------------- + +**Automatic Random Walk Discovery** + The example shows how the Random Walk module automatically: + + * Generates random 256-bit peer IDs for discovery queries + * Performs concurrent random walks to maximize peer discovery + * Validates discovered peers and adds them to the routing table + * Maintains routing table health through periodic refreshes + +**Real-time Network Statistics** + The example displays live statistics every 30 seconds (configurable): + + * **Routing Table Size**: Number of peers in the Kademlia routing table + * **Connected Peers**: Number of actively connected peers + * **Peerstore Size**: Total number of known peers with addresses + +**Connection Management** + The example includes sophisticated connection management: + + * Automatically maintains connections to healthy peers + * Filters for compatible peers (TCP + IPv4 addresses) + * Reconnects to maintain optimal network connectivity + * Handles connection failures gracefully + +**DHT Integration** + Shows seamless integration between Random Walk and Kademlia DHT: + + * RT Refresh Manager coordinates with the DHT routing table + * Peer discovery feeds directly into DHT operations + * Both SERVER and CLIENT modes supported + * Bootstrap connectivity to public IPFS nodes + +Understanding the Output +------------------------ + +When you run the example, you'll see periodic statistics that show how the Random Walk module is working: + +* **Initial Phase**: Routing table starts empty and quickly discovers peers +* **Growth Phase**: Routing table size increases as more peers are discovered +* **Maintenance Phase**: Routing table size stabilizes as the system maintains optimal peer connections + +The Random Walk module runs automatically in the background, performing peer discovery queries every few minutes to ensure the routing table remains populated with fresh, reachable peers. + +Configuration +------------- + +The Random Walk module can be configured through the following parameters in ``libp2p.discovery.random_walk.config``: + +* ``RANDOM_WALK_ENABLED``: Enable/disable automatic random walks (default: True) +* ``REFRESH_INTERVAL``: Time between automatic refreshes in seconds (default: 300) +* ``RANDOM_WALK_CONCURRENCY``: Number of concurrent random walks (default: 3) +* ``MIN_RT_REFRESH_THRESHOLD``: Minimum routing table size before triggering refresh (default: 4) + +See Also +-------- + +* :doc:`examples.kademlia` - Kademlia DHT value storage and content routing +* :doc:`libp2p.discovery.random_walk` - Random Walk module API documentation diff --git a/docs/examples.rst b/docs/examples.rst index b20d0e63..b8ba44d7 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -14,3 +14,4 @@ Examples examples.circuit_relay examples.kademlia examples.mDNS + examples.random_walk diff --git a/docs/libp2p.discovery.random_walk.rst b/docs/libp2p.discovery.random_walk.rst new file mode 100644 index 00000000..1cd7702c --- /dev/null +++ b/docs/libp2p.discovery.random_walk.rst @@ -0,0 +1,48 @@ +libp2p.discovery.random_walk package +==================================== + +The Random Walk module implements a peer discovery mechanism. +It performs random walks through the DHT network to discover new peers and maintain routing table health through periodic refreshes. + +Submodules +---------- + +libp2p.discovery.random_walk.config module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: libp2p.discovery.random_walk.config + :members: + :undoc-members: + :show-inheritance: + +libp2p.discovery.random_walk.exceptions module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: libp2p.discovery.random_walk.exceptions + :members: + :undoc-members: + :show-inheritance: + +libp2p.discovery.random_walk.random_walk module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: libp2p.discovery.random_walk.random_walk + :members: + :undoc-members: + :show-inheritance: + +libp2p.discovery.random_walk.rt_refresh_manager module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: libp2p.discovery.random_walk.rt_refresh_manager + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.discovery.random_walk + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.discovery.rst b/docs/libp2p.discovery.rst index 508ca059..4b812088 100644 --- a/docs/libp2p.discovery.rst +++ b/docs/libp2p.discovery.rst @@ -10,6 +10,7 @@ Subpackages libp2p.discovery.bootstrap libp2p.discovery.events libp2p.discovery.mdns + libp2p.discovery.random_walk Submodules ---------- diff --git a/examples/advanced/network_discover.py b/examples/advanced/network_discover.py new file mode 100644 index 00000000..87b44ddf --- /dev/null +++ b/examples/advanced/network_discover.py @@ -0,0 +1,63 @@ +""" +Advanced demonstration of Thin Waist address handling. + +Run: + python -m examples.advanced.network_discovery +""" + +from __future__ import annotations + +from multiaddr import Multiaddr + +try: + from libp2p.utils.address_validation import ( + expand_wildcard_address, + get_available_interfaces, + get_optimal_binding_address, + ) +except ImportError: + # Fallbacks if utilities are missing + def get_available_interfaces(port: int, protocol: str = "tcp"): + return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")] + + def expand_wildcard_address(addr: Multiaddr, port: int | None = None): + if port is None: + return [addr] + addr_str = str(addr).rsplit("/", 1)[0] + return [Multiaddr(addr_str + f"/{port}")] + + def get_optimal_binding_address(port: int, protocol: str = "tcp"): + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + +def main() -> None: + port = 8080 + interfaces = get_available_interfaces(port) + print(f"Discovered interfaces for port {port}:") + for a in interfaces: + print(f" - {a}") + + wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + expanded_v4 = expand_wildcard_address(wildcard_v4) + print("\nExpanded IPv4 wildcard:") + for a in expanded_v4: + print(f" - {a}") + + wildcard_v6 = Multiaddr(f"/ip6/::/tcp/{port}") + expanded_v6 = expand_wildcard_address(wildcard_v6) + print("\nExpanded IPv6 wildcard:") + for a in expanded_v6: + print(f" - {a}") + + print("\nOptimal binding address heuristic result:") + print(f" -> {get_optimal_binding_address(port)}") + + override_port = 9000 + overridden = expand_wildcard_address(wildcard_v4, port=override_port) + print(f"\nPort override expansion to {override_port}:") + for a in overridden: + print(f" - {a}") + + +if __name__ == "__main__": + main() diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 126a7da2..19e98377 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,4 +1,6 @@ import argparse +import random +import secrets import multiaddr import trio @@ -12,40 +14,54 @@ from libp2p.crypto.secp256k1 import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.network.stream.exceptions import ( + StreamEOF, +) from libp2p.network.stream.net_stream import ( INetStream, ) from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, +) PROTOCOL_ID = TProtocol("/echo/1.0.0") MAX_READ_LEN = 2**32 - 1 async def _echo_stream_handler(stream: INetStream) -> None: - # Wait until EOF - msg = await stream.read(MAX_READ_LEN) - await stream.write(msg) - await stream.close() + try: + peer_id = stream.muxed_conn.peer_id + print(f"Received connection from {peer_id}") + # Wait until EOF + msg = await stream.read(MAX_READ_LEN) + print(f"Echoing message: {msg.decode('utf-8')}") + await stream.write(msg) + except StreamEOF: + print("Stream closed by remote peer.") + except Exception as e: + print(f"Error in echo handler: {e}") + finally: + await stream.close() async def run(port: int, destination: str, seed: int | None = None) -> None: - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + if port <= 0: + port = find_free_port() + listen_addr = get_available_interfaces(port) if seed: - import random - random.seed(seed) secret_number = random.getrandbits(32 * 8) secret = secret_number.to_bytes(length=32, byteorder="big") else: - import secrets - secret = secrets.token_bytes(32) host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -54,10 +70,15 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: if not destination: # its the server host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + # Print all listen addresses with peer ID (JS parity) + print("Listener ready, listening on:\n") + peer_id = host.get_id().to_string() + for addr in listen_addr: + print(f"{addr}/p2p/{peer_id}") + print( - "Run this from the same folder in another console:\n\n" - f"echo-demo " - f"-d {host.get_addrs()[0]}\n" + "\nRun this from the same folder in another console:\n\n" + f"echo-demo -d {host.get_addrs()[0]}\n" ) print("Waiting for incoming connections...") await trio.sleep_forever() diff --git a/examples/kademlia/kademlia.py b/examples/kademlia/kademlia.py index 00c7915a..5daa70d7 100644 --- a/examples/kademlia/kademlia.py +++ b/examples/kademlia/kademlia.py @@ -227,7 +227,7 @@ async def run_node( # Keep the node running while True: - logger.debug( + logger.info( "Status - Connected peers: %d," "Peers in store: %d, Values in store: %d", len(dht.host.get_connected_peers()), diff --git a/examples/pubsub/pubsub.py b/examples/pubsub/pubsub.py index 1ab6d650..41545658 100644 --- a/examples/pubsub/pubsub.py +++ b/examples/pubsub/pubsub.py @@ -1,6 +1,5 @@ import argparse import logging -import socket import base58 import multiaddr @@ -31,6 +30,9 @@ from libp2p.stream_muxer.mplex.mplex import ( from libp2p.tools.async_service.trio_service import ( background_trio_service, ) +from libp2p.utils.address_validation import ( + find_free_port, +) # Configure logging logging.basicConfig( @@ -77,13 +79,6 @@ async def publish_loop(pubsub, topic, termination_event): await trio.sleep(1) # Avoid tight loop on error -def find_free_port(): - """Find a free port on localhost.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to a free port provided by the OS - return s.getsockname()[1] - - async def monitor_peer_topics(pubsub, nursery, termination_event): """ Monitor for new topics that peers are subscribed to and diff --git a/examples/random_walk/random_walk.py b/examples/random_walk/random_walk.py new file mode 100644 index 00000000..845ccd57 --- /dev/null +++ b/examples/random_walk/random_walk.py @@ -0,0 +1,221 @@ +""" +Random Walk Example for py-libp2p Kademlia DHT + +This example demonstrates the Random Walk module's peer discovery capabilities +using real libp2p hosts and Kademlia DHT. It shows how the Random Walk module +automatically discovers new peers and maintains routing table health. + +Usage: + # Start server nodes (they will discover peers via random walk) + python3 random_walk.py --mode server +""" + +import argparse +import logging +import random +import secrets +import sys + +from multiaddr import Multiaddr +import trio + +from libp2p import new_host +from libp2p.abc import IHost +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.kad_dht.kad_dht import DHTMode, KadDHT +from libp2p.tools.async_service import background_trio_service + + +# Simple logging configuration +def setup_logging(verbose: bool = False): + """Setup unified logging configuration.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], + ) + + # Configure key module loggers + for module in ["libp2p.discovery.random_walk", "libp2p.kad_dht"]: + logging.getLogger(module).setLevel(level) + + # Suppress noisy logs + logging.getLogger("multiaddr").setLevel(logging.WARNING) + + +logger = logging.getLogger("random-walk-example") + +# Default bootstrap nodes +DEFAULT_BOOTSTRAP_NODES = [ + "/ip4/104.131.131.82/tcp/4001/p2p/QmaCpDMGvV2BGHeYERUEnRQAwe3N8SzbUtfsmvsqQLuvuJ" +] + + +def filter_compatible_peer_info(peer_info) -> bool: + """Filter peer info to check if it has compatible addresses (TCP + IPv4).""" + if not hasattr(peer_info, "addrs") or not peer_info.addrs: + return False + + for addr in peer_info.addrs: + addr_str = str(addr) + if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str: + return True + return False + + +async def maintain_connections(host: IHost) -> None: + """Maintain connections to ensure the host remains connected to healthy peers.""" + while True: + try: + connected_peers = host.get_connected_peers() + list_peers = host.get_peerstore().peers_with_addrs() + + if len(connected_peers) < 20: + logger.debug("Reconnecting to maintain peer connections...") + + # Find compatible peers + compatible_peers = [] + for peer_id in list_peers: + try: + peer_info = host.get_peerstore().peer_info(peer_id) + if filter_compatible_peer_info(peer_info): + compatible_peers.append(peer_id) + except Exception: + continue + + # Connect to random subset of compatible peers + if compatible_peers: + random_peers = random.sample( + compatible_peers, min(50, len(compatible_peers)) + ) + for peer_id in random_peers: + if peer_id not in connected_peers: + try: + with trio.move_on_after(5): + peer_info = host.get_peerstore().peer_info(peer_id) + await host.connect(peer_info) + logger.debug(f"Connected to peer: {peer_id}") + except Exception as e: + logger.debug(f"Failed to connect to {peer_id}: {e}") + + await trio.sleep(15) + except Exception as e: + logger.error(f"Error maintaining connections: {e}") + + +async def demonstrate_random_walk_discovery(dht: KadDHT, interval: int = 30) -> None: + """Demonstrate Random Walk peer discovery with periodic statistics.""" + iteration = 0 + while True: + iteration += 1 + logger.info(f"--- Iteration {iteration} ---") + logger.info(f"Routing table size: {dht.get_routing_table_size()}") + logger.info(f"Connected peers: {len(dht.host.get_connected_peers())}") + logger.info(f"Peerstore size: {len(dht.host.get_peerstore().peer_ids())}") + await trio.sleep(interval) + + +async def run_node(port: int, mode: str, demo_interval: int = 30) -> None: + """Run a node that demonstrates Random Walk peer discovery.""" + try: + if port <= 0: + port = random.randint(10000, 60000) + + logger.info(f"Starting {mode} node on port {port}") + + # Determine DHT mode + dht_mode = DHTMode.SERVER if mode == "server" else DHTMode.CLIENT + + # Create host and DHT + key_pair = create_new_key_pair(secrets.token_bytes(32)) + host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES) + listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + + async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + # Start maintenance tasks + nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + nursery.start_soon(maintain_connections, host) + + peer_id = host.get_id().pretty() + logger.info(f"Node peer ID: {peer_id}") + logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}") + + # Create and start DHT with Random Walk enabled + dht = KadDHT(host, dht_mode, enable_random_walk=True) + logger.info(f"Initial routing table size: {dht.get_routing_table_size()}") + + async with background_trio_service(dht): + logger.info(f"DHT service started in {dht_mode.value} mode") + logger.info(f"Random Walk enabled: {dht.is_random_walk_enabled()}") + + async with trio.open_nursery() as task_nursery: + # Start demonstration and status reporting + task_nursery.start_soon( + demonstrate_random_walk_discovery, dht, demo_interval + ) + + # Periodic status updates + async def status_reporter(): + while True: + await trio.sleep(30) + logger.debug( + f"Connected: {len(dht.host.get_connected_peers())}, " + f"Routing table: {dht.get_routing_table_size()}, " + f"Peerstore: {len(dht.host.get_peerstore().peer_ids())}" + ) + + task_nursery.start_soon(status_reporter) + await trio.sleep_forever() + + except Exception as e: + logger.error(f"Node error: {e}", exc_info=True) + sys.exit(1) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Random Walk Example for py-libp2p Kademlia DHT", + ) + parser.add_argument( + "--mode", + choices=["server", "client"], + default="server", + help="Node mode: server (DHT server), or client (DHT client)", + ) + parser.add_argument( + "--port", type=int, default=0, help="Port to listen on (0 for random)" + ) + parser.add_argument( + "--demo-interval", + type=int, + default=30, + help="Interval between random walk demonstrations in seconds", + ) + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + return parser.parse_args() + + +def main(): + """Main entry point for the random walk example.""" + try: + args = parse_args() + setup_logging(args.verbose) + + logger.info("=== Random Walk Example for py-libp2p ===") + logger.info( + f"Mode: {args.mode}, Port: {args.port} Demo interval: {args.demo_interval}s" + ) + + trio.run(run_node, args.port, args.mode, args.demo_interval) + + except KeyboardInterrupt: + logger.info("Received interrupt signal, shutting down...") + except Exception as e: + logger.critical(f"Example failed: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/libp2p/discovery/random_walk/__init__.py b/libp2p/discovery/random_walk/__init__.py new file mode 100644 index 00000000..0b455afd --- /dev/null +++ b/libp2p/discovery/random_walk/__init__.py @@ -0,0 +1,17 @@ +"""Random walk discovery modules for py-libp2p.""" + +from .rt_refresh_manager import RTRefreshManager +from .random_walk import RandomWalk +from .exceptions import ( + RoutingTableRefreshError, + RandomWalkError, + PeerValidationError, +) + +__all__ = [ + "RTRefreshManager", + "RandomWalk", + "RoutingTableRefreshError", + "RandomWalkError", + "PeerValidationError", +] diff --git a/libp2p/discovery/random_walk/config.py b/libp2p/discovery/random_walk/config.py new file mode 100644 index 00000000..4a2e9d56 --- /dev/null +++ b/libp2p/discovery/random_walk/config.py @@ -0,0 +1,16 @@ +from typing import Final + +# Timing constants (matching go-libp2p) +PEER_PING_TIMEOUT: Final[float] = 10.0 # seconds +REFRESH_QUERY_TIMEOUT: Final[float] = 60.0 # seconds +REFRESH_INTERVAL: Final[float] = 300.0 # 5 minutes +SUCCESSFUL_OUTBOUND_QUERY_GRACE_PERIOD: Final[float] = 60.0 # 1 minute + +# Routing table thresholds +MIN_RT_REFRESH_THRESHOLD: Final[int] = 4 # Minimum peers before triggering refresh +MAX_N_BOOTSTRAPPERS: Final[int] = 2 # Maximum bootstrap peers to try + +# Random walk specific +RANDOM_WALK_CONCURRENCY: Final[int] = 3 # Number of concurrent random walks +RANDOM_WALK_ENABLED: Final[bool] = True # Enable automatic random walks +RANDOM_WALK_RT_THRESHOLD: Final[int] = 20 # RT size threshold for peerstore fallback diff --git a/libp2p/discovery/random_walk/exceptions.py b/libp2p/discovery/random_walk/exceptions.py new file mode 100644 index 00000000..28325619 --- /dev/null +++ b/libp2p/discovery/random_walk/exceptions.py @@ -0,0 +1,19 @@ +from libp2p.exceptions import BaseLibp2pError + + +class RoutingTableRefreshError(BaseLibp2pError): + """Base exception for routing table refresh operations.""" + + pass + + +class RandomWalkError(RoutingTableRefreshError): + """Exception raised during random walk operations.""" + + pass + + +class PeerValidationError(RoutingTableRefreshError): + """Exception raised when peer validation fails.""" + + pass diff --git a/libp2p/discovery/random_walk/random_walk.py b/libp2p/discovery/random_walk/random_walk.py new file mode 100644 index 00000000..e1b2ae17 --- /dev/null +++ b/libp2p/discovery/random_walk/random_walk.py @@ -0,0 +1,218 @@ +from collections.abc import Awaitable, Callable +import logging +import secrets + +import trio + +from libp2p.abc import IHost +from libp2p.discovery.random_walk.config import ( + RANDOM_WALK_CONCURRENCY, + RANDOM_WALK_RT_THRESHOLD, + REFRESH_QUERY_TIMEOUT, +) +from libp2p.discovery.random_walk.exceptions import RandomWalkError +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo + +logger = logging.getLogger("libp2p.discovery.random_walk") + + +class RandomWalk: + """ + Random Walk implementation for peer discovery in Kademlia DHT. + + Generates random peer IDs and performs FIND_NODE queries to discover + new peers and populate the routing table. + """ + + def __init__( + self, + host: IHost, + local_peer_id: ID, + query_function: Callable[[bytes], Awaitable[list[ID]]], + ): + """ + Initialize Random Walk module. + + Args: + host: The libp2p host instance + local_peer_id: Local peer ID + query_function: Function to query for closest peers given target key bytes + + """ + self.host = host + self.local_peer_id = local_peer_id + self.query_function = query_function + + def generate_random_peer_id(self) -> str: + """ + Generate a completely random peer ID + for random walk queries. + + Returns: + Random peer ID as string + + """ + # Generate 32 random bytes (256 bits) - same as go-libp2p + random_bytes = secrets.token_bytes(32) + # Convert to hex string for query + return random_bytes.hex() + + async def perform_random_walk(self) -> list[PeerInfo]: + """ + Perform a single random walk operation. + + Returns: + List of validated peers discovered during the walk + + """ + try: + # Generate random peer ID + random_peer_id = self.generate_random_peer_id() + logger.info(f"Starting random walk for peer ID: {random_peer_id}") + + # Perform FIND_NODE query + discovered_peer_ids: list[ID] = [] + + with trio.move_on_after(REFRESH_QUERY_TIMEOUT): + # Call the query function with target key bytes + target_key = bytes.fromhex(random_peer_id) + discovered_peer_ids = await self.query_function(target_key) or [] + + if not discovered_peer_ids: + logger.debug(f"No peers discovered in random walk for {random_peer_id}") + return [] + + logger.info( + f"Discovered {len(discovered_peer_ids)} peers in random walk " + f"for {random_peer_id[:8]}..." # Show only first 8 chars for brevity + ) + + # Convert peer IDs to PeerInfo objects and validate + validated_peers: list[PeerInfo] = [] + + for peer_id in discovered_peer_ids: + try: + # Get addresses from peerstore + addrs = self.host.get_peerstore().addrs(peer_id) + if addrs: + peer_info = PeerInfo(peer_id, addrs) + validated_peers.append(peer_info) + except Exception as e: + logger.debug(f"Failed to create PeerInfo for {peer_id}: {e}") + continue + + return validated_peers + + except Exception as e: + logger.error(f"Random walk failed: {e}") + raise RandomWalkError(f"Random walk operation failed: {e}") from e + + async def run_concurrent_random_walks( + self, count: int = RANDOM_WALK_CONCURRENCY, current_routing_table_size: int = 0 + ) -> list[PeerInfo]: + """ + Run multiple random walks concurrently. + + Args: + count: Number of concurrent random walks to perform + current_routing_table_size: Current size of routing table (for optimization) + + Returns: + Combined list of all validated peers discovered + + """ + all_validated_peers: list[PeerInfo] = [] + logger.info(f"Starting {count} concurrent random walks") + + # First, try to add peers from peerstore if routing table is small + if current_routing_table_size < RANDOM_WALK_RT_THRESHOLD: + try: + peerstore_peers = self._get_peerstore_peers() + if peerstore_peers: + logger.debug( + f"RT size ({current_routing_table_size}) below threshold, " + f"adding {len(peerstore_peers)} peerstore peers" + ) + all_validated_peers.extend(peerstore_peers) + except Exception as e: + logger.warning(f"Error processing peerstore peers: {e}") + + async def single_walk() -> None: + try: + peers = await self.perform_random_walk() + all_validated_peers.extend(peers) + except Exception as e: + logger.warning(f"Concurrent random walk failed: {e}") + return + + # Run concurrent random walks + async with trio.open_nursery() as nursery: + for _ in range(count): + nursery.start_soon(single_walk) + + # Remove duplicates based on peer ID + unique_peers = {} + for peer in all_validated_peers: + unique_peers[peer.peer_id] = peer + + result = list(unique_peers.values()) + logger.info( + f"Concurrent random walks completed: {len(result)} unique peers discovered" + ) + return result + + def _get_peerstore_peers(self) -> list[PeerInfo]: + """ + Get peer info objects from the host's peerstore. + + Returns: + List of PeerInfo objects from peerstore + + """ + try: + peerstore = self.host.get_peerstore() + peer_ids = peerstore.peers_with_addrs() + + peer_infos = [] + for peer_id in peer_ids: + try: + # Skip local peer + if peer_id == self.local_peer_id: + continue + + peer_info = peerstore.peer_info(peer_id) + if peer_info and peer_info.addrs: + # Filter for compatible addresses (TCP + IPv4) + if self._has_compatible_addresses(peer_info): + peer_infos.append(peer_info) + except Exception as e: + logger.debug(f"Error getting peer info for {peer_id}: {e}") + + return peer_infos + + except Exception as e: + logger.warning(f"Error accessing peerstore: {e}") + return [] + + def _has_compatible_addresses(self, peer_info: PeerInfo) -> bool: + """ + Check if a peer has TCP+IPv4 compatible addresses. + + Args: + peer_info: PeerInfo to check + + Returns: + True if peer has compatible addresses + + """ + if not peer_info.addrs: + return False + + for addr in peer_info.addrs: + addr_str = str(addr) + # Check for TCP and IPv4 compatibility, avoid QUIC + if "/tcp/" in addr_str and "/ip4/" in addr_str and "/quic" not in addr_str: + return True + + return False diff --git a/libp2p/discovery/random_walk/rt_refresh_manager.py b/libp2p/discovery/random_walk/rt_refresh_manager.py new file mode 100644 index 00000000..7ed63cbd --- /dev/null +++ b/libp2p/discovery/random_walk/rt_refresh_manager.py @@ -0,0 +1,208 @@ +from collections.abc import Awaitable, Callable +import logging +import time +from typing import Protocol + +import trio + +from libp2p.abc import IHost +from libp2p.discovery.random_walk.config import ( + MIN_RT_REFRESH_THRESHOLD, + RANDOM_WALK_CONCURRENCY, + RANDOM_WALK_ENABLED, + REFRESH_INTERVAL, +) +from libp2p.discovery.random_walk.exceptions import RoutingTableRefreshError +from libp2p.discovery.random_walk.random_walk import RandomWalk +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo + + +class RoutingTableProtocol(Protocol): + """Protocol for routing table operations needed by RT refresh manager.""" + + def size(self) -> int: + """Return the current size of the routing table.""" + ... + + async def add_peer(self, peer_obj: PeerInfo) -> bool: + """Add a peer to the routing table.""" + ... + + +logger = logging.getLogger("libp2p.discovery.random_walk.rt_refresh_manager") + + +class RTRefreshManager: + """ + Routing Table Refresh Manager for py-libp2p. + + Manages periodic routing table refreshes and random walk operations + to maintain routing table health and discover new peers. + """ + + def __init__( + self, + host: IHost, + routing_table: RoutingTableProtocol, + local_peer_id: ID, + query_function: Callable[[bytes], Awaitable[list[ID]]], + enable_auto_refresh: bool = RANDOM_WALK_ENABLED, + refresh_interval: float = REFRESH_INTERVAL, + min_refresh_threshold: int = MIN_RT_REFRESH_THRESHOLD, + ): + """ + Initialize RT Refresh Manager. + + Args: + host: The libp2p host instance + routing_table: Routing table of host + local_peer_id: Local peer ID + query_function: Function to query for closest peers given target key bytes + enable_auto_refresh: Whether to enable automatic refresh + refresh_interval: Interval between refreshes in seconds + min_refresh_threshold: Minimum RT size before triggering refresh + + """ + self.host = host + self.routing_table = routing_table + self.local_peer_id = local_peer_id + self.query_function = query_function + + self.enable_auto_refresh = enable_auto_refresh + self.refresh_interval = refresh_interval + self.min_refresh_threshold = min_refresh_threshold + + # Initialize random walk module + self.random_walk = RandomWalk( + host=host, + local_peer_id=self.local_peer_id, + query_function=query_function, + ) + + # Control variables + self._running = False + self._nursery: trio.Nursery | None = None + + # Tracking + self._last_refresh_time = 0.0 + self._refresh_done_callbacks: list[Callable[[], None]] = [] + + async def start(self) -> None: + """Start the RT Refresh Manager.""" + if self._running: + logger.warning("RT Refresh Manager is already running") + return + + self._running = True + + logger.info("Starting RT Refresh Manager") + + # Start the main loop + async with trio.open_nursery() as nursery: + self._nursery = nursery + nursery.start_soon(self._main_loop) + + async def stop(self) -> None: + """Stop the RT Refresh Manager.""" + if not self._running: + return + + logger.info("Stopping RT Refresh Manager") + self._running = False + + async def _main_loop(self) -> None: + """Main loop for the RT Refresh Manager.""" + logger.info("RT Refresh Manager main loop started") + + # Initial refresh if auto-refresh is enabled + if self.enable_auto_refresh: + await self._do_refresh(force=True) + + try: + while self._running: + async with trio.open_nursery() as nursery: + # Schedule periodic refresh if enabled + if self.enable_auto_refresh: + nursery.start_soon(self._periodic_refresh_task) + + except Exception as e: + logger.error(f"RT Refresh Manager main loop error: {e}") + finally: + logger.info("RT Refresh Manager main loop stopped") + + async def _periodic_refresh_task(self) -> None: + """Task for periodic refreshes.""" + while self._running: + await trio.sleep(self.refresh_interval) + if self._running: + await self._do_refresh() + + async def _do_refresh(self, force: bool = False) -> None: + """ + Perform routing table refresh operation. + + Args: + force: Whether to force refresh regardless of timing + + """ + try: + current_time = time.time() + + # Check if refresh is needed + if not force: + if current_time - self._last_refresh_time < self.refresh_interval: + logger.debug("Skipping refresh: interval not elapsed") + return + + if self.routing_table.size() >= self.min_refresh_threshold: + logger.debug("Skipping refresh: routing table size above threshold") + return + + logger.info(f"Starting routing table refresh (force={force})") + start_time = current_time + + # Perform random walks to discover new peers + logger.info("Running concurrent random walks to discover new peers") + current_rt_size = self.routing_table.size() + discovered_peers = await self.random_walk.run_concurrent_random_walks( + count=RANDOM_WALK_CONCURRENCY, + current_routing_table_size=current_rt_size, + ) + + # Add discovered peers to routing table + added_count = 0 + for peer_info in discovered_peers: + result = await self.routing_table.add_peer(peer_info) + if result: + added_count += 1 + + self._last_refresh_time = current_time + + duration = time.time() - start_time + logger.info( + f"Routing table refresh completed: " + f"{added_count}/{len(discovered_peers)} peers added, " + f"RT size: {self.routing_table.size()}, " + f"duration: {duration:.2f}s" + ) + + # Notify refresh completion + for callback in self._refresh_done_callbacks: + try: + callback() + except Exception as e: + logger.warning(f"Refresh callback error: {e}") + + except Exception as e: + logger.error(f"Routing table refresh failed: {e}") + raise RoutingTableRefreshError(f"Refresh operation failed: {e}") from e + + def add_refresh_done_callback(self, callback: Callable[[], None]) -> None: + """Add a callback to be called when refresh completes.""" + self._refresh_done_callbacks.append(callback) + + def remove_refresh_done_callback(self, callback: Callable[[], None]) -> None: + """Remove a refresh completion callback.""" + if callback in self._refresh_done_callbacks: + self._refresh_done_callbacks.remove(callback) diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 70e41953..b40b0128 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -295,6 +295,13 @@ class BasicHost(IHost): ) await net_stream.reset() return + if protocol is None: + logger.debug( + "no protocol negotiated, closing stream from peer %s", + net_stream.muxed_conn.peer_id, + ) + await net_stream.reset() + return net_stream.set_protocol(protocol) if handler is None: logger.debug( diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index dcf323ba..097b6c48 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -5,6 +5,7 @@ This module provides a complete Distributed Hash Table (DHT) implementation based on the Kademlia algorithm and protocol. """ +from collections.abc import Awaitable, Callable from enum import ( Enum, ) @@ -20,6 +21,7 @@ import varint from libp2p.abc import ( IHost, ) +from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager from libp2p.network.stream.net_stream import ( INetStream, ) @@ -73,14 +75,27 @@ class KadDHT(Service): This class provides a DHT implementation that combines routing table management, peer discovery, content routing, and value storage. + + Optional Random Walk feature enhances peer discovery by automatically + performing periodic random queries to discover new peers and maintain + routing table health. + + Example: + # Basic DHT without random walk (default) + dht = KadDHT(host, DHTMode.SERVER) + + # DHT with random walk enabled for enhanced peer discovery + dht = KadDHT(host, DHTMode.SERVER, enable_random_walk=True) + """ - def __init__(self, host: IHost, mode: DHTMode): + def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False): """ Initialize a new Kademlia DHT node. :param host: The libp2p host. :param mode: The mode of host (Client or Server) - must be DHTMode enum + :param enable_random_walk: Whether to enable automatic random walk """ super().__init__() @@ -92,6 +107,7 @@ class KadDHT(Service): raise TypeError(f"mode must be DHTMode enum, got {type(mode)}") self.mode = mode + self.enable_random_walk = enable_random_walk # Initialize the routing table self.routing_table = RoutingTable(self.local_peer_id, self.host) @@ -108,13 +124,56 @@ class KadDHT(Service): # Last time we republished provider records self._last_provider_republish = time.time() + # Initialize RT Refresh Manager (only if random walk is enabled) + self.rt_refresh_manager: RTRefreshManager | None = None + if self.enable_random_walk: + self.rt_refresh_manager = RTRefreshManager( + host=self.host, + routing_table=self.routing_table, + local_peer_id=self.local_peer_id, + query_function=self._create_query_function(), + enable_auto_refresh=True, + ) + # Set protocol handlers host.set_stream_handler(PROTOCOL_ID, self.handle_stream) + def _create_query_function(self) -> Callable[[bytes], Awaitable[list[ID]]]: + """ + Create a query function that wraps peer_routing.find_closest_peers_network. + + This function is used by the RandomWalk module to query for peers without + directly importing PeerRouting, avoiding circular import issues. + + Returns: + Callable that takes target_key bytes and returns list of peer IDs + + """ + + async def query_function(target_key: bytes) -> list[ID]: + """Query for closest peers to target key.""" + return await self.peer_routing.find_closest_peers_network(target_key) + + return query_function + async def run(self) -> None: """Run the DHT service.""" logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}") + # Start the RT Refresh Manager in parallel with the main DHT service + async with trio.open_nursery() as nursery: + # Start the RT Refresh Manager only if random walk is enabled + if self.rt_refresh_manager is not None: + nursery.start_soon(self.rt_refresh_manager.start) + logger.info("RT Refresh Manager started - Random Walk is now active") + else: + logger.info("Random Walk is disabled - RT Refresh Manager not started") + + # Start the main DHT service loop + nursery.start_soon(self._run_main_loop) + + async def _run_main_loop(self) -> None: + """Run the main DHT service loop.""" # Main service loop while self.manager.is_running: # Periodically refresh the routing table @@ -135,6 +194,17 @@ class KadDHT(Service): # Wait before next maintenance cycle await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL) + async def stop(self) -> None: + """Stop the DHT service and cleanup resources.""" + logger.info("Stopping Kademlia DHT") + + # Stop the RT Refresh Manager only if it was started + if self.rt_refresh_manager is not None: + await self.rt_refresh_manager.stop() + logger.info("RT Refresh Manager stopped") + else: + logger.info("RT Refresh Manager was not running (Random Walk disabled)") + async def switch_mode(self, new_mode: DHTMode) -> DHTMode: """ Switch the DHT mode. @@ -614,3 +684,15 @@ class KadDHT(Service): """ return self.value_store.size() + + def is_random_walk_enabled(self) -> bool: + """ + Check if random walk peer discovery is enabled. + + Returns + ------- + bool + True if random walk is enabled, False otherwise. + + """ + return self.enable_random_walk diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index 4bcdb647..c4a066f7 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -170,7 +170,7 @@ class PeerRouting(IPeerRouting): # Return early if we have no peers to start with if not closest_peers: - logger.warning("No local peers available for network lookup") + logger.debug("No local peers available for network lookup") return [] # Iterative lookup until convergence diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 0aa60514..67d46279 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -249,9 +249,11 @@ class Swarm(Service, INetworkService): # We need to wait until `self.listener_nursery` is created. await self.event_listener_nursery_created.wait() + success_count = 0 for maddr in multiaddrs: if str(maddr) in self.listeners: - return True + success_count += 1 + continue async def conn_handler( read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr @@ -302,13 +304,14 @@ class Swarm(Service, INetworkService): # Call notifiers since event occurred await self.notify_listen(maddr) - return True + success_count += 1 + logger.debug("successfully started listening on: %s", maddr) except OSError: # Failed. Continue looping. logger.debug("fail to listen on: %s", maddr) - # No maddr succeeded - return False + # Return true if at least one address succeeded + return success_count > 0 async def close(self) -> None: """ diff --git a/libp2p/protocol_muxer/multiselect.py b/libp2p/protocol_muxer/multiselect.py index 8d311391..287a01f3 100644 --- a/libp2p/protocol_muxer/multiselect.py +++ b/libp2p/protocol_muxer/multiselect.py @@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer): """ self.handlers[protocol] = handler - # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent async def negotiate( self, communicator: IMultiselectCommunicator, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, - ) -> tuple[TProtocol, StreamHandlerFn | None]: + ) -> tuple[TProtocol | None, StreamHandlerFn | None]: """ Negotiate performs protocol selection. @@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer): raise MultiselectError() from error else: - protocol = TProtocol(command) - if protocol in self.handlers: + protocol_to_check = None if not command else TProtocol(command) + if protocol_to_check in self.handlers: try: - await communicator.write(protocol) + await communicator.write(command) except MultiselectCommunicatorError as error: raise MultiselectError() from error - return protocol, self.handlers[protocol] + return protocol_to_check, self.handlers[protocol_to_check] try: await communicator.write(PROTOCOL_NOT_FOUND_MSG) except MultiselectCommunicatorError as error: diff --git a/libp2p/protocol_muxer/multiselect_client.py b/libp2p/protocol_muxer/multiselect_client.py index a5b35006..90adb251 100644 --- a/libp2p/protocol_muxer/multiselect_client.py +++ b/libp2p/protocol_muxer/multiselect_client.py @@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient): :raise MultiselectClientError: raised when protocol negotiation failed :return: selected protocol """ + # Represent `None` protocol as an empty string. + protocol_str = protocol if protocol is not None else "" try: - await communicator.write(protocol) + await communicator.write(protocol_str) except MultiselectCommunicatorError as error: raise MultiselectClientError() from error @@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient): except MultiselectCommunicatorError as error: raise MultiselectClientError() from error - if response == protocol: + if response == protocol_str: return protocol if response == PROTOCOL_NOT_FOUND_MSG: raise MultiselectClientError("protocol not supported") diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index c52266fd..98a8129c 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator): """ :raise MultiselectCommunicatorError: raised when failed to write to underlying reader """ # noqa: E501 - msg_bytes = encode_delim(msg_str.encode()) + if msg_str is None: + msg_bytes = encode_delim(b"") + else: + msg_bytes = encode_delim(msg_str.encode()) try: await self.read_writer.write(msg_bytes) except IOException as error: diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index d396c776..209e1989 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -777,14 +777,18 @@ class GossipSub(IPubsubRouter, Service): # Get list of all seen (seqnos, from) from the (seqno, from) tuples in # seen_messages cache seen_seqnos_and_peers = [ - seqno_and_from for seqno_and_from in self.pubsub.seen_messages.cache.keys() + str(seqno_and_from) + for seqno_and_from in self.pubsub.seen_messages.cache.keys() ] # Add all unknown message ids (ids that appear in ihave_msg but not in # seen_seqnos) to list of messages we want to request + msg_ids_wanted: list[str] = [ + msg_id msg_ids_wanted: list[MessageID] = [ parse_message_id_safe(msg_id) for msg_id in ihave_msg.messageIDs + if msg_id not in seen_seqnos_and_peers if msg_id not in str(seen_seqnos_and_peers) ] diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 193cc092..a9c4b19c 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -17,6 +17,9 @@ from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -104,7 +107,7 @@ class SecurityMultistream(ABC): :param is_initiator: true if we are the initiator, false otherwise :return: selected secure transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if is_initiator: # Select protocol if initiator @@ -114,5 +117,7 @@ class SecurityMultistream(ABC): else: # Select protocol if non-initiator protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MultiselectError("fail to negotiate a security protocol") # Return transport from protocol return self.transports[protocol] diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index 76699c67..322db912 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -17,6 +17,9 @@ from libp2p.custom_types import ( from libp2p.peer.id import ( ID, ) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) from libp2p.protocol_muxer.multiselect import ( Multiselect, ) @@ -73,7 +76,7 @@ class MuxerMultistream: :param conn: conn to choose a transport over :return: selected muxer transport """ - protocol: TProtocol + protocol: TProtocol | None communicator = MultiselectCommunicator(conn) if conn.is_initiator: protocol = await self.multiselect_client.select_one_of( @@ -81,6 +84,8 @@ class MuxerMultistream: ) else: protocol, _ = await self.multiselect.negotiate(communicator) + if protocol is None: + raise MultiselectError("fail to negotiate a stream muxer protocol") return self.transports[protocol] async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: diff --git a/libp2p/utils/__init__.py b/libp2p/utils/__init__.py index 0f78bfcb..b881eb92 100644 --- a/libp2p/utils/__init__.py +++ b/libp2p/utils/__init__.py @@ -15,6 +15,13 @@ from libp2p.utils.version import ( get_agent_version, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + expand_wildcard_address, + find_free_port, +) + __all__ = [ "decode_uvarint_from_stream", "encode_delim", @@ -26,4 +33,8 @@ __all__ = [ "decode_varint_from_bytes", "decode_varint_with_size", "read_length_prefixed_protobuf", + "get_available_interfaces", + "get_optimal_binding_address", + "expand_wildcard_address", + "find_free_port", ] diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py new file mode 100644 index 00000000..77b797a1 --- /dev/null +++ b/libp2p/utils/address_validation.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import socket + +from multiaddr import Multiaddr + +try: + from multiaddr.utils import ( # type: ignore + get_network_addrs, + get_thin_waist_addresses, + ) + + _HAS_THIN_WAIST = True +except ImportError: # pragma: no cover - only executed in older environments + _HAS_THIN_WAIST = False + get_thin_waist_addresses = None # type: ignore + get_network_addrs = None # type: ignore + + +def _safe_get_network_addrs(ip_version: int) -> list[str]: + """ + Internal safe wrapper. Returns a list of IP addresses for the requested IP version. + Falls back to minimal defaults when Thin Waist helpers are missing. + + :param ip_version: 4 or 6 + """ + if _HAS_THIN_WAIST and get_network_addrs: + try: + return get_network_addrs(ip_version) or [] + except Exception: # pragma: no cover - defensive + return [] + # Fallback behavior (very conservative) + if ip_version == 4: + return ["127.0.0.1"] + if ip_version == 6: + return ["::1"] + return [] + + +def find_free_port() -> int: + """Find a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to a free port provided by the OS + return s.getsockname()[1] + + +def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]: + """ + Internal safe expansion wrapper. Returns a list of Multiaddr objects. + If Thin Waist isn't available, returns [addr] (identity). + """ + if _HAS_THIN_WAIST and get_thin_waist_addresses: + try: + if port is not None: + return get_thin_waist_addresses(addr, port=port) or [] + return get_thin_waist_addresses(addr) or [] + except Exception: # pragma: no cover - defensive + return [addr] + return [addr] + + +def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]: + """ + Discover available network interfaces (IPv4 + IPv6 if supported) for binding. + + :param port: Port number to bind to. + :param protocol: Transport protocol (e.g., "tcp" or "udp"). + :return: List of Multiaddr objects representing candidate interface addresses. + """ + addrs: list[Multiaddr] = [] + + # IPv4 enumeration + seen_v4: set[str] = set() + + for ip in _safe_get_network_addrs(4): + seen_v4.add(ip) + addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + + # Ensure IPv4 loopback is always included when IPv4 interfaces are discovered + if seen_v4 and "127.0.0.1" not in seen_v4: + addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) + + # TODO: IPv6 support temporarily disabled due to libp2p handshake issues + # IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure) + # Re-enable IPv6 support once the following issues are resolved: + # - libp2p security handshake over IPv6 + # - multiselect protocol over IPv6 + # - connection establishment over IPv6 + # + # seen_v6: set[str] = set() + # for ip in _safe_get_network_addrs(6): + # seen_v6.add(ip) + # addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + # + # # Always include IPv6 loopback for testing purposes when IPv6 is available + # # This ensures IPv6 functionality can be tested even without global IPv6 addresses + # if "::1" not in seen_v6: + # addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}")) + + # Fallback if nothing discovered + if not addrs: + addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")) + + return addrs + + +def expand_wildcard_address( + addr: Multiaddr, port: int | None = None +) -> list[Multiaddr]: + """ + Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces. + + :param addr: Multiaddr to expand. + :param port: Optional override for port selection. + :return: List of concrete Multiaddr instances. + """ + expanded = _safe_expand(addr, port=port) + if not expanded: # Safety fallback + return [addr] + return expanded + + +def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr: + """ + Choose an optimal address for an example to bind to: + - Prefer non-loopback IPv4 + - Then non-loopback IPv6 + - Fallback to loopback + - Fallback to wildcard + + :param port: Port number. + :param protocol: Transport protocol. + :return: A single Multiaddr chosen heuristically. + """ + candidates = get_available_interfaces(port, protocol) + + def is_non_loopback(ma: Multiaddr) -> bool: + s = str(ma) + return not ("/ip4/127." in s or "/ip6/::1" in s) + + for c in candidates: + if "/ip4/" in str(c) and is_non_loopback(c): + return c + for c in candidates: + if "/ip6/" in str(c) and is_non_loopback(c): + return c + for c in candidates: + if "/ip4/127." in str(c) or "/ip6/::1" in str(c): + return c + + # As a final fallback, produce a wildcard + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + +__all__ = [ + "get_available_interfaces", + "get_optimal_binding_address", + "expand_wildcard_address", + "find_free_port", +] diff --git a/newsfragments/770.internal.rst b/newsfragments/770.internal.rst new file mode 100644 index 00000000..f33cb3c0 --- /dev/null +++ b/newsfragments/770.internal.rst @@ -0,0 +1 @@ +Make TProtocol as Optional[TProtocol] to keep types consistent in py-libp2p/libp2p/protocol_muxer/multiselect.py diff --git a/newsfragments/811.feature.rst b/newsfragments/811.feature.rst new file mode 100644 index 00000000..47a0aa68 --- /dev/null +++ b/newsfragments/811.feature.rst @@ -0,0 +1 @@ + Added Thin Waist address validation utilities (with support for interface enumeration, optimal binding, and wildcard expansion). diff --git a/newsfragments/811.internal.rst b/newsfragments/811.internal.rst new file mode 100644 index 00000000..59804430 --- /dev/null +++ b/newsfragments/811.internal.rst @@ -0,0 +1,7 @@ +Add Thin Waist address validation utilities and integrate into echo example + +- Add ``libp2p/utils/address_validation.py`` with dynamic interface discovery +- Implement ``get_available_interfaces()``, ``get_optimal_binding_address()``, and ``expand_wildcard_address()`` +- Update echo example to use dynamic address discovery instead of hardcoded wildcard +- Add safe fallbacks for environments lacking Thin Waist support +- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved) diff --git a/newsfragments/822.feature.rst b/newsfragments/822.feature.rst new file mode 100644 index 00000000..f9aa3c0e --- /dev/null +++ b/newsfragments/822.feature.rst @@ -0,0 +1 @@ +Added `Random Walk` peer discovery module that enables random peer exploration for improved peer discovery. diff --git a/newsfragments/855.internal.rst b/newsfragments/855.internal.rst new file mode 100644 index 00000000..2c425dde --- /dev/null +++ b/newsfragments/855.internal.rst @@ -0,0 +1 @@ +Improved PubsubNotifee integration tests and added failure scenario coverage. diff --git a/newsfragments/859.feature.rst b/newsfragments/859.feature.rst new file mode 100644 index 00000000..7307c82f --- /dev/null +++ b/newsfragments/859.feature.rst @@ -0,0 +1 @@ +Fix type for gossipsub_message_id for consistency and security diff --git a/newsfragments/863.bugfix.rst b/newsfragments/863.bugfix.rst new file mode 100644 index 00000000..64de57b4 --- /dev/null +++ b/newsfragments/863.bugfix.rst @@ -0,0 +1,5 @@ +Fix multi-address listening bug in swarm.listen() + +- Fix early return in swarm.listen() that prevented listening on all addresses +- Add comprehensive tests for multi-address listening functionality +- Ensure all available interfaces are properly bound and connectable diff --git a/pyproject.toml b/pyproject.toml index c328f038..7f08697e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,9 @@ requires-python = ">=3.10, <4.0" license = { text = "MIT AND Apache-2.0" } keywords = ["libp2p", "p2p"] maintainers = [ - { name = "pacrob", email = "pacrob@protonmail.com" }, + { name = "pacrob", email = "pacrob-py-libp2p@proton.me" }, { name = "Manu Sheel Gupta", email = "manu@seeta.in" }, - { name = "Dave Grantham", email = "dave@aviation.community" }, + { name = "Dave Grantham", email = "dwg@linuxprogrammer.org" }, ] dependencies = [ "base58>=1.0.3", diff --git a/tests/core/network/test_notifee_performance.py b/tests/core/network/test_notifee_performance.py new file mode 100644 index 00000000..cba6d0ad --- /dev/null +++ b/tests/core/network/test_notifee_performance.py @@ -0,0 +1,82 @@ +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ( + INetConn, + INetStream, + INetwork, + INotifee, +) +from libp2p.tools.utils import connect_swarm +from tests.utils.factories import SwarmFactory + + +class CountingNotifee(INotifee): + def __init__(self, event: trio.Event) -> None: + self._event = event + + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def connected(self, network: INetwork, conn: INetConn) -> None: + self._event.set() + + async def disconnected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + +class SlowNotifee(INotifee): + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def connected(self, network: INetwork, conn: INetConn) -> None: + await trio.sleep(0.5) + + async def disconnected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + +@pytest.mark.trio +async def test_many_notifees_receive_connected_quickly() -> None: + async with SwarmFactory.create_batch_and_listen(2) as swarms: + count = 200 + events = [trio.Event() for _ in range(count)] + for ev in events: + swarms[0].register_notifee(CountingNotifee(ev)) + await connect_swarm(swarms[0], swarms[1]) + with trio.fail_after(1.5): + for ev in events: + await ev.wait() + + +@pytest.mark.trio +async def test_slow_notifee_does_not_block_others() -> None: + async with SwarmFactory.create_batch_and_listen(2) as swarms: + fast_events = [trio.Event() for _ in range(20)] + for ev in fast_events: + swarms[0].register_notifee(CountingNotifee(ev)) + swarms[0].register_notifee(SlowNotifee()) + await connect_swarm(swarms[0], swarms[1]) + # Fast notifees should complete quickly despite one slow notifee + with trio.fail_after(0.3): + for ev in fast_events: + await ev.wait() diff --git a/tests/core/network/test_notify_listen_lifecycle.py b/tests/core/network/test_notify_listen_lifecycle.py new file mode 100644 index 00000000..7bac5938 --- /dev/null +++ b/tests/core/network/test_notify_listen_lifecycle.py @@ -0,0 +1,76 @@ +import enum + +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ( + INetConn, + INetStream, + INetwork, + INotifee, +) +from libp2p.tools.async_service import background_trio_service +from libp2p.tools.constants import LISTEN_MADDR +from tests.utils.factories import SwarmFactory + + +class Event(enum.Enum): + Listen = 0 + ListenClose = 1 + + +class MyNotifee(INotifee): + def __init__(self, events: list[Event]): + self.events = events + + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def connected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def disconnected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: + self.events.append(Event.Listen) + + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: + self.events.append(Event.ListenClose) + + +async def wait_for_event( + events_list: list[Event], event: Event, timeout: float = 1.0 +) -> bool: + with trio.move_on_after(timeout): + while event not in events_list: + await trio.sleep(0.01) + return True + return False + + +@pytest.mark.trio +async def test_listen_emitted_when_registered_before_listen(): + events: list[Event] = [] + swarm = SwarmFactory.build() + swarm.register_notifee(MyNotifee(events)) + async with background_trio_service(swarm): + # Start listening now; notifee was registered beforehand + assert await swarm.listen(LISTEN_MADDR) + assert await wait_for_event(events, Event.Listen) + + +@pytest.mark.trio +async def test_single_listener_close_emits_listen_close(): + events: list[Event] = [] + swarm = SwarmFactory.build() + swarm.register_notifee(MyNotifee(events)) + async with background_trio_service(swarm): + assert await swarm.listen(LISTEN_MADDR) + # Explicitly notify listen_close (close path via manager doesn't emit it) + await swarm.notify_listen_close(LISTEN_MADDR) + assert await wait_for_event(events, Event.ListenClose) diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 6389bcb3..605913ec 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -16,6 +16,9 @@ from libp2p.network.exceptions import ( from libp2p.network.swarm import ( Swarm, ) +from libp2p.tools.async_service import ( + background_trio_service, +) from libp2p.tools.utils import ( connect_swarm, ) @@ -184,3 +187,116 @@ def test_new_swarm_quic_multiaddr_raises(): addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic") with pytest.raises(ValueError, match="QUIC not yet supported"): new_swarm(listen_addrs=[addr]) + + +@pytest.mark.trio +async def test_swarm_listen_multiple_addresses(security_protocol): + """Test that swarm can listen on multiple addresses simultaneously.""" + from libp2p.utils.address_validation import get_available_interfaces + + # Get multiple addresses to listen on + listen_addrs = get_available_interfaces(0) # Let OS choose ports + + # Create a swarm and listen on multiple addresses + swarm = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm): + # Listen on all addresses + success = await swarm.listen(*listen_addrs) + assert success, "Should successfully listen on at least one address" + + # Check that we have listeners for the addresses + actual_listeners = list(swarm.listeners.keys()) + assert len(actual_listeners) > 0, "Should have at least one listener" + + # Verify that all successful listeners are in the listeners dict + successful_count = 0 + for addr in listen_addrs: + addr_str = str(addr) + if addr_str in actual_listeners: + successful_count += 1 + # This address successfully started listening + listener = swarm.listeners[addr_str] + listener_addrs = listener.get_addrs() + assert len(listener_addrs) > 0, ( + f"Listener for {addr} should have addresses" + ) + + # Check that the listener address matches the expected address + # (port might be different if we used port 0) + expected_ip = addr.value_for_protocol("ip4") + expected_protocol = addr.value_for_protocol("tcp") + if expected_ip and expected_protocol: + found_matching = False + for listener_addr in listener_addrs: + if ( + listener_addr.value_for_protocol("ip4") == expected_ip + and listener_addr.value_for_protocol("tcp") is not None + ): + found_matching = True + break + assert found_matching, ( + f"Listener for {addr} should have matching IP" + ) + + assert successful_count == len(listen_addrs), ( + f"All {len(listen_addrs)} addresses should be listening, " + f"but only {successful_count} succeeded" + ) + + +@pytest.mark.trio +async def test_swarm_listen_multiple_addresses_connectivity(security_protocol): + """Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501 + from libp2p.peer.peerinfo import info_from_p2p_addr + from libp2p.utils.address_validation import get_available_interfaces + + # Get multiple addresses to listen on + listen_addrs = get_available_interfaces(0) # Let OS choose ports + + # Create a swarm and listen on multiple addresses + swarm1 = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm1): + # Listen on all addresses + success = await swarm1.listen(*listen_addrs) + assert success, "Should successfully listen on at least one address" + + # Verify all available interfaces are listening + assert len(swarm1.listeners) == len(listen_addrs), ( + f"All {len(listen_addrs)} interfaces should be listening, " + f"but only {len(swarm1.listeners)} are" + ) + + # Create a second swarm to test connections + swarm2 = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm2): + # Test connectivity to each listening address using real libp2p connections + for addr_str, listener in swarm1.listeners.items(): + listener_addrs = listener.get_addrs() + for listener_addr in listener_addrs: + # Create a full multiaddr with peer ID for libp2p connection + peer_id = swarm1.get_peer_id() + full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}") + + # Test real libp2p connection + try: + peer_info = info_from_p2p_addr(full_addr) + + # Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501 + swarm2.peerstore.add_addrs( + peer_info.peer_id, [listener_addr], 10000 + ) + + await swarm2.dial_peer(peer_info.peer_id) + + # Verify connection was established + assert peer_info.peer_id in swarm2.connections, ( + f"Connection to {full_addr} should be established" + ) + assert swarm2.get_peer_id() in swarm1.connections, ( + f"Connection from {full_addr} should be established" + ) + + except Exception as e: + pytest.fail( + f"Failed to establish libp2p connection to {full_addr}: {e}" + ) diff --git a/tests/core/protocol_muxer/test_negotiate_timeout.py b/tests/core/protocol_muxer/test_negotiate_timeout.py index a50d65f6..1d089949 100644 --- a/tests/core/protocol_muxer/test_negotiate_timeout.py +++ b/tests/core/protocol_muxer/test_negotiate_timeout.py @@ -1,9 +1,9 @@ +from collections import deque + import pytest import trio -from libp2p.abc import ( - IMultiselectCommunicator, -) +from libp2p.abc import IMultiselectCommunicator, INetStream from libp2p.custom_types import TProtocol from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, @@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect from libp2p.protocol_muxer.multiselect_client import MultiselectClient +async def dummy_handler(stream: INetStream) -> None: + pass + + class DummyMultiselectCommunicator(IMultiselectCommunicator): """ Dummy MultiSelectCommunicator to test out negotiate timmeout. @@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator): @pytest.mark.trio -async def test_select_one_of_timeout(): +async def test_select_one_of_timeout() -> None: ECHO = TProtocol("/echo/1.0.0") communicator = DummyMultiselectCommunicator() @@ -42,7 +46,7 @@ async def test_select_one_of_timeout(): @pytest.mark.trio -async def test_query_multistream_command_timeout(): +async def test_query_multistream_command_timeout() -> None: communicator = DummyMultiselectCommunicator() client = MultiselectClient() @@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout(): @pytest.mark.trio -async def test_negotiate_timeout(): +async def test_negotiate_timeout() -> None: communicator = DummyMultiselectCommunicator() server = Multiselect() with pytest.raises(MultiselectError, match="handshake read timeout"): await server.negotiate(communicator, 2) + + +class HandshakeThenHangCommunicator(IMultiselectCommunicator): + handshaked: bool + + def __init__(self) -> None: + self.handshaked = False + + async def write(self, msg_str: str) -> None: + if msg_str == "/multistream/1.0.0": + self.handshaked = True + return + + async def read(self) -> str: + if not self.handshaked: + return "/multistream/1.0.0" + # After handshake, hang on read. + await trio.sleep_forever() + # Should not be reached. + return "" + + +@pytest.mark.trio +async def test_negotiate_timeout_post_handshake() -> None: + communicator = HandshakeThenHangCommunicator() + server = Multiselect() + with pytest.raises(MultiselectError, match="handshake read timeout"): + await server.negotiate(communicator, 1) + + +class MockCommunicator(IMultiselectCommunicator): + def __init__(self, commands_to_read: list[str]): + self.read_queue = deque(commands_to_read) + self.written_data: list[str] = [] + + async def write(self, msg_str: str) -> None: + self.written_data.append(msg_str) + + async def read(self) -> str: + if not self.read_queue: + raise EOFError + return self.read_queue.popleft() + + +@pytest.mark.trio +async def test_negotiate_empty_string_command() -> None: + # server receives an empty string, which means client wants `None` protocol. + server = Multiselect({None: dummy_handler}) + # Handshake, then empty command + communicator = MockCommunicator(["/multistream/1.0.0", ""]) + protocol, handler = await server.negotiate(communicator) + assert protocol is None + assert handler == dummy_handler + # Check that server sent back handshake and the protocol confirmation (empty string) + assert communicator.written_data == ["/multistream/1.0.0", ""] + + +@pytest.mark.trio +async def test_negotiate_with_none_handler() -> None: + # server has None handler, client sends "" to select it. + server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler}) + # Handshake, then empty command + communicator = MockCommunicator(["/multistream/1.0.0", ""]) + protocol, handler = await server.negotiate(communicator) + assert protocol is None + assert handler == dummy_handler + # Check written data: handshake, protocol confirmation + assert communicator.written_data == ["/multistream/1.0.0", ""] + + +@pytest.mark.trio +async def test_negotiate_with_none_handler_ls() -> None: + # server has None handler, client sends "ls" then empty string. + server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler}) + # Handshake, ls, empty command + communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""]) + protocol, handler = await server.negotiate(communicator) + assert protocol is None + assert handler == dummy_handler + # Check written data: handshake, ls response, protocol confirmation + assert communicator.written_data[0] == "/multistream/1.0.0" + assert "/proto1" in communicator.written_data[1] + # Note: `ls` should not list the `None` protocol. + assert "None" not in communicator.written_data[1] + assert "\n\n" not in communicator.written_data[1] + assert communicator.written_data[2] == "" diff --git a/tests/core/protocol_muxer/test_protocol_muxer.py b/tests/core/protocol_muxer/test_protocol_muxer.py index 1d6a0f86..57939bb6 100644 --- a/tests/core/protocol_muxer/test_protocol_muxer.py +++ b/tests/core/protocol_muxer/test_protocol_muxer.py @@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols(): protocols = ms.get_protocols() assert set(protocols) == {p1, p2, p3} + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol(security_protocol): + with pytest.raises(Exception): + await perform_simple_test( + None, + [None], + [None], + security_protocol, + ) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_client_none_server_no_none( + security_protocol, +): + with pytest.raises(Exception): + await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol): + expected_selected_protocol = PROTOCOL_ECHO + await perform_simple_test( + expected_selected_protocol, + [None, PROTOCOL_ECHO], + [PROTOCOL_ECHO], + security_protocol, + ) + + +@pytest.mark.trio +async def test_negotiate_optional_tprotocol_server_none_client_other( + security_protocol, +): + with pytest.raises(Exception): + await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol) diff --git a/tests/core/pubsub/test_pubsub_notifee_integration.py b/tests/core/pubsub/test_pubsub_notifee_integration.py new file mode 100644 index 00000000..e35dfeb1 --- /dev/null +++ b/tests/core/pubsub/test_pubsub_notifee_integration.py @@ -0,0 +1,90 @@ +from typing import cast + +import pytest +import trio + +from libp2p.tools.utils import connect +from tests.utils.factories import PubsubFactory + + +@pytest.mark.trio +async def test_connected_enqueues_and_adds_peer(): + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + await connect(p0.host, p1.host) + await p0.wait_until_ready() + # Wait until peer is added via queue processing + with trio.fail_after(1.0): + while p1.my_id not in p0.peers: + await trio.sleep(0.01) + assert p1.my_id in p0.peers + + +@pytest.mark.trio +async def test_disconnected_enqueues_and_removes_peer(): + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + await connect(p0.host, p1.host) + await p0.wait_until_ready() + # Ensure present first + with trio.fail_after(1.0): + while p1.my_id not in p0.peers: + await trio.sleep(0.01) + # Now disconnect and expect removal via dead peer queue + await p0.host.get_network().close_peer(p1.host.get_id()) + with trio.fail_after(1.0): + while p1.my_id in p0.peers: + await trio.sleep(0.01) + assert p1.my_id not in p0.peers + + +@pytest.mark.trio +async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None: + # Ensure PubsubNotifee catches BrokenResourceError from its send channel + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + # Find the PubsubNotifee registered on the network + from libp2p.pubsub.pubsub_notifee import PubsubNotifee + + network = p0.host.get_network() + notifees = getattr(network, "notifees", []) + target = None + for nf in notifees: + if isinstance(nf, cast(type, PubsubNotifee)): + target = nf + break + assert target is not None, "PubsubNotifee not found on network" + + async def failing_send(_peer_id): # type: ignore[no-redef] + raise trio.BrokenResourceError + + # Make initiator queue send fail; PubsubNotifee should swallow + monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send) + + # Connect peers; if exceptions are swallowed, service stays running + await connect(p0.host, p1.host) + await p0.wait_until_ready() + assert True + + +@pytest.mark.trio +async def test_duplicate_connection_does_not_duplicate_peer_state(): + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + await connect(p0.host, p1.host) + await p0.wait_until_ready() + with trio.fail_after(1.0): + while p1.my_id not in p0.peers: + await trio.sleep(0.01) + # Connect again should not add duplicates + await connect(p0.host, p1.host) + await trio.sleep(0.1) + assert list(p0.peers.keys()).count(p1.my_id) == 1 + + +@pytest.mark.trio +async def test_blacklist_blocks_peer_added_by_notifee(): + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + # Blacklist before connecting + p0.add_to_blacklist(p1.my_id) + await connect(p0.host, p1.host) + await p0.wait_until_ready() + # Give handler a chance to run + await trio.sleep(0.1) + assert p1.my_id not in p0.peers diff --git a/tests/discovery/random_walk/test_random_walk.py b/tests/discovery/random_walk/test_random_walk.py new file mode 100644 index 00000000..f5691782 --- /dev/null +++ b/tests/discovery/random_walk/test_random_walk.py @@ -0,0 +1,99 @@ +""" +Unit tests for the RandomWalk module in libp2p.discovery.random_walk. +""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from libp2p.discovery.random_walk.random_walk import RandomWalk +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo + + +@pytest.fixture +def mock_host(): + host = Mock() + peerstore = Mock() + peerstore.peers_with_addrs.return_value = [] + peerstore.addrs.return_value = [Mock()] + host.get_peerstore.return_value = peerstore + host.new_stream = AsyncMock() + return host + + +@pytest.fixture +def dummy_query_function(): + async def query(key_bytes): + return [] + + return query + + +@pytest.fixture +def dummy_peer_id(): + return b"\x01" * 32 + + +@pytest.mark.trio +async def test_random_walk_initialization( + mock_host, dummy_peer_id, dummy_query_function +): + rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function) + assert rw.host == mock_host + assert rw.local_peer_id == dummy_peer_id + assert rw.query_function == dummy_query_function + + +def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function): + rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function) + peer_id = rw.generate_random_peer_id() + assert isinstance(peer_id, str) + assert len(peer_id) == 64 # 32 bytes hex + + +@pytest.mark.trio +async def test_run_concurrent_random_walks(mock_host, dummy_peer_id): + # Dummy query function returns different peer IDs for each walk + call_count = {"count": 0} + + async def query(key_bytes): + call_count["count"] += 1 + # Return a unique peer ID for each call + return [ID(bytes([call_count["count"]] * 32))] + + rw = RandomWalk(mock_host, dummy_peer_id, query) + peers = await rw.run_concurrent_random_walks(count=3) + # Should get 3 unique peers + assert len(peers) == 3 + peer_ids = [peer.peer_id for peer in peers] + assert len(set(peer_ids)) == 3 + + +@pytest.mark.trio +async def test_perform_random_walk_running(mock_host, dummy_peer_id): + # Query function returns a single peer ID + async def query(key_bytes): + return [ID(b"\x02" * 32)] + + rw = RandomWalk(mock_host, dummy_peer_id, query) + peers = await rw.perform_random_walk() + assert isinstance(peers, list) + if peers: + assert isinstance(peers[0], PeerInfo) + + +@pytest.mark.trio +async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id): + """Test perform_random_walk when no peers are discovered.""" + + # Query function returns empty list (no peers found) + async def query(key_bytes): + return [] + + rw = RandomWalk(mock_host, dummy_peer_id, query) + peers = await rw.perform_random_walk() + + # Should return empty list when no peers are found + assert isinstance(peers, list) + assert len(peers) == 0 diff --git a/tests/discovery/random_walk/test_rt_refresh_manager.py b/tests/discovery/random_walk/test_rt_refresh_manager.py new file mode 100644 index 00000000..d0a65916 --- /dev/null +++ b/tests/discovery/random_walk/test_rt_refresh_manager.py @@ -0,0 +1,451 @@ +""" +Unit tests for the RTRefreshManager and related random walk logic. +""" + +import time +from unittest.mock import AsyncMock, Mock, patch + +import pytest +import trio + +from libp2p.discovery.random_walk.config import ( + MIN_RT_REFRESH_THRESHOLD, + RANDOM_WALK_CONCURRENCY, + REFRESH_INTERVAL, +) +from libp2p.discovery.random_walk.exceptions import ( + RandomWalkError, +) +from libp2p.discovery.random_walk.random_walk import RandomWalk +from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo + + +class DummyRoutingTable: + def __init__(self, size=0): + self._size = size + self.added_peers = [] + + def size(self): + return self._size + + async def add_peer(self, peer_obj): + self.added_peers.append(peer_obj) + self._size += 1 + return True + + +@pytest.fixture +def mock_host(): + host = Mock() + host.get_peerstore.return_value = Mock() + host.new_stream = AsyncMock() + return host + + +@pytest.fixture +def local_peer_id(): + return ID(b"\x01" * 32) + + +@pytest.fixture +def dummy_query_function(): + async def query(key_bytes): + return [ID(b"\x02" * 32)] + + return query + + +@pytest.mark.trio +async def test_rt_refresh_manager_initialization( + mock_host, local_peer_id, dummy_query_function +): + rt = DummyRoutingTable(size=5) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=REFRESH_INTERVAL, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + assert manager.host == mock_host + assert manager.routing_table == rt + assert manager.local_peer_id == local_peer_id + assert manager.query_function == dummy_query_function + + +@pytest.mark.trio +async def test_rt_refresh_manager_refresh_logic( + mock_host, local_peer_id, dummy_query_function +): + rt = DummyRoutingTable(size=2) + # Simulate refresh logic + if rt.size() < MIN_RT_REFRESH_THRESHOLD: + await rt.add_peer(Mock()) + assert rt.size() >= 3 + + +@pytest.mark.trio +async def test_rt_refresh_manager_random_walk_integration( + mock_host, local_peer_id, dummy_query_function +): + # Simulate random walk usage + rw = RandomWalk(mock_host, local_peer_id, dummy_query_function) + random_peer_id = rw.generate_random_peer_id() + assert isinstance(random_peer_id, str) + assert len(random_peer_id) == 64 + + +@pytest.mark.trio +async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id): + rt = DummyRoutingTable(size=0) + + async def failing_query(_): + raise RandomWalkError("Query failed") + + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=failing_query, + enable_auto_refresh=True, + refresh_interval=REFRESH_INTERVAL, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + with pytest.raises(RandomWalkError): + await manager.query_function(b"key") + + +@pytest.mark.trio +async def test_rt_refresh_manager_start_method( + mock_host, local_peer_id, dummy_query_function +): + """Test the start method functionality.""" + rt = DummyRoutingTable(size=2) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=False, # Disable auto-refresh to control the test + refresh_interval=0.1, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Mock the random walk to return some peers + mock_peer_info = Mock(spec=PeerInfo) + with patch.object( + manager.random_walk, + "run_concurrent_random_walks", + return_value=[mock_peer_info], + ): + # Test starting the manager + assert not manager._running + + # Start the manager in a nursery that we can control + async with trio.open_nursery() as nursery: + nursery.start_soon(manager.start) + await trio.sleep(0.01) # Let it start + + # Verify it's running + assert manager._running + + # Stop the manager + await manager.stop() + assert not manager._running + + +@pytest.mark.trio +async def test_rt_refresh_manager_main_loop_with_auto_refresh( + mock_host, local_peer_id, dummy_query_function +): + """Test the _main_loop method with auto-refresh enabled.""" + rt = DummyRoutingTable(size=1) # Small size to trigger refresh + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=0.1, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Mock the random walk to return some peers + mock_peer_info = Mock(spec=PeerInfo) + with patch.object( + manager.random_walk, + "run_concurrent_random_walks", + return_value=[mock_peer_info], + ) as mock_random_walk: + manager._running = True + + # Run the main loop for a short time + async with trio.open_nursery() as nursery: + nursery.start_soon(manager._main_loop) + await trio.sleep(0.05) # Let it run briefly + manager._running = False # Stop the loop + + # Verify that random walk was called (initial refresh) + mock_random_walk.assert_called() + + +@pytest.mark.trio +async def test_rt_refresh_manager_main_loop_without_auto_refresh( + mock_host, local_peer_id, dummy_query_function +): + """Test the _main_loop method with auto-refresh disabled.""" + rt = DummyRoutingTable(size=1) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=False, + refresh_interval=0.1, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + with patch.object( + manager.random_walk, "run_concurrent_random_walks" + ) as mock_random_walk: + manager._running = True + + # Run the main loop for a short time + async with trio.open_nursery() as nursery: + nursery.start_soon(manager._main_loop) + await trio.sleep(0.05) + manager._running = False + + # Verify that random walk was not called since auto-refresh is disabled + mock_random_walk.assert_not_called() + + +@pytest.mark.trio +async def test_rt_refresh_manager_main_loop_initial_refresh_exception( + mock_host, local_peer_id, dummy_query_function +): + """Test that _main_loop propagates exceptions from initial refresh.""" + rt = DummyRoutingTable(size=1) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=0.1, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Mock _do_refresh to raise an exception on the initial call + with patch.object( + manager, "_do_refresh", side_effect=Exception("Initial refresh failed") + ): + manager._running = True + + # The initial refresh exception should propagate + with pytest.raises(Exception, match="Initial refresh failed"): + await manager._main_loop() + + +@pytest.mark.trio +async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function): + """Test _do_refresh method with force=True.""" + rt = DummyRoutingTable(size=10) # Large size, but force should override + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=REFRESH_INTERVAL, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Mock the random walk to return some peers + mock_peer_info1 = Mock(spec=PeerInfo) + mock_peer_info2 = Mock(spec=PeerInfo) + discovered_peers = [mock_peer_info1, mock_peer_info2] + + with patch.object( + manager.random_walk, + "run_concurrent_random_walks", + return_value=discovered_peers, + ) as mock_random_walk: + # Force refresh should work regardless of RT size + await manager._do_refresh(force=True) + + # Verify random walk was called + mock_random_walk.assert_called_once_with( + count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10 + ) + + # Verify peers were added to routing table + assert len(rt.added_peers) == 2 + assert manager._last_refresh_time > 0 + + +@pytest.mark.trio +async def test_do_refresh_skip_due_to_interval( + mock_host, local_peer_id, dummy_query_function +): + """Test _do_refresh skips refresh when interval hasn't elapsed.""" + rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=100.0, # Long interval + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Set last refresh time to recent + manager._last_refresh_time = time.time() + + with patch.object( + manager.random_walk, "run_concurrent_random_walks" + ) as mock_random_walk: + with patch( + "libp2p.discovery.random_walk.rt_refresh_manager.logger" + ) as mock_logger: + await manager._do_refresh(force=False) + + # Verify refresh was skipped + mock_random_walk.assert_not_called() + mock_logger.debug.assert_called_with( + "Skipping refresh: interval not elapsed" + ) + + +@pytest.mark.trio +async def test_do_refresh_skip_due_to_rt_size( + mock_host, local_peer_id, dummy_query_function +): + """Test _do_refresh skips refresh when RT size is above threshold.""" + rt = DummyRoutingTable(size=20) # Large size above threshold + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=0.1, # Short interval + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Set last refresh time to old + manager._last_refresh_time = 0.0 + + with patch.object( + manager.random_walk, "run_concurrent_random_walks" + ) as mock_random_walk: + with patch( + "libp2p.discovery.random_walk.rt_refresh_manager.logger" + ) as mock_logger: + await manager._do_refresh(force=False) + + # Verify refresh was skipped + mock_random_walk.assert_not_called() + mock_logger.debug.assert_called_with( + "Skipping refresh: routing table size above threshold" + ) + + +@pytest.mark.trio +async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function): + """Test refresh completion callbacks functionality.""" + rt = DummyRoutingTable(size=1) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=0.1, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Create mock callbacks + callback1 = Mock() + callback2 = Mock() + failing_callback = Mock(side_effect=Exception("Callback failed")) + + # Add callbacks + manager.add_refresh_done_callback(callback1) + manager.add_refresh_done_callback(callback2) + manager.add_refresh_done_callback(failing_callback) + + # Mock the random walk + mock_peer_info = Mock(spec=PeerInfo) + with patch.object( + manager.random_walk, + "run_concurrent_random_walks", + return_value=[mock_peer_info], + ): + with patch( + "libp2p.discovery.random_walk.rt_refresh_manager.logger" + ) as mock_logger: + await manager._do_refresh(force=True) + + # Verify all callbacks were called + callback1.assert_called_once() + callback2.assert_called_once() + failing_callback.assert_called_once() + + # Verify warning was logged for failing callback + mock_logger.warning.assert_called() + + +@pytest.mark.trio +async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function): + """Test stop method when manager is not running.""" + rt = DummyRoutingTable(size=1) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=0.1, + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Manager is not running + assert not manager._running + + # Stop should return without doing anything + await manager.stop() + assert not manager._running + + +@pytest.mark.trio +async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function): + """Test the _periodic_refresh_task method.""" + rt = DummyRoutingTable(size=1) + manager = RTRefreshManager( + host=mock_host, + routing_table=rt, + local_peer_id=local_peer_id, + query_function=dummy_query_function, + enable_auto_refresh=True, + refresh_interval=0.05, # Very short interval for testing + min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD, + ) + + # Mock _do_refresh to track calls + with patch.object(manager, "_do_refresh") as mock_do_refresh: + manager._running = True + + # Run periodic refresh task for a short time + async with trio.open_nursery() as nursery: + nursery.start_soon(manager._periodic_refresh_task) + await trio.sleep(0.12) # Let it run for ~2 intervals + manager._running = False # Stop the task + + # Verify _do_refresh was called at least once + assert mock_do_refresh.call_count >= 1 diff --git a/tests/examples/test_echo_thin_waist.py b/tests/examples/test_echo_thin_waist.py new file mode 100644 index 00000000..2bcb52b1 --- /dev/null +++ b/tests/examples/test_echo_thin_waist.py @@ -0,0 +1,109 @@ +import contextlib +import os +from pathlib import Path +import subprocess +import sys +import time + +from multiaddr import Multiaddr +from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP + +# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging + +# This test is intentionally lightweight and can be marked as 'integration'. +# It ensures the echo example runs and prints the new Thin Waist lines using +# Trio primitives. + +current_file = Path(__file__) +project_root = current_file.parent.parent.parent +EXAMPLES_DIR: Path = project_root / "examples" / "echo" + + +def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path): + """Run echo server and validate printed multiaddr and peer id.""" + # Run echo example as server + cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"] + env = {**os.environ, "PYTHONUNBUFFERED": "1"} + proc: subprocess.Popen[str] = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=env, + ) + + if proc.stdout is None: + proc.terminate() + raise RuntimeError("Process stdout is None") + out_stream = proc.stdout + + peer_id: str | None = None + printed_multiaddr: str | None = None + saw_waiting = False + + start = time.time() + timeout_s = 8.0 + try: + while time.time() - start < timeout_s: + line = out_stream.readline() + if not line: + time.sleep(0.05) + continue + s = line.strip() + if s.startswith("I am "): + peer_id = s.partition("I am ")[2] + if s.startswith("echo-demo -d "): + printed_multiaddr = s.partition("echo-demo -d ")[2] + if "Waiting for incoming connections..." in s: + saw_waiting = True + break + finally: + with contextlib.suppress(ProcessLookupError): + proc.terminate() + with contextlib.suppress(ProcessLookupError): + proc.kill() + + assert peer_id, "Did not capture peer ID line" + assert printed_multiaddr, "Did not capture multiaddr line" + assert saw_waiting, "Did not capture waiting-for-connections line" + + # Validate multiaddr structure using py-multiaddr protocol methods + ma = Multiaddr(printed_multiaddr) # should parse without error + + # Check that the multiaddr contains the p2p protocol + try: + peer_id_from_multiaddr = ma.value_for_protocol("p2p") + assert peer_id_from_multiaddr is not None, ( + "Multiaddr missing p2p protocol value" + ) + assert peer_id_from_multiaddr == peer_id, ( + f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}" + ) + except Exception as e: + raise AssertionError(f"Failed to extract p2p protocol value: {e}") + + # Validate the multiaddr structure by checking protocols + protocols = ma.protocols() + + # Should have at least IP, TCP, and P2P protocols + assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), ( + "Missing IP protocol" + ) + assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol" + assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol" + + # Extract the p2p part and validate it matches the captured peer ID + p2p_part = Multiaddr(f"/p2p/{peer_id}") + try: + # Decapsulate the p2p part to get the transport address + transport_addr = ma.decapsulate(p2p_part) + # Verify the decapsulated address doesn't contain p2p + transport_protocols = transport_addr.protocols() + assert not any(p.code == P_P2P for p in transport_protocols), ( + "Decapsulation failed - still contains p2p" + ) + # Verify the original multiaddr can be reconstructed + reconstructed = transport_addr.encapsulate(p2p_part) + assert str(reconstructed) == str(ma), "Reconstruction failed" + except Exception as e: + raise AssertionError(f"Multiaddr decapsulation failed: {e}") diff --git a/tests/utils/test_address_validation.py b/tests/utils/test_address_validation.py new file mode 100644 index 00000000..5b108d09 --- /dev/null +++ b/tests/utils/test_address_validation.py @@ -0,0 +1,56 @@ +import os + +import pytest +from multiaddr import Multiaddr + +from libp2p.utils.address_validation import ( + expand_wildcard_address, + get_available_interfaces, + get_optimal_binding_address, +) + + +@pytest.mark.parametrize("proto", ["tcp"]) +def test_get_available_interfaces(proto: str) -> None: + interfaces = get_available_interfaces(0, protocol=proto) + assert len(interfaces) > 0 + for addr in interfaces: + assert isinstance(addr, Multiaddr) + assert f"/{proto}/" in str(addr) + + +def test_get_optimal_binding_address() -> None: + addr = get_optimal_binding_address(0) + assert isinstance(addr, Multiaddr) + # At least IPv4 or IPv6 prefix present + s = str(addr) + assert ("/ip4/" in s) or ("/ip6/" in s) + + +def test_expand_wildcard_address_ipv4() -> None: + wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0") + expanded = expand_wildcard_address(wildcard) + assert len(expanded) > 0 + for e in expanded: + assert isinstance(e, Multiaddr) + assert "/tcp/" in str(e) + + +def test_expand_wildcard_address_port_override() -> None: + wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000") + overridden = expand_wildcard_address(wildcard, port=9001) + assert len(overridden) > 0 + for e in overridden: + assert str(e).endswith("/tcp/9001") + + +@pytest.mark.skipif( + os.environ.get("NO_IPV6") == "1", + reason="Environment disallows IPv6", +) +def test_expand_wildcard_address_ipv6() -> None: + wildcard = Multiaddr("/ip6/::/tcp/0") + expanded = expand_wildcard_address(wildcard) + assert len(expanded) > 0 + for e in expanded: + assert "/ip6/" in str(e)