From d61bca78ab3ca8369445c3d22f7f749b730ddfdf Mon Sep 17 00:00:00 2001 From: Sumanjeet Date: Tue, 17 Jun 2025 02:16:40 +0530 Subject: [PATCH] Kademlia DHT implementation in py-libp2p (#579) * initialise the module * added content routing * added routing module * added peer routing * added value store * added utility functions * added main kademlia file * fixed create_key_from_binary function * example to test kademlia dht * added protocol ID and enhanced logging for peer store size in provider and consumer nodes * refactor: specify stream type in handle_stream method and add peer in routing table * removed content routing * added default value of count for finding closest peers * added functions to find close peers * refactor: remove content routing and enhance peer discovery * added put value function * added get value function * fix: improve logging and handle key encoding in get_value method * refactor: remove ContentRouting import from __init__.py * refactor: improved basic kademlia example * added protobuf files * replaced json with protobuf * refactor: enhance peer discovery and routing logic in KadDHT * refactor: enhance Kademlia routing table to use PeerInfo objects and improve peer management * refactor: enhance peer addition logic to utilize PeerInfo objects in routing table * feat: implement content provider functionality in Kademlia DHT * refactor: update value store to use datetime for validity management * refactor: update RoutingTable initialization to include host reference * refactor: enhance KBucket and RoutingTable for improved peer management and functionality * refactor: streamline peer discovery and value storage methods in KadDHT * refactor: update KadDHT and related classes for async peer management and enhanced value storage * refactor: enhance ProviderStore initialization and improve peer routing integration * test: add tests for Kademlia DHT functionality * fix linting issues * pydocstyle issues fixed * CICD pipeline issues solved * fix: update docstring format for find_peer method * refactor: improve logging and remove unused code in DHT implementation * refactor: clean up logging and remove unused imports in DHT and test files * Refactor logging setup and improve DHT stream handling with varint length prefixes * Update bootstrap peer handling in basic_dht example and refactor peer routing to accept string addresses * Enhance peer querying in Kademlia DHT by implementing parallel queries using Trio. * Enhance peer querying by adding deduplication checks * Refactor DHT implementation to use varint for length prefixes and enhance logging for better traceability * Add base58 encoding for value storage and enhance logging in basic_dht example * Refactor Kademlia DHT to support server/client modes * Added unit tests * Refactor documentation to fix some warnings * Add unit tests and remove outdated tests * Fixed precommit errors * Refactor error handling test to raise StringParseError for invalid bootstrap addresses * Add libp2p.kad_dht to the list of subpackages in documentation * Fix expiration and republish checks to use inclusive comparison * Add __init__.py file to libp2p.kad_dht.pb package * Refactor get value and put value to run in parallel with query timeout * Refactor provider message handling to use parallel processing with timeout * Add methods for provider store in KadDHT class * Refactor KadDHT and ProviderStore methods to improve type hints and enhance parallel processing * Add documentation for libp2p.kad_dht.pb module.
* Update documentation for libp2p.kad_dht package to include subpackages and correct formatting * Fix formatting in documentation for libp2p.kad_dht package by correcting the subpackage reference * Fix header formatting in libp2p.kad_dht.pb documentation * Change log level from info to debug for various logging statements. * fix CICD issues (post revamp) * fixed value store unit test * Refactored kademlia example * Refactor Kademlia example: enhance logging, improve bootstrap node connection, and streamline server address handling * removed bootstrap module * Refactor Kademlia DHT example and core modules: enhance logging, remove unused code, and improve peer handling * Added docs of kad dht example * Update server address log file path to use the script's directory * Refactor: Introduce DHTMode enum for clearer mode management * moved xor_distance function to utils.py * Enhance logging in ValueStore and KadDHT: include decoded value in debug logs and update parameter description for validity * Add handling for closest peers in GET_VALUE response when value is not found * Handled failure scenario for PUT_VALUE * Remove kademlia demo from project scripts and contributing documentation * spelling and logging --------- Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com> --- Makefile | 4 +- docs/examples.kademlia.rst | 124 +++ docs/examples.rst | 1 + docs/libp2p.kad_dht.pb.rst | 22 + docs/libp2p.kad_dht.rst | 77 ++ docs/libp2p.rst | 1 + examples/kademlia/kademlia.py | 300 +++++++ libp2p/kad_dht/__init__.py | 30 + libp2p/kad_dht/kad_dht.py | 616 ++++++++++++++ libp2p/kad_dht/pb/__init__.py | 0 libp2p/kad_dht/pb/kademlia.proto | 38 + libp2p/kad_dht/pb/kademlia_pb2.py | 33 + libp2p/kad_dht/pb/kademlia_pb2.pyi | 133 +++ libp2p/kad_dht/peer_routing.py | 418 +++++++++ libp2p/kad_dht/provider_store.py | 575 +++++++++++++ libp2p/kad_dht/routing_table.py | 601 +++++++++++++ libp2p/kad_dht/utils.py | 117 +++ libp2p/kad_dht/value_store.py | 393 +++++++++ newsfragments/579.feature.rst | 1 + tests/core/kad_dht/test_kad_dht.py | 168 ++++ tests/core/kad_dht/test_unit_peer_routing.py | 459 ++++++++++ .../core/kad_dht/test_unit_provider_store.py | 805 ++++++++++++++++++ tests/core/kad_dht/test_unit_routing_table.py | 371 ++++++++ tests/core/kad_dht/test_unit_value_store.py | 504 +++++++++++ 24 files changed, 5790 insertions(+), 1 deletion(-) create mode 100644 docs/examples.kademlia.rst create mode 100644 docs/libp2p.kad_dht.pb.rst create mode 100644 docs/libp2p.kad_dht.rst create mode 100644 examples/kademlia/kademlia.py create mode 100644 libp2p/kad_dht/__init__.py create mode 100644 libp2p/kad_dht/kad_dht.py create mode 100644 libp2p/kad_dht/pb/__init__.py create mode 100644 libp2p/kad_dht/pb/kademlia.proto create mode 100644 libp2p/kad_dht/pb/kademlia_pb2.py create mode 100644 libp2p/kad_dht/pb/kademlia_pb2.pyi create mode 100644 libp2p/kad_dht/peer_routing.py create mode 100644 libp2p/kad_dht/provider_store.py create mode 100644 libp2p/kad_dht/routing_table.py create mode 100644 libp2p/kad_dht/utils.py create mode 100644 libp2p/kad_dht/value_store.py create mode 100644 newsfragments/579.feature.rst create mode 100644 tests/core/kad_dht/test_kad_dht.py create mode 100644 tests/core/kad_dht/test_unit_peer_routing.py create mode 100644 tests/core/kad_dht/test_unit_provider_store.py create mode 100644 tests/core/kad_dht/test_unit_routing_table.py create mode 100644 tests/core/kad_dht/test_unit_value_store.py diff --git a/Makefile b/Makefile index e99b3ac9..08adba67 100644 --- a/Makefile +++ 
b/Makefile @@ -58,7 +58,9 @@ PB = libp2p/crypto/pb/crypto.proto \ libp2p/security/secio/pb/spipe.proto \ libp2p/security/noise/pb/noise.proto \ libp2p/identity/identify/pb/identify.proto \ - libp2p/host/autonat/pb/autonat.proto + libp2p/host/autonat/pb/autonat.proto \ + libp2p/kad_dht/pb/kademlia.proto + PY = $(PB:.proto=_pb2.py) PYI = $(PB:.proto=_pb2.pyi) diff --git a/docs/examples.kademlia.rst b/docs/examples.kademlia.rst new file mode 100644 index 00000000..fdc497ab --- /dev/null +++ b/docs/examples.kademlia.rst @@ -0,0 +1,124 @@ +Kademlia DHT Demo +================= + +This example demonstrates a Kademlia Distributed Hash Table (DHT) implementation with both value storage/retrieval and content provider advertisement/discovery functionality. + +.. code-block:: console + + $ python -m pip install libp2p + Collecting libp2p + ... + Successfully installed libp2p-x.x.x + $ cd examples/kademlia + $ python kademlia.py --mode server + 2025-06-13 19:51:25,424 - kademlia-example - INFO - Running in server mode on port 0 + 2025-06-13 19:51:25,426 - kademlia-example - INFO - Connected to bootstrap nodes: [] + 2025-06-13 19:51:25,426 - kademlia-example - INFO - To connect to this node, use: --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef + 2025-06-13 19:51:25,426 - kademlia-example - INFO - Saved server address to log: /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef + 2025-06-13 19:51:25,427 - kademlia-example - INFO - DHT service started in SERVER mode + 2025-06-13 19:51:25,427 - kademlia-example - INFO - Stored value 'Hello message from Sumanjeet' with key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP + 2025-06-13 19:51:25,427 - kademlia-example - INFO - Successfully advertised as server for content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8 + + +Copy the line that starts with ``--bootstrap``, open a new terminal in the same folder and run the client: + +.. code-block:: console + + $ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef + 2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0 + 2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [] + 2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode + 2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP + 2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet + 2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8 + 2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef'] + +Alternatively, if you run the server first, the client can automatically extract the bootstrap address from the server log file: + +.. 
code-block:: console + + $ python kademlia.py --mode client + 2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0 + 2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [] + 2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode + 2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP + 2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet + 2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8 + 2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef'] + +The demo showcases key DHT operations: + +- **Value Storage & Retrieval**: The server stores a value, and the client retrieves it +- **Content Provider Discovery**: The server advertises content, and the client finds providers +- **Peer Discovery**: Automatic bootstrap and peer routing using the Kademlia algorithm +- **Network Resilience**: Distributed storage across multiple nodes (when available) + +Command Line Options +-------------------- + +The Kademlia demo supports several command line options for customization: + +.. code-block:: console + + $ python kademlia.py --help + usage: kademlia.py [-h] [--mode MODE] [--port PORT] [--bootstrap [BOOTSTRAP ...]] [--verbose] + + Kademlia DHT example with content server functionality + + options: + -h, --help show this help message and exit + --mode MODE Run as a server or client node (default: server) + --port PORT Port to listen on (0 for random) (default: 0) + --bootstrap [BOOTSTRAP ...] + Multiaddrs of bootstrap nodes. Provide a space-separated list of addresses. + This is required for client mode. + --verbose Enable verbose logging + +**Examples:** + +Start server on a specific port: + +.. code-block:: console + + $ python kademlia.py --mode server --port 8000 + +Start client with verbose logging: + +.. code-block:: console + + $ python kademlia.py --mode client --verbose + +Connect to multiple bootstrap nodes: + +.. code-block:: console + + $ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/8000/p2p/... /ip4/127.0.0.1/tcp/8001/p2p/... + +How It Works +------------ + +The Kademlia DHT implementation demonstrates several key concepts: + +**Server Mode:** + - Stores key-value pairs in the distributed hash table + - Advertises itself as a content provider for specific content + - Handles incoming DHT requests from other nodes + - Maintains routing table with known peers + +**Client Mode:** + - Connects to bootstrap nodes to join the network + - Retrieves values by their keys from the DHT + - Discovers content providers for specific content + - Performs network lookups using the Kademlia algorithm + +**Key Components:** + - **Routing Table**: Organizes peers in k-buckets based on XOR distance + - **Value Store**: Manages key-value storage with TTL (time-to-live) + - **Provider Store**: Tracks which peers provide specific content + - **Peer Routing**: Implements iterative lookups to find closest peers + +The full source code for this example is below: + +.. 
literalinclude:: ../examples/kademlia/kademlia.py + :language: python + :linenos: diff --git a/docs/examples.rst b/docs/examples.rst index e2f0bdd4..c8d82820 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -11,3 +11,4 @@ Examples examples.echo examples.ping examples.pubsub + examples.kademlia diff --git a/docs/libp2p.kad_dht.pb.rst b/docs/libp2p.kad_dht.pb.rst new file mode 100644 index 00000000..475c838d --- /dev/null +++ b/docs/libp2p.kad_dht.pb.rst @@ -0,0 +1,22 @@ +libp2p.kad\_dht.pb package +========================== + +Submodules +---------- + +libp2p.kad_dht.pb.kademlia_pb2 module +------------------------------------- + +.. automodule:: libp2p.kad_dht.pb.kademlia_pb2 + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.kad_dht.pb + :no-index: + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.kad_dht.rst b/docs/libp2p.kad_dht.rst new file mode 100644 index 00000000..ff59ee5d --- /dev/null +++ b/docs/libp2p.kad_dht.rst @@ -0,0 +1,77 @@ +libp2p.kad\_dht package +======================= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + libp2p.kad_dht.pb + +Submodules +---------- + +libp2p.kad\_dht.kad\_dht module +------------------------------- + +.. automodule:: libp2p.kad_dht.kad_dht + :members: + :undoc-members: + :show-inheritance: + +libp2p.kad\_dht.peer\_routing module +------------------------------------ + +.. automodule:: libp2p.kad_dht.peer_routing + :members: + :undoc-members: + :show-inheritance: + +libp2p.kad\_dht.provider\_store module +-------------------------------------- + +.. automodule:: libp2p.kad_dht.provider_store + :members: + :undoc-members: + :show-inheritance: + +libp2p.kad\_dht.routing\_table module +------------------------------------- + +.. automodule:: libp2p.kad_dht.routing_table + :members: + :undoc-members: + :show-inheritance: + +libp2p.kad\_dht.utils module +---------------------------- + +.. automodule:: libp2p.kad_dht.utils + :members: + :undoc-members: + :show-inheritance: + +libp2p.kad\_dht.value\_store module +----------------------------------- + +.. automodule:: libp2p.kad_dht.value_store + :members: + :undoc-members: + :show-inheritance: + +libp2p.kad\_dht.pb +------------------ + +.. automodule:: libp2p.kad_dht.pb + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.kad_dht + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.rst b/docs/libp2p.rst index b82676d1..7f62e6d7 100644 --- a/docs/libp2p.rst +++ b/docs/libp2p.rst @@ -11,6 +11,7 @@ Subpackages libp2p.host libp2p.identity libp2p.io + libp2p.kad_dht libp2p.network libp2p.peer libp2p.protocol_muxer diff --git a/examples/kademlia/kademlia.py b/examples/kademlia/kademlia.py new file mode 100644 index 00000000..ada81d87 --- /dev/null +++ b/examples/kademlia/kademlia.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python + +""" +A basic example of using the Kademlia DHT implementation, with all setup logic inlined. +This example demonstrates both value storage/retrieval and content server +advertisement/discovery. 
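+
+A minimal sketch of the core calls (it assumes ``host`` is an already-running
+libp2p host and omits error handling; see the full setup further down in this
+file):
+
+    key = create_key_from_binary(b"example-key")
+    dht = KadDHT(host, DHTMode.SERVER)
+    async with background_trio_service(dht):
+        await dht.put_value(key, b"hello")
+        value = await dht.get_value(key)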
+""" + +import argparse +import logging +import os +import random +import secrets +import sys + +import base58 +from multiaddr import ( + Multiaddr, +) +import trio + +from libp2p import ( + new_host, +) +from libp2p.abc import ( + IHost, +) +from libp2p.crypto.secp256k1 import ( + create_new_key_pair, +) +from libp2p.kad_dht.kad_dht import ( + DHTMode, + KadDHT, +) +from libp2p.kad_dht.utils import ( + create_key_from_binary, +) +from libp2p.tools.async_service import ( + background_trio_service, +) +from libp2p.tools.utils import ( + info_from_p2p_addr, +) + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], +) +logger = logging.getLogger("kademlia-example") + +# Configure DHT module loggers to inherit from the parent logger +# This ensures all kademlia-example.* loggers use the same configuration +# Get the directory where this script is located +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt") + +# Set the level for all child loggers +for module in [ + "kad_dht", + "value_store", + "peer_routing", + "routing_table", + "provider_store", +]: + child_logger = logging.getLogger(f"kademlia-example.{module}") + child_logger.setLevel(logging.INFO) + child_logger.propagate = True # Allow propagation to parent + +# File to store node information +bootstrap_nodes = [] + + +# function to take bootstrap_nodes as input and connects to them +async def connect_to_bootstrap_nodes(host: IHost, bootstrap_addrs: list[str]) -> None: + """ + Connect to the bootstrap nodes provided in the list. + + params: host: The host instance to connect to + bootstrap_addrs: List of bootstrap node addresses + + Returns + ------- + None + + """ + for addr in bootstrap_addrs: + try: + peerInfo = info_from_p2p_addr(Multiaddr(addr)) + host.get_peerstore().add_addrs(peerInfo.peer_id, peerInfo.addrs, 3600) + await host.connect(peerInfo) + except Exception as e: + logger.error(f"Failed to connect to bootstrap node {addr}: {e}") + + +def save_server_addr(addr: str) -> None: + """Append the server's multiaddress to the log file.""" + try: + with open(SERVER_ADDR_LOG, "w") as f: + f.write(addr + "\n") + logger.info(f"Saved server address to log: {addr}") + except Exception as e: + logger.error(f"Failed to save server address: {e}") + + +def load_server_addrs() -> list[str]: + """Load all server multiaddresses from the log file.""" + if not os.path.exists(SERVER_ADDR_LOG): + return [] + try: + with open(SERVER_ADDR_LOG) as f: + return [line.strip() for line in f if line.strip()] + except Exception as e: + logger.error(f"Failed to load server addresses: {e}") + return [] + + +async def run_node( + port: int, mode: str, bootstrap_addrs: list[str] | None = None +) -> None: + """Run a node that serves content in the DHT with setup inlined.""" + try: + if port <= 0: + port = random.randint(10000, 60000) + logger.debug(f"Using port: {port}") + + # Convert string mode to DHTMode enum + if mode is None or mode.upper() == "CLIENT": + dht_mode = DHTMode.CLIENT + elif mode.upper() == "SERVER": + dht_mode = DHTMode.SERVER + else: + logger.error(f"Invalid mode: {mode}. 
Must be 'client' or 'server'") + sys.exit(1) + + # Load server addresses for client mode + if dht_mode == DHTMode.CLIENT: + server_addrs = load_server_addrs() + if server_addrs: + logger.info(f"Loaded {len(server_addrs)} server addresses from log") + bootstrap_nodes.append(server_addrs[0]) # Use the first server address + else: + logger.warning("No server addresses found in log file") + + if bootstrap_addrs: + for addr in bootstrap_addrs: + bootstrap_nodes.append(addr) + + key_pair = create_new_key_pair(secrets.token_bytes(32)) + host = new_host(key_pair=key_pair) + listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}") + + async with host.run(listen_addrs=[listen_addr]): + peer_id = host.get_id().pretty() + addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}" + await connect_to_bootstrap_nodes(host, bootstrap_nodes) + dht = KadDHT(host, dht_mode) + # take all peer ids from the host and add them to the dht + for peer_id in host.get_peerstore().peer_ids(): + await dht.routing_table.add_peer(peer_id) + logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}") + bootstrap_cmd = f"--bootstrap {addr_str}" + logger.info("To connect to this node, use: %s", bootstrap_cmd) + + # Save server address in server mode + if dht_mode == DHTMode.SERVER: + save_server_addr(addr_str) + + # Start the DHT service + async with background_trio_service(dht): + logger.info(f"DHT service started in {dht_mode.value} mode") + val_key = create_key_from_binary(b"py-libp2p kademlia example value") + content = b"Hello from python node " + content_key = create_key_from_binary(content) + + if dht_mode == DHTMode.SERVER: + # Store a value in the DHT + msg = "Hello message from Sumanjeet" + val_data = msg.encode() + await dht.put_value(val_key, val_data) + logger.info( + f"Stored value '{val_data.decode()}'" + f"with key: {base58.b58encode(val_key).decode()}" + ) + + # Advertise as content server + success = await dht.provider_store.provide(content_key) + if success: + logger.info( + "Successfully advertised as server" + f"for content: {content_key.hex()}" + ) + else: + logger.warning("Failed to advertise as content server") + + else: + # retrieve the value + logger.info( + "Looking up key: %s", base58.b58encode(val_key).decode() + ) + val_data = await dht.get_value(val_key) + if val_data: + try: + logger.info(f"Retrieved value: {val_data.decode()}") + except UnicodeDecodeError: + logger.info(f"Retrieved value (bytes): {val_data!r}") + else: + logger.warning("Failed to retrieve value") + + # Also check if we can find servers for our own content + logger.info("Looking for servers of content: %s", content_key.hex()) + providers = await dht.provider_store.find_providers(content_key) + if providers: + logger.info( + "Found %d servers for content: %s", + len(providers), + [p.peer_id.pretty() for p in providers], + ) + else: + logger.warning( + "No servers found for content %s", content_key.hex() + ) + + # Keep the node running + while True: + logger.debug( + "Status - Connected peers: %d," + "Peers in store: %d, Values in store: %d", + len(dht.host.get_connected_peers()), + len(dht.host.get_peerstore().peer_ids()), + len(dht.value_store.store), + ) + await trio.sleep(10) + + except Exception as e: + logger.error(f"Server node error: {e}", exc_info=True) + sys.exit(1) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Kademlia DHT example with content server functionality" + ) + parser.add_argument( + "--mode", + default="server", + help="Run as a server 
or client node", + ) + parser.add_argument( + "--port", + type=int, + default=0, + help="Port to listen on (0 for random)", + ) + parser.add_argument( + "--bootstrap", + type=str, + nargs="*", + help=( + "Multiaddrs of bootstrap nodes. " + "Provide a space-separated list of addresses. " + "This is required for client mode." + ), + ) + # add option to use verbose logging + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + # Set logging level based on verbosity + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + else: + logging.getLogger().setLevel(logging.INFO) + + return args + + +def main(): + """Main entry point for the kademlia demo.""" + try: + args = parse_args() + logger.info( + "Running in %s mode on port %d", + args.mode, + args.port, + ) + trio.run(run_node, args.port, args.mode, args.bootstrap) + except Exception as e: + logger.critical(f"Script failed: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/libp2p/kad_dht/__init__.py b/libp2p/kad_dht/__init__.py new file mode 100644 index 00000000..690d37ba --- /dev/null +++ b/libp2p/kad_dht/__init__.py @@ -0,0 +1,30 @@ +""" +Kademlia DHT implementation for py-libp2p. + +This module provides a Distributed Hash Table (DHT) implementation +based on the Kademlia protocol. +""" + +from .kad_dht import ( + KadDHT, +) +from .peer_routing import ( + PeerRouting, +) +from .routing_table import ( + RoutingTable, +) +from .utils import ( + create_key_from_binary, +) +from .value_store import ( + ValueStore, +) + +__all__ = [ + "KadDHT", + "RoutingTable", + "PeerRouting", + "ValueStore", + "create_key_from_binary", +] diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py new file mode 100644 index 00000000..7daad4cb --- /dev/null +++ b/libp2p/kad_dht/kad_dht.py @@ -0,0 +1,616 @@ +""" +Kademlia DHT implementation for py-libp2p. + +This module provides a complete Distributed Hash Table (DHT) +implementation based on the Kademlia algorithm and protocol. +""" + +from enum import Enum +import logging +import time + +from multiaddr import ( + Multiaddr, +) +import trio +import varint + +from libp2p.abc import ( + IHost, +) +from libp2p.custom_types import ( + TProtocol, +) +from libp2p.network.stream.net_stream import ( + INetStream, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) +from libp2p.tools.async_service import ( + Service, +) + +from .pb.kademlia_pb2 import ( + Message, +) +from .peer_routing import ( + PeerRouting, +) +from .provider_store import ( + ProviderStore, +) +from .routing_table import ( + RoutingTable, +) +from .value_store import ( + ValueStore, +) + +logger = logging.getLogger("kademlia-example.kad_dht") +# logger = logging.getLogger("libp2p.kademlia") +# Default parameters +PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0") +ROUTING_TABLE_REFRESH_INTERVAL = 1 * 60 # 1 min in seconds for testing +TTL = 24 * 60 * 60 # 24 hours in seconds +ALPHA = 3 +QUERY_TIMEOUT = 10 # seconds + + +class DHTMode(Enum): + """DHT operation modes.""" + + CLIENT = "CLIENT" + SERVER = "SERVER" + + +class KadDHT(Service): + """ + Kademlia DHT implementation for libp2p. + + This class provides a DHT implementation that combines routing table management, + peer discovery, content routing, and value storage. + """ + + def __init__(self, host: IHost, mode: DHTMode): + """ + Initialize a new Kademlia DHT node. + + :param host: The libp2p host. 
+ :param mode: The mode of host (Client or Server) - must be DHTMode enum + """ + super().__init__() + + self.host = host + self.local_peer_id = host.get_id() + + # Validate that mode is a DHTMode enum + if not isinstance(mode, DHTMode): + raise TypeError(f"mode must be DHTMode enum, got {type(mode)}") + + self.mode = mode + + # Initialize the routing table + self.routing_table = RoutingTable(self.local_peer_id, self.host) + + # Initialize peer routing + self.peer_routing = PeerRouting(host, self.routing_table) + + # Initialize value store + self.value_store = ValueStore(host=host, local_peer_id=self.local_peer_id) + + # Initialize provider store with host and peer_routing references + self.provider_store = ProviderStore(host=host, peer_routing=self.peer_routing) + + # Last time we republished provider records + self._last_provider_republish = time.time() + + # Set protocol handlers + host.set_stream_handler(PROTOCOL_ID, self.handle_stream) + + async def run(self) -> None: + """Run the DHT service.""" + logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}") + + # Main service loop + while self.manager.is_running: + # Periodically refresh the routing table + await self.refresh_routing_table() + + # Check if it's time to republish provider records + current_time = time.time() + # await self._republish_provider_records() + self._last_provider_republish = current_time + + # Clean up expired values and provider records + expired_values = self.value_store.cleanup_expired() + if expired_values > 0: + logger.debug(f"Cleaned up {expired_values} expired values") + + self.provider_store.cleanup_expired() + + # Wait before next maintenance cycle + await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL) + + async def switch_mode(self, new_mode: DHTMode) -> DHTMode: + """ + Switch the DHT mode. + + :param new_mode: The new mode - must be DHTMode enum + :return: The new mode as DHTMode enum + """ + # Validate that new_mode is a DHTMode enum + if not isinstance(new_mode, DHTMode): + raise TypeError(f"new_mode must be DHTMode enum, got {type(new_mode)}") + + if new_mode == DHTMode.CLIENT: + self.routing_table.cleanup_routing_table() + self.mode = new_mode + logger.info(f"Switched to {new_mode.value} mode") + return self.mode + + async def handle_stream(self, stream: INetStream) -> None: + """ + Handle an incoming DHT stream using varint length prefixes. 
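+
+        Framing sketch (inferred from the read/write logic below, not a formal
+        spec): each request and response is one serialized protobuf ``Message``
+        preceded by an unsigned-varint length prefix:
+
+            <uvarint(len(payload))><payload: serialized Message>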
+ """ + if self.mode == DHTMode.CLIENT: + stream.close + return + peer_id = stream.muxed_conn.peer_id + logger.debug(f"Received DHT stream from peer {peer_id}") + await self.add_peer(peer_id) + logger.debug(f"Added peer {peer_id} to routing table") + + try: + # Read varint-prefixed length for the message + length_prefix = b"" + while True: + byte = await stream.read(1) + if not byte: + logger.warning("Stream closed while reading varint length") + await stream.close() + return + length_prefix += byte + if byte[0] & 0x80 == 0: + break + msg_length = varint.decode_bytes(length_prefix) + + # Read the message bytes + msg_bytes = await stream.read(msg_length) + if len(msg_bytes) < msg_length: + logger.warning("Failed to read full message from stream") + await stream.close() + return + + try: + # Parse as protobuf + message = Message() + message.ParseFromString(msg_bytes) + logger.debug( + f"Received DHT message from {peer_id}, type: {message.type}" + ) + + # Handle FIND_NODE message + if message.type == Message.MessageType.FIND_NODE: + # Get target key directly from protobuf + target_key = message.key + + # Find closest peers to the target key + closest_peers = self.routing_table.find_local_closest_peers( + target_key, 20 + ) + logger.debug(f"Found {len(closest_peers)} peers close to target") + + # Build response message with protobuf + response = Message() + response.type = Message.MessageType.FIND_NODE + + # Add closest peers to response + for peer in closest_peers: + # Skip if the peer is the requester + if peer == peer_id: + continue + + # Add peer to closerPeers field + peer_proto = response.closerPeers.add() + peer_proto.id = peer.to_bytes() + peer_proto.connection = Message.ConnectionType.CAN_CONNECT + + # Add addresses if available + try: + addrs = self.host.get_peerstore().addrs(peer) + if addrs: + for addr in addrs: + peer_proto.addrs.append(addr.to_bytes()) + except Exception: + pass + + # Serialize and send response + response_bytes = response.SerializeToString() + await stream.write(varint.encode(len(response_bytes))) + await stream.write(response_bytes) + logger.debug( + f"Sent FIND_NODE response with{len(response.closerPeers)} peers" + ) + + # Handle ADD_PROVIDER message + elif message.type == Message.MessageType.ADD_PROVIDER: + # Process ADD_PROVIDER + key = message.key + logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") + + # Extract provider information + for provider_proto in message.providerPeers: + try: + # Validate that the provider is the sender + provider_id = ID(provider_proto.id) + if provider_id != peer_id: + logger.warning( + f"Provider ID {provider_id} doesn't" + f"match sender {peer_id}, ignoring" + ) + continue + + # Convert addresses to Multiaddr + addrs = [] + for addr_bytes in provider_proto.addrs: + try: + addrs.append(Multiaddr(addr_bytes)) + except Exception as e: + logger.warning(f"Failed to parse address: {e}") + + # Add to provider store + provider_info = PeerInfo(provider_id, addrs) + self.provider_store.add_provider(key, provider_info) + logger.debug( + f"Added provider {provider_id} for key {key.hex()}" + ) + except Exception as e: + logger.warning(f"Failed to process provider info: {e}") + + # Send acknowledgement + response = Message() + response.type = Message.MessageType.ADD_PROVIDER + response.key = key + + response_bytes = response.SerializeToString() + await stream.write(varint.encode(len(response_bytes))) + await stream.write(response_bytes) + logger.debug("Sent ADD_PROVIDER acknowledgement") + + # Handle GET_PROVIDERS message + elif 
message.type == Message.MessageType.GET_PROVIDERS: + # Process GET_PROVIDERS + key = message.key + logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") + + # Find providers for the key + providers = self.provider_store.get_providers(key) + logger.debug( + f"Found {len(providers)} providers for key {key.hex()}" + ) + + # Create response + response = Message() + response.type = Message.MessageType.GET_PROVIDERS + response.key = key + + # Add provider information to response + for provider_info in providers: + provider_proto = response.providerPeers.add() + provider_proto.id = provider_info.peer_id.to_bytes() + provider_proto.connection = Message.ConnectionType.CAN_CONNECT + + # Add addresses if available + for addr in provider_info.addrs: + provider_proto.addrs.append(addr.to_bytes()) + + # Also include closest peers if we don't have providers + if not providers: + closest_peers = self.routing_table.find_local_closest_peers( + key, 20 + ) + logger.debug( + f"No providers found, including {len(closest_peers)}" + "closest peers" + ) + + for peer in closest_peers: + # Skip if peer is the requester + if peer == peer_id: + continue + + peer_proto = response.closerPeers.add() + peer_proto.id = peer.to_bytes() + peer_proto.connection = Message.ConnectionType.CAN_CONNECT + + # Add addresses if available + try: + addrs = self.host.get_peerstore().addrs(peer) + for addr in addrs: + peer_proto.addrs.append(addr.to_bytes()) + except Exception: + pass + + # Serialize and send response + response_bytes = response.SerializeToString() + await stream.write(varint.encode(len(response_bytes))) + await stream.write(response_bytes) + logger.debug("Sent GET_PROVIDERS response") + + # Handle GET_VALUE message + elif message.type == Message.MessageType.GET_VALUE: + # Process GET_VALUE + key = message.key + logger.debug(f"Received GET_VALUE request for key {key.hex()}") + + value = self.value_store.get(key) + if value: + logger.debug(f"Found value for key {key.hex()}") + + # Create response using protobuf + response = Message() + response.type = Message.MessageType.GET_VALUE + + # Create record + response.key = key + response.record.key = key + response.record.value = value + response.record.timeReceived = str(time.time()) + + # Serialize and send response + response_bytes = response.SerializeToString() + await stream.write(varint.encode(len(response_bytes))) + await stream.write(response_bytes) + logger.debug("Sent GET_VALUE response") + else: + logger.debug(f"No value found for key {key.hex()}") + + # Create response with closest peers when no value is found + response = Message() + response.type = Message.MessageType.GET_VALUE + response.key = key + + # Add closest peers to key + closest_peers = self.routing_table.find_local_closest_peers( + key, 20 + ) + logger.debug( + "No value found," + f"including {len(closest_peers)} closest peers" + ) + + for peer in closest_peers: + # Skip if peer is the requester + if peer == peer_id: + continue + + peer_proto = response.closerPeers.add() + peer_proto.id = peer.to_bytes() + peer_proto.connection = Message.ConnectionType.CAN_CONNECT + + # Add addresses if available + try: + addrs = self.host.get_peerstore().addrs(peer) + for addr in addrs: + peer_proto.addrs.append(addr.to_bytes()) + except Exception: + pass + + # Serialize and send response + response_bytes = response.SerializeToString() + await stream.write(varint.encode(len(response_bytes))) + await stream.write(response_bytes) + logger.debug("Sent GET_VALUE response with closest peers") + + # Handle 
PUT_VALUE message + elif message.type == Message.MessageType.PUT_VALUE and message.HasField( + "record" + ): + # Process PUT_VALUE + key = message.record.key + value = message.record.value + success = False + try: + if not (key and value): + raise ValueError( + "Missing key or value in PUT_VALUE message" + ) + + self.value_store.put(key, value) + logger.debug(f"Stored value {value.hex()} for key {key.hex()}") + success = True + except Exception as e: + logger.warning( + f"Failed to store value {value.hex()} for key " + f"{key.hex()}: {e}" + ) + finally: + # Send acknowledgement + response = Message() + response.type = Message.MessageType.PUT_VALUE + if success: + response.key = key + response_bytes = response.SerializeToString() + await stream.write(varint.encode(len(response_bytes))) + await stream.write(response_bytes) + logger.debug("Sent PUT_VALUE acknowledgement") + + except Exception as proto_err: + logger.warning(f"Failed to parse protobuf message: {proto_err}") + + await stream.close() + except Exception as e: + logger.error(f"Error handling DHT stream: {e}") + await stream.close() + + async def refresh_routing_table(self) -> None: + """Refresh the routing table.""" + logger.debug("Refreshing routing table") + await self.peer_routing.refresh_routing_table() + + # Peer routing methods + + async def find_peer(self, peer_id: ID) -> PeerInfo | None: + """ + Find a peer with the given ID. + """ + logger.debug(f"Finding peer: {peer_id}") + return await self.peer_routing.find_peer(peer_id) + + # Value storage and retrieval methods + + async def put_value(self, key: bytes, value: bytes) -> None: + """ + Store a value in the DHT. + """ + logger.debug(f"Storing value for key {key.hex()}") + + # 1. Store locally first + self.value_store.put(key, value) + try: + decoded_value = value.decode("utf-8") + except UnicodeDecodeError: + decoded_value = value.hex() + logger.debug( + f"Stored value locally for key {key.hex()} with value {decoded_value}" + ) + + # 2. Get closest peers, excluding self + closest_peers = [ + peer + for peer in self.routing_table.find_local_closest_peers(key) + if peer != self.local_peer_id + ] + logger.debug(f"Found {len(closest_peers)} peers to store value at") + + # 3. Store at remote peers in batches of ALPHA, in parallel + stored_count = 0 + for i in range(0, len(closest_peers), ALPHA): + batch = closest_peers[i : i + ALPHA] + batch_results = [False] * len(batch) + + async def store_one(idx: int, peer: ID) -> None: + try: + with trio.move_on_after(QUERY_TIMEOUT): + success = await self.value_store._store_at_peer( + peer, key, value + ) + batch_results[idx] = success + if success: + logger.debug(f"Stored value at peer {peer}") + else: + logger.debug(f"Failed to store value at peer {peer}") + except Exception as e: + logger.debug(f"Error storing value at peer {peer}: {e}") + + async with trio.open_nursery() as nursery: + for idx, peer in enumerate(batch): + nursery.start_soon(store_one, idx, peer) + + stored_count += sum(batch_results) + + logger.info(f"Successfully stored value at {stored_count} peers") + + async def get_value(self, key: bytes) -> bytes | None: + logger.debug(f"Getting value for key: {key.hex()}") + + # 1. Check local store first + value = self.value_store.get(key) + if value: + logger.debug("Found value locally") + return value + + # 2. 
Get closest peers, excluding self + closest_peers = [ + peer + for peer in self.routing_table.find_local_closest_peers(key) + if peer != self.local_peer_id + ] + logger.debug(f"Searching {len(closest_peers)} peers for value") + + # 3. Query ALPHA peers at a time in parallel + for i in range(0, len(closest_peers), ALPHA): + batch = closest_peers[i : i + ALPHA] + found_value = None + + async def query_one(peer: ID) -> None: + nonlocal found_value + try: + with trio.move_on_after(QUERY_TIMEOUT): + value = await self.value_store._get_from_peer(peer, key) + if value is not None and found_value is None: + found_value = value + logger.debug(f"Found value at peer {peer}") + except Exception as e: + logger.debug(f"Error querying peer {peer}: {e}") + + async with trio.open_nursery() as nursery: + for peer in batch: + nursery.start_soon(query_one, peer) + + if found_value is not None: + self.value_store.put(key, found_value) + logger.info("Successfully retrieved value from network") + return found_value + + # 4. Not found + logger.warning(f"Value not found for key {key.hex()}") + return None + + # Add these methods in the Utility methods section + + # Utility methods + + async def add_peer(self, peer_id: ID) -> bool: + """ + Add a peer to the routing table. + + params: peer_id: The peer ID to add. + + Returns + ------- + bool + True if peer was added or updated, False otherwise. + + """ + return await self.routing_table.add_peer(peer_id) + + async def provide(self, key: bytes) -> bool: + """ + Reference to provider_store.provide for convenience. + """ + return await self.provider_store.provide(key) + + async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]: + """ + Reference to provider_store.find_providers for convenience. + """ + return await self.provider_store.find_providers(key, count) + + def get_routing_table_size(self) -> int: + """ + Get the number of peers in the routing table. + + Returns + ------- + int + Number of peers. + + """ + return self.routing_table.size() + + def get_value_store_size(self) -> int: + """ + Get the number of items in the value store. + + Returns + ------- + int + Number of items. + + """ + return self.value_store.size() diff --git a/libp2p/kad_dht/pb/__init__.py b/libp2p/kad_dht/pb/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libp2p/kad_dht/pb/kademlia.proto b/libp2p/kad_dht/pb/kademlia.proto new file mode 100644 index 00000000..fd198d28 --- /dev/null +++ b/libp2p/kad_dht/pb/kademlia.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +message Record { + bytes key = 1; + bytes value = 2; + string timeReceived = 5; +}; + +message Message { + enum MessageType { + PUT_VALUE = 0; + GET_VALUE = 1; + ADD_PROVIDER = 2; + GET_PROVIDERS = 3; + FIND_NODE = 4; + PING = 5; + } + + enum ConnectionType { + NOT_CONNECTED = 0; + CONNECTED = 1; + CAN_CONNECT = 2; + CANNOT_CONNECT = 3; + } + + message Peer { + bytes id = 1; + repeated bytes addrs = 2; + ConnectionType connection = 3; + } + + MessageType type = 1; + int32 clusterLevelRaw = 10; + bytes key = 2; + Record record = 3; + repeated Peer closerPeers = 8; + repeated Peer providerPeers = 9; +} diff --git a/libp2p/kad_dht/pb/kademlia_pb2.py b/libp2p/kad_dht/pb/kademlia_pb2.py new file mode 100644 index 00000000..1fe2c032 --- /dev/null +++ b/libp2p/kad_dht/pb/kademlia_pb2.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: libp2p/kad_dht/pb/kademlia.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _globals['_RECORD']._serialized_start=36 + _globals['_RECORD']._serialized_end=94 + _globals['_MESSAGE']._serialized_start=97 + _globals['_MESSAGE']._serialized_end=555 + _globals['_MESSAGE_PEER']._serialized_start=281 + _globals['_MESSAGE_PEER']._serialized_end=359 + _globals['_MESSAGE_MESSAGETYPE']._serialized_start=361 + _globals['_MESSAGE_MESSAGETYPE']._serialized_end=466 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555 +# @@protoc_insertion_point(module_scope) diff --git a/libp2p/kad_dht/pb/kademlia_pb2.pyi b/libp2p/kad_dht/pb/kademlia_pb2.pyi new file mode 100644 index 00000000..c8f16db2 --- /dev/null +++ b/libp2p/kad_dht/pb/kademlia_pb2.pyi @@ -0,0 +1,133 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing.final +class Record(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + TIMERECEIVED_FIELD_NUMBER: builtins.int + key: builtins.bytes + value: builtins.bytes + timeReceived: builtins.str + def __init__( + self, + *, + key: builtins.bytes = ..., + value: builtins.bytes = ..., + timeReceived: builtins.str = ..., + ) -> None: ... 
+ def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... + +global___Record = Record + +@typing.final +class Message(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class _MessageType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + PUT_VALUE: Message._MessageType.ValueType # 0 + GET_VALUE: Message._MessageType.ValueType # 1 + ADD_PROVIDER: Message._MessageType.ValueType # 2 + GET_PROVIDERS: Message._MessageType.ValueType # 3 + FIND_NODE: Message._MessageType.ValueType # 4 + PING: Message._MessageType.ValueType # 5 + + class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... + PUT_VALUE: Message.MessageType.ValueType # 0 + GET_VALUE: Message.MessageType.ValueType # 1 + ADD_PROVIDER: Message.MessageType.ValueType # 2 + GET_PROVIDERS: Message.MessageType.ValueType # 3 + FIND_NODE: Message.MessageType.ValueType # 4 + PING: Message.MessageType.ValueType # 5 + + class _ConnectionType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + NOT_CONNECTED: Message._ConnectionType.ValueType # 0 + CONNECTED: Message._ConnectionType.ValueType # 1 + CAN_CONNECT: Message._ConnectionType.ValueType # 2 + CANNOT_CONNECT: Message._ConnectionType.ValueType # 3 + + class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ... + NOT_CONNECTED: Message.ConnectionType.ValueType # 0 + CONNECTED: Message.ConnectionType.ValueType # 1 + CAN_CONNECT: Message.ConnectionType.ValueType # 2 + CANNOT_CONNECT: Message.ConnectionType.ValueType # 3 + + @typing.final + class Peer(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ID_FIELD_NUMBER: builtins.int + ADDRS_FIELD_NUMBER: builtins.int + CONNECTION_FIELD_NUMBER: builtins.int + id: builtins.bytes + connection: global___Message.ConnectionType.ValueType + @property + def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + def __init__( + self, + *, + id: builtins.bytes = ..., + addrs: collections.abc.Iterable[builtins.bytes] | None = ..., + connection: global___Message.ConnectionType.ValueType = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ... + + TYPE_FIELD_NUMBER: builtins.int + CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int + KEY_FIELD_NUMBER: builtins.int + RECORD_FIELD_NUMBER: builtins.int + CLOSERPEERS_FIELD_NUMBER: builtins.int + PROVIDERPEERS_FIELD_NUMBER: builtins.int + type: global___Message.MessageType.ValueType + clusterLevelRaw: builtins.int + key: builtins.bytes + @property + def record(self) -> global___Record: ... + @property + def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... + @property + def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... 
+ def __init__( + self, + *, + type: global___Message.MessageType.ValueType = ..., + clusterLevelRaw: builtins.int = ..., + key: builtins.bytes = ..., + record: global___Record | None = ..., + closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., + providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ... + +global___Message = Message diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py new file mode 100644 index 00000000..f3689e11 --- /dev/null +++ b/libp2p/kad_dht/peer_routing.py @@ -0,0 +1,418 @@ +""" +Peer routing implementation for Kademlia DHT. + +This module implements the peer routing interface using Kademlia's algorithm +to efficiently locate peers in a distributed network. +""" + +import logging + +import trio +import varint + +from libp2p.abc import ( + IHost, + INetStream, + IPeerRouting, +) +from libp2p.custom_types import ( + TProtocol, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) + +from .pb.kademlia_pb2 import ( + Message, +) +from .routing_table import ( + RoutingTable, +) +from .utils import ( + sort_peer_ids_by_distance, +) + +# logger = logging.getLogger("libp2p.kademlia.peer_routing") +logger = logging.getLogger("kademlia-example.peer_routing") + +# Constants for the Kademlia algorithm +ALPHA = 3 # Concurrency parameter +MAX_PEER_LOOKUP_ROUNDS = 20 # Maximum number of rounds in peer lookup +PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0") + + +class PeerRouting(IPeerRouting): + """ + Implementation of peer routing using the Kademlia algorithm. + + This class provides methods to find peers in the DHT network + and helps maintain the routing table. + """ + + def __init__(self, host: IHost, routing_table: RoutingTable): + """ + Initialize the peer routing service. + + :param host: The libp2p host + :param routing_table: The Kademlia routing table + + """ + self.host = host + self.routing_table = routing_table + self.protocol_id = PROTOCOL_ID + + async def find_peer(self, peer_id: ID) -> PeerInfo | None: + """ + Find a peer with the given ID. 
+ + :param peer_id: The ID of the peer to find + + Returns + ------- + Optional[PeerInfo] + The peer information if found, None otherwise + + """ + # Check if this is actually our peer ID + if peer_id == self.host.get_id(): + try: + # Return our own peer info + return PeerInfo(peer_id, self.host.get_addrs()) + except Exception: + logger.exception("Error getting our own peer info") + return None + + # First check if the peer is in our routing table + peer_info = self.routing_table.get_peer_info(peer_id) + if peer_info: + logger.debug(f"Found peer {peer_id} in routing table") + return peer_info + + # Then check if the peer is in our peerstore + try: + addrs = self.host.get_peerstore().addrs(peer_id) + if addrs: + logger.debug(f"Found peer {peer_id} in peerstore") + return PeerInfo(peer_id, addrs) + except Exception: + pass + + # If not found locally, search the network + try: + closest_peers = await self.find_closest_peers_network(peer_id.to_bytes()) + logger.info(f"Closest peers found: {closest_peers}") + + # Check if we found the peer we're looking for + for found_peer in closest_peers: + if found_peer == peer_id: + try: + addrs = self.host.get_peerstore().addrs(found_peer) + if addrs: + return PeerInfo(found_peer, addrs) + except Exception: + pass + + except Exception as e: + logger.error(f"Error searching for peer {peer_id}: {e}") + + # Not found + logger.info(f"Peer {peer_id} not found") + return None + + async def _query_single_peer_for_closest( + self, peer: ID, target_key: bytes, new_peers: list[ID] + ) -> None: + """ + Query a single peer for closest peers and append results to the shared list. + + params: peer : ID + The peer to query + params: target_key : bytes + The target key to find closest peers for + params: new_peers : list[ID] + Shared list to append results to + + """ + try: + result = await self._query_peer_for_closest(peer, target_key) + # Add deduplication to prevent duplicate peers + for peer_id in result: + if peer_id not in new_peers: + new_peers.append(peer_id) + logger.debug( + "Queried peer %s for closest peers, got %d results (%d unique)", + peer, + len(result), + len([p for p in result if p not in new_peers[: -len(result)]]), + ) + except Exception as e: + logger.debug(f"Query to peer {peer} failed: {e}") + + async def find_closest_peers_network( + self, target_key: bytes, count: int = 20 + ) -> list[ID]: + """ + Find the closest peers to a target key in the entire network. + + Performs an iterative lookup by querying peers for their closest peers. 
+ + Returns + ------- + list[ID] + Closest peer IDs + + """ + # Start with closest peers from our routing table + closest_peers = self.routing_table.find_local_closest_peers(target_key, count) + logger.debug("Local closest peers: %d found", len(closest_peers)) + queried_peers: set[ID] = set() + rounds = 0 + + # Return early if we have no peers to start with + if not closest_peers: + logger.warning("No local peers available for network lookup") + return [] + + # Iterative lookup until convergence + while rounds < MAX_PEER_LOOKUP_ROUNDS: + rounds += 1 + logger.debug(f"Lookup round {rounds}/{MAX_PEER_LOOKUP_ROUNDS}") + + # Find peers we haven't queried yet + peers_to_query = [p for p in closest_peers if p not in queried_peers] + if not peers_to_query: + logger.debug("No more unqueried peers available, ending lookup") + break # No more peers to query + + # Query these peers for their closest peers to target + peers_batch = peers_to_query[:ALPHA] # Limit to ALPHA peers at a time + + # Mark these peers as queried before we actually query them + for peer in peers_batch: + queried_peers.add(peer) + + # Run queries in parallel for this batch using trio nursery + new_peers: list[ID] = [] # Shared array to collect all results + + async with trio.open_nursery() as nursery: + for peer in peers_batch: + nursery.start_soon( + self._query_single_peer_for_closest, peer, target_key, new_peers + ) + + # If we got no new peers, we're done + if not new_peers: + logger.debug("No new peers discovered in this round, ending lookup") + break + + # Update our list of closest peers + all_candidates = closest_peers + new_peers + old_closest_peers = closest_peers[:] + closest_peers = sort_peer_ids_by_distance(target_key, all_candidates)[ + :count + ] + logger.debug(f"Updated closest peers count: {len(closest_peers)}") + + # Check if we made any progress (found closer peers) + if closest_peers == old_closest_peers: + logger.debug("No improvement in closest peers, ending lookup") + break + + logger.info( + f"Network lookup completed after {rounds} rounds, " + f"found {len(closest_peers)} peers" + ) + return closest_peers + + async def _query_peer_for_closest(self, peer: ID, target_key: bytes) -> list[ID]: + """ + Query a peer for their closest peers + to the target key using varint length prefix + """ + stream = None + results = [] + try: + # Add the peer to our routing table regardless of query outcome + try: + addrs = self.host.get_peerstore().addrs(peer) + if addrs: + peer_info = PeerInfo(peer, addrs) + await self.routing_table.add_peer(peer_info) + except Exception as e: + logger.debug(f"Failed to add peer {peer} to routing table: {e}") + + # Open a stream to the peer using the Kademlia protocol + logger.debug(f"Opening stream to {peer} for closest peers query") + try: + stream = await self.host.new_stream(peer, [self.protocol_id]) + logger.debug(f"Stream opened to {peer}") + except Exception as e: + logger.warning(f"Failed to open stream to {peer}: {e}") + return [] + + # Create and send FIND_NODE request using protobuf + find_node_msg = Message() + find_node_msg.type = Message.MessageType.FIND_NODE + find_node_msg.key = target_key # Set target key directly as bytes + + # Serialize and send the protobuf message with varint length prefix + proto_bytes = find_node_msg.SerializeToString() + logger.debug( + f"Sending FIND_NODE: {proto_bytes.hex()} (len={len(proto_bytes)})" + ) + await stream.write(varint.encode(len(proto_bytes))) + await stream.write(proto_bytes) + + # Read varint-prefixed response length + length_bytes 
= b"" + while True: + b = await stream.read(1) + if not b: + logger.warning( + "Error reading varint length from stream: connection closed" + ) + return [] + length_bytes += b + if b[0] & 0x80 == 0: + break + response_length = varint.decode_bytes(length_bytes) + + # Read response data + response_bytes = b"" + remaining = response_length + while remaining > 0: + chunk = await stream.read(remaining) + if not chunk: + logger.debug(f"Connection closed by peer {peer} while reading data") + return [] + response_bytes += chunk + remaining -= len(chunk) + + # Parse the protobuf response + response_msg = Message() + response_msg.ParseFromString(response_bytes) + logger.debug( + "Received response from %s with %d peers", + peer, + len(response_msg.closerPeers), + ) + + # Process closest peers from response + if response_msg.type == Message.MessageType.FIND_NODE: + for peer_data in response_msg.closerPeers: + new_peer_id = ID(peer_data.id) + if new_peer_id not in results: + results.append(new_peer_id) + if peer_data.addrs: + from multiaddr import ( + Multiaddr, + ) + + addrs = [Multiaddr(addr) for addr in peer_data.addrs] + self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600) + + except Exception as e: + logger.debug(f"Error querying peer {peer} for closest: {e}") + + finally: + if stream: + await stream.close() + return results + + async def _handle_kad_stream(self, stream: INetStream) -> None: + """ + Handle incoming Kademlia protocol streams. + + params: stream: The incoming stream + + Returns + ------- + None + + """ + try: + # Read message length + length_bytes = await stream.read(4) + if not length_bytes: + return + + message_length = int.from_bytes(length_bytes, byteorder="big") + + # Read message + message_bytes = await stream.read(message_length) + if not message_bytes: + return + + # Parse protobuf message + kad_message = Message() + try: + kad_message.ParseFromString(message_bytes) + + if kad_message.type == Message.MessageType.FIND_NODE: + # Get target key directly from protobuf message + target_key = kad_message.key + + # Find closest peers to target + closest_peers = self.routing_table.find_local_closest_peers( + target_key, 20 + ) + + # Create protobuf response + response = Message() + response.type = Message.MessageType.FIND_NODE + + # Add peer information to response + for peer_id in closest_peers: + peer_proto = response.closerPeers.add() + peer_proto.id = peer_id.to_bytes() + peer_proto.connection = Message.ConnectionType.CAN_CONNECT + + # Add addresses if available + try: + addrs = self.host.get_peerstore().addrs(peer_id) + if addrs: + for addr in addrs: + peer_proto.addrs.append(addr.to_bytes()) + except Exception: + pass + + # Send response + response_bytes = response.SerializeToString() + await stream.write(len(response_bytes).to_bytes(4, byteorder="big")) + await stream.write(response_bytes) + + except Exception as parse_err: + logger.error(f"Failed to parse protocol buffer message: {parse_err}") + + except Exception as e: + logger.debug(f"Error handling Kademlia stream: {e}") + finally: + await stream.close() + + async def refresh_routing_table(self) -> None: + """ + Refresh the routing table by performing lookups for random keys. 
+ + Returns + ------- + None + + """ + logger.info("Refreshing routing table") + + # Perform a lookup for ourselves to populate the routing table + local_id = self.host.get_id() + closest_peers = await self.find_closest_peers_network(local_id.to_bytes()) + + # Add discovered peers to routing table + for peer_id in closest_peers: + try: + addrs = self.host.get_peerstore().addrs(peer_id) + if addrs: + peer_info = PeerInfo(peer_id, addrs) + await self.routing_table.add_peer(peer_info) + except Exception as e: + logger.debug(f"Failed to add discovered peer {peer_id}: {e}") diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py new file mode 100644 index 00000000..00ac6010 --- /dev/null +++ b/libp2p/kad_dht/provider_store.py @@ -0,0 +1,575 @@ +""" +Provider record storage for Kademlia DHT. + +This module implements the storage for content provider records in the Kademlia DHT. +""" + +import logging +import time +from typing import ( + Any, +) + +from multiaddr import ( + Multiaddr, +) +import trio +import varint + +from libp2p.abc import ( + IHost, +) +from libp2p.custom_types import ( + TProtocol, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) + +from .pb.kademlia_pb2 import ( + Message, +) + +# logger = logging.getLogger("libp2p.kademlia.provider_store") +logger = logging.getLogger("kademlia-example.provider_store") + +# Constants for provider records (based on IPFS standards) +PROVIDER_RECORD_REPUBLISH_INTERVAL = 22 * 60 * 60 # 22 hours in seconds +PROVIDER_RECORD_EXPIRATION_INTERVAL = 48 * 60 * 60 # 48 hours in seconds +PROVIDER_ADDRESS_TTL = 30 * 60 # 30 minutes in seconds +PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0") +ALPHA = 3 # Number of parallel queries/advertisements +QUERY_TIMEOUT = 10 # Timeout for each query in seconds + + +class ProviderRecord: + """ + A record for a content provider in the DHT. + + Contains the peer information and timestamp. + """ + + def __init__( + self, + provider_info: PeerInfo, + timestamp: float | None = None, + ) -> None: + """ + Initialize a new provider record. + + :param provider_info: The provider's peer information + :param timestamp: Time this record was created/updated + (defaults to current time) + + """ + self.provider_info = provider_info + self.timestamp = timestamp or time.time() + + def is_expired(self) -> bool: + """ + Check if this provider record has expired. + + Returns + ------- + bool + True if the record has expired + + """ + current_time = time.time() + return (current_time - self.timestamp) >= PROVIDER_RECORD_EXPIRATION_INTERVAL + + def should_republish(self) -> bool: + """ + Check if this provider record should be republished. + + Returns + ------- + bool + True if the record should be republished + + """ + current_time = time.time() + return (current_time - self.timestamp) >= PROVIDER_RECORD_REPUBLISH_INTERVAL + + @property + def peer_id(self) -> ID: + """Get the provider's peer ID.""" + return self.provider_info.peer_id + + @property + def addresses(self) -> list[Multiaddr]: + """Get the provider's addresses.""" + return self.provider_info.addrs + + +class ProviderStore: + """ + Store for content provider records in the Kademlia DHT. + + Maps content keys to provider records, with support for expiration. + """ + + def __init__(self, host: IHost, peer_routing: Any = None) -> None: + """ + Initialize a new provider store. 
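+
+        Illustrative construction and use (a sketch; ``host`` and
+        ``peer_routing`` are assumed to belong to a running DHT node)::
+
+            store = ProviderStore(host, peer_routing)
+            await store.provide(key)                     # advertise a key we hold
+            providers = await store.find_providers(key)  # look the key up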
+ + :param host: The libp2p host instance (optional) + :param peer_routing: The peer routing instance (optional) + """ + # Maps content keys to a dict of provider records (peer_id -> record) + self.providers: dict[bytes, dict[str, ProviderRecord]] = {} + self.host = host + self.peer_routing = peer_routing + self.providing_keys: set[bytes] = set() + self.local_peer_id = host.get_id() + + async def _republish_provider_records(self) -> None: + """Republish all provider records for content this node is providing.""" + # First, republish keys we're actively providing + for key in self.providing_keys: + logger.debug(f"Republishing provider record for key {key.hex()}") + await self.provide(key) + + # Also check for any records that should be republished + time.time() + for key, providers in self.providers.items(): + for peer_id_str, record in providers.items(): + # Only republish records for our own peer + if self.local_peer_id and str(self.local_peer_id) == peer_id_str: + if record.should_republish(): + logger.debug( + f"Republishing old provider record for key {key.hex()}" + ) + await self.provide(key) + + async def provide(self, key: bytes) -> bool: + """ + Advertise that this node can provide a piece of content. + + Finds the k closest peers to the key and sends them ADD_PROVIDER messages. + + :param key: The content key (multihash) to advertise + + Returns + ------- + bool + True if the advertisement was successful + + """ + if not self.host or not self.peer_routing: + logger.error("Host or peer_routing not initialized, cannot provide content") + return False + + # Add to local provider store + local_addrs = [] + for addr in self.host.get_addrs(): + local_addrs.append(addr) + + local_peer_info = PeerInfo(self.host.get_id(), local_addrs) + self.add_provider(key, local_peer_info) + + # Track that we're providing this key + self.providing_keys.add(key) + + # Find the k closest peers to the key + closest_peers = await self.peer_routing.find_closest_peers_network(key) + logger.debug( + "Found %d peers close to key %s for provider advertisement", + len(closest_peers), + key.hex(), + ) + + # Send ADD_PROVIDER messages to these ALPHA peers in parallel. + success_count = 0 + for i in range(0, len(closest_peers), ALPHA): + batch = closest_peers[i : i + ALPHA] + results: list[bool] = [False] * len(batch) + + async def send_one( + idx: int, peer_id: ID, results: list[bool] = results + ) -> None: + if peer_id == self.local_peer_id: + return + try: + with trio.move_on_after(QUERY_TIMEOUT): + success = await self._send_add_provider(peer_id, key) + results[idx] = success + if not success: + logger.warning(f"Failed to send ADD_PROVIDER to {peer_id}") + except Exception as e: + logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") + + async with trio.open_nursery() as nursery: + for idx, peer_id in enumerate(batch): + nursery.start_soon(send_one, idx, peer_id, results) + success_count += sum(results) + + logger.info(f"Successfully advertised to {success_count} peers") + return success_count > 0 + + async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool: + """ + Send ADD_PROVIDER message to a specific peer. 
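+
+        The request is a protobuf ``Message`` of type ``ADD_PROVIDER`` carrying
+        our own peer ID and addresses in ``providerPeers``; both request and
+        response are framed with a varint length prefix. Framing sketch
+        (mirrors the body below; ``stream`` is an open libp2p stream)::
+
+            proto_bytes = message.SerializeToString()
+            await stream.write(varint.encode(len(proto_bytes)))
+            await stream.write(proto_bytes)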
+ + :param peer_id: The peer to send the message to + :param key: The content key being provided + + Returns + ------- + bool + True if the message was successfully sent and acknowledged + + """ + try: + result = False + # Open a stream to the peer + stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)]) + + # Get our addresses to include in the message + addrs = [] + for addr in self.host.get_addrs(): + addrs.append(addr.to_bytes()) + + # Create the ADD_PROVIDER message + message = Message() + message.type = Message.MessageType.ADD_PROVIDER + message.key = key + + # Add our provider info + provider = message.providerPeers.add() + provider.id = self.local_peer_id.to_bytes() + provider.addrs.extend(addrs) + + # Serialize and send the message + proto_bytes = message.SerializeToString() + await stream.write(varint.encode(len(proto_bytes))) + await stream.write(proto_bytes) + logger.debug(f"Sent ADD_PROVIDER to {peer_id} for key {key.hex()}") + # Read response length prefix + length_bytes = b"" + while True: + logger.debug("Reading response length prefix in add provider") + b = await stream.read(1) + if not b: + return False + length_bytes += b + if b[0] & 0x80 == 0: + break + + response_length = varint.decode_bytes(length_bytes) + # Read response data + response_bytes = b"" + remaining = response_length + while remaining > 0: + chunk = await stream.read(remaining) + if not chunk: + return False + response_bytes += chunk + remaining -= len(chunk) + + # Parse response + response = Message() + response.ParseFromString(response_bytes) + + # Check response type + response.type == Message.MessageType.ADD_PROVIDER + if response.type: + result = True + + except Exception as e: + logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") + + finally: + await stream.close() + return result + + async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]: + """ + Find content providers for a given key. 
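+
+        The local provider store is consulted first; if it has no records for
+        the key, the closest peers to the key are queried in parallel batches
+        of ``ALPHA``, each bounded by ``QUERY_TIMEOUT``. Illustrative usage
+        (a sketch; ``store`` is a ``ProviderStore`` on a running node)::
+
+            providers = await store.find_providers(key, count=5)
+            for info in providers:
+                print(info.peer_id, info.addrs)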
+ + :param key: The content key to look for + :param count: Maximum number of providers to return + + Returns + ------- + List[PeerInfo] + List of content providers + + """ + if not self.host or not self.peer_routing: + logger.error("Host or peer_routing not initialized, cannot find providers") + return [] + + # Check local provider store first + local_providers = self.get_providers(key) + if local_providers: + logger.debug( + f"Found {len(local_providers)} providers locally for {key.hex()}" + ) + return local_providers[:count] + logger.debug("local providers are %s", local_providers) + + # Find the closest peers to the key + closest_peers = await self.peer_routing.find_closest_peers_network(key) + logger.debug( + f"Searching {len(closest_peers)} peers for providers of {key.hex()}" + ) + + # Query these peers for providers in batches of ALPHA, in parallel, with timeout + all_providers = [] + for i in range(0, len(closest_peers), ALPHA): + batch = closest_peers[i : i + ALPHA] + batch_results: list[list[PeerInfo]] = [[] for _ in batch] + + async def get_one( + idx: int, + peer_id: ID, + batch_results: list[list[PeerInfo]] = batch_results, + ) -> None: + if peer_id == self.local_peer_id: + return + try: + with trio.move_on_after(QUERY_TIMEOUT): + providers = await self._get_providers_from_peer(peer_id, key) + if providers: + for provider in providers: + self.add_provider(key, provider) + batch_results[idx] = providers + else: + logger.debug(f"No providers found at peer {peer_id}") + except Exception as e: + logger.warning(f"Failed to get providers from {peer_id}: {e}") + + async with trio.open_nursery() as nursery: + for idx, peer_id in enumerate(batch): + nursery.start_soon(get_one, idx, peer_id, batch_results) + + for providers in batch_results: + all_providers.extend(providers) + if len(all_providers) >= count: + return all_providers[:count] + + return all_providers[:count] + + async def _get_providers_from_peer(self, peer_id: ID, key: bytes) -> list[PeerInfo]: + """ + Get content providers from a specific peer. 
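+
+        The peer is expected to reply with a ``GET_PROVIDERS`` message whose
+        ``providerPeers`` entries are converted to ``PeerInfo`` objects;
+        addresses that fail to parse are skipped. Conversion sketch (mirrors
+        the body below)::
+
+            provider_id = ID(provider_proto.id)
+            addrs = [Multiaddr(a) for a in provider_proto.addrs]
+            providers.append(PeerInfo(provider_id, addrs))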
+ + :param peer_id: The peer to query + :param key: The content key to look for + + Returns + ------- + List[PeerInfo] + List of provider information + + """ + providers: list[PeerInfo] = [] + try: + # Open a stream to the peer + stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)]) + + try: + # Create the GET_PROVIDERS message + message = Message() + message.type = Message.MessageType.GET_PROVIDERS + message.key = key + + # Serialize and send the message + proto_bytes = message.SerializeToString() + await stream.write(varint.encode(len(proto_bytes))) + await stream.write(proto_bytes) + + # Read response length prefix + length_bytes = b"" + while True: + b = await stream.read(1) + if not b: + return [] + length_bytes += b + if b[0] & 0x80 == 0: + break + + response_length = varint.decode_bytes(length_bytes) + # Read response data + response_bytes = b"" + remaining = response_length + while remaining > 0: + chunk = await stream.read(remaining) + if not chunk: + return [] + response_bytes += chunk + remaining -= len(chunk) + + # Parse response + response = Message() + response.ParseFromString(response_bytes) + + # Check response type + if response.type != Message.MessageType.GET_PROVIDERS: + return [] + + # Extract provider information + providers = [] + for provider_proto in response.providerPeers: + try: + # Create peer ID from bytes + provider_id = ID(provider_proto.id) + + # Convert addresses to Multiaddr + addrs = [] + for addr_bytes in provider_proto.addrs: + try: + addrs.append(Multiaddr(addr_bytes)) + except Exception: + pass # Skip invalid addresses + + # Create PeerInfo and add to result + providers.append(PeerInfo(provider_id, addrs)) + except Exception as e: + logger.warning(f"Failed to parse provider info: {e}") + + finally: + await stream.close() + return providers + + except Exception as e: + logger.warning(f"Error getting providers from {peer_id}: {e}") + return [] + + def add_provider(self, key: bytes, provider: PeerInfo) -> None: + """ + Add a provider for a given content key. + + :param key: The content key + :param provider: The provider's peer information + + Returns + ------- + None + + """ + # Initialize providers for this key if needed + if key not in self.providers: + self.providers[key] = {} + + # Add or update the provider record + peer_id_str = str(provider.peer_id) # Use string representation as dict key + self.providers[key][peer_id_str] = ProviderRecord( + provider_info=provider, timestamp=time.time() + ) + logger.debug(f"Added provider {provider.peer_id} for key {key.hex()}") + + def get_providers(self, key: bytes) -> list[PeerInfo]: + """ + Get all providers for a given content key. 
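+
+        Records older than ``PROVIDER_RECORD_EXPIRATION_INTERVAL`` are pruned
+        as a side effect, and addresses older than ``PROVIDER_ADDRESS_TTL`` are
+        omitted from the returned ``PeerInfo``. Illustrative local-only usage
+        (a sketch, no network traffic)::
+
+            store.add_provider(key, PeerInfo(peer_id, addrs))
+            assert store.get_providers(key)[0].peer_id == peer_id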
+ + :param key: The content key + + Returns + ------- + List[PeerInfo] + List of providers for the key + + """ + if key not in self.providers: + return [] + + # Collect valid provider records (not expired) + result = [] + current_time = time.time() + expired_peers = [] + + for peer_id_str, record in self.providers[key].items(): + # Check if the record has expired + if current_time - record.timestamp > PROVIDER_RECORD_EXPIRATION_INTERVAL: + expired_peers.append(peer_id_str) + continue + + # Use addresses only if they haven't expired + addresses = [] + if current_time - record.timestamp <= PROVIDER_ADDRESS_TTL: + addresses = record.addresses + + # Create PeerInfo and add to results + result.append(PeerInfo(record.peer_id, addresses)) + + # Clean up expired records + for peer_id in expired_peers: + del self.providers[key][peer_id] + + # Remove the key if no providers left + if not self.providers[key]: + del self.providers[key] + + return result + + def cleanup_expired(self) -> None: + """Remove expired provider records.""" + current_time = time.time() + expired_keys = [] + + for key, providers in self.providers.items(): + expired_providers = [] + + for peer_id_str, record in providers.items(): + if ( + current_time - record.timestamp + > PROVIDER_RECORD_EXPIRATION_INTERVAL + ): + expired_providers.append(peer_id_str) + logger.debug( + f"Removing expired provider {peer_id_str} for key {key.hex()}" + ) + + # Remove expired providers + for peer_id in expired_providers: + del providers[peer_id] + + # Track empty keys for removal + if not providers: + expired_keys.append(key) + + # Remove empty keys + for key in expired_keys: + del self.providers[key] + logger.debug(f"Removed key with no providers: {key.hex()}") + + def get_provided_keys(self, peer_id: ID) -> list[bytes]: + """ + Get all content keys provided by a specific peer. + + :param peer_id: The peer ID to look for + + Returns + ------- + List[bytes] + List of content keys provided by the peer + + """ + peer_id_str = str(peer_id) + result = [] + + for key, providers in self.providers.items(): + if peer_id_str in providers: + result.append(key) + + return result + + def size(self) -> int: + """ + Get the total number of provider records in the store. + + Returns + ------- + int + Total number of provider records across all keys + + """ + total = 0 + for providers in self.providers.values(): + total += len(providers) + return total diff --git a/libp2p/kad_dht/routing_table.py b/libp2p/kad_dht/routing_table.py new file mode 100644 index 00000000..4377c591 --- /dev/null +++ b/libp2p/kad_dht/routing_table.py @@ -0,0 +1,601 @@ +""" +Kademlia DHT routing table implementation. +""" + +from collections import ( + OrderedDict, +) +import logging +import time + +import trio + +from libp2p.abc import ( + IHost, +) +from libp2p.custom_types import ( + TProtocol, +) +from libp2p.kad_dht.utils import xor_distance +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) + +from .pb.kademlia_pb2 import ( + Message, +) + +# logger = logging.getLogger("libp2p.kademlia.routing_table") +logger = logging.getLogger("kademlia-example.routing_table") + +# Default parameters +BUCKET_SIZE = 20 # k in the Kademlia paper +MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys) +PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds +STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale + + +class KBucket: + """ + A k-bucket implementation for the Kademlia DHT. 
+ + Each k-bucket stores up to k (BUCKET_SIZE) peers, sorted by least-recently seen. + """ + + def __init__( + self, + host: IHost, + bucket_size: int = BUCKET_SIZE, + min_range: int = 0, + max_range: int = 2**256, + ): + """ + Initialize a new k-bucket. + + :param host: The host this bucket belongs to + :param bucket_size: Maximum number of peers to store in the bucket + :param min_range: Lower boundary of the bucket's key range (inclusive) + :param max_range: Upper boundary of the bucket's key range (exclusive) + + """ + self.bucket_size = bucket_size + self.host = host + self.min_range = min_range + self.max_range = max_range + # Store PeerInfo objects along with last-seen timestamp + self.peers: OrderedDict[ID, tuple[PeerInfo, float]] = OrderedDict() + + def peer_ids(self) -> list[ID]: + """Get all peer IDs in the bucket.""" + return list(self.peers.keys()) + + def peer_infos(self) -> list[PeerInfo]: + """Get all PeerInfo objects in the bucket.""" + return [info for info, _ in self.peers.values()] + + def get_oldest_peer(self) -> ID | None: + """Get the least-recently seen peer.""" + if not self.peers: + return None + return next(iter(self.peers.keys())) + + async def add_peer(self, peer_info: PeerInfo) -> bool: + """ + Add a peer to the bucket. Returns True if the peer was added or updated, + False if the bucket is full. + """ + current_time = time.time() + peer_id = peer_info.peer_id + + # If peer is already in the bucket, move it to the end (most recently seen) + if peer_id in self.peers: + self.refresh_peer_last_seen(peer_id) + return True + + # If bucket has space, add the peer + if len(self.peers) < self.bucket_size: + self.peers[peer_id] = (peer_info, current_time) + return True + + # If bucket is full, we need to replace the least-recently seen peer + # Get the least-recently seen peer + oldest_peer_id = self.get_oldest_peer() + if oldest_peer_id is None: + logger.warning("No oldest peer found when bucket is full") + return False + + # Check if the old peer is responsive to ping request + try: + # Try to ping the oldest peer, not the new peer + response = await self._ping_peer(oldest_peer_id) + if response: + # If the old peer is still alive, we will not add the new peer + logger.debug( + "Old peer %s is still alive, cannot add new peer %s", + oldest_peer_id, + peer_id, + ) + return False + except Exception as e: + # If the old peer is unresponsive, we can replace it with the new peer + logger.debug( + "Old peer %s is unresponsive, replacing with new peer %s: %s", + oldest_peer_id, + peer_id, + str(e), + ) + self.peers.popitem(last=False) # Remove oldest peer + self.peers[peer_id] = (peer_info, current_time) + return True + + # If we got here, the oldest peer responded but we couldn't add the new peer + return False + + def remove_peer(self, peer_id: ID) -> bool: + """ + Remove a peer from the bucket. + Returns True if the peer was in the bucket, False otherwise. 
+ """ + if peer_id in self.peers: + del self.peers[peer_id] + return True + return False + + def has_peer(self, peer_id: ID) -> bool: + """Check if the peer is in the bucket.""" + return peer_id in self.peers + + def get_peer_info(self, peer_id: ID) -> PeerInfo | None: + """Get the PeerInfo for a given peer ID if it exists in the bucket.""" + if peer_id in self.peers: + return self.peers[peer_id][0] + return None + + def size(self) -> int: + """Get the number of peers in the bucket.""" + return len(self.peers) + + def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]: + """ + Get peers that haven't been pinged recently. + + params: stale_threshold_seconds: Time in seconds + params: after which a peer is considered stale + + Returns + ------- + list[ID] + List of peer IDs that need to be refreshed + + """ + current_time = time.time() + stale_peers = [] + + for peer_id, (_, last_seen) in self.peers.items(): + if current_time - last_seen > stale_threshold_seconds: + stale_peers.append(peer_id) + + return stale_peers + + async def _periodic_peer_refresh(self) -> None: + """Background task to periodically refresh peers""" + try: + while True: + await trio.sleep(PEER_REFRESH_INTERVAL) # Check every minute + + # Find stale peers (not pinged in last hour) + stale_peers = self.get_stale_peers( + stale_threshold_seconds=STALE_PEER_THRESHOLD + ) + if stale_peers: + logger.debug(f"Found {len(stale_peers)} stale peers to refresh") + + for peer_id in stale_peers: + try: + # Try to ping the peer + logger.debug("Pinging stale peer %s", peer_id) + responce = await self._ping_peer(peer_id) + if responce: + # Update the last seen time + self.refresh_peer_last_seen(peer_id) + logger.debug(f"Refreshed peer {peer_id}") + else: + # If ping fails, remove the peer + logger.debug(f"Failed to ping peer {peer_id}") + self.remove_peer(peer_id) + logger.info(f"Removed unresponsive peer {peer_id}") + + logger.debug(f"Successfully refreshed peer {peer_id}") + except Exception as e: + # If ping fails, remove the peer + logger.debug( + "Failed to ping peer %s: %s", + peer_id, + e, + ) + self.remove_peer(peer_id) + logger.info(f"Removed unresponsive peer {peer_id}") + except trio.Cancelled: + logger.debug("Peer refresh task cancelled") + except Exception as e: + logger.error(f"Error in peer refresh task: {e}", exc_info=True) + + async def _ping_peer(self, peer_id: ID) -> bool: + """ + Ping a peer using protobuf message to check + if it's still alive and update last seen time. 
+ + params: peer_id: The ID of the peer to ping + + Returns + ------- + bool + True if ping successful, False otherwise + + """ + result = False + # Get peer info directly from the bucket + peer_info = self.get_peer_info(peer_id) + if not peer_info: + raise ValueError(f"Peer {peer_id} not in bucket") + + # Default protocol ID for Kademlia DHT + protocol_id = TProtocol("/ipfs/kad/1.0.0") + + try: + # Open a stream to the peer with the DHT protocol + stream = await self.host.new_stream(peer_id, [protocol_id]) + + try: + # Create ping protobuf message + ping_msg = Message() + ping_msg.type = Message.PING # Use correct enum + + # Serialize and send with length prefix (4 bytes big-endian) + msg_bytes = ping_msg.SerializeToString() + logger.debug( + f"Sending PING message to {peer_id}, size: {len(msg_bytes)} bytes" + ) + await stream.write(len(msg_bytes).to_bytes(4, byteorder="big")) + await stream.write(msg_bytes) + + # Wait for response with timeout + with trio.move_on_after(2): # 2 second timeout + # Read response length (4 bytes) + length_bytes = await stream.read(4) + if not length_bytes or len(length_bytes) < 4: + logger.warning(f"Peer {peer_id} disconnected during ping") + return False + + msg_len = int.from_bytes(length_bytes, byteorder="big") + if ( + msg_len <= 0 or msg_len > 1024 * 1024 + ): # Sanity check on message size + logger.warning( + f"Invalid message length from {peer_id}: {msg_len}" + ) + return False + + logger.debug( + f"Receiving response from {peer_id}, size: {msg_len} bytes" + ) + + # Read full message + response_bytes = await stream.read(msg_len) + if not response_bytes: + logger.warning(f"Failed to read response from {peer_id}") + return False + + # Parse protobuf response + response = Message() + try: + response.ParseFromString(response_bytes) + except Exception as e: + logger.warning( + f"Failed to parse protobuf response from {peer_id}: {e}" + ) + return False + + if response.type == Message.PING: + # Update the last seen timestamp for this peer + logger.debug(f"Successfully pinged peer {peer_id}") + result = True + return result + + else: + logger.warning( + f"Unexpected response type from {peer_id}: {response.type}" + ) + return False + + # If we get here, the ping timed out + logger.warning(f"Ping to peer {peer_id} timed out") + return False + + finally: + await stream.close() + return result + + except Exception as e: + logger.error(f"Error pinging peer {peer_id}: {str(e)}") + return False + + def refresh_peer_last_seen(self, peer_id: ID) -> bool: + """ + Update the last-seen timestamp for a peer in the bucket. + + params: peer_id: The ID of the peer to refresh + + Returns + ------- + bool + True if the peer was found and refreshed, False otherwise + + """ + if peer_id in self.peers: + # Get current peer info and update the timestamp + peer_info, _ = self.peers[peer_id] + current_time = time.time() + self.peers[peer_id] = (peer_info, current_time) + # Move to end of ordered dict to mark as most recently seen + self.peers.move_to_end(peer_id) + return True + + return False + + def key_in_range(self, key: bytes) -> bool: + """ + Check if a key is in the range of this bucket. + + params: key: The key to check (bytes) + + Returns + ------- + bool + True if the key is in range, False otherwise + + """ + key_int = int.from_bytes(key, byteorder="big") + return self.min_range <= key_int < self.max_range + + def split(self) -> tuple["KBucket", "KBucket"]: + """ + Split the bucket into two buckets. 
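+
+        The bucket's key range ``[min_range, max_range)`` is halved at its
+        midpoint and existing peers are redistributed according to the integer
+        value of their peer ID bytes. Illustrative usage (a sketch)::
+
+            lower, upper = bucket.split()
+            assert lower.max_range == upper.min_range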
+ + Returns + ------- + tuple + (lower_bucket, upper_bucket) + + """ + midpoint = (self.min_range + self.max_range) // 2 + lower_bucket = KBucket(self.host, self.bucket_size, self.min_range, midpoint) + upper_bucket = KBucket(self.host, self.bucket_size, midpoint, self.max_range) + + # Redistribute peers + for peer_id, (peer_info, timestamp) in self.peers.items(): + peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big") + if peer_key < midpoint: + lower_bucket.peers[peer_id] = (peer_info, timestamp) + else: + upper_bucket.peers[peer_id] = (peer_info, timestamp) + + return lower_bucket, upper_bucket + + +class RoutingTable: + """ + The Kademlia routing table maintains information on which peers to contact for any + given peer ID in the network. + """ + + def __init__(self, local_id: ID, host: IHost) -> None: + """ + Initialize the routing table. + + :param local_id: The ID of the local node. + :param host: The host this routing table belongs to. + + """ + self.local_id = local_id + self.host = host + self.buckets = [KBucket(host, BUCKET_SIZE)] + + async def add_peer(self, peer_obj: PeerInfo | ID) -> bool: + """ + Add a peer to the routing table. + + :param peer_obj: Either PeerInfo object or peer ID to add + + Returns + ------- + bool: True if the peer was added or updated, False otherwise + + """ + peer_id = None + peer_info = None + + try: + # Handle different types of input + if isinstance(peer_obj, PeerInfo): + # Already have PeerInfo object + peer_info = peer_obj + peer_id = peer_obj.peer_id + else: + # Assume it's a peer ID + peer_id = peer_obj + # Try to get addresses from the peerstore if available + try: + addrs = self.host.get_peerstore().addrs(peer_id) + if addrs: + # Create PeerInfo object + peer_info = PeerInfo(peer_id, addrs) + else: + logger.debug( + "No addresses found for peer %s in peerstore, skipping", + peer_id, + ) + return False + except Exception as peerstore_error: + # Handle case where peer is not in peerstore yet + logger.debug( + "Peer %s not found in peerstore: %s, skipping", + peer_id, + str(peerstore_error), + ) + return False + + # Don't add ourselves + if peer_id == self.local_id: + return False + + # Find the right bucket for this peer + bucket = self.find_bucket(peer_id) + + # Try to add to the bucket + success = await bucket.add_peer(peer_info) + if success: + logger.debug(f"Successfully added peer {peer_id} to routing table") + return success + + except Exception as e: + logger.debug(f"Error adding peer {peer_obj} to routing table: {e}") + return False + + def remove_peer(self, peer_id: ID) -> bool: + """ + Remove a peer from the routing table. + + :param peer_id: The ID of the peer to remove + + Returns + ------- + bool: True if the peer was removed, False otherwise + + """ + bucket = self.find_bucket(peer_id) + return bucket.remove_peer(peer_id) + + def find_bucket(self, peer_id: ID) -> KBucket: + """ + Find the bucket that would contain the given peer ID or PeerInfo. + + :param peer_obj: Either a peer ID or a PeerInfo object + + Returns + ------- + KBucket: The bucket for this peer + + """ + for bucket in self.buckets: + if bucket.key_in_range(peer_id.to_bytes()): + return bucket + + return self.buckets[0] + + def find_local_closest_peers(self, key: bytes, count: int = 20) -> list[ID]: + """ + Find the closest peers to a given key. 
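+
+        Peers from every bucket are sorted by the XOR distance between their
+        peer ID bytes and ``key``; no network queries are made. Illustrative
+        usage (a sketch)::
+
+            closest = routing_table.find_local_closest_peers(key, count=20)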
+ + :param key: The key to find closest peers to (bytes) + :param count: Maximum number of peers to return + + Returns + ------- + List[ID]: List of peer IDs closest to the key + + """ + # Get all peers from all buckets + all_peers = [] + for bucket in self.buckets: + all_peers.extend(bucket.peer_ids()) + + # Sort by XOR distance to the key + all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key)) + + return all_peers[:count] + + def get_peer_ids(self) -> list[ID]: + """ + Get all peer IDs in the routing table. + + Returns + ------- + :param List[ID]: List of all peer IDs + + """ + peers = [] + for bucket in self.buckets: + peers.extend(bucket.peer_ids()) + return peers + + def get_peer_info(self, peer_id: ID) -> PeerInfo | None: + """ + Get the peer info for a specific peer. + + :param peer_id: The ID of the peer to get info for + + Returns + ------- + PeerInfo: The peer info, or None if not found + + """ + bucket = self.find_bucket(peer_id) + return bucket.get_peer_info(peer_id) + + def peer_in_table(self, peer_id: ID) -> bool: + """ + Check if a peer is in the routing table. + + :param peer_id: The ID of the peer to check + + Returns + ------- + bool: True if the peer is in the routing table, False otherwise + + """ + bucket = self.find_bucket(peer_id) + return bucket.has_peer(peer_id) + + def size(self) -> int: + """ + Get the number of peers in the routing table. + + Returns + ------- + int: Number of peers + + """ + count = 0 + for bucket in self.buckets: + count += bucket.size() + return count + + def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]: + """ + Get all stale peers from all buckets + + params: stale_threshold_seconds: + Time in seconds after which a peer is considered stale + + Returns + ------- + list[ID] + List of stale peer IDs + + """ + stale_peers = [] + for bucket in self.buckets: + stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds)) + return stale_peers + + def cleanup_routing_table(self) -> None: + """ + Cleanup the routing table by removing all data. + This is useful for resetting the routing table during tests or reinitialization. + """ + self.buckets = [KBucket(self.host, BUCKET_SIZE)] + logger.info("Routing table cleaned up, all data removed.") diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py new file mode 100644 index 00000000..61158320 --- /dev/null +++ b/libp2p/kad_dht/utils.py @@ -0,0 +1,117 @@ +""" +Utility functions for Kademlia DHT implementation. +""" + +import base58 +import multihash + +from libp2p.peer.id import ( + ID, +) + + +def create_key_from_binary(binary_data: bytes) -> bytes: + """ + Creates a key for the DHT by hashing binary data with SHA-256. + + params: binary_data: The binary data to hash. + + Returns + ------- + bytes: The resulting key. + + """ + return multihash.digest(binary_data, "sha2-256").digest + + +def xor_distance(key1: bytes, key2: bytes) -> int: + """ + Calculate the XOR distance between two keys. + + params: key1: First key (bytes) + params: key2: Second key (bytes) + + Returns + ------- + int: The XOR distance between the keys + + """ + # Ensure the inputs are bytes + if not isinstance(key1, bytes) or not isinstance(key2, bytes): + raise TypeError("Both key1 and key2 must be bytes objects") + + # Convert to integers + k1 = int.from_bytes(key1, byteorder="big") + k2 = int.from_bytes(key2, byteorder="big") + + # Calculate XOR distance + return k1 ^ k2 + + +def bytes_to_base58(data: bytes) -> str: + """ + Convert bytes to base58 encoded string. 
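+
+    Example (``base58.b58encode`` with its default Bitcoin-style alphabet)::
+
+        >>> bytes_to_base58(b"hello")
+        'Cn8eVZg'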
+ + params: data: Input bytes + + Returns + ------- + str: Base58 encoded string + + """ + return base58.b58encode(data).decode("utf-8") + + +def sort_peer_ids_by_distance(target_key: bytes, peer_ids: list[ID]) -> list[ID]: + """ + Sort a list of peer IDs by their distance to the target key. + + params: target_key: The target key to measure distance from + params: peer_ids: List of peer IDs to sort + + Returns + ------- + List[ID]: Sorted list of peer IDs from closest to furthest + + """ + + def get_distance(peer_id: ID) -> int: + # Hash the peer ID bytes to get a key for distance calculation + peer_hash = multihash.digest(peer_id.to_bytes(), "sha2-256").digest + return xor_distance(target_key, peer_hash) + + return sorted(peer_ids, key=get_distance) + + +def shared_prefix_len(first: bytes, second: bytes) -> int: + """ + Calculate the number of prefix bits shared by two byte sequences. + + params: first: First byte sequence + params: second: Second byte sequence + + Returns + ------- + int: Number of shared prefix bits + + """ + # Compare each byte to find the first bit difference + common_length = 0 + for i in range(min(len(first), len(second))): + byte_first = first[i] + byte_second = second[i] + + if byte_first == byte_second: + common_length += 8 + else: + # Find specific bit where they differ + xor = byte_first ^ byte_second + # Count leading zeros in the xor result + for j in range(7, -1, -1): + if (xor >> j) & 1 == 1: + return common_length + (7 - j) + + # This shouldn't be reached if xor != 0 + return common_length + 8 + + return common_length diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py new file mode 100644 index 00000000..a2e54776 --- /dev/null +++ b/libp2p/kad_dht/value_store.py @@ -0,0 +1,393 @@ +""" +Value store implementation for Kademlia DHT. + +Provides a way to store and retrieve key-value pairs with optional expiration. +""" + +import logging +import time + +import varint + +from libp2p.abc import ( + IHost, +) +from libp2p.custom_types import ( + TProtocol, +) +from libp2p.peer.id import ( + ID, +) + +from .pb.kademlia_pb2 import ( + Message, +) + +# logger = logging.getLogger("libp2p.kademlia.value_store") +logger = logging.getLogger("kademlia-example.value_store") + +# Default time to live for values in seconds (24 hours) +DEFAULT_TTL = 24 * 60 * 60 +PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0") + + +class ValueStore: + """ + Store for key-value pairs in a Kademlia DHT. + + Values are stored with a timestamp and optional expiration time. + """ + + def __init__(self, host: IHost, local_peer_id: ID): + """ + Initialize an empty value store. + + :param host: The libp2p host instance. + :param local_peer_id: The local peer ID to ignore in peer requests. + + """ + # Store format: {key: (value, validity)} + self.store: dict[bytes, tuple[bytes, float]] = {} + # Store references to the host and local peer ID for making requests + self.host = host + self.local_peer_id = local_peer_id + + def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None: + """ + Store a value in the DHT. + + :param key: The key to store the value under + :param value: The value to store + :param validity: validity in seconds before the value expires. + Defaults to `DEFAULT_TTL` if set to 0.0. + + Returns + ------- + None + + """ + if validity == 0.0: + validity = time.time() + DEFAULT_TTL + logger.debug( + "Storing value for key %s... 
with validity %s", key.hex(), validity + ) + self.store[key] = (value, validity) + logger.debug(f"Stored value for key {key.hex()}") + + async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool: + """ + Store a value at a specific peer. + + params: peer_id: The ID of the peer to store the value at + params: key: The key to store + params: value: The value to store + + Returns + ------- + bool + True if the value was successfully stored, False otherwise + + """ + result = False + stream = None + try: + # Don't try to store at ourselves + if self.local_peer_id and peer_id == self.local_peer_id: + result = True + return result + + if not self.host: + logger.error("Host not initialized, cannot store value at peer") + return False + + logger.debug(f"Storing value for key {key.hex()} at peer {peer_id}") + + # Open a stream to the peer + stream = await self.host.new_stream(peer_id, [PROTOCOL_ID]) + logger.debug(f"Opened stream to peer {peer_id}") + + # Create the PUT_VALUE message with protobuf + message = Message() + message.type = Message.MessageType.PUT_VALUE + + # Set message fields + message.key = key + message.record.key = key + message.record.value = value + message.record.timeReceived = str(time.time()) + + # Serialize and send the protobuf message with length prefix + proto_bytes = message.SerializeToString() + await stream.write(varint.encode(len(proto_bytes))) + await stream.write(proto_bytes) + logger.debug("Sent PUT_VALUE protobuf message with varint length") + # Read varint-prefixed response length + + length_bytes = b"" + while True: + logger.debug("Reading varint length prefix for response...") + b = await stream.read(1) + if not b: + logger.warning("Connection closed while reading varint length") + return False + length_bytes += b + if b[0] & 0x80 == 0: + break + logger.debug(f"Received varint length bytes: {length_bytes.hex()}") + response_length = varint.decode_bytes(length_bytes) + logger.debug("Response length: %d bytes", response_length) + # Read response data + response_bytes = b"" + remaining = response_length + while remaining > 0: + chunk = await stream.read(remaining) + if not chunk: + logger.debug( + f"Connection closed by peer {peer_id} while reading data" + ) + return False + response_bytes += chunk + remaining -= len(chunk) + + # Parse protobuf response + response = Message() + response.ParseFromString(response_bytes) + + # Check if response is valid + if response.type == Message.MessageType.PUT_VALUE: + if response.key: + result = True + return result + + except Exception as e: + logger.warning(f"Failed to store value at peer {peer_id}: {e}") + return False + + finally: + if stream: + await stream.close() + return result + + def get(self, key: bytes) -> bytes | None: + """ + Retrieve a value from the DHT. + + params: key: The key to look up + + Returns + ------- + Optional[bytes] + The stored value, or None if not found or expired + + """ + logger.debug("Retrieving value for key %s...", key.hex()[:8]) + if key not in self.store: + return None + + value, validity = self.store[key] + logger.debug( + "Found value for key %s... with validity %s", + key.hex(), + validity, + ) + # Check if the value has expired + if validity is not None and validity < time.time(): + logger.debug( + "Value for key %s... has expired, removing it", + key.hex()[:8], + ) + self.remove(key) + return None + + return value + + async def _get_from_peer(self, peer_id: ID, key: bytes) -> bytes | None: + """ + Retrieve a value from a specific peer. 
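+
+        Sends a varint-framed ``GET_VALUE`` protobuf request and returns
+        ``response.record.value`` when the peer holds the key; if the peer
+        only returns ``closerPeers``, ``None`` is returned instead.
+        Illustrative usage (a sketch; ``value_store`` is a ``ValueStore`` on
+        a running node)::
+
+            value = await value_store._get_from_peer(peer_id, key)
+            if value is None:
+                ...  # fall back to querying other peers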
+ + params: peer_id: The ID of the peer to retrieve the value from + params: key: The key to retrieve + + Returns + ------- + Optional[bytes] + The value if found, None otherwise + + """ + stream = None + try: + # Don't try to get from ourselves + if peer_id == self.local_peer_id: + return None + + logger.debug(f"Getting value for key {key.hex()} from peer {peer_id}") + + # Open a stream to the peer + stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)]) + logger.debug(f"Opened stream to peer {peer_id} for GET_VALUE") + + # Create the GET_VALUE message using protobuf + message = Message() + message.type = Message.MessageType.GET_VALUE + message.key = key + + # Serialize and send the protobuf message + proto_bytes = message.SerializeToString() + await stream.write(varint.encode(len(proto_bytes))) + await stream.write(proto_bytes) + + # Read response length + length_bytes = b"" + while True: + b = await stream.read(1) + if not b: + logger.warning("Connection closed while reading length") + return None + length_bytes += b + if b[0] & 0x80 == 0: + break + response_length = varint.decode_bytes(length_bytes) + # Read response data + response_bytes = b"" + remaining = response_length + while remaining > 0: + chunk = await stream.read(remaining) + if not chunk: + logger.debug( + f"Connection closed by peer {peer_id} while reading data" + ) + return None + response_bytes += chunk + remaining -= len(chunk) + + # Parse protobuf response + try: + response = Message() + response.ParseFromString(response_bytes) + logger.debug( + f"Received protobuf response from peer" + f" {peer_id}, type: {response.type}" + ) + + # Process protobuf response + if ( + response.type == Message.MessageType.GET_VALUE + and response.HasField("record") + and response.record.value + ): + logger.debug( + f"Received value for key {key.hex()} from peer {peer_id}" + ) + return response.record.value + + # Handle case where value is not found but peer infos are returned + else: + logger.debug( + f"Value not found for key {key.hex()} from peer {peer_id}," + f" received {len(response.closerPeers)} closer peers" + ) + return None + + except Exception as proto_err: + logger.warning(f"Failed to parse as protobuf: {proto_err}") + + return None + + except Exception as e: + logger.warning(f"Failed to get value from peer {peer_id}: {e}") + return None + + finally: + if stream: + await stream.close() + + def remove(self, key: bytes) -> bool: + """ + Remove a value from the DHT. + + + params: key: The key to remove + + Returns + ------- + bool + True if the key was found and removed, False otherwise + + """ + if key in self.store: + del self.store[key] + logger.debug(f"Removed value for key {key.hex()[:8]}...") + return True + return False + + def has(self, key: bytes) -> bool: + """ + Check if a key exists in the store and hasn't expired. + + params: key: The key to check + + Returns + ------- + bool + True if the key exists and hasn't expired, False otherwise + + """ + if key not in self.store: + return False + + _, validity = self.store[key] + if validity is not None and time.time() > validity: + self.remove(key) + return False + + return True + + def cleanup_expired(self) -> int: + """ + Remove all expired values from the store. 
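+
+        Entries whose validity timestamp lies in the past are deleted; this is
+        also invoked implicitly by ``get_keys()`` and ``size()``. Illustrative
+        usage (a sketch, assuming the store holds only this one entry)::
+
+            store.put(key, b"value", validity=time.time() - 1)  # expired
+            assert store.cleanup_expired() == 1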
+ + Returns + ------- + int + The number of expired values that were removed + + """ + current_time = time.time() + expired_keys = [ + key for key, (_, validity) in self.store.items() if current_time > validity + ] + + for key in expired_keys: + del self.store[key] + + if expired_keys: + logger.debug(f"Cleaned up {len(expired_keys)} expired values") + + return len(expired_keys) + + def get_keys(self) -> list[bytes]: + """ + Get all non-expired keys in the store. + + Returns + ------- + list[bytes] + List of keys + + """ + # Clean up expired values first + self.cleanup_expired() + return list(self.store.keys()) + + def size(self) -> int: + """ + Get the number of items in the store (after removing expired entries). + + Returns + ------- + int + Number of items + + """ + self.cleanup_expired() + return len(self.store) diff --git a/newsfragments/579.feature.rst b/newsfragments/579.feature.rst new file mode 100644 index 00000000..9da91328 --- /dev/null +++ b/newsfragments/579.feature.rst @@ -0,0 +1 @@ +Added support for ``Kademlia DHT`` in py-libp2p. diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py new file mode 100644 index 00000000..a6f73074 --- /dev/null +++ b/tests/core/kad_dht/test_kad_dht.py @@ -0,0 +1,168 @@ +""" +Tests for the Kademlia DHT implementation. + +This module tests core functionality of the Kademlia DHT including: +- Node discovery (find_node) +- Value storage and retrieval (put_value, get_value) +- Content provider advertisement and discovery (provide, find_providers) +""" + +import hashlib +import logging +import uuid + +import pytest +import trio + +from libp2p.kad_dht.kad_dht import ( + DHTMode, + KadDHT, +) +from libp2p.kad_dht.utils import ( + create_key_from_binary, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) +from libp2p.tools.async_service import ( + background_trio_service, +) +from tests.utils.factories import ( + host_pair_factory, +) + +# Configure logger +logger = logging.getLogger("test.kad_dht") + +# Constants for the tests +TEST_TIMEOUT = 5 # Timeout in seconds + + +@pytest.fixture +async def dht_pair(security_protocol): + """Create a pair of connected DHT nodes for testing.""" + async with host_pair_factory(security_protocol=security_protocol) as ( + host_a, + host_b, + ): + # Get peer info for bootstrapping + peer_b_info = PeerInfo(host_b.get_id(), host_b.get_addrs()) + peer_a_info = PeerInfo(host_a.get_id(), host_a.get_addrs()) + + # Create DHT nodes from the hosts with bootstrap peers as multiaddr strings + dht_a: KadDHT = KadDHT(host_a, mode=DHTMode.SERVER) + dht_b: KadDHT = KadDHT(host_b, mode=DHTMode.SERVER) + await dht_a.peer_routing.routing_table.add_peer(peer_b_info) + await dht_b.peer_routing.routing_table.add_peer(peer_a_info) + + # Start both DHT services + async with background_trio_service(dht_a), background_trio_service(dht_b): + # Allow time for bootstrap to complete and connections to establish + await trio.sleep(0.1) + + logger.debug( + "After bootstrap: Node A peers: %s", dht_a.routing_table.get_peer_ids() + ) + logger.debug( + "After bootstrap: Node B peers: %s", dht_b.routing_table.get_peer_ids() + ) + + # Return the DHT pair + yield (dht_a, dht_b) + + +@pytest.mark.trio +async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): + """Test that nodes can find each other in the DHT.""" + dht_a, dht_b = dht_pair + + # Node A should be able to find Node B + with trio.fail_after(TEST_TIMEOUT): + found_info = await dht_a.find_peer(dht_b.host.get_id()) + + # Verify that the found peer has the 
correct peer ID + assert found_info is not None, "Failed to find the target peer" + assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" + + +@pytest.mark.trio +async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): + """Test storing and retrieving values in the DHT.""" + dht_a, dht_b = dht_pair + # dht_a.peer_routing.routing_table.add_peer(dht_b.pe) + peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs()) + # Generate a random key and value + key = create_key_from_binary(b"test-key") + value = b"test-value" + + # First add the value directly to node A's store to verify storage works + dht_a.value_store.put(key, value) + logger.debug("Local value store: %s", dht_a.value_store.store) + local_value = dht_a.value_store.get(key) + assert local_value == value, "Local value storage failed" + print("number of nodes in peer store", dht_a.host.get_peerstore().peer_ids()) + await dht_a.routing_table.add_peer(peer_b_info) + print("Routing table of a has ", dht_a.routing_table.get_peer_ids()) + + # Store the value using the first node (this will also store locally) + with trio.fail_after(TEST_TIMEOUT): + await dht_a.put_value(key, value) + + # # Log debugging information + logger.debug("Put value with key %s...", key.hex()[:10]) + logger.debug("Node A value store: %s", dht_a.value_store.store) + print("hello test") + + # # Allow more time for the value to propagate + await trio.sleep(0.5) + + # # Try direct connection between nodes to ensure they're properly linked + logger.debug("Node A peers: %s", dht_a.routing_table.get_peer_ids()) + logger.debug("Node B peers: %s", dht_b.routing_table.get_peer_ids()) + + # Retrieve the value using the second node + with trio.fail_after(TEST_TIMEOUT): + retrieved_value = await dht_b.get_value(key) + print("the value stored in node b is", dht_b.get_value_store_size()) + logger.debug("Retrieved value: %s", retrieved_value) + + # Verify that the retrieved value matches the original + assert retrieved_value == value, "Retrieved value does not match the stored value" + + +@pytest.mark.trio +async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): + """Test advertising and finding content providers.""" + dht_a, dht_b = dht_pair + + # Generate a random content ID + content = f"test-content-{uuid.uuid4()}".encode() + content_id = hashlib.sha256(content).digest() + + # Store content on the first node + dht_a.value_store.put(content_id, content) + + # Advertise the first node as a provider + with trio.fail_after(TEST_TIMEOUT): + success = await dht_a.provide(content_id) + assert success, "Failed to advertise as provider" + + # Allow time for the provider record to propagate + await trio.sleep(0.1) + + # Find providers using the second node + with trio.fail_after(TEST_TIMEOUT): + providers = await dht_b.find_providers(content_id) + + # Verify that we found the first node as a provider + assert providers, "No providers found" + assert any(p.peer_id == dht_a.local_peer_id for p in providers), ( + "Expected provider not found" + ) + + # Retrieve the content using the provider information + with trio.fail_after(TEST_TIMEOUT): + retrieved_value = await dht_b.get_value(content_id) + assert retrieved_value == content, ( + "Retrieved content does not match the original" + ) diff --git a/tests/core/kad_dht/test_unit_peer_routing.py b/tests/core/kad_dht/test_unit_peer_routing.py new file mode 100644 index 00000000..72320b73 --- /dev/null +++ b/tests/core/kad_dht/test_unit_peer_routing.py @@ -0,0 +1,459 @@ +""" +Unit 
tests for the PeerRouting class in Kademlia DHT. + +This module tests the core functionality of peer routing including: +- Peer discovery and lookup +- Network queries for closest peers +- Protocol message handling +- Error handling and edge cases +""" + +import time +from unittest.mock import ( + AsyncMock, + Mock, + patch, +) + +import pytest +from multiaddr import ( + Multiaddr, +) +import varint + +from libp2p.crypto.secp256k1 import ( + create_new_key_pair, +) +from libp2p.kad_dht.pb.kademlia_pb2 import ( + Message, +) +from libp2p.kad_dht.peer_routing import ( + ALPHA, + MAX_PEER_LOOKUP_ROUNDS, + PROTOCOL_ID, + PeerRouting, +) +from libp2p.kad_dht.routing_table import ( + RoutingTable, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) + + +def create_valid_peer_id(name: str) -> ID: + """Create a valid peer ID for testing.""" + key_pair = create_new_key_pair() + return ID.from_pubkey(key_pair.public_key) + + +class TestPeerRouting: + """Test suite for PeerRouting class.""" + + @pytest.fixture + def mock_host(self): + """Create a mock host for testing.""" + host = Mock() + host.get_id.return_value = create_valid_peer_id("local") + host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + host.get_peerstore.return_value = Mock() + host.new_stream = AsyncMock() + host.connect = AsyncMock() + return host + + @pytest.fixture + def mock_routing_table(self, mock_host): + """Create a mock routing table for testing.""" + local_id = create_valid_peer_id("local") + routing_table = RoutingTable(local_id, mock_host) + return routing_table + + @pytest.fixture + def peer_routing(self, mock_host, mock_routing_table): + """Create a PeerRouting instance for testing.""" + return PeerRouting(mock_host, mock_routing_table) + + @pytest.fixture + def sample_peer_info(self): + """Create sample peer info for testing.""" + peer_id = create_valid_peer_id("sample") + addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8001")] + return PeerInfo(peer_id, addresses) + + def test_init_peer_routing(self, mock_host, mock_routing_table): + """Test PeerRouting initialization.""" + peer_routing = PeerRouting(mock_host, mock_routing_table) + + assert peer_routing.host == mock_host + assert peer_routing.routing_table == mock_routing_table + assert peer_routing.protocol_id == PROTOCOL_ID + + @pytest.mark.trio + async def test_find_peer_local_host(self, peer_routing, mock_host): + """Test finding our own peer.""" + local_id = mock_host.get_id() + + result = await peer_routing.find_peer(local_id) + + assert result is not None + assert result.peer_id == local_id + assert result.addrs == mock_host.get_addrs() + + @pytest.mark.trio + async def test_find_peer_in_routing_table(self, peer_routing, sample_peer_info): + """Test finding peer that exists in routing table.""" + # Add peer to routing table + await peer_routing.routing_table.add_peer(sample_peer_info) + + result = await peer_routing.find_peer(sample_peer_info.peer_id) + + assert result is not None + assert result.peer_id == sample_peer_info.peer_id + + @pytest.mark.trio + async def test_find_peer_in_peerstore(self, peer_routing, mock_host): + """Test finding peer that exists in peerstore.""" + peer_id = create_valid_peer_id("peerstore") + mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8002")] + + # Mock peerstore to return addresses + mock_host.get_peerstore().addrs.return_value = mock_addrs + + result = await peer_routing.find_peer(peer_id) + + assert result is not None + assert result.peer_id == peer_id + assert result.addrs 
== mock_addrs + + @pytest.mark.trio + async def test_find_peer_not_found(self, peer_routing, mock_host): + """Test finding peer that doesn't exist anywhere.""" + peer_id = create_valid_peer_id("nonexistent") + + # Mock peerstore to return no addresses + mock_host.get_peerstore().addrs.return_value = [] + + # Mock network search to return empty results + with patch.object(peer_routing, "find_closest_peers_network", return_value=[]): + result = await peer_routing.find_peer(peer_id) + + assert result is None + + @pytest.mark.trio + async def test_find_closest_peers_network_empty_start(self, peer_routing): + """Test network search with no local peers.""" + target_key = b"target_key" + + # Mock routing table to return empty list + with patch.object( + peer_routing.routing_table, "find_local_closest_peers", return_value=[] + ): + result = await peer_routing.find_closest_peers_network(target_key) + + assert result == [] + + @pytest.mark.trio + async def test_find_closest_peers_network_with_peers(self, peer_routing, mock_host): + """Test network search with some initial peers.""" + target_key = b"target_key" + + # Create some test peers + initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(3)] + + # Mock routing table to return initial peers + with patch.object( + peer_routing.routing_table, + "find_local_closest_peers", + return_value=initial_peers, + ): + # Mock _query_peer_for_closest to return empty results (no new peers found) + with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]): + result = await peer_routing.find_closest_peers_network( + target_key, count=5 + ) + + assert len(result) <= 5 + # Should return the initial peers since no new ones were discovered + assert all(peer in initial_peers for peer in result) + + @pytest.mark.trio + async def test_find_closest_peers_convergence(self, peer_routing): + """Test that network search converges properly.""" + target_key = b"target_key" + + # Create test peers + initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(2)] + + # Mock to simulate convergence (no improvement in closest peers) + with patch.object( + peer_routing.routing_table, + "find_local_closest_peers", + return_value=initial_peers, + ): + with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]): + with patch( + "libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance", + return_value=initial_peers, + ): + result = await peer_routing.find_closest_peers_network(target_key) + + assert result == initial_peers + + @pytest.mark.trio + async def test_query_peer_for_closest_success( + self, peer_routing, mock_host, sample_peer_info + ): + """Test successful peer query for closest peers.""" + target_key = b"target_key" + + # Create mock stream + mock_stream = AsyncMock() + mock_host.new_stream.return_value = mock_stream + + # Create mock response + response_msg = Message() + response_msg.type = Message.MessageType.FIND_NODE + + # Add a peer to the response + peer_proto = response_msg.closerPeers.add() + response_peer_id = create_valid_peer_id("response_peer") + peer_proto.id = response_peer_id.to_bytes() + peer_proto.addrs.append(Multiaddr("/ip4/127.0.0.1/tcp/8003").to_bytes()) + + response_bytes = response_msg.SerializeToString() + + # Mock stream reading + varint_length = varint.encode(len(response_bytes)) + mock_stream.read.side_effect = [varint_length, response_bytes] + + # Mock peerstore + mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]] + mock_host.get_peerstore().add_addrs = Mock() + + result = 
await peer_routing._query_peer_for_closest( + sample_peer_info.peer_id, target_key + ) + + assert len(result) == 1 + assert result[0] == response_peer_id + mock_stream.write.assert_called() + mock_stream.close.assert_called_once() + + @pytest.mark.trio + async def test_query_peer_for_closest_stream_failure(self, peer_routing, mock_host): + """Test peer query when stream creation fails.""" + target_key = b"target_key" + peer_id = create_valid_peer_id("test") + + # Mock stream creation failure + mock_host.new_stream.side_effect = Exception("Stream failed") + mock_host.get_peerstore().addrs.return_value = [] + + result = await peer_routing._query_peer_for_closest(peer_id, target_key) + + assert result == [] + + @pytest.mark.trio + async def test_query_peer_for_closest_read_failure( + self, peer_routing, mock_host, sample_peer_info + ): + """Test peer query when reading response fails.""" + target_key = b"target_key" + + # Create mock stream that fails to read + mock_stream = AsyncMock() + mock_stream.read.side_effect = [b""] # Empty read simulates connection close + mock_host.new_stream.return_value = mock_stream + mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]] + + result = await peer_routing._query_peer_for_closest( + sample_peer_info.peer_id, target_key + ) + + assert result == [] + mock_stream.close.assert_called_once() + + @pytest.mark.trio + async def test_refresh_routing_table(self, peer_routing, mock_host): + """Test routing table refresh.""" + local_id = mock_host.get_id() + discovered_peers = [create_valid_peer_id(f"discovered{i}") for i in range(3)] + + # Mock find_closest_peers_network to return discovered peers + with patch.object( + peer_routing, "find_closest_peers_network", return_value=discovered_peers + ): + # Mock peerstore to return addresses for discovered peers + mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8003")] + mock_host.get_peerstore().addrs.return_value = mock_addrs + + await peer_routing.refresh_routing_table() + + # Should perform lookup for local ID + peer_routing.find_closest_peers_network.assert_called_once_with( + local_id.to_bytes() + ) + + @pytest.mark.trio + async def test_handle_kad_stream_find_node(self, peer_routing, mock_host): + """Test handling incoming FIND_NODE requests.""" + # Create mock stream + mock_stream = AsyncMock() + + # Create FIND_NODE request + request_msg = Message() + request_msg.type = Message.MessageType.FIND_NODE + request_msg.key = b"target_key" + + request_bytes = request_msg.SerializeToString() + + # Mock stream reading + mock_stream.read.side_effect = [ + len(request_bytes).to_bytes(4, byteorder="big"), + request_bytes, + ] + + # Mock routing table to return some peers + closest_peers = [create_valid_peer_id(f"close{i}") for i in range(2)] + with patch.object( + peer_routing.routing_table, + "find_local_closest_peers", + return_value=closest_peers, + ): + mock_host.get_peerstore().addrs.return_value = [ + Multiaddr("/ip4/127.0.0.1/tcp/8004") + ] + + await peer_routing._handle_kad_stream(mock_stream) + + # Should write response + mock_stream.write.assert_called() + mock_stream.close.assert_called_once() + + @pytest.mark.trio + async def test_handle_kad_stream_invalid_message(self, peer_routing): + """Test handling stream with invalid message.""" + mock_stream = AsyncMock() + + # Mock stream to return invalid data + mock_stream.read.side_effect = [ + (10).to_bytes(4, byteorder="big"), + b"invalid_proto_data", + ] + + # Should handle gracefully without raising exception + await 
peer_routing._handle_kad_stream(mock_stream) + + mock_stream.close.assert_called_once() + + @pytest.mark.trio + async def test_handle_kad_stream_connection_closed(self, peer_routing): + """Test handling stream when connection is closed early.""" + mock_stream = AsyncMock() + + # Mock stream to return empty data (connection closed) + mock_stream.read.return_value = b"" + + await peer_routing._handle_kad_stream(mock_stream) + + mock_stream.close.assert_called_once() + + @pytest.mark.trio + async def test_query_single_peer_for_closest_success(self, peer_routing): + """Test _query_single_peer_for_closest method.""" + target_key = b"target_key" + peer_id = create_valid_peer_id("test") + new_peers = [] + + # Mock successful query + mock_result = [create_valid_peer_id("result1"), create_valid_peer_id("result2")] + with patch.object( + peer_routing, "_query_peer_for_closest", return_value=mock_result + ): + await peer_routing._query_single_peer_for_closest( + peer_id, target_key, new_peers + ) + + assert len(new_peers) == 2 + assert all(peer in new_peers for peer in mock_result) + + @pytest.mark.trio + async def test_query_single_peer_for_closest_failure(self, peer_routing): + """Test _query_single_peer_for_closest when query fails.""" + target_key = b"target_key" + peer_id = create_valid_peer_id("test") + new_peers = [] + + # Mock query failure + with patch.object( + peer_routing, + "_query_peer_for_closest", + side_effect=Exception("Query failed"), + ): + await peer_routing._query_single_peer_for_closest( + peer_id, target_key, new_peers + ) + + # Should handle exception gracefully + assert len(new_peers) == 0 + + @pytest.mark.trio + async def test_query_single_peer_deduplication(self, peer_routing): + """Test that _query_single_peer_for_closest deduplicates peers.""" + target_key = b"target_key" + peer_id = create_valid_peer_id("test") + duplicate_peer = create_valid_peer_id("duplicate") + new_peers = [duplicate_peer] # Pre-existing peer + + # Mock query to return the same peer + mock_result = [duplicate_peer, create_valid_peer_id("new")] + with patch.object( + peer_routing, "_query_peer_for_closest", return_value=mock_result + ): + await peer_routing._query_single_peer_for_closest( + peer_id, target_key, new_peers + ) + + # Should not add duplicate + assert len(new_peers) == 2 # Original + 1 new peer + assert new_peers.count(duplicate_peer) == 1 + + def test_constants(self): + """Test that important constants are properly defined.""" + assert ALPHA == 3 + assert MAX_PEER_LOOKUP_ROUNDS == 20 + assert PROTOCOL_ID == "/ipfs/kad/1.0.0" + + @pytest.mark.trio + async def test_edge_case_max_rounds_reached(self, peer_routing): + """Test that lookup stops after maximum rounds.""" + target_key = b"target_key" + initial_peers = [create_valid_peer_id("peer1")] + + # Mock to always return new peers to force max rounds + def mock_query_side_effect(peer, key): + return [create_valid_peer_id(f"new_peer_{time.time()}")] + + with patch.object( + peer_routing.routing_table, + "find_local_closest_peers", + return_value=initial_peers, + ): + with patch.object( + peer_routing, + "_query_peer_for_closest", + side_effect=mock_query_side_effect, + ): + with patch( + "libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance" + ) as mock_sort: + # Always return different peers to prevent convergence + mock_sort.side_effect = lambda key, peers: peers[:20] + + result = await peer_routing.find_closest_peers_network(target_key) + + # Should stop after max rounds, not infinite loop + assert isinstance(result, list) diff --git 
a/tests/core/kad_dht/test_unit_provider_store.py b/tests/core/kad_dht/test_unit_provider_store.py new file mode 100644 index 00000000..560c56e5 --- /dev/null +++ b/tests/core/kad_dht/test_unit_provider_store.py @@ -0,0 +1,805 @@ +""" +Unit tests for the ProviderStore and ProviderRecord classes in Kademlia DHT. + +This module tests the core functionality of provider record management including: +- ProviderRecord creation, expiration, and republish logic +- ProviderStore operations (add, get, cleanup) +- Expiration and TTL handling +- Network operations (mocked) +- Edge cases and error conditions +""" + +import time +from unittest.mock import ( + AsyncMock, + Mock, + patch, +) + +import pytest +from multiaddr import ( + Multiaddr, +) + +from libp2p.kad_dht.provider_store import ( + PROVIDER_ADDRESS_TTL, + PROVIDER_RECORD_EXPIRATION_INTERVAL, + PROVIDER_RECORD_REPUBLISH_INTERVAL, + ProviderRecord, + ProviderStore, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) + +mock_host = Mock() + + +class TestProviderRecord: + """Test suite for ProviderRecord class.""" + + def test_init_with_default_timestamp(self): + """Test ProviderRecord initialization with default timestamp.""" + peer_id = ID.from_base58("QmTest123") + addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + peer_info = PeerInfo(peer_id, addresses) + + start_time = time.time() + record = ProviderRecord(peer_info) + end_time = time.time() + + assert record.provider_info == peer_info + assert start_time <= record.timestamp <= end_time + assert record.peer_id == peer_id + assert record.addresses == addresses + + def test_init_with_custom_timestamp(self): + """Test ProviderRecord initialization with custom timestamp.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + custom_timestamp = time.time() - 3600 # 1 hour ago + + record = ProviderRecord(peer_info, timestamp=custom_timestamp) + + assert record.timestamp == custom_timestamp + + def test_is_expired_fresh_record(self): + """Test that fresh records are not expired.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + record = ProviderRecord(peer_info) + + assert not record.is_expired() + + def test_is_expired_old_record(self): + """Test that old records are expired.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + old_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1 + record = ProviderRecord(peer_info, timestamp=old_timestamp) + + assert record.is_expired() + + def test_is_expired_boundary_condition(self): + """Test expiration at exact boundary.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + boundary_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL + record = ProviderRecord(peer_info, timestamp=boundary_timestamp) + + # At the exact boundary, should be expired (implementation uses >) + assert record.is_expired() + + def test_should_republish_fresh_record(self): + """Test that fresh records don't need republishing.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + record = ProviderRecord(peer_info) + + assert not record.should_republish() + + def test_should_republish_old_record(self): + """Test that old records need republishing.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + old_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1 + record = ProviderRecord(peer_info, timestamp=old_timestamp) + + assert 
record.should_republish() + + def test_should_republish_boundary_condition(self): + """Test republish at exact boundary.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + boundary_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL + record = ProviderRecord(peer_info, timestamp=boundary_timestamp) + + # At the exact boundary, should need republishing (implementation uses >) + assert record.should_republish() + + def test_properties(self): + """Test peer_id and addresses properties.""" + peer_id = ID.from_base58("QmTest123") + addresses = [ + Multiaddr("/ip4/127.0.0.1/tcp/8000"), + Multiaddr("/ip6/::1/tcp/8001"), + ] + peer_info = PeerInfo(peer_id, addresses) + record = ProviderRecord(peer_info) + + assert record.peer_id == peer_id + assert record.addresses == addresses + + def test_empty_addresses(self): + """Test ProviderRecord with empty address list.""" + peer_id = ID.from_base58("QmTest123") + peer_info = PeerInfo(peer_id, []) + record = ProviderRecord(peer_info) + + assert record.addresses == [] + + +class TestProviderStore: + """Test suite for ProviderStore class.""" + + def test_init_empty_store(self): + """Test that a new ProviderStore is initialized empty.""" + store = ProviderStore(host=mock_host) + + assert len(store.providers) == 0 + assert store.peer_routing is None + assert len(store.providing_keys) == 0 + + def test_init_with_host(self): + """Test initialization with host.""" + mock_host = Mock() + mock_peer_id = ID.from_base58("QmTest123") + mock_host.get_id.return_value = mock_peer_id + + store = ProviderStore(host=mock_host) + + assert store.host == mock_host + assert store.local_peer_id == mock_peer_id + assert len(store.providers) == 0 + + def test_init_with_host_and_peer_routing(self): + """Test initialization with both host and peer routing.""" + mock_host = Mock() + mock_peer_routing = Mock() + mock_peer_id = ID.from_base58("QmTest123") + mock_host.get_id.return_value = mock_peer_id + + store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing) + + assert store.host == mock_host + assert store.peer_routing == mock_peer_routing + assert store.local_peer_id == mock_peer_id + + def test_add_provider_new_key(self): + """Test adding a provider for a new key.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + peer_id = ID.from_base58("QmTest123") + addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + provider = PeerInfo(peer_id, addresses) + + store.add_provider(key, provider) + + assert key in store.providers + assert str(peer_id) in store.providers[key] + + record = store.providers[key][str(peer_id)] + assert record.provider_info == provider + assert isinstance(record.timestamp, float) + + def test_add_provider_existing_key(self): + """Test adding multiple providers for the same key.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + # Add first provider + peer_id1 = ID.from_base58("QmTest123") + provider1 = PeerInfo(peer_id1, []) + store.add_provider(key, provider1) + + # Add second provider + peer_id2 = ID.from_base58("QmTest456") + provider2 = PeerInfo(peer_id2, []) + store.add_provider(key, provider2) + + assert len(store.providers[key]) == 2 + assert str(peer_id1) in store.providers[key] + assert str(peer_id2) in store.providers[key] + + def test_add_provider_update_existing(self): + """Test updating an existing provider.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + peer_id = ID.from_base58("QmTest123") + + # Add initial provider + provider1 = PeerInfo(peer_id, 
[Multiaddr("/ip4/127.0.0.1/tcp/8000")]) + store.add_provider(key, provider1) + first_timestamp = store.providers[key][str(peer_id)].timestamp + + # Small delay to ensure timestamp difference + time.sleep(0.001) + + # Update provider + provider2 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8001")]) + store.add_provider(key, provider2) + + # Should have same peer but updated info + assert len(store.providers[key]) == 1 + assert str(peer_id) in store.providers[key] + + record = store.providers[key][str(peer_id)] + assert record.provider_info == provider2 + assert record.timestamp > first_timestamp + + def test_get_providers_empty_key(self): + """Test getting providers for non-existent key.""" + store = ProviderStore(host=mock_host) + key = b"nonexistent_key" + + providers = store.get_providers(key) + + assert providers == [] + + def test_get_providers_valid_records(self): + """Test getting providers with valid records.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + # Add multiple providers + peer_id1 = ID.from_base58("QmTest123") + peer_id2 = ID.from_base58("QmTest456") + provider1 = PeerInfo(peer_id1, [Multiaddr("/ip4/127.0.0.1/tcp/8000")]) + provider2 = PeerInfo(peer_id2, [Multiaddr("/ip4/127.0.0.1/tcp/8001")]) + + store.add_provider(key, provider1) + store.add_provider(key, provider2) + + providers = store.get_providers(key) + + assert len(providers) == 2 + provider_ids = {p.peer_id for p in providers} + assert peer_id1 in provider_ids + assert peer_id2 in provider_ids + + def test_get_providers_expired_records(self): + """Test that expired records are filtered out and cleaned up.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + # Add valid provider + peer_id1 = ID.from_base58("QmTest123") + provider1 = PeerInfo(peer_id1, []) + store.add_provider(key, provider1) + + # Add expired provider manually + peer_id2 = ID.from_base58("QmTest456") + provider2 = PeerInfo(peer_id2, []) + expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1 + store.providers[key][str(peer_id2)] = ProviderRecord( + provider2, expired_timestamp + ) + + providers = store.get_providers(key) + + # Should only return valid provider + assert len(providers) == 1 + assert providers[0].peer_id == peer_id1 + + # Expired provider should be cleaned up + assert str(peer_id2) not in store.providers[key] + + def test_get_providers_address_ttl(self): + """Test address TTL handling in get_providers.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + peer_id = ID.from_base58("QmTest123") + addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + provider = PeerInfo(peer_id, addresses) + + # Add provider with old timestamp (addresses expired but record valid) + old_timestamp = time.time() - PROVIDER_ADDRESS_TTL - 1 + store.providers[key] = {str(peer_id): ProviderRecord(provider, old_timestamp)} + + providers = store.get_providers(key) + + # Should return provider but with empty addresses + assert len(providers) == 1 + assert providers[0].peer_id == peer_id + assert providers[0].addrs == [] + + def test_get_providers_cleanup_empty_key(self): + """Test that keys with no valid providers are removed.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + # Add only expired providers + peer_id = ID.from_base58("QmTest123") + provider = PeerInfo(peer_id, []) + expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1 + store.providers[key] = { + str(peer_id): ProviderRecord(provider, expired_timestamp) + } + + providers = 
store.get_providers(key) + + assert providers == [] + assert key not in store.providers # Key should be removed + + def test_cleanup_expired_no_expired_records(self): + """Test cleanup when there are no expired records.""" + store = ProviderStore(host=mock_host) + key1 = b"key1" + key2 = b"key2" + + # Add valid providers + peer_id1 = ID.from_base58("QmTest123") + peer_id2 = ID.from_base58("QmTest456") + provider1 = PeerInfo(peer_id1, []) + provider2 = PeerInfo(peer_id2, []) + + store.add_provider(key1, provider1) + store.add_provider(key2, provider2) + + initial_size = store.size() + store.cleanup_expired() + + assert store.size() == initial_size + assert key1 in store.providers + assert key2 in store.providers + + def test_cleanup_expired_with_expired_records(self): + """Test cleanup removes expired records.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + # Add valid provider + peer_id1 = ID.from_base58("QmTest123") + provider1 = PeerInfo(peer_id1, []) + store.add_provider(key, provider1) + + # Add expired provider + peer_id2 = ID.from_base58("QmTest456") + provider2 = PeerInfo(peer_id2, []) + expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1 + store.providers[key][str(peer_id2)] = ProviderRecord( + provider2, expired_timestamp + ) + + assert store.size() == 2 + store.cleanup_expired() + + assert store.size() == 1 + assert str(peer_id1) in store.providers[key] + assert str(peer_id2) not in store.providers[key] + + def test_cleanup_expired_remove_empty_keys(self): + """Test that keys with only expired providers are removed.""" + store = ProviderStore(host=mock_host) + key1 = b"key1" + key2 = b"key2" + + # Add valid provider to key1 + peer_id1 = ID.from_base58("QmTest123") + provider1 = PeerInfo(peer_id1, []) + store.add_provider(key1, provider1) + + # Add only expired provider to key2 + peer_id2 = ID.from_base58("QmTest456") + provider2 = PeerInfo(peer_id2, []) + expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1 + store.providers[key2] = { + str(peer_id2): ProviderRecord(provider2, expired_timestamp) + } + + store.cleanup_expired() + + assert key1 in store.providers + assert key2 not in store.providers + + def test_get_provided_keys_empty_store(self): + """Test get_provided_keys with empty store.""" + store = ProviderStore(host=mock_host) + peer_id = ID.from_base58("QmTest123") + + keys = store.get_provided_keys(peer_id) + + assert keys == [] + + def test_get_provided_keys_single_peer(self): + """Test get_provided_keys for a specific peer.""" + store = ProviderStore(host=mock_host) + peer_id1 = ID.from_base58("QmTest123") + peer_id2 = ID.from_base58("QmTest456") + + key1 = b"key1" + key2 = b"key2" + key3 = b"key3" + + provider1 = PeerInfo(peer_id1, []) + provider2 = PeerInfo(peer_id2, []) + + # peer_id1 provides key1 and key2 + store.add_provider(key1, provider1) + store.add_provider(key2, provider1) + + # peer_id2 provides key2 and key3 + store.add_provider(key2, provider2) + store.add_provider(key3, provider2) + + keys1 = store.get_provided_keys(peer_id1) + keys2 = store.get_provided_keys(peer_id2) + + assert len(keys1) == 2 + assert key1 in keys1 + assert key2 in keys1 + + assert len(keys2) == 2 + assert key2 in keys2 + assert key3 in keys2 + + def test_get_provided_keys_nonexistent_peer(self): + """Test get_provided_keys for peer that provides nothing.""" + store = ProviderStore(host=mock_host) + peer_id1 = ID.from_base58("QmTest123") + peer_id2 = ID.from_base58("QmTest456") + + # Add provider for peer_id1 only + key = 
b"key" + provider = PeerInfo(peer_id1, []) + store.add_provider(key, provider) + + # Query for peer_id2 (provides nothing) + keys = store.get_provided_keys(peer_id2) + + assert keys == [] + + def test_size_empty_store(self): + """Test size() with empty store.""" + store = ProviderStore(host=mock_host) + + assert store.size() == 0 + + def test_size_with_providers(self): + """Test size() with multiple providers.""" + store = ProviderStore(host=mock_host) + + # Add providers + key1 = b"key1" + key2 = b"key2" + peer_id1 = ID.from_base58("QmTest123") + peer_id2 = ID.from_base58("QmTest456") + peer_id3 = ID.from_base58("QmTest789") + + provider1 = PeerInfo(peer_id1, []) + provider2 = PeerInfo(peer_id2, []) + provider3 = PeerInfo(peer_id3, []) + + store.add_provider(key1, provider1) + store.add_provider(key1, provider2) # 2 providers for key1 + store.add_provider(key2, provider3) # 1 provider for key2 + + assert store.size() == 3 + + @pytest.mark.trio + async def test_provide_no_host(self): + """Test provide() returns False when no host is configured.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + result = await store.provide(key) + + assert result is False + + @pytest.mark.trio + async def test_provide_no_peer_routing(self): + """Test provide() returns False when no peer routing is configured.""" + mock_host = Mock() + store = ProviderStore(host=mock_host) + key = b"test_key" + + result = await store.provide(key) + + assert result is False + + @pytest.mark.trio + async def test_provide_success(self): + """Test successful provide operation.""" + # Setup mocks + mock_host = Mock() + mock_peer_routing = AsyncMock() + peer_id = ID.from_base58("QmTest123") + + mock_host.get_id.return_value = peer_id + mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + + # Mock finding closest peers + closest_peers = [ID.from_base58("QmPeer1"), ID.from_base58("QmPeer2")] + mock_peer_routing.find_closest_peers_network.return_value = closest_peers + + store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing) + + # Mock _send_add_provider to return success + with patch.object(store, "_send_add_provider", return_value=True) as mock_send: + key = b"test_key" + result = await store.provide(key) + + assert result is True + assert key in store.providing_keys + assert key in store.providers + + # Should have called _send_add_provider for each peer + assert mock_send.call_count == len(closest_peers) + + @pytest.mark.trio + async def test_provide_skip_local_peer(self): + """Test that provide() skips sending to local peer.""" + # Setup mocks + mock_host = Mock() + mock_peer_routing = AsyncMock() + peer_id = ID.from_base58("QmTest123") + + mock_host.get_id.return_value = peer_id + mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + + # Include local peer in closest peers + closest_peers = [peer_id, ID.from_base58("QmPeer1")] + mock_peer_routing.find_closest_peers_network.return_value = closest_peers + + store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing) + + with patch.object(store, "_send_add_provider", return_value=True) as mock_send: + key = b"test_key" + result = await store.provide(key) + + assert result is True + # Should only call _send_add_provider once (skip local peer) + assert mock_send.call_count == 1 + + @pytest.mark.trio + async def test_find_providers_no_host(self): + """Test find_providers() returns empty list when no host.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + + result = await 
store.find_providers(key) + + assert result == [] + + @pytest.mark.trio + async def test_find_providers_local_only(self): + """Test find_providers() returns local providers.""" + mock_host = Mock() + mock_peer_routing = Mock() + store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing) + + # Add local providers + key = b"test_key" + peer_id = ID.from_base58("QmTest123") + provider = PeerInfo(peer_id, []) + store.add_provider(key, provider) + + result = await store.find_providers(key) + + assert len(result) == 1 + assert result[0].peer_id == peer_id + + @pytest.mark.trio + async def test_find_providers_network_search(self): + """Test find_providers() searches network when no local providers.""" + mock_host = Mock() + mock_peer_routing = AsyncMock() + store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing) + + # Mock network search + closest_peers = [ID.from_base58("QmPeer1")] + mock_peer_routing.find_closest_peers_network.return_value = closest_peers + + # Mock provider response + remote_peer_id = ID.from_base58("QmRemote123") + remote_providers = [PeerInfo(remote_peer_id, [])] + + with patch.object( + store, "_get_providers_from_peer", return_value=remote_providers + ): + key = b"test_key" + result = await store.find_providers(key) + + assert len(result) == 1 + assert result[0].peer_id == remote_peer_id + + @pytest.mark.trio + async def test_get_providers_from_peer_no_host(self): + """Test _get_providers_from_peer without host.""" + store = ProviderStore(host=mock_host) + peer_id = ID.from_base58("QmTest123") + key = b"test_key" + + # Should handle missing host gracefully + result = await store._get_providers_from_peer(peer_id, key) + assert result == [] + + def test_edge_case_empty_key(self): + """Test handling of empty key.""" + store = ProviderStore(host=mock_host) + key = b"" + peer_id = ID.from_base58("QmTest123") + provider = PeerInfo(peer_id, []) + + store.add_provider(key, provider) + providers = store.get_providers(key) + + assert len(providers) == 1 + assert providers[0].peer_id == peer_id + + def test_edge_case_large_key(self): + """Test handling of large key.""" + store = ProviderStore(host=mock_host) + key = b"x" * 10000 # 10KB key + peer_id = ID.from_base58("QmTest123") + provider = PeerInfo(peer_id, []) + + store.add_provider(key, provider) + providers = store.get_providers(key) + + assert len(providers) == 1 + assert providers[0].peer_id == peer_id + + def test_concurrent_operations(self): + """Test multiple concurrent operations.""" + store = ProviderStore(host=mock_host) + + # Add many providers + num_keys = 100 + num_providers_per_key = 5 + + for i in range(num_keys): + _key = f"key_{i}".encode() + for j in range(num_providers_per_key): + # Generate unique valid Base58 peer IDs + # Use a different approach that ensures uniqueness + unique_id = i * num_providers_per_key + j + 1 # Ensure > 0 + _peer_id_str = f"QmPeer{unique_id:06d}".replace("0", "A") + "1" * 38 + peer_id = ID.from_base58(_peer_id_str) + provider = PeerInfo(peer_id, []) + store.add_provider(_key, provider) + + # Verify total size + expected_size = num_keys * num_providers_per_key + assert store.size() == expected_size + + # Verify individual keys + for i in range(num_keys): + _key = f"key_{i}".encode() + providers = store.get_providers(_key) + assert len(providers) == num_providers_per_key + + def test_memory_efficiency_large_dataset(self): + """Test memory behavior with large datasets.""" + store = ProviderStore(host=mock_host) + + # Add large number of providers + num_entries = 
1000 + for i in range(num_entries): + _key = f"key_{i:05d}".encode() + # Generate valid Base58 peer IDs (replace 0 with valid characters) + peer_str = f"QmPeer{i:05d}".replace("0", "1") + "1" * 35 + peer_id = ID.from_base58(peer_str) + provider = PeerInfo(peer_id, []) + store.add_provider(_key, provider) + + assert store.size() == num_entries + + # Clean up all entries by making them expired + current_time = time.time() + for _key, providers in store.providers.items(): + for _peer_id_str, record in providers.items(): + record.timestamp = ( + current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1 + ) + + store.cleanup_expired() + assert store.size() == 0 + assert len(store.providers) == 0 + + def test_unicode_key_handling(self): + """Test handling of unicode content in keys.""" + store = ProviderStore(host=mock_host) + + # Test various unicode keys + unicode_keys = [ + b"hello", + "héllo".encode(), + "🔑".encode(), + "ключ".encode(), # Russian + "键".encode(), # Chinese + ] + + for i, key in enumerate(unicode_keys): + # Generate valid Base58 peer IDs + peer_id = ID.from_base58(f"QmPeer{i + 1}" + "1" * 42) # Valid base58 + provider = PeerInfo(peer_id, []) + store.add_provider(key, provider) + + providers = store.get_providers(key) + assert len(providers) == 1 + assert providers[0].peer_id == peer_id + + def test_multiple_addresses_per_provider(self): + """Test providers with multiple addresses.""" + store = ProviderStore(host=mock_host) + key = b"test_key" + peer_id = ID.from_base58("QmTest123") + + addresses = [ + Multiaddr("/ip4/127.0.0.1/tcp/8000"), + Multiaddr("/ip6/::1/tcp/8001"), + Multiaddr("/ip4/192.168.1.100/tcp/8002"), + ] + provider = PeerInfo(peer_id, addresses) + + store.add_provider(key, provider) + providers = store.get_providers(key) + + assert len(providers) == 1 + assert providers[0].peer_id == peer_id + assert len(providers[0].addrs) == len(addresses) + assert all(addr in providers[0].addrs for addr in addresses) + + @pytest.mark.trio + async def test_republish_provider_records_no_keys(self): + """Test _republish_provider_records with no providing keys.""" + store = ProviderStore(host=mock_host) + + # Should complete without error even with no providing keys + await store._republish_provider_records() + + assert len(store.providing_keys) == 0 + + def test_expiration_boundary_conditions(self): + """Test expiration around boundary conditions.""" + store = ProviderStore(host=mock_host) + peer_id = ID.from_base58("QmTest123") + provider = PeerInfo(peer_id, []) + + current_time = time.time() + + # Test records at various timestamps + timestamps = [ + current_time, # Fresh + current_time - PROVIDER_ADDRESS_TTL + 1, # Addresses valid + current_time - PROVIDER_ADDRESS_TTL - 1, # Addresses expired + current_time + - PROVIDER_RECORD_REPUBLISH_INTERVAL + + 1, # No republish needed + current_time - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1, # Republish needed + current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL + 1, # Not expired + current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1, # Expired + ] + + for i, timestamp in enumerate(timestamps): + test_key = f"key_{i}".encode() + record = ProviderRecord(provider, timestamp) + store.providers[test_key] = {str(peer_id): record} + + # Test various operations + for i, timestamp in enumerate(timestamps): + test_key = f"key_{i}".encode() + providers = store.get_providers(test_key) + + if timestamp <= current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL: + # Should be expired and removed + assert len(providers) == 0 + assert test_key not in 
store.providers + else: + # Should be present + assert len(providers) == 1 + assert providers[0].peer_id == peer_id diff --git a/tests/core/kad_dht/test_unit_routing_table.py b/tests/core/kad_dht/test_unit_routing_table.py new file mode 100644 index 00000000..af77eda5 --- /dev/null +++ b/tests/core/kad_dht/test_unit_routing_table.py @@ -0,0 +1,371 @@ +""" +Unit tests for the RoutingTable and KBucket classes in Kademlia DHT. + +This module tests the core functionality of the routing table including: +- KBucket operations (add, remove, split, ping) +- RoutingTable management (peer addition, closest peer finding) +- Distance calculations and peer ordering +- Bucket splitting and range management +""" + +import time +from unittest.mock import ( + AsyncMock, + Mock, + patch, +) + +import pytest +from multiaddr import ( + Multiaddr, +) +import trio + +from libp2p.crypto.secp256k1 import ( + create_new_key_pair, +) +from libp2p.kad_dht.routing_table import ( + BUCKET_SIZE, + KBucket, + RoutingTable, +) +from libp2p.kad_dht.utils import ( + create_key_from_binary, + xor_distance, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.peer.peerinfo import ( + PeerInfo, +) + + +def create_valid_peer_id(name: str) -> ID: + """Create a valid peer ID for testing.""" + # Use crypto to generate valid peer IDs + key_pair = create_new_key_pair() + return ID.from_pubkey(key_pair.public_key) + + +class TestKBucket: + """Test suite for KBucket class.""" + + @pytest.fixture + def mock_host(self): + """Create a mock host for testing.""" + host = Mock() + host.get_peerstore.return_value = Mock() + host.new_stream = AsyncMock() + return host + + @pytest.fixture + def sample_peer_info(self): + """Create sample peer info for testing.""" + peer_id = create_valid_peer_id("test") + addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + return PeerInfo(peer_id, addresses) + + def test_init_default_parameters(self, mock_host): + """Test KBucket initialization with default parameters.""" + bucket = KBucket(mock_host) + + assert bucket.bucket_size == BUCKET_SIZE + assert bucket.host == mock_host + assert bucket.min_range == 0 + assert bucket.max_range == 2**256 + assert len(bucket.peers) == 0 + + def test_peer_operations(self, mock_host, sample_peer_info): + """Test basic peer operations: add, check, and remove.""" + bucket = KBucket(mock_host) + + # Test empty bucket + assert bucket.peer_ids() == [] + assert bucket.size() == 0 + assert not bucket.has_peer(sample_peer_info.peer_id) + + # Add peer manually + bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time()) + + # Test with peer + assert len(bucket.peer_ids()) == 1 + assert sample_peer_info.peer_id in bucket.peer_ids() + assert bucket.size() == 1 + assert bucket.has_peer(sample_peer_info.peer_id) + assert bucket.get_peer_info(sample_peer_info.peer_id) == sample_peer_info + + # Remove peer + result = bucket.remove_peer(sample_peer_info.peer_id) + assert result is True + assert bucket.size() == 0 + assert not bucket.has_peer(sample_peer_info.peer_id) + + @pytest.mark.trio + async def test_add_peer_functionality(self, mock_host): + """Test add_peer method with different scenarios.""" + bucket = KBucket(mock_host, bucket_size=2) # Small bucket for testing + + # Add first peer + peer1 = PeerInfo(create_valid_peer_id("peer1"), []) + result = await bucket.add_peer(peer1) + assert result is True + assert bucket.size() == 1 + + # Add second peer + peer2 = PeerInfo(create_valid_peer_id("peer2"), []) + result = await bucket.add_peer(peer2) + assert result is True 
+ assert bucket.size() == 2 + + # Add same peer again (should update timestamp) + await trio.sleep(0.001) + result = await bucket.add_peer(peer1) + assert result is True + assert bucket.size() == 2 # Still 2 peers + + # Try to add third peer when bucket is full + peer3 = PeerInfo(create_valid_peer_id("peer3"), []) + with patch.object(bucket, "_ping_peer", return_value=True): + result = await bucket.add_peer(peer3) + assert result is False # Should fail if oldest peer responds + + def test_get_oldest_peer(self, mock_host): + """Test get_oldest_peer method.""" + bucket = KBucket(mock_host) + + # Empty bucket + assert bucket.get_oldest_peer() is None + + # Add peers with different timestamps + peer1 = PeerInfo(create_valid_peer_id("peer1"), []) + peer2 = PeerInfo(create_valid_peer_id("peer2"), []) + + current_time = time.time() + bucket.peers[peer1.peer_id] = (peer1, current_time - 300) # Older + bucket.peers[peer2.peer_id] = (peer2, current_time) # Newer + + oldest = bucket.get_oldest_peer() + assert oldest == peer1.peer_id + + def test_stale_peers(self, mock_host): + """Test stale peer identification.""" + bucket = KBucket(mock_host) + + current_time = time.time() + fresh_peer = PeerInfo(create_valid_peer_id("fresh"), []) + stale_peer = PeerInfo(create_valid_peer_id("stale"), []) + + bucket.peers[fresh_peer.peer_id] = (fresh_peer, current_time) + bucket.peers[stale_peer.peer_id] = ( + stale_peer, + current_time - 7200, + ) # 2 hours ago + + stale_peers = bucket.get_stale_peers(3600) # 1 hour threshold + assert len(stale_peers) == 1 + assert stale_peer.peer_id in stale_peers + + def test_key_in_range(self, mock_host): + """Test key_in_range method.""" + bucket = KBucket(mock_host, min_range=100, max_range=200) + + # Test keys within range + key_in_range = (150).to_bytes(32, byteorder="big") + assert bucket.key_in_range(key_in_range) is True + + # Test keys outside range + key_below = (50).to_bytes(32, byteorder="big") + assert bucket.key_in_range(key_below) is False + + key_above = (250).to_bytes(32, byteorder="big") + assert bucket.key_in_range(key_above) is False + + # Test boundary conditions + key_min = (100).to_bytes(32, byteorder="big") + assert bucket.key_in_range(key_min) is True + + key_max = (200).to_bytes(32, byteorder="big") + assert bucket.key_in_range(key_max) is False + + def test_split_bucket(self, mock_host): + """Test bucket splitting functionality.""" + bucket = KBucket(mock_host, min_range=0, max_range=256) + + lower_bucket, upper_bucket = bucket.split() + + # Check ranges + assert lower_bucket.min_range == 0 + assert lower_bucket.max_range == 128 + assert upper_bucket.min_range == 128 + assert upper_bucket.max_range == 256 + + # Check properties + assert lower_bucket.bucket_size == bucket.bucket_size + assert upper_bucket.bucket_size == bucket.bucket_size + assert lower_bucket.host == mock_host + assert upper_bucket.host == mock_host + + @pytest.mark.trio + async def test_ping_peer_scenarios(self, mock_host, sample_peer_info): + """Test different ping scenarios.""" + bucket = KBucket(mock_host) + bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time()) + + # Test ping peer not in bucket + other_peer_id = create_valid_peer_id("other") + with pytest.raises(ValueError, match="Peer .* not in bucket"): + await bucket._ping_peer(other_peer_id) + + # Test ping failure due to stream error + mock_host.new_stream.side_effect = Exception("Stream failed") + result = await bucket._ping_peer(sample_peer_info.peer_id) + assert result is False + + +class 
TestRoutingTable: + """Test suite for RoutingTable class.""" + + @pytest.fixture + def mock_host(self): + """Create a mock host for testing.""" + host = Mock() + host.get_peerstore.return_value = Mock() + return host + + @pytest.fixture + def local_peer_id(self): + """Create a local peer ID for testing.""" + return create_valid_peer_id("local") + + @pytest.fixture + def sample_peer_info(self): + """Create sample peer info for testing.""" + peer_id = create_valid_peer_id("sample") + addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] + return PeerInfo(peer_id, addresses) + + def test_init_routing_table(self, mock_host, local_peer_id): + """Test RoutingTable initialization.""" + routing_table = RoutingTable(local_peer_id, mock_host) + + assert routing_table.local_id == local_peer_id + assert routing_table.host == mock_host + assert len(routing_table.buckets) == 1 + assert isinstance(routing_table.buckets[0], KBucket) + + @pytest.mark.trio + async def test_add_peer_operations( + self, mock_host, local_peer_id, sample_peer_info + ): + """Test adding peers to routing table.""" + routing_table = RoutingTable(local_peer_id, mock_host) + + # Test adding peer with PeerInfo + result = await routing_table.add_peer(sample_peer_info) + assert result is True + assert routing_table.size() == 1 + assert routing_table.peer_in_table(sample_peer_info.peer_id) + + # Test adding peer with just ID + peer_id = create_valid_peer_id("test") + mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8001")] + mock_host.get_peerstore().addrs.return_value = mock_addrs + + result = await routing_table.add_peer(peer_id) + assert result is True + assert routing_table.size() == 2 + + # Test adding peer with no addresses + no_addr_peer_id = create_valid_peer_id("no_addr") + mock_host.get_peerstore().addrs.return_value = [] + + result = await routing_table.add_peer(no_addr_peer_id) + assert result is False + assert routing_table.size() == 2 + + # Test adding local peer (should be ignored) + result = await routing_table.add_peer(local_peer_id) + assert result is False + assert routing_table.size() == 2 + + def test_find_bucket(self, mock_host, local_peer_id): + """Test finding appropriate bucket for peers.""" + routing_table = RoutingTable(local_peer_id, mock_host) + + # Test with peer ID + peer_id = create_valid_peer_id("test") + bucket = routing_table.find_bucket(peer_id) + assert isinstance(bucket, KBucket) + + def test_peer_management(self, mock_host, local_peer_id, sample_peer_info): + """Test peer management operations.""" + routing_table = RoutingTable(local_peer_id, mock_host) + + # Add peer manually + bucket = routing_table.find_bucket(sample_peer_info.peer_id) + bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time()) + + # Test peer queries + assert routing_table.peer_in_table(sample_peer_info.peer_id) + assert routing_table.get_peer_info(sample_peer_info.peer_id) == sample_peer_info + assert routing_table.size() == 1 + assert len(routing_table.get_peer_ids()) == 1 + + # Test remove peer + result = routing_table.remove_peer(sample_peer_info.peer_id) + assert result is True + assert not routing_table.peer_in_table(sample_peer_info.peer_id) + assert routing_table.size() == 0 + + def test_find_closest_peers(self, mock_host, local_peer_id): + """Test finding closest peers.""" + routing_table = RoutingTable(local_peer_id, mock_host) + + # Empty table + target_key = create_key_from_binary(b"target_key") + closest_peers = routing_table.find_local_closest_peers(target_key, 5) + assert closest_peers == [] + + # Add some 
peers + bucket = routing_table.buckets[0] + test_peers = [] + for i in range(5): + peer = PeerInfo(create_valid_peer_id(f"peer{i}"), []) + test_peers.append(peer) + bucket.peers[peer.peer_id] = (peer, time.time()) + + closest_peers = routing_table.find_local_closest_peers(target_key, 3) + assert len(closest_peers) <= 3 + assert len(closest_peers) <= len(test_peers) + assert all(isinstance(peer_id, ID) for peer_id in closest_peers) + + def test_distance_calculation(self, mock_host, local_peer_id): + """Test XOR distance calculation.""" + # Test same keys + key = b"\x42" * 32 + distance = xor_distance(key, key) + assert distance == 0 + + # Test different keys + key1 = b"\x00" * 32 + key2 = b"\xff" * 32 + distance = xor_distance(key1, key2) + expected = int.from_bytes(b"\xff" * 32, byteorder="big") + assert distance == expected + + def test_edge_cases(self, mock_host, local_peer_id): + """Test various edge cases.""" + routing_table = RoutingTable(local_peer_id, mock_host) + + # Test with invalid peer ID + nonexistent_peer_id = create_valid_peer_id("nonexistent") + assert not routing_table.peer_in_table(nonexistent_peer_id) + assert routing_table.get_peer_info(nonexistent_peer_id) is None + assert routing_table.remove_peer(nonexistent_peer_id) is False + + # Test bucket splitting scenario + assert len(routing_table.buckets) == 1 + initial_bucket = routing_table.buckets[0] + assert initial_bucket.min_range == 0 + assert initial_bucket.max_range == 2**256 diff --git a/tests/core/kad_dht/test_unit_value_store.py b/tests/core/kad_dht/test_unit_value_store.py new file mode 100644 index 00000000..b287b5e2 --- /dev/null +++ b/tests/core/kad_dht/test_unit_value_store.py @@ -0,0 +1,504 @@ +""" +Unit tests for the ValueStore class in Kademlia DHT. + +This module tests the core functionality of the ValueStore including: +- Basic storage and retrieval operations +- Expiration and TTL handling +- Edge cases and error conditions +- Store management operations +""" + +import time +from unittest.mock import ( + Mock, +) + +import pytest + +from libp2p.kad_dht.value_store import ( + DEFAULT_TTL, + ValueStore, +) +from libp2p.peer.id import ( + ID, +) + +mock_host = Mock() +peer_id = ID.from_base58("QmTest123") + + +class TestValueStore: + """Test suite for ValueStore class.""" + + def test_init_empty_store(self): + """Test that a new ValueStore is initialized empty.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + assert len(store.store) == 0 + + def test_init_with_host_and_peer_id(self): + """Test initialization with host and local peer ID.""" + mock_host = Mock() + peer_id = ID.from_base58("QmTest123") + + store = ValueStore(host=mock_host, local_peer_id=peer_id) + assert store.host == mock_host + assert store.local_peer_id == peer_id + assert len(store.store) == 0 + + def test_put_basic(self): + """Test basic put operation.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + + store.put(key, value) + + assert key in store.store + stored_value, validity = store.store[key] + assert stored_value == value + assert validity is not None + assert validity > time.time() # Should be in the future + + def test_put_with_custom_validity(self): + """Test put operation with custom validity time.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + custom_validity = time.time() + 3600 # 1 hour from now + + store.put(key, value, validity=custom_validity) + + stored_value, validity = 
store.store[key] + assert stored_value == value + assert validity == custom_validity + + def test_put_overwrite_existing(self): + """Test that put overwrites existing values.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value1 = b"value1" + value2 = b"value2" + + store.put(key, value1) + store.put(key, value2) + + assert len(store.store) == 1 + stored_value, _ = store.store[key] + assert stored_value == value2 + + def test_get_existing_valid_value(self): + """Test retrieving an existing, non-expired value.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + + store.put(key, value) + retrieved_value = store.get(key) + + assert retrieved_value == value + + def test_get_nonexistent_key(self): + """Test retrieving a non-existent key returns None.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"nonexistent_key" + + retrieved_value = store.get(key) + + assert retrieved_value is None + + def test_get_expired_value(self): + """Test that expired values are automatically removed and return None.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + expired_validity = time.time() - 1 # 1 second ago + + # Manually insert expired value + store.store[key] = (value, expired_validity) + + retrieved_value = store.get(key) + + assert retrieved_value is None + assert key not in store.store # Should be removed + + def test_remove_existing_key(self): + """Test removing an existing key.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + + store.put(key, value) + result = store.remove(key) + + assert result is True + assert key not in store.store + + def test_remove_nonexistent_key(self): + """Test removing a non-existent key returns False.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"nonexistent_key" + + result = store.remove(key) + + assert result is False + + def test_has_existing_valid_key(self): + """Test has() returns True for existing, valid keys.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + + store.put(key, value) + result = store.has(key) + + assert result is True + + def test_has_nonexistent_key(self): + """Test has() returns False for non-existent keys.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"nonexistent_key" + + result = store.has(key) + + assert result is False + + def test_has_expired_key(self): + """Test has() returns False for expired keys and removes them.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"test_key" + value = b"test_value" + expired_validity = time.time() - 1 + + # Manually insert expired value + store.store[key] = (value, expired_validity) + + result = store.has(key) + + assert result is False + assert key not in store.store # Should be removed + + def test_cleanup_expired_no_expired_values(self): + """Test cleanup when there are no expired values.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key1 = b"key1" + key2 = b"key2" + value = b"value" + + store.put(key1, value) + store.put(key2, value) + + expired_count = store.cleanup_expired() + + assert expired_count == 0 + assert len(store.store) == 2 + + def test_cleanup_expired_with_expired_values(self): + """Test cleanup removes expired values.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key1 = b"valid_key" + key2 
= b"expired_key1" + key3 = b"expired_key2" + value = b"value" + expired_validity = time.time() - 1 + + store.put(key1, value) # Valid + store.store[key2] = (value, expired_validity) # Expired + store.store[key3] = (value, expired_validity) # Expired + + expired_count = store.cleanup_expired() + + assert expired_count == 2 + assert len(store.store) == 1 + assert key1 in store.store + assert key2 not in store.store + assert key3 not in store.store + + def test_cleanup_expired_mixed_validity_types(self): + """Test cleanup with mix of values with and without expiration.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key1 = b"no_expiry" + key2 = b"valid_expiry" + key3 = b"expired" + value = b"value" + + # No expiration (None validity) + store.put(key1, value) + # Valid expiration + store.put(key2, value, validity=time.time() + 3600) + # Expired + store.store[key3] = (value, time.time() - 1) + + expired_count = store.cleanup_expired() + + assert expired_count == 1 + assert len(store.store) == 2 + assert key1 in store.store + assert key2 in store.store + assert key3 not in store.store + + def test_get_keys_empty_store(self): + """Test get_keys() returns empty list for empty store.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + + keys = store.get_keys() + + assert keys == [] + + def test_get_keys_with_valid_values(self): + """Test get_keys() returns all non-expired keys.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key1 = b"key1" + key2 = b"key2" + key3 = b"expired_key" + value = b"value" + + store.put(key1, value) + store.put(key2, value) + store.store[key3] = (value, time.time() - 1) # Expired + + keys = store.get_keys() + + assert len(keys) == 2 + assert key1 in keys + assert key2 in keys + assert key3 not in keys + + def test_size_empty_store(self): + """Test size() returns 0 for empty store.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + + size = store.size() + + assert size == 0 + + def test_size_with_valid_values(self): + """Test size() returns correct count after cleaning expired values.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key1 = b"key1" + key2 = b"key2" + key3 = b"expired_key" + value = b"value" + + store.put(key1, value) + store.put(key2, value) + store.store[key3] = (value, time.time() - 1) # Expired + + size = store.size() + + assert size == 2 + + def test_edge_case_empty_key(self): + """Test handling of empty key.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"" + value = b"value" + + store.put(key, value) + retrieved_value = store.get(key) + + assert retrieved_value == value + + def test_edge_case_empty_value(self): + """Test handling of empty value.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"key" + value = b"" + + store.put(key, value) + retrieved_value = store.get(key) + + assert retrieved_value == value + + def test_edge_case_large_key_value(self): + """Test handling of large keys and values.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"x" * 10000 # 10KB key + value = b"y" * 100000 # 100KB value + + store.put(key, value) + retrieved_value = store.get(key) + + assert retrieved_value == value + + def test_edge_case_negative_validity(self): + """Test handling of negative validity time.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"key" + value = b"value" + + store.put(key, value, validity=-1) + + # Should be expired + retrieved_value = store.get(key) + assert 
retrieved_value is None + + def test_default_ttl_calculation(self): + """Test that default TTL is correctly applied.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"key" + value = b"value" + start_time = time.time() + + store.put(key, value) + + _, validity = store.store[key] + expected_validity = start_time + DEFAULT_TTL + + # Allow small time difference for execution + assert abs(validity - expected_validity) < 1 + + def test_concurrent_operations(self): + """Test that multiple operations don't interfere with each other.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + + # Add multiple key-value pairs + for i in range(100): + key = f"key_{i}".encode() + value = f"value_{i}".encode() + store.put(key, value) + + # Verify all are stored + assert store.size() == 100 + + # Remove every other key + for i in range(0, 100, 2): + key = f"key_{i}".encode() + store.remove(key) + + # Verify correct count + assert store.size() == 50 + + # Verify remaining keys are correct + for i in range(1, 100, 2): + key = f"key_{i}".encode() + assert store.has(key) + + def test_expiration_boundary_conditions(self): + """Test expiration around current time boundary.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key1 = b"key1" + key2 = b"key2" + key3 = b"key3" + value = b"value" + current_time = time.time() + + # Just expired + store.store[key1] = (value, current_time - 0.001) + # Valid for a longer time to account for test execution time + store.store[key2] = (value, current_time + 1.0) + # Exactly current time (should be expired) + store.store[key3] = (value, current_time) + + # Small delay to ensure time has passed + time.sleep(0.002) + + assert not store.has(key1) # Should be expired + assert store.has(key2) # Should be valid + assert not store.has(key3) # Should be expired (exactly at current time) + + def test_store_internal_structure(self): + """Test that internal store structure is maintained correctly.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"key" + value = b"value" + validity = time.time() + 3600 + + store.put(key, value, validity=validity) + + # Verify internal structure + assert isinstance(store.store, dict) + assert key in store.store + stored_tuple = store.store[key] + assert isinstance(stored_tuple, tuple) + assert len(stored_tuple) == 2 + assert stored_tuple[0] == value + assert stored_tuple[1] == validity + + @pytest.mark.trio + async def test_store_at_peer_local_peer(self): + """Test _store_at_peer returns True when storing at local peer.""" + mock_host = Mock() + peer_id = ID.from_base58("QmTest123") + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"key" + value = b"value" + + result = await store._store_at_peer(peer_id, key, value) + + assert result is True + + @pytest.mark.trio + async def test_get_from_peer_local_peer(self): + """Test _get_from_peer returns None when querying local peer.""" + mock_host = Mock() + peer_id = ID.from_base58("QmTest123") + store = ValueStore(host=mock_host, local_peer_id=peer_id) + key = b"key" + + result = await store._get_from_peer(peer_id, key) + + assert result is None + + def test_memory_efficiency_large_dataset(self): + """Test memory behavior with large datasets.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + + # Add a large number of entries + num_entries = 10000 + for i in range(num_entries): + key = f"key_{i:05d}".encode() + value = f"value_{i:05d}".encode() + store.put(key, value) + + assert store.size() == num_entries + + # 
Clean up all entries + for i in range(num_entries): + key = f"key_{i:05d}".encode() + store.remove(key) + + assert store.size() == 0 + assert len(store.store) == 0 + + def test_key_collision_resistance(self): + """Test that similar keys don't collide.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + + # Test keys that might cause collisions + keys = [ + b"key", + b"key\x00", + b"key1", + b"Key", # Different case + b"key ", # With space + b" key", # Leading space + ] + + for i, key in enumerate(keys): + value = f"value_{i}".encode() + store.put(key, value) + + # Verify all keys are stored separately + assert store.size() == len(keys) + + for i, key in enumerate(keys): + expected_value = f"value_{i}".encode() + assert store.get(key) == expected_value + + def test_unicode_key_handling(self): + """Test handling of unicode content in keys.""" + store = ValueStore(host=mock_host, local_peer_id=peer_id) + + # Test various unicode keys + unicode_keys = [ + b"hello", + "héllo".encode(), + "🔑".encode(), + "ключ".encode(), # Russian + "键".encode(), # Chinese + ] + + for i, key in enumerate(unicode_keys): + value = f"value_{i}".encode() + store.put(key, value) + assert store.get(key) == value
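
The value-store tests above collectively pin down the TTL lifecycle: put() stamps a validity of roughly now + DEFAULT_TTL unless an explicit validity is supplied, get() and has() treat anything past that instant as gone and evict it from the backing dict on access, and cleanup_expired() sweeps whatever expired entries remain. What follows is a minimal illustrative sketch of that lifecycle, not part of the patch itself; it assumes only the ValueStore constructor and methods the tests exercise, and reuses their Mock host since purely local storage never touches the network.

import time
from unittest.mock import Mock

from libp2p.kad_dht.value_store import DEFAULT_TTL, ValueStore
from libp2p.peer.id import ID

# Mirror the test fixtures: a Mock host is sufficient for local-only storage.
host = Mock()
store = ValueStore(host=host, local_peer_id=ID.from_base58("QmTest123"))

# put() without an explicit validity stores the value with the default TTL.
store.put(b"greeting", b"hello")
value, validity = store.store[b"greeting"]
assert value == b"hello"
assert abs(validity - (time.time() + DEFAULT_TTL)) < 1

# A validity in the past makes the record invisible, and the first access
# also evicts it from the backing dict (as test_get_expired_value shows).
store.put(b"stale", b"old", validity=time.time() - 1)
assert store.get(b"stale") is None
assert b"stale" not in store.store

# cleanup_expired() reports how many entries it had to sweep; here the
# expired record was already dropped by get(), so only the fresh one remains.
assert store.cleanup_expired() == 0
assert store.size() == 1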