Kademlia DHT implementation in py-libp2p (#579)

* initialise the module

* added content routing

* added routing module

* added peer routing

* added value store

* added utility functions

* added main kademlia file

* fixed create_key_from_binary function

* example to test kademlia dht

* added protocol ID and enhanced logging for peer store size in provider and consumer nodes

* refactor: specify stream type in handle_stream method and add peer in routing table

* removed content routing

* added default value of count for finding closest peers

* added functions to find close peers

* refactor: remove content routing and enhance peer discovery

* added put value function

* added get value function

* fix: improve logging and handle key encoding in get_value method

* refactor: remove ContentRouting import from __init__.py

* refactor: improved basic kademlia example

* added protobuf files

* replaced json with protobuf

* refactor: enhance peer discovery and routing logic in KadDHT

* refactor: enhance Kademlia routing table to use PeerInfo objects and improve peer management

* refactor: enhance peer addition logic to utilize PeerInfo objects in routing table

* feat: implement content provider functionality in Kademlia DHT

* refactor: update value store to use datetime for validity management

* refactor: update RoutingTable initialization to include host reference

* refactor: enhance KBucket and RoutingTable for improved peer management and functionality

* refactor: streamline peer discovery and value storage methods in KadDHT

* refactor: update KadDHT and related classes for async peer management and enhanced value storage

* refactor: enhance ProviderStore initialization and improve peer routing integration

* test: add tests for Kademlia DHT functionality

* fix linting issues

* pydocstyle issues fixed

* CICD pipeline issues solved

* fix: update docstring format for find_peer method

* refactor: improve logging and remove unused code in DHT implementation

* refactor: clean up logging and remove unused imports in DHT and test files

* Refactor logging setup and improve DHT stream handling with varint length prefixes

* Update bootstrap peer handling in basic_dht example and refactor peer routing to accept string addresses

* Enhance peer querying in Kademlia DHT by implementing parallel queries using Trio.

* Enhance peer querying by adding deduplication checks

* Refactor DHT implementation to use varint for length prefixes and enhance logging for better traceability

* Add base58 encoding for value storage and enhance logging in basic_dht example

* Refactor Kademlia DHT to support server/client modes

* Added unit tests

* Refactor documentation to fix some warnings

* Add unit tests and remove outdated tests

* Fixed pre-commit errors

* Refactor error handling test to raise StringParseError for invalid bootstrap addresses

* Add libp2p.kad_dht to the list of subpackages in documentation

* Fix expiration and republish checks to use inclusive comparison

* Add __init__.py file to libp2p.kad_dht.pb package

* Refactor get value and put value to run in parallel with query timeout

* Refactor provider message handling to use parallel processing with timeout

* Add methods for provider store in KadDHT class

* Refactor KadDHT and ProviderStore methods to improve type hints and enhance parallel processing

* Add documentation for libp2p.kad_dht.pb module.

* Update documentation for libp2p.kad_dht package to include subpackages and correct formatting

* Fix formatting in documentation for libp2p.kad_dht package by correcting the subpackage reference

* Fix header formatting in libp2p.kad_dht.pb documentation

* Change log level from info to debug for various logging statements.

* fix CICD issues (post revamp)

* fixed value store unit test

* Refactored kademlia example

* Refactor Kademlia example: enhance logging, improve bootstrap node connection, and streamline server address handling

* removed bootstrap module

* Refactor Kademlia DHT example and core modules: enhance logging, remove unused code, and improve peer handling

* Added docs of kad dht example

* Update server address log file path to use the script's directory

* Refactor: Introduce DHTMode enum for clearer mode management

* moved xor_distance function to utils.py

* Enhance logging in ValueStore and KadDHT: include decoded value in debug logs and update parameter description for validity

* Add handling for closest peers in GET_VALUE response when value is not found

* Handled failure scenario for PUT_VALUE

* Remove kademlia demo from project scripts and contributing documentation

* spelling and logging

---------

Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
Authored by Sumanjeet on 2025-06-17 02:16:40 +05:30, committed by GitHub
parent 733ef86e62
commit d61bca78ab
24 changed files with 5790 additions and 1 deletions


@ -58,7 +58,9 @@ PB = libp2p/crypto/pb/crypto.proto \
libp2p/security/secio/pb/spipe.proto \
libp2p/security/noise/pb/noise.proto \
libp2p/identity/identify/pb/identify.proto \
-libp2p/host/autonat/pb/autonat.proto
+libp2p/host/autonat/pb/autonat.proto \
+libp2p/kad_dht/pb/kademlia.proto
PY = $(PB:.proto=_pb2.py)
PYI = $(PB:.proto=_pb2.pyi)

docs/examples.kademlia.rst (new file, 124 additions)

@ -0,0 +1,124 @@
Kademlia DHT Demo
=================
This example demonstrates a Kademlia Distributed Hash Table (DHT) implementation with both value storage/retrieval and content provider advertisement/discovery functionality.
.. code-block:: console
$ python -m pip install libp2p
Collecting libp2p
...
Successfully installed libp2p-x.x.x
$ cd examples/kademlia
$ python kademlia.py --mode server
2025-06-13 19:51:25,424 - kademlia-example - INFO - Running in server mode on port 0
2025-06-13 19:51:25,426 - kademlia-example - INFO - Connected to bootstrap nodes: []
2025-06-13 19:51:25,426 - kademlia-example - INFO - To connect to this node, use: --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-06-13 19:51:25,426 - kademlia-example - INFO - Saved server address to log: /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-06-13 19:51:25,427 - kademlia-example - INFO - DHT service started in SERVER mode
2025-06-13 19:51:25,427 - kademlia-example - INFO - Stored value 'Hello message from Sumanjeet' with key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
2025-06-13 19:51:25,427 - kademlia-example - INFO - Successfully advertised as server for content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
Copy the line that starts with ``--bootstrap``, open a new terminal in the same folder and run the client:
.. code-block:: console
$ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']
Alternatively, if you run the server first, the client can automatically extract the bootstrap address from the server log file:
.. code-block:: console
$ python kademlia.py --mode client
2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']
The demo showcases key DHT operations (the key derivation is sketched after this list):
- **Value Storage & Retrieval**: The server stores a value, and the client retrieves it
- **Content Provider Discovery**: The server advertises content, and the client finds providers
- **Peer Discovery**: Automatic bootstrap and peer routing using the Kademlia algorithm
- **Network Resilience**: Distributed storage across multiple nodes (when available)
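The base58 key and the hex content identifier shown in the log output are produced by the ``create_key_from_binary`` helper used in the example. A short sketch, mirroring the example's own encoding and input bytes, is:
.. code-block:: python

    import base58

    from libp2p.kad_dht.utils import create_key_from_binary

    # Key under which the demo value is stored (printed base58-encoded above).
    val_key = create_key_from_binary(b"py-libp2p kademlia example value")
    print(base58.b58encode(val_key).decode())

    # Content identifier the server advertises (printed hex-encoded above).
    content_key = create_key_from_binary(b"Hello from python node ")
    print(content_key.hex())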
Command Line Options
--------------------
The Kademlia demo supports several command line options for customization:
.. code-block:: console
$ python kademlia.py --help
usage: kademlia.py [-h] [--mode MODE] [--port PORT] [--bootstrap [BOOTSTRAP ...]] [--verbose]
Kademlia DHT example with content server functionality
options:
-h, --help show this help message and exit
--mode MODE Run as a server or client node (default: server)
--port PORT Port to listen on (0 for random) (default: 0)
--bootstrap [BOOTSTRAP ...]
Multiaddrs of bootstrap nodes. Provide a space-separated list of addresses.
This is required for client mode.
--verbose Enable verbose logging
**Examples:**
Start server on a specific port:
.. code-block:: console
$ python kademlia.py --mode server --port 8000
Start client with verbose logging:
.. code-block:: console
$ python kademlia.py --mode client --verbose
Connect to multiple bootstrap nodes:
.. code-block:: console
$ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/8000/p2p/... /ip4/127.0.0.1/tcp/8001/p2p/...
How It Works
------------
The Kademlia DHT implementation demonstrates several key concepts:
**Server Mode:**
- Stores key-value pairs in the distributed hash table
- Advertises itself as a content provider for specific content
- Handles incoming DHT requests from other nodes
- Maintains routing table with known peers
**Client Mode:**
- Connects to bootstrap nodes to join the network
- Retrieves values by their keys from the DHT
- Discovers content providers for specific content
- Performs network lookups using the Kademlia algorithm
**Key Components** (see the usage sketch after this list):
- **Routing Table**: Organizes peers in k-buckets based on XOR distance
- **Value Store**: Manages key-value storage with TTL (time-to-live)
- **Provider Store**: Tracks which peers provide specific content
- **Peer Routing**: Implements iterative lookups to find closest peers
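A minimal usage sketch, adapted from the full example below (the listen port, key bytes, and printed output are arbitrary choices for illustration), shows how these components are exercised through the ``KadDHT`` API:
.. code-block:: python

    import secrets

    import trio
    from multiaddr import Multiaddr

    from libp2p import new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair
    from libp2p.kad_dht.kad_dht import DHTMode, KadDHT
    from libp2p.kad_dht.utils import create_key_from_binary
    from libp2p.tools.async_service import background_trio_service


    async def run() -> None:
        # Create a host and listen on a local TCP port.
        host = new_host(key_pair=create_new_key_pair(secrets.token_bytes(32)))
        async with host.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/8000")]):
            # Start the DHT service in server mode.
            dht = KadDHT(host, DHTMode.SERVER)
            async with background_trio_service(dht):
                key = create_key_from_binary(b"example key")
                # Value storage and retrieval (value store).
                await dht.put_value(key, b"example value")
                value = await dht.get_value(key)
                # Provider advertisement and discovery (provider store).
                await dht.provide(key)
                providers = await dht.find_providers(key)
                print(value, [p.peer_id for p in providers])


    trio.run(run)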
The full source code for this example is below:
.. literalinclude:: ../examples/kademlia/kademlia.py
:language: python
:linenos:


@ -11,3 +11,4 @@ Examples
examples.echo
examples.ping
examples.pubsub
examples.kademlia


@ -0,0 +1,22 @@
libp2p.kad\_dht.pb package
==========================
Submodules
----------
libp2p.kad_dht.pb.kademlia_pb2 module
-------------------------------------
.. automodule:: libp2p.kad_dht.pb.kademlia_pb2
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.kad_dht.pb
:no-index:
:members:
:undoc-members:
:show-inheritance:

docs/libp2p.kad_dht.rst (new file, 77 additions)

@ -0,0 +1,77 @@
libp2p.kad\_dht package
=======================
Subpackages
-----------
.. toctree::
:maxdepth: 4
libp2p.kad_dht.pb
Submodules
----------
libp2p.kad\_dht.kad\_dht module
-------------------------------
.. automodule:: libp2p.kad_dht.kad_dht
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.peer\_routing module
------------------------------------
.. automodule:: libp2p.kad_dht.peer_routing
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.provider\_store module
--------------------------------------
.. automodule:: libp2p.kad_dht.provider_store
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.routing\_table module
-------------------------------------
.. automodule:: libp2p.kad_dht.routing_table
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.utils module
----------------------------
.. automodule:: libp2p.kad_dht.utils
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.value\_store module
-----------------------------------
.. automodule:: libp2p.kad_dht.value_store
:members:
:undoc-members:
:show-inheritance:
libp2p.kad\_dht.pb
------------------
.. automodule:: libp2p.kad_dht.pb
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.kad_dht
:members:
:undoc-members:
:show-inheritance:


@ -11,6 +11,7 @@ Subpackages
libp2p.host
libp2p.identity
libp2p.io
libp2p.kad_dht
libp2p.network
libp2p.peer
libp2p.protocol_muxer


@ -0,0 +1,300 @@
#!/usr/bin/env python
"""
A basic example of using the Kademlia DHT implementation, with all setup logic inlined.
This example demonstrates both value storage/retrieval and content server
advertisement/discovery.
"""
import argparse
import logging
import os
import random
import secrets
import sys
import base58
from multiaddr import (
Multiaddr,
)
import trio
from libp2p import (
new_host,
)
from libp2p.abc import (
IHost,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
)
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.utils import (
info_from_p2p_addr,
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger("kademlia-example")
# Configure DHT module loggers to inherit from the parent logger
# This ensures all kademlia-example.* loggers use the same configuration
# Get the directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
# Set the level for all child loggers
for module in [
"kad_dht",
"value_store",
"peer_routing",
"routing_table",
"provider_store",
]:
child_logger = logging.getLogger(f"kademlia-example.{module}")
child_logger.setLevel(logging.INFO)
child_logger.propagate = True # Allow propagation to parent
# File to store node information
bootstrap_nodes = []
# function to take bootstrap_nodes as input and connects to them
async def connect_to_bootstrap_nodes(host: IHost, bootstrap_addrs: list[str]) -> None:
"""
Connect to the bootstrap nodes provided in the list.
params: host: The host instance used to dial the bootstrap nodes
bootstrap_addrs: List of bootstrap node addresses
Returns
-------
None
"""
for addr in bootstrap_addrs:
try:
peerInfo = info_from_p2p_addr(Multiaddr(addr))
host.get_peerstore().add_addrs(peerInfo.peer_id, peerInfo.addrs, 3600)
await host.connect(peerInfo)
except Exception as e:
logger.error(f"Failed to connect to bootstrap node {addr}: {e}")
def save_server_addr(addr: str) -> None:
"""Write the server's multiaddress to the log file (overwriting any previous entry)."""
try:
with open(SERVER_ADDR_LOG, "w") as f:
f.write(addr + "\n")
logger.info(f"Saved server address to log: {addr}")
except Exception as e:
logger.error(f"Failed to save server address: {e}")
def load_server_addrs() -> list[str]:
"""Load all server multiaddresses from the log file."""
if not os.path.exists(SERVER_ADDR_LOG):
return []
try:
with open(SERVER_ADDR_LOG) as f:
return [line.strip() for line in f if line.strip()]
except Exception as e:
logger.error(f"Failed to load server addresses: {e}")
return []
async def run_node(
port: int, mode: str, bootstrap_addrs: list[str] | None = None
) -> None:
"""Run a node that serves content in the DHT with setup inlined."""
try:
if port <= 0:
port = random.randint(10000, 60000)
logger.debug(f"Using port: {port}")
# Convert string mode to DHTMode enum
if mode is None or mode.upper() == "CLIENT":
dht_mode = DHTMode.CLIENT
elif mode.upper() == "SERVER":
dht_mode = DHTMode.SERVER
else:
logger.error(f"Invalid mode: {mode}. Must be 'client' or 'server'")
sys.exit(1)
# Load server addresses for client mode
if dht_mode == DHTMode.CLIENT:
server_addrs = load_server_addrs()
if server_addrs:
logger.info(f"Loaded {len(server_addrs)} server addresses from log")
bootstrap_nodes.append(server_addrs[0]) # Use the first server address
else:
logger.warning("No server addresses found in log file")
if bootstrap_addrs:
for addr in bootstrap_addrs:
bootstrap_nodes.append(addr)
key_pair = create_new_key_pair(secrets.token_bytes(32))
host = new_host(key_pair=key_pair)
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
async with host.run(listen_addrs=[listen_addr]):
peer_id = host.get_id().pretty()
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
dht = KadDHT(host, dht_mode)
# Add every peer already known to the peerstore to the DHT routing table
for known_peer_id in host.get_peerstore().peer_ids():
await dht.routing_table.add_peer(known_peer_id)
logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}")
bootstrap_cmd = f"--bootstrap {addr_str}"
logger.info("To connect to this node, use: %s", bootstrap_cmd)
# Save server address in server mode
if dht_mode == DHTMode.SERVER:
save_server_addr(addr_str)
# Start the DHT service
async with background_trio_service(dht):
logger.info(f"DHT service started in {dht_mode.value} mode")
val_key = create_key_from_binary(b"py-libp2p kademlia example value")
content = b"Hello from python node "
content_key = create_key_from_binary(content)
if dht_mode == DHTMode.SERVER:
# Store a value in the DHT
msg = "Hello message from Sumanjeet"
val_data = msg.encode()
await dht.put_value(val_key, val_data)
logger.info(
f"Stored value '{val_data.decode()}' "
f"with key: {base58.b58encode(val_key).decode()}"
)
# Advertise as content server
success = await dht.provider_store.provide(content_key)
if success:
logger.info(
"Successfully advertised as server "
f"for content: {content_key.hex()}"
)
else:
logger.warning("Failed to advertise as content server")
else:
# retrieve the value
logger.info(
"Looking up key: %s", base58.b58encode(val_key).decode()
)
val_data = await dht.get_value(val_key)
if val_data:
try:
logger.info(f"Retrieved value: {val_data.decode()}")
except UnicodeDecodeError:
logger.info(f"Retrieved value (bytes): {val_data!r}")
else:
logger.warning("Failed to retrieve value")
# Also check if we can find servers for our own content
logger.info("Looking for servers of content: %s", content_key.hex())
providers = await dht.provider_store.find_providers(content_key)
if providers:
logger.info(
"Found %d servers for content: %s",
len(providers),
[p.peer_id.pretty() for p in providers],
)
else:
logger.warning(
"No servers found for content %s", content_key.hex()
)
# Keep the node running
while True:
logger.debug(
"Status - Connected peers: %d, "
"Peers in store: %d, Values in store: %d",
len(dht.host.get_connected_peers()),
len(dht.host.get_peerstore().peer_ids()),
len(dht.value_store.store),
)
await trio.sleep(10)
except Exception as e:
logger.error(f"Server node error: {e}", exc_info=True)
sys.exit(1)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Kademlia DHT example with content server functionality"
)
parser.add_argument(
"--mode",
default="server",
help="Run as a server or client node",
)
parser.add_argument(
"--port",
type=int,
default=0,
help="Port to listen on (0 for random)",
)
parser.add_argument(
"--bootstrap",
type=str,
nargs="*",
help=(
"Multiaddrs of bootstrap nodes. "
"Provide a space-separated list of addresses. "
"This is required for client mode."
),
)
# add option to use verbose logging
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging",
)
args = parser.parse_args()
# Set logging level based on verbosity
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
else:
logging.getLogger().setLevel(logging.INFO)
return args
def main():
"""Main entry point for the kademlia demo."""
try:
args = parse_args()
logger.info(
"Running in %s mode on port %d",
args.mode,
args.port,
)
trio.run(run_node, args.port, args.mode, args.bootstrap)
except Exception as e:
logger.critical(f"Script failed: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()


@ -0,0 +1,30 @@
"""
Kademlia DHT implementation for py-libp2p.
This module provides a Distributed Hash Table (DHT) implementation
based on the Kademlia protocol.
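Typical usage (sketch): construct a libp2p host, create ``KadDHT(host, mode)``
from ``libp2p.kad_dht.kad_dht``, and run it with ``background_trio_service``;
see ``docs/examples.kademlia.rst`` for a complete demo.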
"""
from .kad_dht import (
KadDHT,
)
from .peer_routing import (
PeerRouting,
)
from .routing_table import (
RoutingTable,
)
from .utils import (
create_key_from_binary,
)
from .value_store import (
ValueStore,
)
__all__ = [
"KadDHT",
"RoutingTable",
"PeerRouting",
"ValueStore",
"create_key_from_binary",
]

libp2p/kad_dht/kad_dht.py (new file, 616 additions)

@ -0,0 +1,616 @@
"""
Kademlia DHT implementation for py-libp2p.
This module provides a complete Distributed Hash Table (DHT)
implementation based on the Kademlia algorithm and protocol.
"""
from enum import Enum
import logging
import time
from multiaddr import (
Multiaddr,
)
import trio
import varint
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.async_service import (
Service,
)
from .pb.kademlia_pb2 import (
Message,
)
from .peer_routing import (
PeerRouting,
)
from .provider_store import (
ProviderStore,
)
from .routing_table import (
RoutingTable,
)
from .value_store import (
ValueStore,
)
logger = logging.getLogger("kademlia-example.kad_dht")
# logger = logging.getLogger("libp2p.kademlia")
# Default parameters
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
ROUTING_TABLE_REFRESH_INTERVAL = 1 * 60 # 1 min in seconds for testing
TTL = 24 * 60 * 60 # 24 hours in seconds
ALPHA = 3
QUERY_TIMEOUT = 10 # seconds
class DHTMode(Enum):
"""DHT operation modes."""
CLIENT = "CLIENT"
SERVER = "SERVER"
class KadDHT(Service):
"""
Kademlia DHT implementation for libp2p.
This class provides a DHT implementation that combines routing table management,
peer discovery, content routing, and value storage.
"""
def __init__(self, host: IHost, mode: DHTMode):
"""
Initialize a new Kademlia DHT node.
:param host: The libp2p host.
:param mode: The DHT mode of the host (client or server); must be a DHTMode enum
"""
super().__init__()
self.host = host
self.local_peer_id = host.get_id()
# Validate that mode is a DHTMode enum
if not isinstance(mode, DHTMode):
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
self.mode = mode
# Initialize the routing table
self.routing_table = RoutingTable(self.local_peer_id, self.host)
# Initialize peer routing
self.peer_routing = PeerRouting(host, self.routing_table)
# Initialize value store
self.value_store = ValueStore(host=host, local_peer_id=self.local_peer_id)
# Initialize provider store with host and peer_routing references
self.provider_store = ProviderStore(host=host, peer_routing=self.peer_routing)
# Last time we republished provider records
self._last_provider_republish = time.time()
# Set protocol handlers
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
async def run(self) -> None:
"""Run the DHT service."""
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
# Main service loop
while self.manager.is_running:
# Periodically refresh the routing table
await self.refresh_routing_table()
# Check if it's time to republish provider records
current_time = time.time()
# await self._republish_provider_records()
self._last_provider_republish = current_time
# Clean up expired values and provider records
expired_values = self.value_store.cleanup_expired()
if expired_values > 0:
logger.debug(f"Cleaned up {expired_values} expired values")
self.provider_store.cleanup_expired()
# Wait before next maintenance cycle
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
"""
Switch the DHT mode.
:param new_mode: The new mode - must be DHTMode enum
:return: The new mode as DHTMode enum
"""
# Validate that new_mode is a DHTMode enum
if not isinstance(new_mode, DHTMode):
raise TypeError(f"new_mode must be DHTMode enum, got {type(new_mode)}")
if new_mode == DHTMode.CLIENT:
self.routing_table.cleanup_routing_table()
self.mode = new_mode
logger.info(f"Switched to {new_mode.value} mode")
return self.mode
async def handle_stream(self, stream: INetStream) -> None:
"""
Handle an incoming DHT stream using varint length prefixes.
"""
if self.mode == DHTMode.CLIENT:
await stream.close()
return
peer_id = stream.muxed_conn.peer_id
logger.debug(f"Received DHT stream from peer {peer_id}")
await self.add_peer(peer_id)
logger.debug(f"Added peer {peer_id} to routing table")
try:
# Read varint-prefixed length for the message
length_prefix = b""
while True:
byte = await stream.read(1)
if not byte:
logger.warning("Stream closed while reading varint length")
await stream.close()
return
length_prefix += byte
if byte[0] & 0x80 == 0:
break
msg_length = varint.decode_bytes(length_prefix)
# Read the message bytes
msg_bytes = await stream.read(msg_length)
if len(msg_bytes) < msg_length:
logger.warning("Failed to read full message from stream")
await stream.close()
return
try:
# Parse as protobuf
message = Message()
message.ParseFromString(msg_bytes)
logger.debug(
f"Received DHT message from {peer_id}, type: {message.type}"
)
# Handle FIND_NODE message
if message.type == Message.MessageType.FIND_NODE:
# Get target key directly from protobuf
target_key = message.key
# Find closest peers to the target key
closest_peers = self.routing_table.find_local_closest_peers(
target_key, 20
)
logger.debug(f"Found {len(closest_peers)} peers close to target")
# Build response message with protobuf
response = Message()
response.type = Message.MessageType.FIND_NODE
# Add closest peers to response
for peer in closest_peers:
# Skip if the peer is the requester
if peer == peer_id:
continue
# Add peer to closerPeers field
peer_proto = response.closerPeers.add()
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
if addrs:
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug(
f"Sent FIND_NODE response with {len(response.closerPeers)} peers"
)
# Handle ADD_PROVIDER message
elif message.type == Message.MessageType.ADD_PROVIDER:
# Process ADD_PROVIDER
key = message.key
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
# Extract provider information
for provider_proto in message.providerPeers:
try:
# Validate that the provider is the sender
provider_id = ID(provider_proto.id)
if provider_id != peer_id:
logger.warning(
f"Provider ID {provider_id} doesn't "
f"match sender {peer_id}, ignoring"
)
continue
# Convert addresses to Multiaddr
addrs = []
for addr_bytes in provider_proto.addrs:
try:
addrs.append(Multiaddr(addr_bytes))
except Exception as e:
logger.warning(f"Failed to parse address: {e}")
# Add to provider store
provider_info = PeerInfo(provider_id, addrs)
self.provider_store.add_provider(key, provider_info)
logger.debug(
f"Added provider {provider_id} for key {key.hex()}"
)
except Exception as e:
logger.warning(f"Failed to process provider info: {e}")
# Send acknowledgement
response = Message()
response.type = Message.MessageType.ADD_PROVIDER
response.key = key
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent ADD_PROVIDER acknowledgement")
# Handle GET_PROVIDERS message
elif message.type == Message.MessageType.GET_PROVIDERS:
# Process GET_PROVIDERS
key = message.key
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
# Find providers for the key
providers = self.provider_store.get_providers(key)
logger.debug(
f"Found {len(providers)} providers for key {key.hex()}"
)
# Create response
response = Message()
response.type = Message.MessageType.GET_PROVIDERS
response.key = key
# Add provider information to response
for provider_info in providers:
provider_proto = response.providerPeers.add()
provider_proto.id = provider_info.peer_id.to_bytes()
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
for addr in provider_info.addrs:
provider_proto.addrs.append(addr.to_bytes())
# Also include closest peers if we don't have providers
if not providers:
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
)
logger.debug(
f"No providers found, including {len(closest_peers)} "
"closest peers"
)
for peer in closest_peers:
# Skip if peer is the requester
if peer == peer_id:
continue
peer_proto = response.closerPeers.add()
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent GET_PROVIDERS response")
# Handle GET_VALUE message
elif message.type == Message.MessageType.GET_VALUE:
# Process GET_VALUE
key = message.key
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
value = self.value_store.get(key)
if value:
logger.debug(f"Found value for key {key.hex()}")
# Create response using protobuf
response = Message()
response.type = Message.MessageType.GET_VALUE
# Create record
response.key = key
response.record.key = key
response.record.value = value
response.record.timeReceived = str(time.time())
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent GET_VALUE response")
else:
logger.debug(f"No value found for key {key.hex()}")
# Create response with closest peers when no value is found
response = Message()
response.type = Message.MessageType.GET_VALUE
response.key = key
# Add closest peers to key
closest_peers = self.routing_table.find_local_closest_peers(
key, 20
)
logger.debug(
"No value found, "
f"including {len(closest_peers)} closest peers"
)
for peer in closest_peers:
# Skip if peer is the requester
if peer == peer_id:
continue
peer_proto = response.closerPeers.add()
peer_proto.id = peer.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer)
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Serialize and send response
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent GET_VALUE response with closest peers")
# Handle PUT_VALUE message
elif message.type == Message.MessageType.PUT_VALUE and message.HasField(
"record"
):
# Process PUT_VALUE
key = message.record.key
value = message.record.value
success = False
try:
if not (key and value):
raise ValueError(
"Missing key or value in PUT_VALUE message"
)
self.value_store.put(key, value)
logger.debug(f"Stored value {value.hex()} for key {key.hex()}")
success = True
except Exception as e:
logger.warning(
f"Failed to store value {value.hex()} for key "
f"{key.hex()}: {e}"
)
finally:
# Send acknowledgement
response = Message()
response.type = Message.MessageType.PUT_VALUE
if success:
response.key = key
response_bytes = response.SerializeToString()
await stream.write(varint.encode(len(response_bytes)))
await stream.write(response_bytes)
logger.debug("Sent PUT_VALUE acknowledgement")
except Exception as proto_err:
logger.warning(f"Failed to parse protobuf message: {proto_err}")
await stream.close()
except Exception as e:
logger.error(f"Error handling DHT stream: {e}")
await stream.close()
async def refresh_routing_table(self) -> None:
"""Refresh the routing table."""
logger.debug("Refreshing routing table")
await self.peer_routing.refresh_routing_table()
# Peer routing methods
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
"""
Find a peer with the given ID.
"""
logger.debug(f"Finding peer: {peer_id}")
return await self.peer_routing.find_peer(peer_id)
# Value storage and retrieval methods
async def put_value(self, key: bytes, value: bytes) -> None:
"""
Store a value in the DHT.
"""
logger.debug(f"Storing value for key {key.hex()}")
# 1. Store locally first
self.value_store.put(key, value)
try:
decoded_value = value.decode("utf-8")
except UnicodeDecodeError:
decoded_value = value.hex()
logger.debug(
f"Stored value locally for key {key.hex()} with value {decoded_value}"
)
# 2. Get closest peers, excluding self
closest_peers = [
peer
for peer in self.routing_table.find_local_closest_peers(key)
if peer != self.local_peer_id
]
logger.debug(f"Found {len(closest_peers)} peers to store value at")
# 3. Store at remote peers in batches of ALPHA, in parallel
stored_count = 0
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
batch_results = [False] * len(batch)
async def store_one(idx: int, peer: ID) -> None:
try:
with trio.move_on_after(QUERY_TIMEOUT):
success = await self.value_store._store_at_peer(
peer, key, value
)
batch_results[idx] = success
if success:
logger.debug(f"Stored value at peer {peer}")
else:
logger.debug(f"Failed to store value at peer {peer}")
except Exception as e:
logger.debug(f"Error storing value at peer {peer}: {e}")
async with trio.open_nursery() as nursery:
for idx, peer in enumerate(batch):
nursery.start_soon(store_one, idx, peer)
stored_count += sum(batch_results)
logger.info(f"Successfully stored value at {stored_count} peers")
async def get_value(self, key: bytes) -> bytes | None:
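"""
Retrieve a value from the DHT: check the local store first, then query
the closest peers in parallel batches of ALPHA with a per-query timeout.
"""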
logger.debug(f"Getting value for key: {key.hex()}")
# 1. Check local store first
value = self.value_store.get(key)
if value:
logger.debug("Found value locally")
return value
# 2. Get closest peers, excluding self
closest_peers = [
peer
for peer in self.routing_table.find_local_closest_peers(key)
if peer != self.local_peer_id
]
logger.debug(f"Searching {len(closest_peers)} peers for value")
# 3. Query ALPHA peers at a time in parallel
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
found_value = None
async def query_one(peer: ID) -> None:
nonlocal found_value
try:
with trio.move_on_after(QUERY_TIMEOUT):
value = await self.value_store._get_from_peer(peer, key)
if value is not None and found_value is None:
found_value = value
logger.debug(f"Found value at peer {peer}")
except Exception as e:
logger.debug(f"Error querying peer {peer}: {e}")
async with trio.open_nursery() as nursery:
for peer in batch:
nursery.start_soon(query_one, peer)
if found_value is not None:
self.value_store.put(key, found_value)
logger.info("Successfully retrieved value from network")
return found_value
# 4. Not found
logger.warning(f"Value not found for key {key.hex()}")
return None
# Utility methods
async def add_peer(self, peer_id: ID) -> bool:
"""
Add a peer to the routing table.
params: peer_id: The peer ID to add.
Returns
-------
bool
True if peer was added or updated, False otherwise.
"""
return await self.routing_table.add_peer(peer_id)
async def provide(self, key: bytes) -> bool:
"""
Reference to provider_store.provide for convenience.
"""
return await self.provider_store.provide(key)
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
"""
Reference to provider_store.find_providers for convenience.
"""
return await self.provider_store.find_providers(key, count)
def get_routing_table_size(self) -> int:
"""
Get the number of peers in the routing table.
Returns
-------
int
Number of peers.
"""
return self.routing_table.size()
def get_value_store_size(self) -> int:
"""
Get the number of items in the value store.
Returns
-------
int
Number of items.
"""
return self.value_store.size()


@ -0,0 +1,38 @@
syntax = "proto3";
message Record {
bytes key = 1;
bytes value = 2;
string timeReceived = 5;
};
message Message {
enum MessageType {
PUT_VALUE = 0;
GET_VALUE = 1;
ADD_PROVIDER = 2;
GET_PROVIDERS = 3;
FIND_NODE = 4;
PING = 5;
}
enum ConnectionType {
NOT_CONNECTED = 0;
CONNECTED = 1;
CAN_CONNECT = 2;
CANNOT_CONNECT = 3;
}
message Peer {
bytes id = 1;
repeated bytes addrs = 2;
ConnectionType connection = 3;
}
MessageType type = 1;
int32 clusterLevelRaw = 10;
bytes key = 2;
Record record = 3;
repeated Peer closerPeers = 8;
repeated Peer providerPeers = 9;
}


@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: libp2p/kad_dht/pb/kademlia.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_RECORD']._serialized_start=36
_globals['_RECORD']._serialized_end=94
_globals['_MESSAGE']._serialized_start=97
_globals['_MESSAGE']._serialized_end=555
_globals['_MESSAGE_PEER']._serialized_start=281
_globals['_MESSAGE_PEER']._serialized_end=359
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
# @@protoc_insertion_point(module_scope)


@ -0,0 +1,133 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import collections.abc
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class Record(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
TIMERECEIVED_FIELD_NUMBER: builtins.int
key: builtins.bytes
value: builtins.bytes
timeReceived: builtins.str
def __init__(
self,
*,
key: builtins.bytes = ...,
value: builtins.bytes = ...,
timeReceived: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
global___Record = Record
@typing.final
class Message(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _MessageType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
PUT_VALUE: Message._MessageType.ValueType # 0
GET_VALUE: Message._MessageType.ValueType # 1
ADD_PROVIDER: Message._MessageType.ValueType # 2
GET_PROVIDERS: Message._MessageType.ValueType # 3
FIND_NODE: Message._MessageType.ValueType # 4
PING: Message._MessageType.ValueType # 5
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
PUT_VALUE: Message.MessageType.ValueType # 0
GET_VALUE: Message.MessageType.ValueType # 1
ADD_PROVIDER: Message.MessageType.ValueType # 2
GET_PROVIDERS: Message.MessageType.ValueType # 3
FIND_NODE: Message.MessageType.ValueType # 4
PING: Message.MessageType.ValueType # 5
class _ConnectionType:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
CONNECTED: Message._ConnectionType.ValueType # 1
CAN_CONNECT: Message._ConnectionType.ValueType # 2
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
CONNECTED: Message.ConnectionType.ValueType # 1
CAN_CONNECT: Message.ConnectionType.ValueType # 2
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
@typing.final
class Peer(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
ID_FIELD_NUMBER: builtins.int
ADDRS_FIELD_NUMBER: builtins.int
CONNECTION_FIELD_NUMBER: builtins.int
id: builtins.bytes
connection: global___Message.ConnectionType.ValueType
@property
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
def __init__(
self,
*,
id: builtins.bytes = ...,
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
connection: global___Message.ConnectionType.ValueType = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
TYPE_FIELD_NUMBER: builtins.int
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
KEY_FIELD_NUMBER: builtins.int
RECORD_FIELD_NUMBER: builtins.int
CLOSERPEERS_FIELD_NUMBER: builtins.int
PROVIDERPEERS_FIELD_NUMBER: builtins.int
type: global___Message.MessageType.ValueType
clusterLevelRaw: builtins.int
key: builtins.bytes
@property
def record(self) -> global___Record: ...
@property
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
@property
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
def __init__(
self,
*,
type: global___Message.MessageType.ValueType = ...,
clusterLevelRaw: builtins.int = ...,
key: builtins.bytes = ...,
record: global___Record | None = ...,
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
global___Message = Message


@ -0,0 +1,418 @@
"""
Peer routing implementation for Kademlia DHT.
This module implements the peer routing interface using Kademlia's algorithm
to efficiently locate peers in a distributed network.
"""
import logging
import trio
import varint
from libp2p.abc import (
IHost,
INetStream,
IPeerRouting,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .pb.kademlia_pb2 import (
Message,
)
from .routing_table import (
RoutingTable,
)
from .utils import (
sort_peer_ids_by_distance,
)
# logger = logging.getLogger("libp2p.kademlia.peer_routing")
logger = logging.getLogger("kademlia-example.peer_routing")
# Constants for the Kademlia algorithm
ALPHA = 3 # Concurrency parameter
MAX_PEER_LOOKUP_ROUNDS = 20 # Maximum number of rounds in peer lookup
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
class PeerRouting(IPeerRouting):
"""
Implementation of peer routing using the Kademlia algorithm.
This class provides methods to find peers in the DHT network
and helps maintain the routing table.
"""
def __init__(self, host: IHost, routing_table: RoutingTable):
"""
Initialize the peer routing service.
:param host: The libp2p host
:param routing_table: The Kademlia routing table
"""
self.host = host
self.routing_table = routing_table
self.protocol_id = PROTOCOL_ID
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
"""
Find a peer with the given ID.
:param peer_id: The ID of the peer to find
Returns
-------
Optional[PeerInfo]
The peer information if found, None otherwise
"""
# Check if this is actually our peer ID
if peer_id == self.host.get_id():
try:
# Return our own peer info
return PeerInfo(peer_id, self.host.get_addrs())
except Exception:
logger.exception("Error getting our own peer info")
return None
# First check if the peer is in our routing table
peer_info = self.routing_table.get_peer_info(peer_id)
if peer_info:
logger.debug(f"Found peer {peer_id} in routing table")
return peer_info
# Then check if the peer is in our peerstore
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
logger.debug(f"Found peer {peer_id} in peerstore")
return PeerInfo(peer_id, addrs)
except Exception:
pass
# If not found locally, search the network
try:
closest_peers = await self.find_closest_peers_network(peer_id.to_bytes())
logger.info(f"Closest peers found: {closest_peers}")
# Check if we found the peer we're looking for
for found_peer in closest_peers:
if found_peer == peer_id:
try:
addrs = self.host.get_peerstore().addrs(found_peer)
if addrs:
return PeerInfo(found_peer, addrs)
except Exception:
pass
except Exception as e:
logger.error(f"Error searching for peer {peer_id}: {e}")
# Not found
logger.info(f"Peer {peer_id} not found")
return None
async def _query_single_peer_for_closest(
self, peer: ID, target_key: bytes, new_peers: list[ID]
) -> None:
"""
Query a single peer for closest peers and append results to the shared list.
params: peer : ID
The peer to query
params: target_key : bytes
The target key to find closest peers for
params: new_peers : list[ID]
Shared list to append results to
"""
try:
result = await self._query_peer_for_closest(peer, target_key)
# Add deduplication to prevent duplicate peers
for peer_id in result:
if peer_id not in new_peers:
new_peers.append(peer_id)
logger.debug(
"Queried peer %s for closest peers, got %d results (%d unique)",
peer,
len(result),
len([p for p in result if p not in new_peers[: -len(result)]]),
)
except Exception as e:
logger.debug(f"Query to peer {peer} failed: {e}")
async def find_closest_peers_network(
self, target_key: bytes, count: int = 20
) -> list[ID]:
"""
Find the closest peers to a target key in the entire network.
Performs an iterative lookup by querying peers for their closest peers.
Returns
-------
list[ID]
Closest peer IDs
"""
# Start with closest peers from our routing table
closest_peers = self.routing_table.find_local_closest_peers(target_key, count)
logger.debug("Local closest peers: %d found", len(closest_peers))
queried_peers: set[ID] = set()
rounds = 0
# Return early if we have no peers to start with
if not closest_peers:
logger.warning("No local peers available for network lookup")
return []
# Iterative lookup until convergence
while rounds < MAX_PEER_LOOKUP_ROUNDS:
rounds += 1
logger.debug(f"Lookup round {rounds}/{MAX_PEER_LOOKUP_ROUNDS}")
# Find peers we haven't queried yet
peers_to_query = [p for p in closest_peers if p not in queried_peers]
if not peers_to_query:
logger.debug("No more unqueried peers available, ending lookup")
break # No more peers to query
# Query these peers for their closest peers to target
peers_batch = peers_to_query[:ALPHA] # Limit to ALPHA peers at a time
# Mark these peers as queried before we actually query them
for peer in peers_batch:
queried_peers.add(peer)
# Run queries in parallel for this batch using trio nursery
new_peers: list[ID] = [] # Shared array to collect all results
async with trio.open_nursery() as nursery:
for peer in peers_batch:
nursery.start_soon(
self._query_single_peer_for_closest, peer, target_key, new_peers
)
# If we got no new peers, we're done
if not new_peers:
logger.debug("No new peers discovered in this round, ending lookup")
break
# Update our list of closest peers
all_candidates = closest_peers + new_peers
old_closest_peers = closest_peers[:]
closest_peers = sort_peer_ids_by_distance(target_key, all_candidates)[
:count
]
logger.debug(f"Updated closest peers count: {len(closest_peers)}")
# Check if we made any progress (found closer peers)
if closest_peers == old_closest_peers:
logger.debug("No improvement in closest peers, ending lookup")
break
logger.info(
f"Network lookup completed after {rounds} rounds, "
f"found {len(closest_peers)} peers"
)
return closest_peers
async def _query_peer_for_closest(self, peer: ID, target_key: bytes) -> list[ID]:
"""
Query a peer for their closest peers
to the target key using varint length prefix
"""
stream = None
results = []
try:
# Add the peer to our routing table regardless of query outcome
try:
addrs = self.host.get_peerstore().addrs(peer)
if addrs:
peer_info = PeerInfo(peer, addrs)
await self.routing_table.add_peer(peer_info)
except Exception as e:
logger.debug(f"Failed to add peer {peer} to routing table: {e}")
# Open a stream to the peer using the Kademlia protocol
logger.debug(f"Opening stream to {peer} for closest peers query")
try:
stream = await self.host.new_stream(peer, [self.protocol_id])
logger.debug(f"Stream opened to {peer}")
except Exception as e:
logger.warning(f"Failed to open stream to {peer}: {e}")
return []
# Create and send FIND_NODE request using protobuf
find_node_msg = Message()
find_node_msg.type = Message.MessageType.FIND_NODE
find_node_msg.key = target_key # Set target key directly as bytes
# Serialize and send the protobuf message with varint length prefix
proto_bytes = find_node_msg.SerializeToString()
logger.debug(
f"Sending FIND_NODE: {proto_bytes.hex()} (len={len(proto_bytes)})"
)
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
# Read varint-prefixed response length
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
logger.warning(
"Error reading varint length from stream: connection closed"
)
return []
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
logger.debug(f"Connection closed by peer {peer} while reading data")
return []
response_bytes += chunk
remaining -= len(chunk)
# Parse the protobuf response
response_msg = Message()
response_msg.ParseFromString(response_bytes)
logger.debug(
"Received response from %s with %d peers",
peer,
len(response_msg.closerPeers),
)
# Process closest peers from response
if response_msg.type == Message.MessageType.FIND_NODE:
for peer_data in response_msg.closerPeers:
new_peer_id = ID(peer_data.id)
if new_peer_id not in results:
results.append(new_peer_id)
if peer_data.addrs:
from multiaddr import (
Multiaddr,
)
addrs = [Multiaddr(addr) for addr in peer_data.addrs]
self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600)
except Exception as e:
logger.debug(f"Error querying peer {peer} for closest: {e}")
finally:
if stream:
await stream.close()
return results
async def _handle_kad_stream(self, stream: INetStream) -> None:
"""
Handle incoming Kademlia protocol streams.
params: stream: The incoming stream
Returns
-------
None
"""
try:
# Read message length
length_bytes = await stream.read(4)
if not length_bytes:
return
message_length = int.from_bytes(length_bytes, byteorder="big")
# Read message
message_bytes = await stream.read(message_length)
if not message_bytes:
return
# Parse protobuf message
kad_message = Message()
try:
kad_message.ParseFromString(message_bytes)
if kad_message.type == Message.MessageType.FIND_NODE:
# Get target key directly from protobuf message
target_key = kad_message.key
# Find closest peers to target
closest_peers = self.routing_table.find_local_closest_peers(
target_key, 20
)
# Create protobuf response
response = Message()
response.type = Message.MessageType.FIND_NODE
# Add peer information to response
for peer_id in closest_peers:
peer_proto = response.closerPeers.add()
peer_proto.id = peer_id.to_bytes()
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
# Add addresses if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
for addr in addrs:
peer_proto.addrs.append(addr.to_bytes())
except Exception:
pass
# Send response
response_bytes = response.SerializeToString()
await stream.write(len(response_bytes).to_bytes(4, byteorder="big"))
await stream.write(response_bytes)
except Exception as parse_err:
logger.error(f"Failed to parse protocol buffer message: {parse_err}")
except Exception as e:
logger.debug(f"Error handling Kademlia stream: {e}")
finally:
await stream.close()
async def refresh_routing_table(self) -> None:
"""
Refresh the routing table by performing lookups for random keys.
Returns
-------
None
"""
logger.info("Refreshing routing table")
# Perform a lookup for ourselves to populate the routing table
local_id = self.host.get_id()
closest_peers = await self.find_closest_peers_network(local_id.to_bytes())
# Add discovered peers to routing table
for peer_id in closest_peers:
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
peer_info = PeerInfo(peer_id, addrs)
await self.routing_table.add_peer(peer_info)
except Exception as e:
logger.debug(f"Failed to add discovered peer {peer_id}: {e}")


@ -0,0 +1,575 @@
"""
Provider record storage for Kademlia DHT.
This module implements the storage for content provider records in the Kademlia DHT.
"""
import logging
import time
from typing import (
Any,
)
from multiaddr import (
Multiaddr,
)
import trio
import varint
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .pb.kademlia_pb2 import (
Message,
)
# logger = logging.getLogger("libp2p.kademlia.provider_store")
logger = logging.getLogger("kademlia-example.provider_store")
# Constants for provider records (based on IPFS standards)
PROVIDER_RECORD_REPUBLISH_INTERVAL = 22 * 60 * 60 # 22 hours in seconds
PROVIDER_RECORD_EXPIRATION_INTERVAL = 48 * 60 * 60 # 48 hours in seconds
PROVIDER_ADDRESS_TTL = 30 * 60 # 30 minutes in seconds
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
ALPHA = 3 # Number of parallel queries/advertisements
QUERY_TIMEOUT = 10 # Timeout for each query in seconds
class ProviderRecord:
"""
A record for a content provider in the DHT.
Contains the peer information and timestamp.
"""
def __init__(
self,
provider_info: PeerInfo,
timestamp: float | None = None,
) -> None:
"""
Initialize a new provider record.
:param provider_info: The provider's peer information
:param timestamp: Time this record was created/updated
(defaults to current time)
"""
self.provider_info = provider_info
self.timestamp = timestamp or time.time()
def is_expired(self) -> bool:
"""
Check if this provider record has expired.
Returns
-------
bool
True if the record has expired
"""
current_time = time.time()
return (current_time - self.timestamp) >= PROVIDER_RECORD_EXPIRATION_INTERVAL
def should_republish(self) -> bool:
"""
Check if this provider record should be republished.
Returns
-------
bool
True if the record should be republished
"""
current_time = time.time()
return (current_time - self.timestamp) >= PROVIDER_RECORD_REPUBLISH_INTERVAL
@property
def peer_id(self) -> ID:
"""Get the provider's peer ID."""
return self.provider_info.peer_id
@property
def addresses(self) -> list[Multiaddr]:
"""Get the provider's addresses."""
return self.provider_info.addrs
class ProviderStore:
"""
Store for content provider records in the Kademlia DHT.
Maps content keys to provider records, with support for expiration.
"""
def __init__(self, host: IHost, peer_routing: Any = None) -> None:
"""
Initialize a new provider store.
:param host: The libp2p host instance (optional)
:param peer_routing: The peer routing instance (optional)
"""
# Maps content keys to a dict of provider records (peer_id -> record)
self.providers: dict[bytes, dict[str, ProviderRecord]] = {}
self.host = host
self.peer_routing = peer_routing
self.providing_keys: set[bytes] = set()
self.local_peer_id = host.get_id()
async def _republish_provider_records(self) -> None:
"""Republish all provider records for content this node is providing."""
# First, republish keys we're actively providing
for key in self.providing_keys:
logger.debug(f"Republishing provider record for key {key.hex()}")
await self.provide(key)
# Also check for any records that should be republished
for key, providers in self.providers.items():
for peer_id_str, record in providers.items():
# Only republish records for our own peer
if self.local_peer_id and str(self.local_peer_id) == peer_id_str:
if record.should_republish():
logger.debug(
f"Republishing old provider record for key {key.hex()}"
)
await self.provide(key)
async def provide(self, key: bytes) -> bool:
"""
Advertise that this node can provide a piece of content.
Finds the k closest peers to the key and sends them ADD_PROVIDER messages.
:param key: The content key (multihash) to advertise
Returns
-------
bool
True if the advertisement was successful
"""
if not self.host or not self.peer_routing:
logger.error("Host or peer_routing not initialized, cannot provide content")
return False
# Add to local provider store
local_addrs = []
for addr in self.host.get_addrs():
local_addrs.append(addr)
local_peer_info = PeerInfo(self.host.get_id(), local_addrs)
self.add_provider(key, local_peer_info)
# Track that we're providing this key
self.providing_keys.add(key)
# Find the k closest peers to the key
closest_peers = await self.peer_routing.find_closest_peers_network(key)
logger.debug(
"Found %d peers close to key %s for provider advertisement",
len(closest_peers),
key.hex(),
)
# Send ADD_PROVIDER messages in parallel batches of ALPHA peers.
success_count = 0
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
results: list[bool] = [False] * len(batch)
async def send_one(
idx: int, peer_id: ID, results: list[bool] = results
) -> None:
if peer_id == self.local_peer_id:
return
try:
with trio.move_on_after(QUERY_TIMEOUT):
success = await self._send_add_provider(peer_id, key)
results[idx] = success
if not success:
logger.warning(f"Failed to send ADD_PROVIDER to {peer_id}")
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
async with trio.open_nursery() as nursery:
for idx, peer_id in enumerate(batch):
nursery.start_soon(send_one, idx, peer_id, results)
success_count += sum(results)
logger.info(f"Successfully advertised to {success_count} peers")
return success_count > 0
async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool:
"""
Send ADD_PROVIDER message to a specific peer.
:param peer_id: The peer to send the message to
:param key: The content key being provided
Returns
-------
bool
True if the message was successfully sent and acknowledged
"""
result = False
stream = None
try:
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
# Get our addresses to include in the message
addrs = []
for addr in self.host.get_addrs():
addrs.append(addr.to_bytes())
# Create the ADD_PROVIDER message
message = Message()
message.type = Message.MessageType.ADD_PROVIDER
message.key = key
# Add our provider info
provider = message.providerPeers.add()
provider.id = self.local_peer_id.to_bytes()
provider.addrs.extend(addrs)
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
logger.debug(f"Sent ADD_PROVIDER to {peer_id} for key {key.hex()}")
# Read response length prefix
length_bytes = b""
while True:
logger.debug("Reading response length prefix in add provider")
b = await stream.read(1)
if not b:
return False
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
return False
response_bytes += chunk
remaining -= len(chunk)
# Parse response
response = Message()
response.ParseFromString(response_bytes)
# Check response type
if response.type == Message.MessageType.ADD_PROVIDER:
result = True
except Exception as e:
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
finally:
if stream:
await stream.close()
return result
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
"""
Find content providers for a given key.
:param key: The content key to look for
:param count: Maximum number of providers to return
Returns
-------
List[PeerInfo]
List of content providers
"""
if not self.host or not self.peer_routing:
logger.error("Host or peer_routing not initialized, cannot find providers")
return []
# Check local provider store first
local_providers = self.get_providers(key)
if local_providers:
logger.debug(
f"Found {len(local_providers)} providers locally for {key.hex()}"
)
return local_providers[:count]
logger.debug("local providers are %s", local_providers)
# Find the closest peers to the key
closest_peers = await self.peer_routing.find_closest_peers_network(key)
logger.debug(
f"Searching {len(closest_peers)} peers for providers of {key.hex()}"
)
# Query these peers for providers in batches of ALPHA, in parallel, with timeout
all_providers = []
for i in range(0, len(closest_peers), ALPHA):
batch = closest_peers[i : i + ALPHA]
batch_results: list[list[PeerInfo]] = [[] for _ in batch]
async def get_one(
idx: int,
peer_id: ID,
batch_results: list[list[PeerInfo]] = batch_results,
) -> None:
if peer_id == self.local_peer_id:
return
try:
with trio.move_on_after(QUERY_TIMEOUT):
providers = await self._get_providers_from_peer(peer_id, key)
if providers:
for provider in providers:
self.add_provider(key, provider)
batch_results[idx] = providers
else:
logger.debug(f"No providers found at peer {peer_id}")
except Exception as e:
logger.warning(f"Failed to get providers from {peer_id}: {e}")
async with trio.open_nursery() as nursery:
for idx, peer_id in enumerate(batch):
nursery.start_soon(get_one, idx, peer_id, batch_results)
for providers in batch_results:
all_providers.extend(providers)
if len(all_providers) >= count:
return all_providers[:count]
return all_providers[:count]
async def _get_providers_from_peer(self, peer_id: ID, key: bytes) -> list[PeerInfo]:
"""
Get content providers from a specific peer.
:param peer_id: The peer to query
:param key: The content key to look for
Returns
-------
List[PeerInfo]
List of provider information
"""
providers: list[PeerInfo] = []
try:
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
try:
# Create the GET_PROVIDERS message
message = Message()
message.type = Message.MessageType.GET_PROVIDERS
message.key = key
# Serialize and send the message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
# Read response length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
return []
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
return []
response_bytes += chunk
remaining -= len(chunk)
# Parse response
response = Message()
response.ParseFromString(response_bytes)
# Check response type
if response.type != Message.MessageType.GET_PROVIDERS:
return []
# Extract provider information
providers = []
for provider_proto in response.providerPeers:
try:
# Create peer ID from bytes
provider_id = ID(provider_proto.id)
# Convert addresses to Multiaddr
addrs = []
for addr_bytes in provider_proto.addrs:
try:
addrs.append(Multiaddr(addr_bytes))
except Exception:
pass # Skip invalid addresses
# Create PeerInfo and add to result
providers.append(PeerInfo(provider_id, addrs))
except Exception as e:
logger.warning(f"Failed to parse provider info: {e}")
finally:
await stream.close()
return providers
except Exception as e:
logger.warning(f"Error getting providers from {peer_id}: {e}")
return []
def add_provider(self, key: bytes, provider: PeerInfo) -> None:
"""
Add a provider for a given content key.
:param key: The content key
:param provider: The provider's peer information
Returns
-------
None
"""
# Initialize providers for this key if needed
if key not in self.providers:
self.providers[key] = {}
# Add or update the provider record
peer_id_str = str(provider.peer_id) # Use string representation as dict key
self.providers[key][peer_id_str] = ProviderRecord(
provider_info=provider, timestamp=time.time()
)
logger.debug(f"Added provider {provider.peer_id} for key {key.hex()}")
def get_providers(self, key: bytes) -> list[PeerInfo]:
"""
Get all providers for a given content key.
:param key: The content key
Returns
-------
List[PeerInfo]
List of providers for the key
"""
if key not in self.providers:
return []
# Collect valid provider records (not expired)
result = []
current_time = time.time()
expired_peers = []
for peer_id_str, record in self.providers[key].items():
# Check if the record has expired
if record.is_expired():
expired_peers.append(peer_id_str)
continue
# Use addresses only if they haven't expired
addresses = []
if current_time - record.timestamp <= PROVIDER_ADDRESS_TTL:
addresses = record.addresses
# Create PeerInfo and add to results
result.append(PeerInfo(record.peer_id, addresses))
# Clean up expired records
for peer_id in expired_peers:
del self.providers[key][peer_id]
# Remove the key if no providers left
if not self.providers[key]:
del self.providers[key]
return result
def cleanup_expired(self) -> None:
"""Remove expired provider records."""
current_time = time.time()
expired_keys = []
for key, providers in self.providers.items():
expired_providers = []
for peer_id_str, record in providers.items():
if (
current_time - record.timestamp
> PROVIDER_RECORD_EXPIRATION_INTERVAL
):
expired_providers.append(peer_id_str)
logger.debug(
f"Removing expired provider {peer_id_str} for key {key.hex()}"
)
# Remove expired providers
for peer_id in expired_providers:
del providers[peer_id]
# Track empty keys for removal
if not providers:
expired_keys.append(key)
# Remove empty keys
for key in expired_keys:
del self.providers[key]
logger.debug(f"Removed key with no providers: {key.hex()}")
def get_provided_keys(self, peer_id: ID) -> list[bytes]:
"""
Get all content keys provided by a specific peer.
:param peer_id: The peer ID to look for
Returns
-------
List[bytes]
List of content keys provided by the peer
"""
peer_id_str = str(peer_id)
result = []
for key, providers in self.providers.items():
if peer_id_str in providers:
result.append(key)
return result
def size(self) -> int:
"""
Get the total number of provider records in the store.
Returns
-------
int
Total number of provider records across all keys
"""
total = 0
for providers in self.providers.values():
total += len(providers)
return total
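A minimal local usage sketch of ProviderStore; the host is mocked and the key is a made-up placeholder, so only the in-memory bookkeeping is exercised (provide() and find_providers() additionally need a live host and peer routing):
# Illustrative sketch; not part of the module above.
from unittest.mock import Mock
from multiaddr import Multiaddr
from libp2p.kad_dht.provider_store import ProviderStore
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
store = ProviderStore(host=Mock())  # peer_routing left unset for a purely local demo
key = b"example-content-key"        # placeholder key, not a real multihash
provider = PeerInfo(ID.from_base58("QmTest123"), [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
store.add_provider(key, provider)                                # record the provider locally
assert store.get_providers(key)[0].peer_id == provider.peer_id   # fresh record, still valid
assert store.get_provided_keys(provider.peer_id) == [key]
assert store.size() == 1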


@@ -0,0 +1,601 @@
"""
Kademlia DHT routing table implementation.
"""
from collections import (
OrderedDict,
)
import logging
import time
import trio
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.kad_dht.utils import xor_distance
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .pb.kademlia_pb2 import (
Message,
)
# logger = logging.getLogger("libp2p.kademlia.routing_table")
logger = logging.getLogger("kademlia-example.routing_table")
# Default parameters
BUCKET_SIZE = 20 # k in the Kademlia paper
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
class KBucket:
"""
A k-bucket implementation for the Kademlia DHT.
Each k-bucket stores up to k (BUCKET_SIZE) peers, sorted by least-recently seen.
"""
def __init__(
self,
host: IHost,
bucket_size: int = BUCKET_SIZE,
min_range: int = 0,
max_range: int = 2**256,
):
"""
Initialize a new k-bucket.
:param host: The host this bucket belongs to
:param bucket_size: Maximum number of peers to store in the bucket
:param min_range: Lower boundary of the bucket's key range (inclusive)
:param max_range: Upper boundary of the bucket's key range (exclusive)
"""
self.bucket_size = bucket_size
self.host = host
self.min_range = min_range
self.max_range = max_range
# Store PeerInfo objects along with last-seen timestamp
self.peers: OrderedDict[ID, tuple[PeerInfo, float]] = OrderedDict()
def peer_ids(self) -> list[ID]:
"""Get all peer IDs in the bucket."""
return list(self.peers.keys())
def peer_infos(self) -> list[PeerInfo]:
"""Get all PeerInfo objects in the bucket."""
return [info for info, _ in self.peers.values()]
def get_oldest_peer(self) -> ID | None:
"""Get the least-recently seen peer."""
if not self.peers:
return None
return next(iter(self.peers.keys()))
async def add_peer(self, peer_info: PeerInfo) -> bool:
"""
Add a peer to the bucket. Returns True if the peer was added or updated,
False if the bucket is full.
"""
current_time = time.time()
peer_id = peer_info.peer_id
# If peer is already in the bucket, move it to the end (most recently seen)
if peer_id in self.peers:
self.refresh_peer_last_seen(peer_id)
return True
# If bucket has space, add the peer
if len(self.peers) < self.bucket_size:
self.peers[peer_id] = (peer_info, current_time)
return True
# If bucket is full, we need to replace the least-recently seen peer
# Get the least-recently seen peer
oldest_peer_id = self.get_oldest_peer()
if oldest_peer_id is None:
logger.warning("No oldest peer found when bucket is full")
return False
# Check if the old peer is responsive to ping request
try:
# Try to ping the oldest peer, not the new peer
response = await self._ping_peer(oldest_peer_id)
if response:
# If the old peer is still alive, we will not add the new peer
logger.debug(
"Old peer %s is still alive, cannot add new peer %s",
oldest_peer_id,
peer_id,
)
return False
except Exception as e:
# If the old peer is unresponsive, we can replace it with the new peer
logger.debug(
"Old peer %s is unresponsive, replacing with new peer %s: %s",
oldest_peer_id,
peer_id,
str(e),
)
self.peers.popitem(last=False) # Remove oldest peer
self.peers[peer_id] = (peer_info, current_time)
return True
# If we got here, the oldest peer responded but we couldn't add the new peer
return False
def remove_peer(self, peer_id: ID) -> bool:
"""
Remove a peer from the bucket.
Returns True if the peer was in the bucket, False otherwise.
"""
if peer_id in self.peers:
del self.peers[peer_id]
return True
return False
def has_peer(self, peer_id: ID) -> bool:
"""Check if the peer is in the bucket."""
return peer_id in self.peers
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
"""Get the PeerInfo for a given peer ID if it exists in the bucket."""
if peer_id in self.peers:
return self.peers[peer_id][0]
return None
def size(self) -> int:
"""Get the number of peers in the bucket."""
return len(self.peers)
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
"""
Get peers that haven't been pinged recently.
params: stale_threshold_seconds: Time in seconds
params: after which a peer is considered stale
Returns
-------
list[ID]
List of peer IDs that need to be refreshed
"""
current_time = time.time()
stale_peers = []
for peer_id, (_, last_seen) in self.peers.items():
if current_time - last_seen > stale_threshold_seconds:
stale_peers.append(peer_id)
return stale_peers
async def _periodic_peer_refresh(self) -> None:
"""Background task to periodically refresh peers"""
try:
while True:
await trio.sleep(PEER_REFRESH_INTERVAL) # Check every minute
# Find stale peers (not pinged in last hour)
stale_peers = self.get_stale_peers(
stale_threshold_seconds=STALE_PEER_THRESHOLD
)
if stale_peers:
logger.debug(f"Found {len(stale_peers)} stale peers to refresh")
for peer_id in stale_peers:
try:
# Try to ping the peer
logger.debug("Pinging stale peer %s", peer_id)
response = await self._ping_peer(peer_id)
if response:
# Update the last seen time
self.refresh_peer_last_seen(peer_id)
logger.debug(f"Refreshed peer {peer_id}")
else:
# If ping fails, remove the peer
logger.debug(f"Failed to ping peer {peer_id}")
self.remove_peer(peer_id)
logger.info(f"Removed unresponsive peer {peer_id}")
except Exception as e:
# If ping fails, remove the peer
logger.debug(
"Failed to ping peer %s: %s",
peer_id,
e,
)
self.remove_peer(peer_id)
logger.info(f"Removed unresponsive peer {peer_id}")
except trio.Cancelled:
logger.debug("Peer refresh task cancelled")
except Exception as e:
logger.error(f"Error in peer refresh task: {e}", exc_info=True)
async def _ping_peer(self, peer_id: ID) -> bool:
"""
Ping a peer using protobuf message to check
if it's still alive and update last seen time.
params: peer_id: The ID of the peer to ping
Returns
-------
bool
True if ping successful, False otherwise
"""
result = False
# Get peer info directly from the bucket
peer_info = self.get_peer_info(peer_id)
if not peer_info:
raise ValueError(f"Peer {peer_id} not in bucket")
# Default protocol ID for Kademlia DHT
protocol_id = TProtocol("/ipfs/kad/1.0.0")
try:
# Open a stream to the peer with the DHT protocol
stream = await self.host.new_stream(peer_id, [protocol_id])
try:
# Create ping protobuf message
ping_msg = Message()
ping_msg.type = Message.PING # Use correct enum
# Serialize and send with length prefix (4 bytes big-endian)
msg_bytes = ping_msg.SerializeToString()
logger.debug(
f"Sending PING message to {peer_id}, size: {len(msg_bytes)} bytes"
)
await stream.write(len(msg_bytes).to_bytes(4, byteorder="big"))
await stream.write(msg_bytes)
# Wait for response with timeout
with trio.move_on_after(2): # 2 second timeout
# Read response length (4 bytes)
length_bytes = await stream.read(4)
if not length_bytes or len(length_bytes) < 4:
logger.warning(f"Peer {peer_id} disconnected during ping")
return False
msg_len = int.from_bytes(length_bytes, byteorder="big")
if (
msg_len <= 0 or msg_len > 1024 * 1024
): # Sanity check on message size
logger.warning(
f"Invalid message length from {peer_id}: {msg_len}"
)
return False
logger.debug(
f"Receiving response from {peer_id}, size: {msg_len} bytes"
)
# Read full message
response_bytes = await stream.read(msg_len)
if not response_bytes:
logger.warning(f"Failed to read response from {peer_id}")
return False
# Parse protobuf response
response = Message()
try:
response.ParseFromString(response_bytes)
except Exception as e:
logger.warning(
f"Failed to parse protobuf response from {peer_id}: {e}"
)
return False
if response.type == Message.PING:
# Update the last seen timestamp for this peer
logger.debug(f"Successfully pinged peer {peer_id}")
result = True
return result
else:
logger.warning(
f"Unexpected response type from {peer_id}: {response.type}"
)
return False
# If we get here, the ping timed out
logger.warning(f"Ping to peer {peer_id} timed out")
return False
finally:
await stream.close()
return result
except Exception as e:
logger.error(f"Error pinging peer {peer_id}: {str(e)}")
return False
def refresh_peer_last_seen(self, peer_id: ID) -> bool:
"""
Update the last-seen timestamp for a peer in the bucket.
params: peer_id: The ID of the peer to refresh
Returns
-------
bool
True if the peer was found and refreshed, False otherwise
"""
if peer_id in self.peers:
# Get current peer info and update the timestamp
peer_info, _ = self.peers[peer_id]
current_time = time.time()
self.peers[peer_id] = (peer_info, current_time)
# Move to end of ordered dict to mark as most recently seen
self.peers.move_to_end(peer_id)
return True
return False
def key_in_range(self, key: bytes) -> bool:
"""
Check if a key is in the range of this bucket.
params: key: The key to check (bytes)
Returns
-------
bool
True if the key is in range, False otherwise
"""
key_int = int.from_bytes(key, byteorder="big")
return self.min_range <= key_int < self.max_range
def split(self) -> tuple["KBucket", "KBucket"]:
"""
Split the bucket into two buckets.
Returns
-------
tuple
(lower_bucket, upper_bucket)
"""
midpoint = (self.min_range + self.max_range) // 2
lower_bucket = KBucket(self.host, self.bucket_size, self.min_range, midpoint)
upper_bucket = KBucket(self.host, self.bucket_size, midpoint, self.max_range)
# Redistribute peers
for peer_id, (peer_info, timestamp) in self.peers.items():
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
if peer_key < midpoint:
lower_bucket.peers[peer_id] = (peer_info, timestamp)
else:
upper_bucket.peers[peer_id] = (peer_info, timestamp)
return lower_bucket, upper_bucket
class RoutingTable:
"""
The Kademlia routing table maintains information on which peers to contact for any
given peer ID in the network.
"""
def __init__(self, local_id: ID, host: IHost) -> None:
"""
Initialize the routing table.
:param local_id: The ID of the local node.
:param host: The host this routing table belongs to.
"""
self.local_id = local_id
self.host = host
self.buckets = [KBucket(host, BUCKET_SIZE)]
async def add_peer(self, peer_obj: PeerInfo | ID) -> bool:
"""
Add a peer to the routing table.
:param peer_obj: Either PeerInfo object or peer ID to add
Returns
-------
bool: True if the peer was added or updated, False otherwise
"""
peer_id = None
peer_info = None
try:
# Handle different types of input
if isinstance(peer_obj, PeerInfo):
# Already have PeerInfo object
peer_info = peer_obj
peer_id = peer_obj.peer_id
else:
# Assume it's a peer ID
peer_id = peer_obj
# Try to get addresses from the peerstore if available
try:
addrs = self.host.get_peerstore().addrs(peer_id)
if addrs:
# Create PeerInfo object
peer_info = PeerInfo(peer_id, addrs)
else:
logger.debug(
"No addresses found for peer %s in peerstore, skipping",
peer_id,
)
return False
except Exception as peerstore_error:
# Handle case where peer is not in peerstore yet
logger.debug(
"Peer %s not found in peerstore: %s, skipping",
peer_id,
str(peerstore_error),
)
return False
# Don't add ourselves
if peer_id == self.local_id:
return False
# Find the right bucket for this peer
bucket = self.find_bucket(peer_id)
# Try to add to the bucket
success = await bucket.add_peer(peer_info)
if success:
logger.debug(f"Successfully added peer {peer_id} to routing table")
return success
except Exception as e:
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
return False
def remove_peer(self, peer_id: ID) -> bool:
"""
Remove a peer from the routing table.
:param peer_id: The ID of the peer to remove
Returns
-------
bool: True if the peer was removed, False otherwise
"""
bucket = self.find_bucket(peer_id)
return bucket.remove_peer(peer_id)
def find_bucket(self, peer_id: ID) -> KBucket:
"""
Find the bucket that would contain the given peer ID.
:param peer_id: The ID of the peer to locate a bucket for
Returns
-------
KBucket: The bucket for this peer
"""
for bucket in self.buckets:
if bucket.key_in_range(peer_id.to_bytes()):
return bucket
return self.buckets[0]
def find_local_closest_peers(self, key: bytes, count: int = 20) -> list[ID]:
"""
Find the closest peers to a given key.
:param key: The key to find closest peers to (bytes)
:param count: Maximum number of peers to return
Returns
-------
List[ID]: List of peer IDs closest to the key
"""
# Get all peers from all buckets
all_peers = []
for bucket in self.buckets:
all_peers.extend(bucket.peer_ids())
# Sort by XOR distance to the key
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
return all_peers[:count]
def get_peer_ids(self) -> list[ID]:
"""
Get all peer IDs in the routing table.
Returns
-------
List[ID]: List of all peer IDs
"""
peers = []
for bucket in self.buckets:
peers.extend(bucket.peer_ids())
return peers
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
"""
Get the peer info for a specific peer.
:param peer_id: The ID of the peer to get info for
Returns
-------
PeerInfo: The peer info, or None if not found
"""
bucket = self.find_bucket(peer_id)
return bucket.get_peer_info(peer_id)
def peer_in_table(self, peer_id: ID) -> bool:
"""
Check if a peer is in the routing table.
:param peer_id: The ID of the peer to check
Returns
-------
bool: True if the peer is in the routing table, False otherwise
"""
bucket = self.find_bucket(peer_id)
return bucket.has_peer(peer_id)
def size(self) -> int:
"""
Get the number of peers in the routing table.
Returns
-------
int: Number of peers
"""
count = 0
for bucket in self.buckets:
count += bucket.size()
return count
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
"""
Get all stale peers from all buckets
params: stale_threshold_seconds:
Time in seconds after which a peer is considered stale
Returns
-------
list[ID]
List of stale peer IDs
"""
stale_peers = []
for bucket in self.buckets:
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
return stale_peers
def cleanup_routing_table(self) -> None:
"""
Cleanup the routing table by removing all data.
This is useful for resetting the routing table during tests or reinitialization.
"""
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
logger.info("Routing table cleaned up, all data removed.")

libp2p/kad_dht/utils.py

@@ -0,0 +1,117 @@
"""
Utility functions for Kademlia DHT implementation.
"""
import base58
import multihash
from libp2p.peer.id import (
ID,
)
def create_key_from_binary(binary_data: bytes) -> bytes:
"""
Creates a key for the DHT by hashing binary data with SHA-256.
params: binary_data: The binary data to hash.
Returns
-------
bytes: The resulting key.
"""
return multihash.digest(binary_data, "sha2-256").digest
def xor_distance(key1: bytes, key2: bytes) -> int:
"""
Calculate the XOR distance between two keys.
params: key1: First key (bytes)
params: key2: Second key (bytes)
Returns
-------
int: The XOR distance between the keys
"""
# Ensure the inputs are bytes
if not isinstance(key1, bytes) or not isinstance(key2, bytes):
raise TypeError("Both key1 and key2 must be bytes objects")
# Convert to integers
k1 = int.from_bytes(key1, byteorder="big")
k2 = int.from_bytes(key2, byteorder="big")
# Calculate XOR distance
return k1 ^ k2
def bytes_to_base58(data: bytes) -> str:
"""
Convert bytes to base58 encoded string.
params: data: Input bytes
Returns
-------
str: Base58 encoded string
"""
return base58.b58encode(data).decode("utf-8")
def sort_peer_ids_by_distance(target_key: bytes, peer_ids: list[ID]) -> list[ID]:
"""
Sort a list of peer IDs by their distance to the target key.
params: target_key: The target key to measure distance from
params: peer_ids: List of peer IDs to sort
Returns
-------
List[ID]: Sorted list of peer IDs from closest to furthest
"""
def get_distance(peer_id: ID) -> int:
# Hash the peer ID bytes to get a key for distance calculation
peer_hash = multihash.digest(peer_id.to_bytes(), "sha2-256").digest
return xor_distance(target_key, peer_hash)
return sorted(peer_ids, key=get_distance)
def shared_prefix_len(first: bytes, second: bytes) -> int:
"""
Calculate the number of prefix bits shared by two byte sequences.
params: first: First byte sequence
params: second: Second byte sequence
Returns
-------
int: Number of shared prefix bits
"""
# Compare each byte to find the first bit difference
common_length = 0
for i in range(min(len(first), len(second))):
byte_first = first[i]
byte_second = second[i]
if byte_first == byte_second:
common_length += 8
else:
# Find specific bit where they differ
xor = byte_first ^ byte_second
# Count leading zeros in the xor result
for j in range(7, -1, -1):
if (xor >> j) & 1 == 1:
return common_length + (7 - j)
# This shouldn't be reached if xor != 0
return common_length + 8
return common_length
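Worked examples for the helpers above (values computed by hand from the definitions):
# Illustrative examples; not part of the module above.
from libp2p.kad_dht.utils import bytes_to_base58, shared_prefix_len, xor_distance
assert xor_distance(b"\x01", b"\x03") == 2       # 0b00000001 ^ 0b00000011 = 0b00000010
assert shared_prefix_len(b"\x01", b"\x03") == 6  # the first six bits match
assert shared_prefix_len(b"\xab", b"\xab") == 8  # identical bytes share all eight bits
assert bytes_to_base58(b"\x00") == "1"           # a single zero byte encodes as "1"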


@@ -0,0 +1,393 @@
"""
Value store implementation for Kademlia DHT.
Provides a way to store and retrieve key-value pairs with optional expiration.
"""
import logging
import time
import varint
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from .pb.kademlia_pb2 import (
Message,
)
# logger = logging.getLogger("libp2p.kademlia.value_store")
logger = logging.getLogger("kademlia-example.value_store")
# Default time to live for values in seconds (24 hours)
DEFAULT_TTL = 24 * 60 * 60
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
class ValueStore:
"""
Store for key-value pairs in a Kademlia DHT.
Values are stored with a timestamp and optional expiration time.
"""
def __init__(self, host: IHost, local_peer_id: ID):
"""
Initialize an empty value store.
:param host: The libp2p host instance.
:param local_peer_id: The local peer ID to ignore in peer requests.
"""
# Store format: {key: (value, validity)}
self.store: dict[bytes, tuple[bytes, float]] = {}
# Store references to the host and local peer ID for making requests
self.host = host
self.local_peer_id = local_peer_id
def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None:
"""
Store a value in the DHT.
:param key: The key to store the value under
:param value: The value to store
:param validity: Absolute expiry time (Unix timestamp) for the value.
If set to 0.0, defaults to the current time plus `DEFAULT_TTL`.
Returns
-------
None
"""
if validity == 0.0:
validity = time.time() + DEFAULT_TTL
logger.debug(
"Storing value for key %s... with validity %s", key.hex(), validity
)
self.store[key] = (value, validity)
logger.debug(f"Stored value for key {key.hex()}")
async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool:
"""
Store a value at a specific peer.
params: peer_id: The ID of the peer to store the value at
params: key: The key to store
params: value: The value to store
Returns
-------
bool
True if the value was successfully stored, False otherwise
"""
result = False
stream = None
try:
# Don't try to store at ourselves
if self.local_peer_id and peer_id == self.local_peer_id:
result = True
return result
if not self.host:
logger.error("Host not initialized, cannot store value at peer")
return False
logger.debug(f"Storing value for key {key.hex()} at peer {peer_id}")
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
logger.debug(f"Opened stream to peer {peer_id}")
# Create the PUT_VALUE message with protobuf
message = Message()
message.type = Message.MessageType.PUT_VALUE
# Set message fields
message.key = key
message.record.key = key
message.record.value = value
message.record.timeReceived = str(time.time())
# Serialize and send the protobuf message with length prefix
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
logger.debug("Sent PUT_VALUE protobuf message with varint length")
# Read varint-prefixed response length
length_bytes = b""
while True:
logger.debug("Reading varint length prefix for response...")
b = await stream.read(1)
if not b:
logger.warning("Connection closed while reading varint length")
return False
length_bytes += b
if b[0] & 0x80 == 0:
break
logger.debug(f"Received varint length bytes: {length_bytes.hex()}")
response_length = varint.decode_bytes(length_bytes)
logger.debug("Response length: %d bytes", response_length)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
logger.debug(
f"Connection closed by peer {peer_id} while reading data"
)
return False
response_bytes += chunk
remaining -= len(chunk)
# Parse protobuf response
response = Message()
response.ParseFromString(response_bytes)
# Check if response is valid
if response.type == Message.MessageType.PUT_VALUE:
if response.key:
result = True
return result
except Exception as e:
logger.warning(f"Failed to store value at peer {peer_id}: {e}")
return False
finally:
if stream:
await stream.close()
return result
def get(self, key: bytes) -> bytes | None:
"""
Retrieve a value from the DHT.
params: key: The key to look up
Returns
-------
Optional[bytes]
The stored value, or None if not found or expired
"""
logger.debug("Retrieving value for key %s...", key.hex()[:8])
if key not in self.store:
return None
value, validity = self.store[key]
logger.debug(
"Found value for key %s... with validity %s",
key.hex(),
validity,
)
# Check if the value has expired
if validity is not None and validity < time.time():
logger.debug(
"Value for key %s... has expired, removing it",
key.hex()[:8],
)
self.remove(key)
return None
return value
async def _get_from_peer(self, peer_id: ID, key: bytes) -> bytes | None:
"""
Retrieve a value from a specific peer.
params: peer_id: The ID of the peer to retrieve the value from
params: key: The key to retrieve
Returns
-------
Optional[bytes]
The value if found, None otherwise
"""
stream = None
try:
# Don't try to get from ourselves
if peer_id == self.local_peer_id:
return None
logger.debug(f"Getting value for key {key.hex()} from peer {peer_id}")
# Open a stream to the peer
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
logger.debug(f"Opened stream to peer {peer_id} for GET_VALUE")
# Create the GET_VALUE message using protobuf
message = Message()
message.type = Message.MessageType.GET_VALUE
message.key = key
# Serialize and send the protobuf message
proto_bytes = message.SerializeToString()
await stream.write(varint.encode(len(proto_bytes)))
await stream.write(proto_bytes)
# Read response length
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
logger.warning("Connection closed while reading length")
return None
length_bytes += b
if b[0] & 0x80 == 0:
break
response_length = varint.decode_bytes(length_bytes)
# Read response data
response_bytes = b""
remaining = response_length
while remaining > 0:
chunk = await stream.read(remaining)
if not chunk:
logger.debug(
f"Connection closed by peer {peer_id} while reading data"
)
return None
response_bytes += chunk
remaining -= len(chunk)
# Parse protobuf response
try:
response = Message()
response.ParseFromString(response_bytes)
logger.debug(
f"Received protobuf response from peer"
f" {peer_id}, type: {response.type}"
)
# Process protobuf response
if (
response.type == Message.MessageType.GET_VALUE
and response.HasField("record")
and response.record.value
):
logger.debug(
f"Received value for key {key.hex()} from peer {peer_id}"
)
return response.record.value
# Handle case where value is not found but peer infos are returned
else:
logger.debug(
f"Value not found for key {key.hex()} from peer {peer_id},"
f" received {len(response.closerPeers)} closer peers"
)
return None
except Exception as proto_err:
logger.warning(f"Failed to parse as protobuf: {proto_err}")
return None
except Exception as e:
logger.warning(f"Failed to get value from peer {peer_id}: {e}")
return None
finally:
if stream:
await stream.close()
def remove(self, key: bytes) -> bool:
"""
Remove a value from the DHT.
params: key: The key to remove
Returns
-------
bool
True if the key was found and removed, False otherwise
"""
if key in self.store:
del self.store[key]
logger.debug(f"Removed value for key {key.hex()[:8]}...")
return True
return False
def has(self, key: bytes) -> bool:
"""
Check if a key exists in the store and hasn't expired.
params: key: The key to check
Returns
-------
bool
True if the key exists and hasn't expired, False otherwise
"""
if key not in self.store:
return False
_, validity = self.store[key]
if validity is not None and time.time() > validity:
self.remove(key)
return False
return True
def cleanup_expired(self) -> int:
"""
Remove all expired values from the store.
Returns
-------
int
The number of expired values that were removed
"""
current_time = time.time()
expired_keys = [
key for key, (_, validity) in self.store.items() if current_time > validity
]
for key in expired_keys:
del self.store[key]
if expired_keys:
logger.debug(f"Cleaned up {len(expired_keys)} expired values")
return len(expired_keys)
def get_keys(self) -> list[bytes]:
"""
Get all non-expired keys in the store.
Returns
-------
list[bytes]
List of keys
"""
# Clean up expired values first
self.cleanup_expired()
return list(self.store.keys())
def size(self) -> int:
"""
Get the number of items in the store (after removing expired entries).
Returns
-------
int
Number of items
"""
self.cleanup_expired()
return len(self.store)
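A local usage sketch of ValueStore; the host is mocked since only the in-memory path is exercised, and the import path is assumed to be libp2p.kad_dht.value_store, matching the other modules in this PR:
# Illustrative sketch; not part of the module above.
import time
from unittest.mock import Mock
from libp2p.kad_dht.value_store import ValueStore
from libp2p.peer.id import ID
vs = ValueStore(host=Mock(), local_peer_id=ID.from_base58("QmTest123"))
key, value = b"example-key", b"example-value"
vs.put(key, value)                             # expiry defaults to now + DEFAULT_TTL
assert vs.get(key) == value and vs.has(key)
vs.put(key, value, validity=time.time() - 1)   # absolute expiry already in the past
assert vs.get(key) is None                     # expired entry is dropped on access
assert vs.size() == 0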


@@ -0,0 +1 @@
Added support for ``Kademlia DHT`` in py-libp2p.


@@ -0,0 +1,168 @@
"""
Tests for the Kademlia DHT implementation.
This module tests core functionality of the Kademlia DHT including:
- Node discovery (find_node)
- Value storage and retrieval (put_value, get_value)
- Content provider advertisement and discovery (provide, find_providers)
"""
import hashlib
import logging
import uuid
import pytest
import trio
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
)
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from tests.utils.factories import (
host_pair_factory,
)
# Configure logger
logger = logging.getLogger("test.kad_dht")
# Constants for the tests
TEST_TIMEOUT = 5 # Timeout in seconds
@pytest.fixture
async def dht_pair(security_protocol):
"""Create a pair of connected DHT nodes for testing."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Get peer info for bootstrapping
peer_b_info = PeerInfo(host_b.get_id(), host_b.get_addrs())
peer_a_info = PeerInfo(host_a.get_id(), host_a.get_addrs())
# Create DHT nodes from the hosts with bootstrap peers as multiaddr strings
dht_a: KadDHT = KadDHT(host_a, mode=DHTMode.SERVER)
dht_b: KadDHT = KadDHT(host_b, mode=DHTMode.SERVER)
await dht_a.peer_routing.routing_table.add_peer(peer_b_info)
await dht_b.peer_routing.routing_table.add_peer(peer_a_info)
# Start both DHT services
async with background_trio_service(dht_a), background_trio_service(dht_b):
# Allow time for bootstrap to complete and connections to establish
await trio.sleep(0.1)
logger.debug(
"After bootstrap: Node A peers: %s", dht_a.routing_table.get_peer_ids()
)
logger.debug(
"After bootstrap: Node B peers: %s", dht_b.routing_table.get_peer_ids()
)
# Return the DHT pair
yield (dht_a, dht_b)
@pytest.mark.trio
async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair
# Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id())
# Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@pytest.mark.trio
async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
"""Test storing and retrieving values in the DHT."""
dht_a, dht_b = dht_pair
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
# Generate a random key and value
key = create_key_from_binary(b"test-key")
value = b"test-value"
# First add the value directly to node A's store to verify storage works
dht_a.value_store.put(key, value)
logger.debug("Local value store: %s", dht_a.value_store.store)
local_value = dht_a.value_store.get(key)
assert local_value == value, "Local value storage failed"
print("number of nodes in peer store", dht_a.host.get_peerstore().peer_ids())
await dht_a.routing_table.add_peer(peer_b_info)
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
# Store the value using the first node (this will also store locally)
with trio.fail_after(TEST_TIMEOUT):
await dht_a.put_value(key, value)
# Log debugging information
logger.debug("Put value with key %s...", key.hex()[:10])
logger.debug("Node A value store: %s", dht_a.value_store.store)
# Allow more time for the value to propagate
await trio.sleep(0.5)
# Verify both nodes see each other in their routing tables
logger.debug("Node A peers: %s", dht_a.routing_table.get_peer_ids())
logger.debug("Node B peers: %s", dht_b.routing_table.get_peer_ids())
# Retrieve the value using the second node
with trio.fail_after(TEST_TIMEOUT):
retrieved_value = await dht_b.get_value(key)
print("the value stored in node b is", dht_b.get_value_store_size())
logger.debug("Retrieved value: %s", retrieved_value)
# Verify that the retrieved value matches the original
assert retrieved_value == value, "Retrieved value does not match the stored value"
@pytest.mark.trio
async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
"""Test advertising and finding content providers."""
dht_a, dht_b = dht_pair
# Generate a random content ID
content = f"test-content-{uuid.uuid4()}".encode()
content_id = hashlib.sha256(content).digest()
# Store content on the first node
dht_a.value_store.put(content_id, content)
# Advertise the first node as a provider
with trio.fail_after(TEST_TIMEOUT):
success = await dht_a.provide(content_id)
assert success, "Failed to advertise as provider"
# Allow time for the provider record to propagate
await trio.sleep(0.1)
# Find providers using the second node
with trio.fail_after(TEST_TIMEOUT):
providers = await dht_b.find_providers(content_id)
# Verify that we found the first node as a provider
assert providers, "No providers found"
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
"Expected provider not found"
)
# Retrieve the content using the provider information
with trio.fail_after(TEST_TIMEOUT):
retrieved_value = await dht_b.get_value(content_id)
assert retrieved_value == content, (
"Retrieved content does not match the original"
)


@@ -0,0 +1,459 @@
"""
Unit tests for the PeerRouting class in Kademlia DHT.
This module tests the core functionality of peer routing including:
- Peer discovery and lookup
- Network queries for closest peers
- Protocol message handling
- Error handling and edge cases
"""
import time
from unittest.mock import (
AsyncMock,
Mock,
patch,
)
import pytest
from multiaddr import (
Multiaddr,
)
import varint
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.kad_dht.pb.kademlia_pb2 import (
Message,
)
from libp2p.kad_dht.peer_routing import (
ALPHA,
MAX_PEER_LOOKUP_ROUNDS,
PROTOCOL_ID,
PeerRouting,
)
from libp2p.kad_dht.routing_table import (
RoutingTable,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
def create_valid_peer_id(name: str) -> ID:
"""Create a valid peer ID for testing."""
key_pair = create_new_key_pair()
return ID.from_pubkey(key_pair.public_key)
class TestPeerRouting:
"""Test suite for PeerRouting class."""
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_id.return_value = create_valid_peer_id("local")
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
host.connect = AsyncMock()
return host
@pytest.fixture
def mock_routing_table(self, mock_host):
"""Create a mock routing table for testing."""
local_id = create_valid_peer_id("local")
routing_table = RoutingTable(local_id, mock_host)
return routing_table
@pytest.fixture
def peer_routing(self, mock_host, mock_routing_table):
"""Create a PeerRouting instance for testing."""
return PeerRouting(mock_host, mock_routing_table)
@pytest.fixture
def sample_peer_info(self):
"""Create sample peer info for testing."""
peer_id = create_valid_peer_id("sample")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
return PeerInfo(peer_id, addresses)
def test_init_peer_routing(self, mock_host, mock_routing_table):
"""Test PeerRouting initialization."""
peer_routing = PeerRouting(mock_host, mock_routing_table)
assert peer_routing.host == mock_host
assert peer_routing.routing_table == mock_routing_table
assert peer_routing.protocol_id == PROTOCOL_ID
@pytest.mark.trio
async def test_find_peer_local_host(self, peer_routing, mock_host):
"""Test finding our own peer."""
local_id = mock_host.get_id()
result = await peer_routing.find_peer(local_id)
assert result is not None
assert result.peer_id == local_id
assert result.addrs == mock_host.get_addrs()
@pytest.mark.trio
async def test_find_peer_in_routing_table(self, peer_routing, sample_peer_info):
"""Test finding peer that exists in routing table."""
# Add peer to routing table
await peer_routing.routing_table.add_peer(sample_peer_info)
result = await peer_routing.find_peer(sample_peer_info.peer_id)
assert result is not None
assert result.peer_id == sample_peer_info.peer_id
@pytest.mark.trio
async def test_find_peer_in_peerstore(self, peer_routing, mock_host):
"""Test finding peer that exists in peerstore."""
peer_id = create_valid_peer_id("peerstore")
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8002")]
# Mock peerstore to return addresses
mock_host.get_peerstore().addrs.return_value = mock_addrs
result = await peer_routing.find_peer(peer_id)
assert result is not None
assert result.peer_id == peer_id
assert result.addrs == mock_addrs
@pytest.mark.trio
async def test_find_peer_not_found(self, peer_routing, mock_host):
"""Test finding peer that doesn't exist anywhere."""
peer_id = create_valid_peer_id("nonexistent")
# Mock peerstore to return no addresses
mock_host.get_peerstore().addrs.return_value = []
# Mock network search to return empty results
with patch.object(peer_routing, "find_closest_peers_network", return_value=[]):
result = await peer_routing.find_peer(peer_id)
assert result is None
@pytest.mark.trio
async def test_find_closest_peers_network_empty_start(self, peer_routing):
"""Test network search with no local peers."""
target_key = b"target_key"
# Mock routing table to return empty list
with patch.object(
peer_routing.routing_table, "find_local_closest_peers", return_value=[]
):
result = await peer_routing.find_closest_peers_network(target_key)
assert result == []
@pytest.mark.trio
async def test_find_closest_peers_network_with_peers(self, peer_routing, mock_host):
"""Test network search with some initial peers."""
target_key = b"target_key"
# Create some test peers
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(3)]
# Mock routing table to return initial peers
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=initial_peers,
):
# Mock _query_peer_for_closest to return empty results (no new peers found)
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
result = await peer_routing.find_closest_peers_network(
target_key, count=5
)
assert len(result) <= 5
# Should return the initial peers since no new ones were discovered
assert all(peer in initial_peers for peer in result)
@pytest.mark.trio
async def test_find_closest_peers_convergence(self, peer_routing):
"""Test that network search converges properly."""
target_key = b"target_key"
# Create test peers
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(2)]
# Mock to simulate convergence (no improvement in closest peers)
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=initial_peers,
):
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
with patch(
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance",
return_value=initial_peers,
):
result = await peer_routing.find_closest_peers_network(target_key)
assert result == initial_peers
@pytest.mark.trio
async def test_query_peer_for_closest_success(
self, peer_routing, mock_host, sample_peer_info
):
"""Test successful peer query for closest peers."""
target_key = b"target_key"
# Create mock stream
mock_stream = AsyncMock()
mock_host.new_stream.return_value = mock_stream
# Create mock response
response_msg = Message()
response_msg.type = Message.MessageType.FIND_NODE
# Add a peer to the response
peer_proto = response_msg.closerPeers.add()
response_peer_id = create_valid_peer_id("response_peer")
peer_proto.id = response_peer_id.to_bytes()
peer_proto.addrs.append(Multiaddr("/ip4/127.0.0.1/tcp/8003").to_bytes())
response_bytes = response_msg.SerializeToString()
# Mock stream reading
varint_length = varint.encode(len(response_bytes))
mock_stream.read.side_effect = [varint_length, response_bytes]
# Mock peerstore
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
mock_host.get_peerstore().add_addrs = Mock()
result = await peer_routing._query_peer_for_closest(
sample_peer_info.peer_id, target_key
)
assert len(result) == 1
assert result[0] == response_peer_id
mock_stream.write.assert_called()
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_query_peer_for_closest_stream_failure(self, peer_routing, mock_host):
"""Test peer query when stream creation fails."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
# Mock stream creation failure
mock_host.new_stream.side_effect = Exception("Stream failed")
mock_host.get_peerstore().addrs.return_value = []
result = await peer_routing._query_peer_for_closest(peer_id, target_key)
assert result == []
@pytest.mark.trio
async def test_query_peer_for_closest_read_failure(
self, peer_routing, mock_host, sample_peer_info
):
"""Test peer query when reading response fails."""
target_key = b"target_key"
# Create mock stream that fails to read
mock_stream = AsyncMock()
mock_stream.read.side_effect = [b""] # Empty read simulates connection close
mock_host.new_stream.return_value = mock_stream
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
result = await peer_routing._query_peer_for_closest(
sample_peer_info.peer_id, target_key
)
assert result == []
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_refresh_routing_table(self, peer_routing, mock_host):
"""Test routing table refresh."""
local_id = mock_host.get_id()
discovered_peers = [create_valid_peer_id(f"discovered{i}") for i in range(3)]
# Mock find_closest_peers_network to return discovered peers
with patch.object(
peer_routing, "find_closest_peers_network", return_value=discovered_peers
):
# Mock peerstore to return addresses for discovered peers
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8003")]
mock_host.get_peerstore().addrs.return_value = mock_addrs
await peer_routing.refresh_routing_table()
# Should perform lookup for local ID
peer_routing.find_closest_peers_network.assert_called_once_with(
local_id.to_bytes()
)
@pytest.mark.trio
async def test_handle_kad_stream_find_node(self, peer_routing, mock_host):
"""Test handling incoming FIND_NODE requests."""
# Create mock stream
mock_stream = AsyncMock()
# Create FIND_NODE request
request_msg = Message()
request_msg.type = Message.MessageType.FIND_NODE
request_msg.key = b"target_key"
request_bytes = request_msg.SerializeToString()
# Mock stream reading
mock_stream.read.side_effect = [
len(request_bytes).to_bytes(4, byteorder="big"),
request_bytes,
]
# Mock routing table to return some peers
closest_peers = [create_valid_peer_id(f"close{i}") for i in range(2)]
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=closest_peers,
):
mock_host.get_peerstore().addrs.return_value = [
Multiaddr("/ip4/127.0.0.1/tcp/8004")
]
await peer_routing._handle_kad_stream(mock_stream)
# Should write response
mock_stream.write.assert_called()
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_handle_kad_stream_invalid_message(self, peer_routing):
"""Test handling stream with invalid message."""
mock_stream = AsyncMock()
# Mock stream to return invalid data
mock_stream.read.side_effect = [
(10).to_bytes(4, byteorder="big"),
b"invalid_proto_data",
]
# Should handle gracefully without raising exception
await peer_routing._handle_kad_stream(mock_stream)
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_handle_kad_stream_connection_closed(self, peer_routing):
"""Test handling stream when connection is closed early."""
mock_stream = AsyncMock()
# Mock stream to return empty data (connection closed)
mock_stream.read.return_value = b""
await peer_routing._handle_kad_stream(mock_stream)
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_query_single_peer_for_closest_success(self, peer_routing):
"""Test _query_single_peer_for_closest method."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
new_peers = []
# Mock successful query
mock_result = [create_valid_peer_id("result1"), create_valid_peer_id("result2")]
with patch.object(
peer_routing, "_query_peer_for_closest", return_value=mock_result
):
await peer_routing._query_single_peer_for_closest(
peer_id, target_key, new_peers
)
assert len(new_peers) == 2
assert all(peer in new_peers for peer in mock_result)
@pytest.mark.trio
async def test_query_single_peer_for_closest_failure(self, peer_routing):
"""Test _query_single_peer_for_closest when query fails."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
new_peers = []
# Mock query failure
with patch.object(
peer_routing,
"_query_peer_for_closest",
side_effect=Exception("Query failed"),
):
await peer_routing._query_single_peer_for_closest(
peer_id, target_key, new_peers
)
# Should handle exception gracefully
assert len(new_peers) == 0
@pytest.mark.trio
async def test_query_single_peer_deduplication(self, peer_routing):
"""Test that _query_single_peer_for_closest deduplicates peers."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
duplicate_peer = create_valid_peer_id("duplicate")
new_peers = [duplicate_peer] # Pre-existing peer
# Mock query to return the same peer
mock_result = [duplicate_peer, create_valid_peer_id("new")]
with patch.object(
peer_routing, "_query_peer_for_closest", return_value=mock_result
):
await peer_routing._query_single_peer_for_closest(
peer_id, target_key, new_peers
)
# Should not add duplicate
assert len(new_peers) == 2 # Original + 1 new peer
assert new_peers.count(duplicate_peer) == 1
def test_constants(self):
"""Test that important constants are properly defined."""
assert ALPHA == 3
assert MAX_PEER_LOOKUP_ROUNDS == 20
assert PROTOCOL_ID == "/ipfs/kad/1.0.0"
@pytest.mark.trio
async def test_edge_case_max_rounds_reached(self, peer_routing):
"""Test that lookup stops after maximum rounds."""
target_key = b"target_key"
initial_peers = [create_valid_peer_id("peer1")]
# Mock to always return new peers to force max rounds
def mock_query_side_effect(peer, key):
return [create_valid_peer_id(f"new_peer_{time.time()}")]
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=initial_peers,
):
with patch.object(
peer_routing,
"_query_peer_for_closest",
side_effect=mock_query_side_effect,
):
with patch(
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance"
) as mock_sort:
# Always return different peers to prevent convergence
mock_sort.side_effect = lambda key, peers: peers[:20]
result = await peer_routing.find_closest_peers_network(target_key)
# Should stop after max rounds, not infinite loop
assert isinstance(result, list)


@@ -0,0 +1,805 @@
"""
Unit tests for the ProviderStore and ProviderRecord classes in Kademlia DHT.
This module tests the core functionality of provider record management including:
- ProviderRecord creation, expiration, and republish logic
- ProviderStore operations (add, get, cleanup)
- Expiration and TTL handling
- Network operations (mocked)
- Edge cases and error conditions
"""
import time
from unittest.mock import (
AsyncMock,
Mock,
patch,
)
import pytest
from multiaddr import (
Multiaddr,
)
from libp2p.kad_dht.provider_store import (
PROVIDER_ADDRESS_TTL,
PROVIDER_RECORD_EXPIRATION_INTERVAL,
PROVIDER_RECORD_REPUBLISH_INTERVAL,
ProviderRecord,
ProviderStore,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
mock_host = Mock()
class TestProviderRecord:
"""Test suite for ProviderRecord class."""
def test_init_with_default_timestamp(self):
"""Test ProviderRecord initialization with default timestamp."""
peer_id = ID.from_base58("QmTest123")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
peer_info = PeerInfo(peer_id, addresses)
start_time = time.time()
record = ProviderRecord(peer_info)
end_time = time.time()
assert record.provider_info == peer_info
assert start_time <= record.timestamp <= end_time
assert record.peer_id == peer_id
assert record.addresses == addresses
def test_init_with_custom_timestamp(self):
"""Test ProviderRecord initialization with custom timestamp."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
custom_timestamp = time.time() - 3600 # 1 hour ago
record = ProviderRecord(peer_info, timestamp=custom_timestamp)
assert record.timestamp == custom_timestamp
def test_is_expired_fresh_record(self):
"""Test that fresh records are not expired."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
record = ProviderRecord(peer_info)
assert not record.is_expired()
def test_is_expired_old_record(self):
"""Test that old records are expired."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
old_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
record = ProviderRecord(peer_info, timestamp=old_timestamp)
assert record.is_expired()
def test_is_expired_boundary_condition(self):
"""Test expiration at exact boundary."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
boundary_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
# At the exact boundary the record counts as expired (inclusive >= comparison)
assert record.is_expired()
def test_should_republish_fresh_record(self):
"""Test that fresh records don't need republishing."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
record = ProviderRecord(peer_info)
assert not record.should_republish()
def test_should_republish_old_record(self):
"""Test that old records need republishing."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
old_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1
record = ProviderRecord(peer_info, timestamp=old_timestamp)
assert record.should_republish()
def test_should_republish_boundary_condition(self):
"""Test republish at exact boundary."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
boundary_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
# At the exact boundary the record is due for republishing (inclusive >= comparison)
assert record.should_republish()
def test_properties(self):
"""Test peer_id and addresses properties."""
peer_id = ID.from_base58("QmTest123")
addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
Multiaddr("/ip6/::1/tcp/8001"),
]
peer_info = PeerInfo(peer_id, addresses)
record = ProviderRecord(peer_info)
assert record.peer_id == peer_id
assert record.addresses == addresses
def test_empty_addresses(self):
"""Test ProviderRecord with empty address list."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
record = ProviderRecord(peer_info)
assert record.addresses == []
class TestProviderStore:
"""Test suite for ProviderStore class."""
def test_init_empty_store(self):
"""Test that a new ProviderStore is initialized empty."""
store = ProviderStore(host=mock_host)
assert len(store.providers) == 0
assert store.peer_routing is None
assert len(store.providing_keys) == 0
def test_init_with_host(self):
"""Test initialization with host."""
mock_host = Mock()
mock_peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = mock_peer_id
store = ProviderStore(host=mock_host)
assert store.host == mock_host
assert store.local_peer_id == mock_peer_id
assert len(store.providers) == 0
def test_init_with_host_and_peer_routing(self):
"""Test initialization with both host and peer routing."""
mock_host = Mock()
mock_peer_routing = Mock()
mock_peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = mock_peer_id
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
assert store.host == mock_host
assert store.peer_routing == mock_peer_routing
assert store.local_peer_id == mock_peer_id
def test_add_provider_new_key(self):
"""Test adding a provider for a new key."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
provider = PeerInfo(peer_id, addresses)
store.add_provider(key, provider)
assert key in store.providers
assert str(peer_id) in store.providers[key]
record = store.providers[key][str(peer_id)]
assert record.provider_info == provider
assert isinstance(record.timestamp, float)
def test_add_provider_existing_key(self):
"""Test adding multiple providers for the same key."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add first provider
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key, provider1)
# Add second provider
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
store.add_provider(key, provider2)
assert len(store.providers[key]) == 2
assert str(peer_id1) in store.providers[key]
assert str(peer_id2) in store.providers[key]
def test_add_provider_update_existing(self):
"""Test updating an existing provider."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
# Add initial provider
provider1 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
store.add_provider(key, provider1)
first_timestamp = store.providers[key][str(peer_id)].timestamp
# Small delay to ensure timestamp difference
time.sleep(0.001)
# Update provider
provider2 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
store.add_provider(key, provider2)
# Should have same peer but updated info
assert len(store.providers[key]) == 1
assert str(peer_id) in store.providers[key]
record = store.providers[key][str(peer_id)]
assert record.provider_info == provider2
assert record.timestamp > first_timestamp
def test_get_providers_empty_key(self):
"""Test getting providers for non-existent key."""
store = ProviderStore(host=mock_host)
key = b"nonexistent_key"
providers = store.get_providers(key)
assert providers == []
def test_get_providers_valid_records(self):
"""Test getting providers with valid records."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add multiple providers
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
provider1 = PeerInfo(peer_id1, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
provider2 = PeerInfo(peer_id2, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
store.add_provider(key, provider1)
store.add_provider(key, provider2)
providers = store.get_providers(key)
assert len(providers) == 2
provider_ids = {p.peer_id for p in providers}
assert peer_id1 in provider_ids
assert peer_id2 in provider_ids
def test_get_providers_expired_records(self):
"""Test that expired records are filtered out and cleaned up."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add valid provider
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key, provider1)
# Add expired provider manually
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key][str(peer_id2)] = ProviderRecord(
provider2, expired_timestamp
)
providers = store.get_providers(key)
# Should only return valid provider
assert len(providers) == 1
assert providers[0].peer_id == peer_id1
# Expired provider should be cleaned up
assert str(peer_id2) not in store.providers[key]
def test_get_providers_address_ttl(self):
"""Test address TTL handling in get_providers."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
provider = PeerInfo(peer_id, addresses)
# Add provider with old timestamp (addresses expired but record valid)
old_timestamp = time.time() - PROVIDER_ADDRESS_TTL - 1
store.providers[key] = {str(peer_id): ProviderRecord(provider, old_timestamp)}
providers = store.get_providers(key)
# Should return provider but with empty addresses
assert len(providers) == 1
assert providers[0].peer_id == peer_id
assert providers[0].addrs == []
def test_get_providers_cleanup_empty_key(self):
"""Test that keys with no valid providers are removed."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add only expired providers
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key] = {
str(peer_id): ProviderRecord(provider, expired_timestamp)
}
providers = store.get_providers(key)
assert providers == []
assert key not in store.providers # Key should be removed
def test_cleanup_expired_no_expired_records(self):
"""Test cleanup when there are no expired records."""
store = ProviderStore(host=mock_host)
key1 = b"key1"
key2 = b"key2"
# Add valid providers
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
provider1 = PeerInfo(peer_id1, [])
provider2 = PeerInfo(peer_id2, [])
store.add_provider(key1, provider1)
store.add_provider(key2, provider2)
initial_size = store.size()
store.cleanup_expired()
assert store.size() == initial_size
assert key1 in store.providers
assert key2 in store.providers
def test_cleanup_expired_with_expired_records(self):
"""Test cleanup removes expired records."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add valid provider
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key, provider1)
# Add expired provider
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key][str(peer_id2)] = ProviderRecord(
provider2, expired_timestamp
)
assert store.size() == 2
store.cleanup_expired()
assert store.size() == 1
assert str(peer_id1) in store.providers[key]
assert str(peer_id2) not in store.providers[key]
def test_cleanup_expired_remove_empty_keys(self):
"""Test that keys with only expired providers are removed."""
store = ProviderStore(host=mock_host)
key1 = b"key1"
key2 = b"key2"
# Add valid provider to key1
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key1, provider1)
# Add only expired provider to key2
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key2] = {
str(peer_id2): ProviderRecord(provider2, expired_timestamp)
}
store.cleanup_expired()
assert key1 in store.providers
assert key2 not in store.providers
def test_get_provided_keys_empty_store(self):
"""Test get_provided_keys with empty store."""
store = ProviderStore(host=mock_host)
peer_id = ID.from_base58("QmTest123")
keys = store.get_provided_keys(peer_id)
assert keys == []
def test_get_provided_keys_single_peer(self):
"""Test get_provided_keys for a specific peer."""
store = ProviderStore(host=mock_host)
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
key1 = b"key1"
key2 = b"key2"
key3 = b"key3"
provider1 = PeerInfo(peer_id1, [])
provider2 = PeerInfo(peer_id2, [])
# peer_id1 provides key1 and key2
store.add_provider(key1, provider1)
store.add_provider(key2, provider1)
# peer_id2 provides key2 and key3
store.add_provider(key2, provider2)
store.add_provider(key3, provider2)
keys1 = store.get_provided_keys(peer_id1)
keys2 = store.get_provided_keys(peer_id2)
assert len(keys1) == 2
assert key1 in keys1
assert key2 in keys1
assert len(keys2) == 2
assert key2 in keys2
assert key3 in keys2
def test_get_provided_keys_nonexistent_peer(self):
"""Test get_provided_keys for peer that provides nothing."""
store = ProviderStore(host=mock_host)
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
# Add provider for peer_id1 only
key = b"key"
provider = PeerInfo(peer_id1, [])
store.add_provider(key, provider)
# Query for peer_id2 (provides nothing)
keys = store.get_provided_keys(peer_id2)
assert keys == []
def test_size_empty_store(self):
"""Test size() with empty store."""
store = ProviderStore(host=mock_host)
assert store.size() == 0
def test_size_with_providers(self):
"""Test size() with multiple providers."""
store = ProviderStore(host=mock_host)
# Add providers
key1 = b"key1"
key2 = b"key2"
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
peer_id3 = ID.from_base58("QmTest789")
provider1 = PeerInfo(peer_id1, [])
provider2 = PeerInfo(peer_id2, [])
provider3 = PeerInfo(peer_id3, [])
store.add_provider(key1, provider1)
store.add_provider(key1, provider2) # 2 providers for key1
store.add_provider(key2, provider3) # 1 provider for key2
assert store.size() == 3
@pytest.mark.trio
async def test_provide_no_host(self):
"""Test provide() returns False when no host is configured."""
store = ProviderStore(host=mock_host)
key = b"test_key"
result = await store.provide(key)
assert result is False
@pytest.mark.trio
async def test_provide_no_peer_routing(self):
"""Test provide() returns False when no peer routing is configured."""
mock_host = Mock()
store = ProviderStore(host=mock_host)
key = b"test_key"
result = await store.provide(key)
assert result is False
@pytest.mark.trio
async def test_provide_success(self):
"""Test successful provide operation."""
# Setup mocks
mock_host = Mock()
mock_peer_routing = AsyncMock()
peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = peer_id
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
# Mock finding closest peers
closest_peers = [ID.from_base58("QmPeer1"), ID.from_base58("QmPeer2")]
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
# Mock _send_add_provider to return success
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
key = b"test_key"
result = await store.provide(key)
assert result is True
assert key in store.providing_keys
assert key in store.providers
# Should have called _send_add_provider for each peer
assert mock_send.call_count == len(closest_peers)
@pytest.mark.trio
async def test_provide_skip_local_peer(self):
"""Test that provide() skips sending to local peer."""
# Setup mocks
mock_host = Mock()
mock_peer_routing = AsyncMock()
peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = peer_id
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
# Include local peer in closest peers
closest_peers = [peer_id, ID.from_base58("QmPeer1")]
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
key = b"test_key"
result = await store.provide(key)
assert result is True
# Should only call _send_add_provider once (skip local peer)
assert mock_send.call_count == 1
@pytest.mark.trio
async def test_find_providers_no_host(self):
"""Test find_providers() returns empty list when no host."""
store = ProviderStore(host=mock_host)
key = b"test_key"
result = await store.find_providers(key)
assert result == []
@pytest.mark.trio
async def test_find_providers_local_only(self):
"""Test find_providers() returns local providers."""
mock_host = Mock()
mock_peer_routing = Mock()
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
# Add local providers
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
result = await store.find_providers(key)
assert len(result) == 1
assert result[0].peer_id == peer_id
@pytest.mark.trio
async def test_find_providers_network_search(self):
"""Test find_providers() searches network when no local providers."""
mock_host = Mock()
mock_peer_routing = AsyncMock()
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
# Mock network search
closest_peers = [ID.from_base58("QmPeer1")]
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
# Mock provider response
remote_peer_id = ID.from_base58("QmRemote123")
remote_providers = [PeerInfo(remote_peer_id, [])]
with patch.object(
store, "_get_providers_from_peer", return_value=remote_providers
):
key = b"test_key"
result = await store.find_providers(key)
assert len(result) == 1
assert result[0].peer_id == remote_peer_id
@pytest.mark.trio
async def test_get_providers_from_peer_no_host(self):
"""Test _get_providers_from_peer without host."""
store = ProviderStore(host=mock_host)
peer_id = ID.from_base58("QmTest123")
key = b"test_key"
# Should handle missing host gracefully
result = await store._get_providers_from_peer(peer_id, key)
assert result == []
def test_edge_case_empty_key(self):
"""Test handling of empty key."""
store = ProviderStore(host=mock_host)
key = b""
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
def test_edge_case_large_key(self):
"""Test handling of large key."""
store = ProviderStore(host=mock_host)
key = b"x" * 10000 # 10KB key
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
def test_concurrent_operations(self):
"""Test multiple concurrent operations."""
store = ProviderStore(host=mock_host)
# Add many providers
num_keys = 100
num_providers_per_key = 5
for i in range(num_keys):
_key = f"key_{i}".encode()
for j in range(num_providers_per_key):
# Generate unique, valid Base58 peer ID strings: map each index to a
# distinct digit string and swap '0' (not a Base58 character) for 'A'
unique_id = i * num_providers_per_key + j + 1 # Ensure > 0
_peer_id_str = f"QmPeer{unique_id:06d}".replace("0", "A") + "1" * 38
peer_id = ID.from_base58(_peer_id_str)
provider = PeerInfo(peer_id, [])
store.add_provider(_key, provider)
# Verify total size
expected_size = num_keys * num_providers_per_key
assert store.size() == expected_size
# Verify individual keys
for i in range(num_keys):
_key = f"key_{i}".encode()
providers = store.get_providers(_key)
assert len(providers) == num_providers_per_key
def test_memory_efficiency_large_dataset(self):
"""Test memory behavior with large datasets."""
store = ProviderStore(host=mock_host)
# Add large number of providers
num_entries = 1000
for i in range(num_entries):
_key = f"key_{i:05d}".encode()
# Generate valid Base58 peer IDs (replace 0 with valid characters)
peer_str = f"QmPeer{i:05d}".replace("0", "1") + "1" * 35
peer_id = ID.from_base58(peer_str)
provider = PeerInfo(peer_id, [])
store.add_provider(_key, provider)
assert store.size() == num_entries
# Clean up all entries by making them expired
current_time = time.time()
for _key, providers in store.providers.items():
for _peer_id_str, record in providers.items():
record.timestamp = (
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
)
store.cleanup_expired()
assert store.size() == 0
assert len(store.providers) == 0
def test_unicode_key_handling(self):
"""Test handling of unicode content in keys."""
store = ProviderStore(host=mock_host)
# Test various unicode keys
unicode_keys = [
b"hello",
"héllo".encode(),
"🔑".encode(),
"ключ".encode(), # Russian
"".encode(), # Chinese
]
for i, key in enumerate(unicode_keys):
# Generate valid Base58 peer IDs
peer_id = ID.from_base58(f"QmPeer{i + 1}" + "1" * 42) # Valid base58
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
def test_multiple_addresses_per_provider(self):
"""Test providers with multiple addresses."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
Multiaddr("/ip6/::1/tcp/8001"),
Multiaddr("/ip4/192.168.1.100/tcp/8002"),
]
provider = PeerInfo(peer_id, addresses)
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
assert len(providers[0].addrs) == len(addresses)
assert all(addr in providers[0].addrs for addr in addresses)
@pytest.mark.trio
async def test_republish_provider_records_no_keys(self):
"""Test _republish_provider_records with no providing keys."""
store = ProviderStore(host=mock_host)
# Should complete without error even with no providing keys
await store._republish_provider_records()
assert len(store.providing_keys) == 0
def test_expiration_boundary_conditions(self):
"""Test expiration around boundary conditions."""
store = ProviderStore(host=mock_host)
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
current_time = time.time()
# Test records at various timestamps
timestamps = [
current_time, # Fresh
current_time - PROVIDER_ADDRESS_TTL + 1, # Addresses valid
current_time - PROVIDER_ADDRESS_TTL - 1, # Addresses expired
current_time
- PROVIDER_RECORD_REPUBLISH_INTERVAL
+ 1, # No republish needed
current_time - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1, # Republish needed
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL + 1, # Not expired
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1, # Expired
]
for i, timestamp in enumerate(timestamps):
test_key = f"key_{i}".encode()
record = ProviderRecord(provider, timestamp)
store.providers[test_key] = {str(peer_id): record}
# Test various operations
for i, timestamp in enumerate(timestamps):
test_key = f"key_{i}".encode()
providers = store.get_providers(test_key)
if timestamp <= current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL:
# Should be expired and removed
assert len(providers) == 0
assert test_key not in store.providers
else:
# Should be present
assert len(providers) == 1
assert providers[0].peer_id == peer_id


@@ -0,0 +1,371 @@
"""
Unit tests for the RoutingTable and KBucket classes in Kademlia DHT.
This module tests the core functionality of the routing table including:
- KBucket operations (add, remove, split, ping)
- RoutingTable management (peer addition, closest peer finding)
- Distance calculations and peer ordering
- Bucket splitting and range management
"""
import time
from unittest.mock import (
AsyncMock,
Mock,
patch,
)
import pytest
from multiaddr import (
Multiaddr,
)
import trio
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.kad_dht.routing_table import (
BUCKET_SIZE,
KBucket,
RoutingTable,
)
from libp2p.kad_dht.utils import (
create_key_from_binary,
xor_distance,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
def create_valid_peer_id(name: str) -> ID:
"""Create a valid peer ID for testing."""
# Use crypto to generate valid peer IDs
key_pair = create_new_key_pair()
return ID.from_pubkey(key_pair.public_key)
class TestKBucket:
"""Test suite for KBucket class."""
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
return host
@pytest.fixture
def sample_peer_info(self):
"""Create sample peer info for testing."""
peer_id = create_valid_peer_id("test")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
return PeerInfo(peer_id, addresses)
def test_init_default_parameters(self, mock_host):
"""Test KBucket initialization with default parameters."""
bucket = KBucket(mock_host)
assert bucket.bucket_size == BUCKET_SIZE
assert bucket.host == mock_host
assert bucket.min_range == 0
assert bucket.max_range == 2**256
assert len(bucket.peers) == 0
def test_peer_operations(self, mock_host, sample_peer_info):
"""Test basic peer operations: add, check, and remove."""
bucket = KBucket(mock_host)
# Test empty bucket
assert bucket.peer_ids() == []
assert bucket.size() == 0
assert not bucket.has_peer(sample_peer_info.peer_id)
# Add peer manually
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
# Test with peer
assert len(bucket.peer_ids()) == 1
assert sample_peer_info.peer_id in bucket.peer_ids()
assert bucket.size() == 1
assert bucket.has_peer(sample_peer_info.peer_id)
assert bucket.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
# Remove peer
result = bucket.remove_peer(sample_peer_info.peer_id)
assert result is True
assert bucket.size() == 0
assert not bucket.has_peer(sample_peer_info.peer_id)
@pytest.mark.trio
async def test_add_peer_functionality(self, mock_host):
"""Test add_peer method with different scenarios."""
bucket = KBucket(mock_host, bucket_size=2) # Small bucket for testing
# Add first peer
peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
result = await bucket.add_peer(peer1)
assert result is True
assert bucket.size() == 1
# Add second peer
peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
result = await bucket.add_peer(peer2)
assert result is True
assert bucket.size() == 2
# Add same peer again (should update timestamp)
await trio.sleep(0.001)
result = await bucket.add_peer(peer1)
assert result is True
assert bucket.size() == 2 # Still 2 peers
# Try to add third peer when bucket is full
peer3 = PeerInfo(create_valid_peer_id("peer3"), [])
with patch.object(bucket, "_ping_peer", return_value=True):
result = await bucket.add_peer(peer3)
assert result is False # Should fail if oldest peer responds
def test_get_oldest_peer(self, mock_host):
"""Test get_oldest_peer method."""
bucket = KBucket(mock_host)
# Empty bucket
assert bucket.get_oldest_peer() is None
# Add peers with different timestamps
peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
current_time = time.time()
bucket.peers[peer1.peer_id] = (peer1, current_time - 300) # Older
bucket.peers[peer2.peer_id] = (peer2, current_time) # Newer
oldest = bucket.get_oldest_peer()
assert oldest == peer1.peer_id
def test_stale_peers(self, mock_host):
"""Test stale peer identification."""
bucket = KBucket(mock_host)
current_time = time.time()
fresh_peer = PeerInfo(create_valid_peer_id("fresh"), [])
stale_peer = PeerInfo(create_valid_peer_id("stale"), [])
bucket.peers[fresh_peer.peer_id] = (fresh_peer, current_time)
bucket.peers[stale_peer.peer_id] = (
stale_peer,
current_time - 7200,
) # 2 hours ago
stale_peers = bucket.get_stale_peers(3600) # 1 hour threshold
assert len(stale_peers) == 1
assert stale_peer.peer_id in stale_peers
def test_key_in_range(self, mock_host):
"""Test key_in_range method."""
bucket = KBucket(mock_host, min_range=100, max_range=200)
# Test keys within range
key_in_range = (150).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_in_range) is True
# Test keys outside range
key_below = (50).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_below) is False
key_above = (250).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_above) is False
# Test boundary conditions
key_min = (100).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_min) is True
key_max = (200).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_max) is False
def test_split_bucket(self, mock_host):
"""Test bucket splitting functionality."""
bucket = KBucket(mock_host, min_range=0, max_range=256)
lower_bucket, upper_bucket = bucket.split()
# Check ranges
assert lower_bucket.min_range == 0
assert lower_bucket.max_range == 128
assert upper_bucket.min_range == 128
assert upper_bucket.max_range == 256
# Check properties
assert lower_bucket.bucket_size == bucket.bucket_size
assert upper_bucket.bucket_size == bucket.bucket_size
assert lower_bucket.host == mock_host
assert upper_bucket.host == mock_host
@pytest.mark.trio
async def test_ping_peer_scenarios(self, mock_host, sample_peer_info):
"""Test different ping scenarios."""
bucket = KBucket(mock_host)
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
# Test ping peer not in bucket
other_peer_id = create_valid_peer_id("other")
with pytest.raises(ValueError, match="Peer .* not in bucket"):
await bucket._ping_peer(other_peer_id)
# Test ping failure due to stream error
mock_host.new_stream.side_effect = Exception("Stream failed")
result = await bucket._ping_peer(sample_peer_info.peer_id)
assert result is False
class TestRoutingTable:
"""Test suite for RoutingTable class."""
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_peerstore.return_value = Mock()
return host
@pytest.fixture
def local_peer_id(self):
"""Create a local peer ID for testing."""
return create_valid_peer_id("local")
@pytest.fixture
def sample_peer_info(self):
"""Create sample peer info for testing."""
peer_id = create_valid_peer_id("sample")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
return PeerInfo(peer_id, addresses)
def test_init_routing_table(self, mock_host, local_peer_id):
"""Test RoutingTable initialization."""
routing_table = RoutingTable(local_peer_id, mock_host)
assert routing_table.local_id == local_peer_id
assert routing_table.host == mock_host
assert len(routing_table.buckets) == 1
assert isinstance(routing_table.buckets[0], KBucket)
@pytest.mark.trio
async def test_add_peer_operations(
self, mock_host, local_peer_id, sample_peer_info
):
"""Test adding peers to routing table."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Test adding peer with PeerInfo
result = await routing_table.add_peer(sample_peer_info)
assert result is True
assert routing_table.size() == 1
assert routing_table.peer_in_table(sample_peer_info.peer_id)
# Test adding peer with just ID
peer_id = create_valid_peer_id("test")
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
mock_host.get_peerstore().addrs.return_value = mock_addrs
result = await routing_table.add_peer(peer_id)
assert result is True
assert routing_table.size() == 2
# Test adding peer with no addresses
no_addr_peer_id = create_valid_peer_id("no_addr")
mock_host.get_peerstore().addrs.return_value = []
result = await routing_table.add_peer(no_addr_peer_id)
assert result is False
assert routing_table.size() == 2
# Test adding local peer (should be ignored)
result = await routing_table.add_peer(local_peer_id)
assert result is False
assert routing_table.size() == 2
def test_find_bucket(self, mock_host, local_peer_id):
"""Test finding appropriate bucket for peers."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Test with peer ID
peer_id = create_valid_peer_id("test")
bucket = routing_table.find_bucket(peer_id)
assert isinstance(bucket, KBucket)
def test_peer_management(self, mock_host, local_peer_id, sample_peer_info):
"""Test peer management operations."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Add peer manually
bucket = routing_table.find_bucket(sample_peer_info.peer_id)
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
# Test peer queries
assert routing_table.peer_in_table(sample_peer_info.peer_id)
assert routing_table.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
assert routing_table.size() == 1
assert len(routing_table.get_peer_ids()) == 1
# Test remove peer
result = routing_table.remove_peer(sample_peer_info.peer_id)
assert result is True
assert not routing_table.peer_in_table(sample_peer_info.peer_id)
assert routing_table.size() == 0
def test_find_closest_peers(self, mock_host, local_peer_id):
"""Test finding closest peers."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Empty table
target_key = create_key_from_binary(b"target_key")
closest_peers = routing_table.find_local_closest_peers(target_key, 5)
assert closest_peers == []
# Add some peers
bucket = routing_table.buckets[0]
test_peers = []
for i in range(5):
peer = PeerInfo(create_valid_peer_id(f"peer{i}"), [])
test_peers.append(peer)
bucket.peers[peer.peer_id] = (peer, time.time())
closest_peers = routing_table.find_local_closest_peers(target_key, 3)
assert len(closest_peers) <= 3
assert len(closest_peers) <= len(test_peers)
assert all(isinstance(peer_id, ID) for peer_id in closest_peers)
def test_distance_calculation(self, mock_host, local_peer_id):
"""Test XOR distance calculation."""
# Test same keys
key = b"\x42" * 32
distance = xor_distance(key, key)
assert distance == 0
# Test different keys
key1 = b"\x00" * 32
key2 = b"\xff" * 32
distance = xor_distance(key1, key2)
expected = int.from_bytes(b"\xff" * 32, byteorder="big")
assert distance == expected
def test_edge_cases(self, mock_host, local_peer_id):
"""Test various edge cases."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Test with invalid peer ID
nonexistent_peer_id = create_valid_peer_id("nonexistent")
assert not routing_table.peer_in_table(nonexistent_peer_id)
assert routing_table.get_peer_info(nonexistent_peer_id) is None
assert routing_table.remove_peer(nonexistent_peer_id) is False
# Test bucket splitting scenario
assert len(routing_table.buckets) == 1
initial_bucket = routing_table.buckets[0]
assert initial_bucket.min_range == 0
assert initial_bucket.max_range == 2**256


@@ -0,0 +1,504 @@
"""
Unit tests for the ValueStore class in Kademlia DHT.
This module tests the core functionality of the ValueStore including:
- Basic storage and retrieval operations
- Expiration and TTL handling
- Edge cases and error conditions
- Store management operations
"""
import time
from unittest.mock import (
Mock,
)
import pytest
from libp2p.kad_dht.value_store import (
DEFAULT_TTL,
ValueStore,
)
from libp2p.peer.id import (
ID,
)
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
class TestValueStore:
"""Test suite for ValueStore class."""
def test_init_empty_store(self):
"""Test that a new ValueStore is initialized empty."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
assert len(store.store) == 0
def test_init_with_host_and_peer_id(self):
"""Test initialization with host and local peer ID."""
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
store = ValueStore(host=mock_host, local_peer_id=peer_id)
assert store.host == mock_host
assert store.local_peer_id == peer_id
assert len(store.store) == 0
def test_put_basic(self):
"""Test basic put operation."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
assert key in store.store
stored_value, validity = store.store[key]
assert stored_value == value
assert validity is not None
assert validity > time.time() # Should be in the future
def test_put_with_custom_validity(self):
"""Test put operation with custom validity time."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
custom_validity = time.time() + 3600 # 1 hour from now
store.put(key, value, validity=custom_validity)
stored_value, validity = store.store[key]
assert stored_value == value
assert validity == custom_validity
def test_put_overwrite_existing(self):
"""Test that put overwrites existing values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value1 = b"value1"
value2 = b"value2"
store.put(key, value1)
store.put(key, value2)
assert len(store.store) == 1
stored_value, _ = store.store[key]
assert stored_value == value2
def test_get_existing_valid_value(self):
"""Test retrieving an existing, non-expired value."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_get_nonexistent_key(self):
"""Test retrieving a non-existent key returns None."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"nonexistent_key"
retrieved_value = store.get(key)
assert retrieved_value is None
def test_get_expired_value(self):
"""Test that expired values are automatically removed and return None."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
expired_validity = time.time() - 1 # 1 second ago
# Manually insert expired value
store.store[key] = (value, expired_validity)
retrieved_value = store.get(key)
assert retrieved_value is None
assert key not in store.store # Should be removed
def test_remove_existing_key(self):
"""Test removing an existing key."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
result = store.remove(key)
assert result is True
assert key not in store.store
def test_remove_nonexistent_key(self):
"""Test removing a non-existent key returns False."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"nonexistent_key"
result = store.remove(key)
assert result is False
def test_has_existing_valid_key(self):
"""Test has() returns True for existing, valid keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
result = store.has(key)
assert result is True
def test_has_nonexistent_key(self):
"""Test has() returns False for non-existent keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"nonexistent_key"
result = store.has(key)
assert result is False
def test_has_expired_key(self):
"""Test has() returns False for expired keys and removes them."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
expired_validity = time.time() - 1
# Manually insert expired value
store.store[key] = (value, expired_validity)
result = store.has(key)
assert result is False
assert key not in store.store # Should be removed
def test_cleanup_expired_no_expired_values(self):
"""Test cleanup when there are no expired values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
value = b"value"
store.put(key1, value)
store.put(key2, value)
expired_count = store.cleanup_expired()
assert expired_count == 0
assert len(store.store) == 2
def test_cleanup_expired_with_expired_values(self):
"""Test cleanup removes expired values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"valid_key"
key2 = b"expired_key1"
key3 = b"expired_key2"
value = b"value"
expired_validity = time.time() - 1
store.put(key1, value) # Valid
store.store[key2] = (value, expired_validity) # Expired
store.store[key3] = (value, expired_validity) # Expired
expired_count = store.cleanup_expired()
assert expired_count == 2
assert len(store.store) == 1
assert key1 in store.store
assert key2 not in store.store
assert key3 not in store.store
def test_cleanup_expired_mixed_validity_types(self):
"""Test cleanup with mix of values with and without expiration."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"no_expiry"
key2 = b"valid_expiry"
key3 = b"expired"
value = b"value"
# Default validity (put() applies DEFAULT_TTL, so this entry is far from expiring)
store.put(key1, value)
# Valid expiration
store.put(key2, value, validity=time.time() + 3600)
# Expired
store.store[key3] = (value, time.time() - 1)
expired_count = store.cleanup_expired()
assert expired_count == 1
assert len(store.store) == 2
assert key1 in store.store
assert key2 in store.store
assert key3 not in store.store
def test_get_keys_empty_store(self):
"""Test get_keys() returns empty list for empty store."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
keys = store.get_keys()
assert keys == []
def test_get_keys_with_valid_values(self):
"""Test get_keys() returns all non-expired keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
key3 = b"expired_key"
value = b"value"
store.put(key1, value)
store.put(key2, value)
store.store[key3] = (value, time.time() - 1) # Expired
keys = store.get_keys()
assert len(keys) == 2
assert key1 in keys
assert key2 in keys
assert key3 not in keys
def test_size_empty_store(self):
"""Test size() returns 0 for empty store."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
size = store.size()
assert size == 0
def test_size_with_valid_values(self):
"""Test size() returns correct count after cleaning expired values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
key3 = b"expired_key"
value = b"value"
store.put(key1, value)
store.put(key2, value)
store.store[key3] = (value, time.time() - 1) # Expired
size = store.size()
assert size == 2
def test_edge_case_empty_key(self):
"""Test handling of empty key."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b""
value = b"value"
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_edge_case_empty_value(self):
"""Test handling of empty value."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b""
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_edge_case_large_key_value(self):
"""Test handling of large keys and values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"x" * 10000 # 10KB key
value = b"y" * 100000 # 100KB value
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_edge_case_negative_validity(self):
"""Test handling of negative validity time."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
store.put(key, value, validity=-1)
# Should be expired
retrieved_value = store.get(key)
assert retrieved_value is None
def test_default_ttl_calculation(self):
"""Test that default TTL is correctly applied."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
start_time = time.time()
store.put(key, value)
_, validity = store.store[key]
expected_validity = start_time + DEFAULT_TTL
# Allow small time difference for execution
assert abs(validity - expected_validity) < 1
def test_concurrent_operations(self):
"""Test that multiple operations don't interfere with each other."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Add multiple key-value pairs
for i in range(100):
key = f"key_{i}".encode()
value = f"value_{i}".encode()
store.put(key, value)
# Verify all are stored
assert store.size() == 100
# Remove every other key
for i in range(0, 100, 2):
key = f"key_{i}".encode()
store.remove(key)
# Verify correct count
assert store.size() == 50
# Verify remaining keys are correct
for i in range(1, 100, 2):
key = f"key_{i}".encode()
assert store.has(key)
def test_expiration_boundary_conditions(self):
"""Test expiration around current time boundary."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
key3 = b"key3"
value = b"value"
current_time = time.time()
# Just expired
store.store[key1] = (value, current_time - 0.001)
# Valid for a longer time to account for test execution time
store.store[key2] = (value, current_time + 1.0)
# Exactly current time (should be expired)
store.store[key3] = (value, current_time)
# Small delay to ensure time has passed
time.sleep(0.002)
assert not store.has(key1) # Should be expired
assert store.has(key2) # Should be valid
assert not store.has(key3) # Should be expired (exactly at current time)
def test_store_internal_structure(self):
"""Test that internal store structure is maintained correctly."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
validity = time.time() + 3600
store.put(key, value, validity=validity)
# Verify internal structure
assert isinstance(store.store, dict)
assert key in store.store
stored_tuple = store.store[key]
assert isinstance(stored_tuple, tuple)
assert len(stored_tuple) == 2
assert stored_tuple[0] == value
assert stored_tuple[1] == validity
@pytest.mark.trio
async def test_store_at_peer_local_peer(self):
"""Test _store_at_peer returns True when storing at local peer."""
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
result = await store._store_at_peer(peer_id, key, value)
assert result is True
@pytest.mark.trio
async def test_get_from_peer_local_peer(self):
"""Test _get_from_peer returns None when querying local peer."""
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
result = await store._get_from_peer(peer_id, key)
assert result is None
def test_memory_efficiency_large_dataset(self):
"""Test memory behavior with large datasets."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Add a large number of entries
num_entries = 10000
for i in range(num_entries):
key = f"key_{i:05d}".encode()
value = f"value_{i:05d}".encode()
store.put(key, value)
assert store.size() == num_entries
# Clean up all entries
for i in range(num_entries):
key = f"key_{i:05d}".encode()
store.remove(key)
assert store.size() == 0
assert len(store.store) == 0
def test_key_collision_resistance(self):
"""Test that similar keys don't collide."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Test keys that might cause collisions
keys = [
b"key",
b"key\x00",
b"key1",
b"Key", # Different case
b"key ", # With space
b" key", # Leading space
]
for i, key in enumerate(keys):
value = f"value_{i}".encode()
store.put(key, value)
# Verify all keys are stored separately
assert store.size() == len(keys)
for i, key in enumerate(keys):
expected_value = f"value_{i}".encode()
assert store.get(key) == expected_value
def test_unicode_key_handling(self):
"""Test handling of unicode content in keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Test various unicode keys
unicode_keys = [
b"hello",
"héllo".encode(),
"🔑".encode(),
"ключ".encode(), # Russian
"".encode(), # Chinese
]
for i, key in enumerate(unicode_keys):
value = f"value_{i}".encode()
store.put(key, value)
assert store.get(key) == value