mirror of https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00

Kademlia DHT implementation in py-libp2p (#579)

Squashed commit messages:

* initialise the module
* added content routing
* added routing module
* added peer routing
* added value store
* added utilities functions
* added main kademlia file
* fixed create_key_from_binary function
* example to test kademlia dht
* added protocol ID and enhanced logging for peer store size in provider and consumer nodes
* refactor: specify stream type in handle_stream method and add peer in routing table
* removed content routing
* added default value of count for finding closest peers
* added functions to find close peers
* refactor: remove content routing and enhance peer discovery
* added put value function
* added get value function
* fix: improve logging and handle key encoding in get_value method
* refactor: remove ContentRouting import from __init__.py
* refactor: improved basic kademlia example
* added protobuf files
* replaced json with protobuf
* refactor: enhance peer discovery and routing logic in KadDHT
* refactor: enhance Kademlia routing table to use PeerInfo objects and improve peer management
* refactor: enhance peer addition logic to utilize PeerInfo objects in routing table
* feat: implement content provider functionality in Kademlia DHT
* refactor: update value store to use datetime for validity management
* refactor: update RoutingTable initialization to include host reference
* refactor: enhance KBucket and RoutingTable for improved peer management and functionality
* refactor: streamline peer discovery and value storage methods in KadDHT
* refactor: update KadDHT and related classes for async peer management and enhanced value storage
* refactor: enhance ProviderStore initialization and improve peer routing integration
* test: add tests for Kademlia DHT functionality
* fix linting issues
* pydocstyle issues fixed
* CICD pipeline issues solved
* fix: update docstring format for find_peer method
* refactor: improve logging and remove unused code in DHT implementation
* refactor: clean up logging and remove unused imports in DHT and test files
* Refactor logging setup and improve DHT stream handling with varint length prefixes
* Update bootstrap peer handling in basic_dht example and refactor peer routing to accept string addresses
* Enhance peer querying in Kademlia DHT by implementing parallel queries using Trio
* Enhance peer querying by adding deduplication checks
* Refactor DHT implementation to use varint for length prefixes and enhance logging for better traceability
* Add base58 encoding for value storage and enhance logging in basic_dht example
* Refactor Kademlia DHT to support server/client modes
* Added unit tests
* Refactor documentation to fix some warnings
* Add unit tests and remove outdated tests
* Fixed pre-commit errors
* Refactor error handling test to raise StringParseError for invalid bootstrap addresses
* Add libp2p.kad_dht to the list of subpackages in documentation
* Fix expiration and republish checks to use inclusive comparison
* Add __init__.py file to libp2p.kad_dht.pb package
* Refactor get value and put value to run in parallel with query timeout
* Refactor provider message handling to use parallel processing with timeout
* Add methods for provider store in KadDHT class
* Refactor KadDHT and ProviderStore methods to improve type hints and enhance parallel processing
* Add documentation for libp2p.kad_dht.pb module
* Update documentation for libp2p.kad_dht package to include subpackages and correct formatting
* Fix formatting in documentation for libp2p.kad_dht package by correcting the subpackage reference
* Fix header formatting in libp2p.kad_dht.pb documentation
* Change log level from info to debug for various logging statements
* fix CICD issues (post revamp)
* fixed value store unit test
* Refactored kademlia example
* Refactor Kademlia example: enhance logging, improve bootstrap node connection, and streamline server address handling
* removed bootstrap module
* Refactor Kademlia DHT example and core modules: enhance logging, remove unused code, and improve peer handling
* Added docs of kad dht example
* Update server address log file path to use the script's directory
* Refactor: Introduce DHTMode enum for clearer mode management
* moved xor_distance function to utils.py
* Enhance logging in ValueStore and KadDHT: include decoded value in debug logs and update parameter description for validity
* Add handling for closest peers in GET_VALUE response when value is not found
* Handled failure scenario for PUT_VALUE
* Remove kademlia demo from project scripts and contributing documentation
* spelling and logging

Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
Makefile (4 changed lines)

@@ -58,7 +58,9 @@ PB = libp2p/crypto/pb/crypto.proto \
	libp2p/security/secio/pb/spipe.proto \
	libp2p/security/noise/pb/noise.proto \
	libp2p/identity/identify/pb/identify.proto \
	libp2p/host/autonat/pb/autonat.proto
	libp2p/host/autonat/pb/autonat.proto \
	libp2p/kad_dht/pb/kademlia.proto

PY = $(PB:.proto=_pb2.py)
PYI = $(PB:.proto=_pb2.pyi)
docs/examples.kademlia.rst (new file, 124 lines)

@@ -0,0 +1,124 @@

Kademlia DHT Demo
=================

This example demonstrates a Kademlia Distributed Hash Table (DHT) implementation with both value storage/retrieval and content provider advertisement/discovery functionality.

.. code-block:: console

    $ python -m pip install libp2p
    Collecting libp2p
    ...
    Successfully installed libp2p-x.x.x
    $ cd examples/kademlia
    $ python kademlia.py --mode server
    2025-06-13 19:51:25,424 - kademlia-example - INFO - Running in server mode on port 0
    2025-06-13 19:51:25,426 - kademlia-example - INFO - Connected to bootstrap nodes: []
    2025-06-13 19:51:25,426 - kademlia-example - INFO - To connect to this node, use: --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
    2025-06-13 19:51:25,426 - kademlia-example - INFO - Saved server address to log: /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
    2025-06-13 19:51:25,427 - kademlia-example - INFO - DHT service started in SERVER mode
    2025-06-13 19:51:25,427 - kademlia-example - INFO - Stored value 'Hello message from Sumanjeet' with key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
    2025-06-13 19:51:25,427 - kademlia-example - INFO - Successfully advertised as server for content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8

Copy the line that starts with ``--bootstrap``, open a new terminal in the same folder and run the client:

.. code-block:: console

    $ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/28910/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef
    2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
    2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
    2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
    2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
    2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
    2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
    2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']

Alternatively, if you run the server first, the client can automatically extract the bootstrap address from the server log file:

.. code-block:: console

    $ python kademlia.py --mode client
    2025-06-13 19:51:37,022 - kademlia-example - INFO - Running in client mode on port 0
    2025-06-13 19:51:37,026 - kademlia-example - INFO - Connected to bootstrap nodes: [<libp2p.peer.id.ID (16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef)>]
    2025-06-13 19:51:37,027 - kademlia-example - INFO - DHT service started in CLIENT mode
    2025-06-13 19:51:37,027 - kademlia-example - INFO - Looking up key: FVDjasarSFDoLPMdgnp1dHSbW2ZAfN8NU2zNbCQeczgP
    2025-06-13 19:51:37,031 - kademlia-example - INFO - Retrieved value: Hello message from Sumanjeet
    2025-06-13 19:51:37,031 - kademlia-example - INFO - Looking for servers of content: 361f2ed1183bca491b8aec11f0b9e5c06724759b0f7480ae7fb4894901993bc8
    2025-06-13 19:51:37,035 - kademlia-example - INFO - Found 1 servers for content: ['16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef']

The demo showcases key DHT operations (a minimal API sketch follows below):

- **Value Storage & Retrieval**: The server stores a value, and the client retrieves it
- **Content Provider Discovery**: The server advertises content, and the client finds providers
- **Peer Discovery**: Automatic bootstrap and peer routing using the Kademlia algorithm
- **Network Resilience**: Distributed storage across multiple nodes (when available)
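The operations listed above can also be driven directly from the ``KadDHT`` API that this example uses. The following sketch is illustrative only and is not part of ``kademlia.py``; it assumes a libp2p host that is already running and connected to at least one bootstrap peer:

.. code-block:: python

    from libp2p.kad_dht import KadDHT, create_key_from_binary
    from libp2p.kad_dht.kad_dht import DHTMode
    from libp2p.tools.async_service import background_trio_service


    async def demo(host):
        # Attach a DHT node to an already-running host and start its service.
        dht = KadDHT(host, DHTMode.SERVER)
        async with background_trio_service(dht):
            key = create_key_from_binary(b"example-key")

            # Value storage and retrieval
            await dht.put_value(key, b"example-value")
            value = await dht.get_value(key)

            # Content provider advertisement and discovery
            await dht.provide(key)
            providers = await dht.find_providers(key)
            print(value, [p.peer_id for p in providers])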
Command Line Options
--------------------

The Kademlia demo supports several command line options for customization:

.. code-block:: console

    $ python kademlia.py --help
    usage: kademlia.py [-h] [--mode MODE] [--port PORT] [--bootstrap [BOOTSTRAP ...]] [--verbose]

    Kademlia DHT example with content server functionality

    options:
      -h, --help            show this help message and exit
      --mode MODE           Run as a server or client node (default: server)
      --port PORT           Port to listen on (0 for random) (default: 0)
      --bootstrap [BOOTSTRAP ...]
                            Multiaddrs of bootstrap nodes. Provide a space-separated list of addresses.
                            This is required for client mode.
      --verbose             Enable verbose logging

**Examples:**

Start server on a specific port:

.. code-block:: console

    $ python kademlia.py --mode server --port 8000

Start client with verbose logging:

.. code-block:: console

    $ python kademlia.py --mode client --verbose

Connect to multiple bootstrap nodes:

.. code-block:: console

    $ python kademlia.py --mode client --bootstrap /ip4/127.0.0.1/tcp/8000/p2p/... /ip4/127.0.0.1/tcp/8001/p2p/...

How It Works
------------

The Kademlia DHT implementation demonstrates several key concepts:

**Server Mode:**

- Stores key-value pairs in the distributed hash table
- Advertises itself as a content provider for specific content
- Handles incoming DHT requests from other nodes
- Maintains routing table with known peers

**Client Mode:**

- Connects to bootstrap nodes to join the network
- Retrieves values by their keys from the DHT
- Discovers content providers for specific content
- Performs network lookups using the Kademlia algorithm

**Key Components:**

- **Routing Table**: Organizes peers in k-buckets based on XOR distance
- **Value Store**: Manages key-value storage with TTL (time-to-live)
- **Provider Store**: Tracks which peers provide specific content
- **Peer Routing**: Implements iterative lookups to find closest peers
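The XOR distance metric that drives the routing table can be illustrated in a few lines. This is a simplified sketch for readers; the real helpers live in ``libp2p.kad_dht.utils`` and ``libp2p.kad_dht.routing_table``, and the function names below are illustrative:

.. code-block:: python

    def xor_distance(key_a: bytes, key_b: bytes) -> int:
        # Kademlia distance: XOR the two keys and read the result as an integer.
        return int.from_bytes(bytes(a ^ b for a, b in zip(key_a, key_b)), "big")


    def bucket_index(local_key: bytes, peer_key: bytes) -> int:
        # A peer belongs in k-bucket i when 2**i <= distance < 2**(i + 1),
        # so the index is one less than the bit length of the distance.
        distance = xor_distance(local_key, peer_key)
        return max(distance.bit_length() - 1, 0)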
The full source code for this example is below:

.. literalinclude:: ../examples/kademlia/kademlia.py
   :language: python
   :linenos:
@@ -11,3 +11,4 @@ Examples
   examples.echo
   examples.ping
   examples.pubsub
   examples.kademlia
docs/libp2p.kad_dht.pb.rst (new file, 22 lines)

@@ -0,0 +1,22 @@

libp2p.kad\_dht.pb package
==========================

Submodules
----------

libp2p.kad_dht.pb.kademlia_pb2 module
-------------------------------------

.. automodule:: libp2p.kad_dht.pb.kademlia_pb2
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: libp2p.kad_dht.pb
   :no-index:
   :members:
   :undoc-members:
   :show-inheritance:
docs/libp2p.kad_dht.rst (new file, 77 lines)

@@ -0,0 +1,77 @@

libp2p.kad\_dht package
=======================

Subpackages
-----------

.. toctree::
   :maxdepth: 4

   libp2p.kad_dht.pb

Submodules
----------

libp2p.kad\_dht.kad\_dht module
-------------------------------

.. automodule:: libp2p.kad_dht.kad_dht
   :members:
   :undoc-members:
   :show-inheritance:

libp2p.kad\_dht.peer\_routing module
------------------------------------

.. automodule:: libp2p.kad_dht.peer_routing
   :members:
   :undoc-members:
   :show-inheritance:

libp2p.kad\_dht.provider\_store module
--------------------------------------

.. automodule:: libp2p.kad_dht.provider_store
   :members:
   :undoc-members:
   :show-inheritance:

libp2p.kad\_dht.routing\_table module
-------------------------------------

.. automodule:: libp2p.kad_dht.routing_table
   :members:
   :undoc-members:
   :show-inheritance:

libp2p.kad\_dht.utils module
----------------------------

.. automodule:: libp2p.kad_dht.utils
   :members:
   :undoc-members:
   :show-inheritance:

libp2p.kad\_dht.value\_store module
-----------------------------------

.. automodule:: libp2p.kad_dht.value_store
   :members:
   :undoc-members:
   :show-inheritance:

libp2p.kad\_dht.pb
------------------

.. automodule:: libp2p.kad_dht.pb
   :members:
   :undoc-members:
   :show-inheritance:

Module contents
---------------

.. automodule:: libp2p.kad_dht
   :members:
   :undoc-members:
   :show-inheritance:
@@ -11,6 +11,7 @@ Subpackages
   libp2p.host
   libp2p.identity
   libp2p.io
   libp2p.kad_dht
   libp2p.network
   libp2p.peer
   libp2p.protocol_muxer
examples/kademlia/kademlia.py (new file, 300 lines)

@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
A basic example of using the Kademlia DHT implementation, with all setup logic inlined.
|
||||
This example demonstrates both value storage/retrieval and content server
|
||||
advertisement/discovery.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
import base58
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.kad_dht.kad_dht import (
|
||||
DHTMode,
|
||||
KadDHT,
|
||||
)
|
||||
from libp2p.kad_dht.utils import (
|
||||
create_key_from_binary,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
logger = logging.getLogger("kademlia-example")
|
||||
|
||||
# Configure DHT module loggers to inherit from the parent logger
|
||||
# This ensures all kademlia-example.* loggers use the same configuration
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt")
|
||||
|
||||
# Set the level for all child loggers
|
||||
for module in [
|
||||
"kad_dht",
|
||||
"value_store",
|
||||
"peer_routing",
|
||||
"routing_table",
|
||||
"provider_store",
|
||||
]:
|
||||
child_logger = logging.getLogger(f"kademlia-example.{module}")
|
||||
child_logger.setLevel(logging.INFO)
|
||||
child_logger.propagate = True # Allow propagation to parent
|
||||
|
||||
# Bootstrap node addresses to connect to at startup
|
||||
bootstrap_nodes = []
|
||||
|
||||
|
||||
# function to take bootstrap_nodes as input and connects to them
|
||||
async def connect_to_bootstrap_nodes(host: IHost, bootstrap_addrs: list[str]) -> None:
|
||||
"""
|
||||
Connect to the bootstrap nodes provided in the list.
|
||||
|
||||
params: host: The host instance to connect to
|
||||
bootstrap_addrs: List of bootstrap node addresses
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
for addr in bootstrap_addrs:
|
||||
try:
|
||||
peerInfo = info_from_p2p_addr(Multiaddr(addr))
|
||||
host.get_peerstore().add_addrs(peerInfo.peer_id, peerInfo.addrs, 3600)
|
||||
await host.connect(peerInfo)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to bootstrap node {addr}: {e}")
|
||||
|
||||
|
||||
def save_server_addr(addr: str) -> None:
|
||||
"""Append the server's multiaddress to the log file."""
|
||||
try:
|
||||
with open(SERVER_ADDR_LOG, "w") as f:
|
||||
f.write(addr + "\n")
|
||||
logger.info(f"Saved server address to log: {addr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save server address: {e}")
|
||||
|
||||
|
||||
def load_server_addrs() -> list[str]:
|
||||
"""Load all server multiaddresses from the log file."""
|
||||
if not os.path.exists(SERVER_ADDR_LOG):
|
||||
return []
|
||||
try:
|
||||
with open(SERVER_ADDR_LOG) as f:
|
||||
return [line.strip() for line in f if line.strip()]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load server addresses: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def run_node(
|
||||
port: int, mode: str, bootstrap_addrs: list[str] | None = None
|
||||
) -> None:
|
||||
"""Run a node that serves content in the DHT with setup inlined."""
|
||||
try:
|
||||
if port <= 0:
|
||||
port = random.randint(10000, 60000)
|
||||
logger.debug(f"Using port: {port}")
|
||||
|
||||
# Convert string mode to DHTMode enum
|
||||
if mode is None or mode.upper() == "CLIENT":
|
||||
dht_mode = DHTMode.CLIENT
|
||||
elif mode.upper() == "SERVER":
|
||||
dht_mode = DHTMode.SERVER
|
||||
else:
|
||||
logger.error(f"Invalid mode: {mode}. Must be 'client' or 'server'")
|
||||
sys.exit(1)
|
||||
|
||||
# Load server addresses for client mode
|
||||
if dht_mode == DHTMode.CLIENT:
|
||||
server_addrs = load_server_addrs()
|
||||
if server_addrs:
|
||||
logger.info(f"Loaded {len(server_addrs)} server addresses from log")
|
||||
bootstrap_nodes.append(server_addrs[0]) # Use the first server address
|
||||
else:
|
||||
logger.warning("No server addresses found in log file")
|
||||
|
||||
if bootstrap_addrs:
|
||||
for addr in bootstrap_addrs:
|
||||
bootstrap_nodes.append(addr)
|
||||
|
||||
key_pair = create_new_key_pair(secrets.token_bytes(32))
|
||||
host = new_host(key_pair=key_pair)
|
||||
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
peer_id = host.get_id().pretty()
|
||||
addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}"
|
||||
await connect_to_bootstrap_nodes(host, bootstrap_nodes)
|
||||
dht = KadDHT(host, dht_mode)
|
||||
# take all peer ids from the host and add them to the dht
|
||||
for peer_id in host.get_peerstore().peer_ids():
|
||||
await dht.routing_table.add_peer(peer_id)
|
||||
logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}")
|
||||
bootstrap_cmd = f"--bootstrap {addr_str}"
|
||||
logger.info("To connect to this node, use: %s", bootstrap_cmd)
|
||||
|
||||
# Save server address in server mode
|
||||
if dht_mode == DHTMode.SERVER:
|
||||
save_server_addr(addr_str)
|
||||
|
||||
# Start the DHT service
|
||||
async with background_trio_service(dht):
|
||||
logger.info(f"DHT service started in {dht_mode.value} mode")
|
||||
val_key = create_key_from_binary(b"py-libp2p kademlia example value")
|
||||
content = b"Hello from python node "
|
||||
content_key = create_key_from_binary(content)
|
||||
|
||||
if dht_mode == DHTMode.SERVER:
|
||||
# Store a value in the DHT
|
||||
msg = "Hello message from Sumanjeet"
|
||||
val_data = msg.encode()
|
||||
await dht.put_value(val_key, val_data)
|
||||
logger.info(
|
||||
f"Stored value '{val_data.decode()}'"
|
||||
f"with key: {base58.b58encode(val_key).decode()}"
|
||||
)
|
||||
|
||||
# Advertise as content server
|
||||
success = await dht.provider_store.provide(content_key)
|
||||
if success:
|
||||
logger.info(
|
||||
"Successfully advertised as server"
|
||||
f"for content: {content_key.hex()}"
|
||||
)
|
||||
else:
|
||||
logger.warning("Failed to advertise as content server")
|
||||
|
||||
else:
|
||||
# retrieve the value
|
||||
logger.info(
|
||||
"Looking up key: %s", base58.b58encode(val_key).decode()
|
||||
)
|
||||
val_data = await dht.get_value(val_key)
|
||||
if val_data:
|
||||
try:
|
||||
logger.info(f"Retrieved value: {val_data.decode()}")
|
||||
except UnicodeDecodeError:
|
||||
logger.info(f"Retrieved value (bytes): {val_data!r}")
|
||||
else:
|
||||
logger.warning("Failed to retrieve value")
|
||||
|
||||
# Also check if we can find servers for our own content
|
||||
logger.info("Looking for servers of content: %s", content_key.hex())
|
||||
providers = await dht.provider_store.find_providers(content_key)
|
||||
if providers:
|
||||
logger.info(
|
||||
"Found %d servers for content: %s",
|
||||
len(providers),
|
||||
[p.peer_id.pretty() for p in providers],
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No servers found for content %s", content_key.hex()
|
||||
)
|
||||
|
||||
# Keep the node running
|
||||
while True:
|
||||
logger.debug(
|
||||
"Status - Connected peers: %d,"
|
||||
"Peers in store: %d, Values in store: %d",
|
||||
len(dht.host.get_connected_peers()),
|
||||
len(dht.host.get_peerstore().peer_ids()),
|
||||
len(dht.value_store.store),
|
||||
)
|
||||
await trio.sleep(10)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Server node error: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Kademlia DHT example with content server functionality"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
default="server",
|
||||
help="Run as a server or client node",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port to listen on (0 for random)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bootstrap",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help=(
|
||||
"Multiaddrs of bootstrap nodes. "
|
||||
"Provide a space-separated list of addresses. "
|
||||
"This is required for client mode."
|
||||
),
|
||||
)
|
||||
# add option to use verbose logging
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
# Set logging level based on verbosity
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
else:
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the kademlia demo."""
|
||||
try:
|
||||
args = parse_args()
|
||||
logger.info(
|
||||
"Running in %s mode on port %d",
|
||||
args.mode,
|
||||
args.port,
|
||||
)
|
||||
trio.run(run_node, args.port, args.mode, args.bootstrap)
|
||||
except Exception as e:
|
||||
logger.critical(f"Script failed: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
libp2p/kad_dht/__init__.py (new file, 30 lines)

@@ -0,0 +1,30 @@

"""
Kademlia DHT implementation for py-libp2p.

This module provides a Distributed Hash Table (DHT) implementation
based on the Kademlia protocol.
"""

from .kad_dht import (
    KadDHT,
)
from .peer_routing import (
    PeerRouting,
)
from .routing_table import (
    RoutingTable,
)
from .utils import (
    create_key_from_binary,
)
from .value_store import (
    ValueStore,
)

__all__ = [
    "KadDHT",
    "RoutingTable",
    "PeerRouting",
    "ValueStore",
    "create_key_from_binary",
]
libp2p/kad_dht/kad_dht.py (new file, 616 lines)

@@ -0,0 +1,616 @@
|
||||
"""
|
||||
Kademlia DHT implementation for py-libp2p.
|
||||
|
||||
This module provides a complete Distributed Hash Table (DHT)
|
||||
implementation based on the Kademlia algorithm and protocol.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
import time
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
from .peer_routing import (
|
||||
PeerRouting,
|
||||
)
|
||||
from .provider_store import (
|
||||
ProviderStore,
|
||||
)
|
||||
from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .value_store import (
|
||||
ValueStore,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("kademlia-example.kad_dht")
|
||||
# logger = logging.getLogger("libp2p.kademlia")
|
||||
# Default parameters
|
||||
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
|
||||
ROUTING_TABLE_REFRESH_INTERVAL = 1 * 60 # 1 min in seconds for testing
|
||||
TTL = 24 * 60 * 60 # 24 hours in seconds
|
||||
ALPHA = 3
|
||||
QUERY_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class DHTMode(Enum):
|
||||
"""DHT operation modes."""
|
||||
|
||||
CLIENT = "CLIENT"
|
||||
SERVER = "SERVER"
|
||||
|
||||
|
||||
class KadDHT(Service):
|
||||
"""
|
||||
Kademlia DHT implementation for libp2p.
|
||||
|
||||
This class provides a DHT implementation that combines routing table management,
|
||||
peer discovery, content routing, and value storage.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, mode: DHTMode):
|
||||
"""
|
||||
Initialize a new Kademlia DHT node.
|
||||
|
||||
:param host: The libp2p host.
|
||||
:param mode: The mode of host (Client or Server) - must be DHTMode enum
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.host = host
|
||||
self.local_peer_id = host.get_id()
|
||||
|
||||
# Validate that mode is a DHTMode enum
|
||||
if not isinstance(mode, DHTMode):
|
||||
raise TypeError(f"mode must be DHTMode enum, got {type(mode)}")
|
||||
|
||||
self.mode = mode
|
||||
|
||||
# Initialize the routing table
|
||||
self.routing_table = RoutingTable(self.local_peer_id, self.host)
|
||||
|
||||
# Initialize peer routing
|
||||
self.peer_routing = PeerRouting(host, self.routing_table)
|
||||
|
||||
# Initialize value store
|
||||
self.value_store = ValueStore(host=host, local_peer_id=self.local_peer_id)
|
||||
|
||||
# Initialize provider store with host and peer_routing references
|
||||
self.provider_store = ProviderStore(host=host, peer_routing=self.peer_routing)
|
||||
|
||||
# Last time we republished provider records
|
||||
self._last_provider_republish = time.time()
|
||||
|
||||
# Set protocol handlers
|
||||
host.set_stream_handler(PROTOCOL_ID, self.handle_stream)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the DHT service."""
|
||||
logger.info(f"Starting Kademlia DHT with peer ID {self.local_peer_id}")
|
||||
|
||||
# Main service loop
|
||||
while self.manager.is_running:
|
||||
# Periodically refresh the routing table
|
||||
await self.refresh_routing_table()
|
||||
|
||||
# Check if it's time to republish provider records
|
||||
current_time = time.time()
|
||||
# await self._republish_provider_records()
|
||||
self._last_provider_republish = current_time
|
||||
|
||||
# Clean up expired values and provider records
|
||||
expired_values = self.value_store.cleanup_expired()
|
||||
if expired_values > 0:
|
||||
logger.debug(f"Cleaned up {expired_values} expired values")
|
||||
|
||||
self.provider_store.cleanup_expired()
|
||||
|
||||
# Wait before next maintenance cycle
|
||||
await trio.sleep(ROUTING_TABLE_REFRESH_INTERVAL)
|
||||
|
||||
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
|
||||
"""
|
||||
Switch the DHT mode.
|
||||
|
||||
:param new_mode: The new mode - must be DHTMode enum
|
||||
:return: The new mode as DHTMode enum
|
||||
"""
|
||||
# Validate that new_mode is a DHTMode enum
|
||||
if not isinstance(new_mode, DHTMode):
|
||||
raise TypeError(f"new_mode must be DHTMode enum, got {type(new_mode)}")
|
||||
|
||||
if new_mode == DHTMode.CLIENT:
|
||||
self.routing_table.cleanup_routing_table()
|
||||
self.mode = new_mode
|
||||
logger.info(f"Switched to {new_mode.value} mode")
|
||||
return self.mode
|
||||
|
||||
async def handle_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle an incoming DHT stream using varint length prefixes.
|
||||
"""
|
||||
if self.mode == DHTMode.CLIENT:
|
||||
await stream.close()
|
||||
return
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
logger.debug(f"Received DHT stream from peer {peer_id}")
|
||||
await self.add_peer(peer_id)
|
||||
logger.debug(f"Added peer {peer_id} to routing table")
|
||||
|
||||
try:
|
||||
# Read varint-prefixed length for the message
|
||||
length_prefix = b""
|
||||
while True:
|
||||
byte = await stream.read(1)
|
||||
if not byte:
|
||||
logger.warning("Stream closed while reading varint length")
|
||||
await stream.close()
|
||||
return
|
||||
length_prefix += byte
|
||||
if byte[0] & 0x80 == 0:
|
||||
break
|
||||
msg_length = varint.decode_bytes(length_prefix)
|
||||
|
||||
# Read the message bytes
|
||||
msg_bytes = await stream.read(msg_length)
|
||||
if len(msg_bytes) < msg_length:
|
||||
logger.warning("Failed to read full message from stream")
|
||||
await stream.close()
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse as protobuf
|
||||
message = Message()
|
||||
message.ParseFromString(msg_bytes)
|
||||
logger.debug(
|
||||
f"Received DHT message from {peer_id}, type: {message.type}"
|
||||
)
|
||||
|
||||
# Handle FIND_NODE message
|
||||
if message.type == Message.MessageType.FIND_NODE:
|
||||
# Get target key directly from protobuf
|
||||
target_key = message.key
|
||||
|
||||
# Find closest peers to the target key
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
target_key, 20
|
||||
)
|
||||
logger.debug(f"Found {len(closest_peers)} peers close to target")
|
||||
|
||||
# Build response message with protobuf
|
||||
response = Message()
|
||||
response.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Add closest peers to response
|
||||
for peer in closest_peers:
|
||||
# Skip if the peer is the requester
|
||||
if peer == peer_id:
|
||||
continue
|
||||
|
||||
# Add peer to closerPeers field
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
if addrs:
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug(
|
||||
f"Sent FIND_NODE response with{len(response.closerPeers)} peers"
|
||||
)
|
||||
|
||||
# Handle ADD_PROVIDER message
|
||||
elif message.type == Message.MessageType.ADD_PROVIDER:
|
||||
# Process ADD_PROVIDER
|
||||
key = message.key
|
||||
logger.debug(f"Received ADD_PROVIDER for key {key.hex()}")
|
||||
|
||||
# Extract provider information
|
||||
for provider_proto in message.providerPeers:
|
||||
try:
|
||||
# Validate that the provider is the sender
|
||||
provider_id = ID(provider_proto.id)
|
||||
if provider_id != peer_id:
|
||||
logger.warning(
|
||||
f"Provider ID {provider_id} doesn't"
|
||||
f"match sender {peer_id}, ignoring"
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert addresses to Multiaddr
|
||||
addrs = []
|
||||
for addr_bytes in provider_proto.addrs:
|
||||
try:
|
||||
addrs.append(Multiaddr(addr_bytes))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse address: {e}")
|
||||
|
||||
# Add to provider store
|
||||
provider_info = PeerInfo(provider_id, addrs)
|
||||
self.provider_store.add_provider(key, provider_info)
|
||||
logger.debug(
|
||||
f"Added provider {provider_id} for key {key.hex()}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process provider info: {e}")
|
||||
|
||||
# Send acknowledgement
|
||||
response = Message()
|
||||
response.type = Message.MessageType.ADD_PROVIDER
|
||||
response.key = key
|
||||
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent ADD_PROVIDER acknowledgement")
|
||||
|
||||
# Handle GET_PROVIDERS message
|
||||
elif message.type == Message.MessageType.GET_PROVIDERS:
|
||||
# Process GET_PROVIDERS
|
||||
key = message.key
|
||||
logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}")
|
||||
|
||||
# Find providers for the key
|
||||
providers = self.provider_store.get_providers(key)
|
||||
logger.debug(
|
||||
f"Found {len(providers)} providers for key {key.hex()}"
|
||||
)
|
||||
|
||||
# Create response
|
||||
response = Message()
|
||||
response.type = Message.MessageType.GET_PROVIDERS
|
||||
response.key = key
|
||||
|
||||
# Add provider information to response
|
||||
for provider_info in providers:
|
||||
provider_proto = response.providerPeers.add()
|
||||
provider_proto.id = provider_info.peer_id.to_bytes()
|
||||
provider_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
for addr in provider_info.addrs:
|
||||
provider_proto.addrs.append(addr.to_bytes())
|
||||
|
||||
# Also include closest peers if we don't have providers
|
||||
if not providers:
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
key, 20
|
||||
)
|
||||
logger.debug(
|
||||
f"No providers found, including {len(closest_peers)}"
|
||||
"closest peers"
|
||||
)
|
||||
|
||||
for peer in closest_peers:
|
||||
# Skip if peer is the requester
|
||||
if peer == peer_id:
|
||||
continue
|
||||
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent GET_PROVIDERS response")
|
||||
|
||||
# Handle GET_VALUE message
|
||||
elif message.type == Message.MessageType.GET_VALUE:
|
||||
# Process GET_VALUE
|
||||
key = message.key
|
||||
logger.debug(f"Received GET_VALUE request for key {key.hex()}")
|
||||
|
||||
value = self.value_store.get(key)
|
||||
if value:
|
||||
logger.debug(f"Found value for key {key.hex()}")
|
||||
|
||||
# Create response using protobuf
|
||||
response = Message()
|
||||
response.type = Message.MessageType.GET_VALUE
|
||||
|
||||
# Create record
|
||||
response.key = key
|
||||
response.record.key = key
|
||||
response.record.value = value
|
||||
response.record.timeReceived = str(time.time())
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent GET_VALUE response")
|
||||
else:
|
||||
logger.debug(f"No value found for key {key.hex()}")
|
||||
|
||||
# Create response with closest peers when no value is found
|
||||
response = Message()
|
||||
response.type = Message.MessageType.GET_VALUE
|
||||
response.key = key
|
||||
|
||||
# Add closest peers to key
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
key, 20
|
||||
)
|
||||
logger.debug(
|
||||
"No value found,"
|
||||
f"including {len(closest_peers)} closest peers"
|
||||
)
|
||||
|
||||
for peer in closest_peers:
|
||||
# Skip if peer is the requester
|
||||
if peer == peer_id:
|
||||
continue
|
||||
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Serialize and send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent GET_VALUE response with closest peers")
|
||||
|
||||
# Handle PUT_VALUE message
|
||||
elif message.type == Message.MessageType.PUT_VALUE and message.HasField(
|
||||
"record"
|
||||
):
|
||||
# Process PUT_VALUE
|
||||
key = message.record.key
|
||||
value = message.record.value
|
||||
success = False
|
||||
try:
|
||||
if not (key and value):
|
||||
raise ValueError(
|
||||
"Missing key or value in PUT_VALUE message"
|
||||
)
|
||||
|
||||
self.value_store.put(key, value)
|
||||
logger.debug(f"Stored value {value.hex()} for key {key.hex()}")
|
||||
success = True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to store value {value.hex()} for key "
|
||||
f"{key.hex()}: {e}"
|
||||
)
|
||||
finally:
|
||||
# Send acknowledgement
|
||||
response = Message()
|
||||
response.type = Message.MessageType.PUT_VALUE
|
||||
if success:
|
||||
response.key = key
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(varint.encode(len(response_bytes)))
|
||||
await stream.write(response_bytes)
|
||||
logger.debug("Sent PUT_VALUE acknowledgement")
|
||||
|
||||
except Exception as proto_err:
|
||||
logger.warning(f"Failed to parse protobuf message: {proto_err}")
|
||||
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling DHT stream: {e}")
|
||||
await stream.close()
|
||||
|
||||
async def refresh_routing_table(self) -> None:
|
||||
"""Refresh the routing table."""
|
||||
logger.debug("Refreshing routing table")
|
||||
await self.peer_routing.refresh_routing_table()
|
||||
|
||||
# Peer routing methods
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""
|
||||
Find a peer with the given ID.
|
||||
"""
|
||||
logger.debug(f"Finding peer: {peer_id}")
|
||||
return await self.peer_routing.find_peer(peer_id)
|
||||
|
||||
# Value storage and retrieval methods
|
||||
|
||||
async def put_value(self, key: bytes, value: bytes) -> None:
|
||||
"""
|
||||
Store a value in the DHT.
|
||||
"""
|
||||
logger.debug(f"Storing value for key {key.hex()}")
|
||||
|
||||
# 1. Store locally first
|
||||
self.value_store.put(key, value)
|
||||
try:
|
||||
decoded_value = value.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
decoded_value = value.hex()
|
||||
logger.debug(
|
||||
f"Stored value locally for key {key.hex()} with value {decoded_value}"
|
||||
)
|
||||
|
||||
# 2. Get closest peers, excluding self
|
||||
closest_peers = [
|
||||
peer
|
||||
for peer in self.routing_table.find_local_closest_peers(key)
|
||||
if peer != self.local_peer_id
|
||||
]
|
||||
logger.debug(f"Found {len(closest_peers)} peers to store value at")
|
||||
|
||||
# 3. Store at remote peers in batches of ALPHA, in parallel
|
||||
stored_count = 0
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
batch_results = [False] * len(batch)
|
||||
|
||||
async def store_one(idx: int, peer: ID) -> None:
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
success = await self.value_store._store_at_peer(
|
||||
peer, key, value
|
||||
)
|
||||
batch_results[idx] = success
|
||||
if success:
|
||||
logger.debug(f"Stored value at peer {peer}")
|
||||
else:
|
||||
logger.debug(f"Failed to store value at peer {peer}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error storing value at peer {peer}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for idx, peer in enumerate(batch):
|
||||
nursery.start_soon(store_one, idx, peer)
|
||||
|
||||
stored_count += sum(batch_results)
|
||||
|
||||
logger.info(f"Successfully stored value at {stored_count} peers")
|
||||
|
||||
async def get_value(self, key: bytes) -> bytes | None:
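"""
Get a value for the given key, checking the local store first and then
querying up to ALPHA of the closest peers in parallel per round.
"""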
|
||||
logger.debug(f"Getting value for key: {key.hex()}")
|
||||
|
||||
# 1. Check local store first
|
||||
value = self.value_store.get(key)
|
||||
if value:
|
||||
logger.debug("Found value locally")
|
||||
return value
|
||||
|
||||
# 2. Get closest peers, excluding self
|
||||
closest_peers = [
|
||||
peer
|
||||
for peer in self.routing_table.find_local_closest_peers(key)
|
||||
if peer != self.local_peer_id
|
||||
]
|
||||
logger.debug(f"Searching {len(closest_peers)} peers for value")
|
||||
|
||||
# 3. Query ALPHA peers at a time in parallel
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
found_value = None
|
||||
|
||||
async def query_one(peer: ID) -> None:
|
||||
nonlocal found_value
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
value = await self.value_store._get_from_peer(peer, key)
|
||||
if value is not None and found_value is None:
|
||||
found_value = value
|
||||
logger.debug(f"Found value at peer {peer}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error querying peer {peer}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer in batch:
|
||||
nursery.start_soon(query_one, peer)
|
||||
|
||||
if found_value is not None:
|
||||
self.value_store.put(key, found_value)
|
||||
logger.info("Successfully retrieved value from network")
|
||||
return found_value
|
||||
|
||||
# 4. Not found
|
||||
logger.warning(f"Value not found for key {key.hex()}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
# Utility methods
|
||||
|
||||
async def add_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Add a peer to the routing table.
|
||||
|
||||
params: peer_id: The peer ID to add.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if peer was added or updated, False otherwise.
|
||||
|
||||
"""
|
||||
return await self.routing_table.add_peer(peer_id)
|
||||
|
||||
async def provide(self, key: bytes) -> bool:
|
||||
"""
|
||||
Reference to provider_store.provide for convenience.
|
||||
"""
|
||||
return await self.provider_store.provide(key)
|
||||
|
||||
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
|
||||
"""
|
||||
Reference to provider_store.find_providers for convenience.
|
||||
"""
|
||||
return await self.provider_store.find_providers(key, count)
|
||||
|
||||
def get_routing_table_size(self) -> int:
|
||||
"""
|
||||
Get the number of peers in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of peers.
|
||||
|
||||
"""
|
||||
return self.routing_table.size()
|
||||
|
||||
def get_value_store_size(self) -> int:
|
||||
"""
|
||||
Get the number of items in the value store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of items.
|
||||
|
||||
"""
|
||||
return self.value_store.size()
|
||||
libp2p/kad_dht/pb/__init__.py (new, empty file)

libp2p/kad_dht/pb/kademlia.proto (new file, 38 lines)

@@ -0,0 +1,38 @@

syntax = "proto3";

message Record {
    bytes key = 1;
    bytes value = 2;
    string timeReceived = 5;
};

message Message {
    enum MessageType {
        PUT_VALUE = 0;
        GET_VALUE = 1;
        ADD_PROVIDER = 2;
        GET_PROVIDERS = 3;
        FIND_NODE = 4;
        PING = 5;
    }

    enum ConnectionType {
        NOT_CONNECTED = 0;
        CONNECTED = 1;
        CAN_CONNECT = 2;
        CANNOT_CONNECT = 3;
    }

    message Peer {
        bytes id = 1;
        repeated bytes addrs = 2;
        ConnectionType connection = 3;
    }

    MessageType type = 1;
    int32 clusterLevelRaw = 10;
    bytes key = 2;
    Record record = 3;
    repeated Peer closerPeers = 8;
    repeated Peer providerPeers = 9;
}
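Each Message on the /ipfs/kad/1.0.0 stream is sent with an unsigned-varint length prefix, matching the read loop in kad_dht.py's handle_stream. A minimal, illustrative sketch of that framing (it assumes only the generated kademlia_pb2 module and the varint package already used in this PR):

import varint
from libp2p.kad_dht.pb.kademlia_pb2 import Message

def frame_find_node(target_key: bytes) -> bytes:
    # Build a FIND_NODE request and prepend its varint-encoded length.
    msg = Message()
    msg.type = Message.MessageType.FIND_NODE
    msg.key = target_key
    payload = msg.SerializeToString()
    return varint.encode(len(payload)) + payload

def unframe(data: bytes) -> Message:
    # Decode the varint length prefix, then parse the protobuf that follows it.
    length = varint.decode_bytes(data)
    prefix_len = len(varint.encode(length))
    msg = Message()
    msg.ParseFromString(data[prefix_len : prefix_len + length])
    return msg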
libp2p/kad_dht/pb/kademlia_pb2.py (new file, 33 lines)

@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: libp2p/kad_dht/pb/kademlia.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals)
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
DESCRIPTOR._options = None
|
||||
_globals['_RECORD']._serialized_start=36
|
||||
_globals['_RECORD']._serialized_end=94
|
||||
_globals['_MESSAGE']._serialized_start=97
|
||||
_globals['_MESSAGE']._serialized_end=555
|
||||
_globals['_MESSAGE_PEER']._serialized_start=281
|
||||
_globals['_MESSAGE_PEER']._serialized_end=359
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_start=361
|
||||
_globals['_MESSAGE_MESSAGETYPE']._serialized_end=466
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=468
|
||||
_globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=555
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
libp2p/kad_dht/pb/kademlia_pb2.pyi (new file, 133 lines)

@@ -0,0 +1,133 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import collections.abc
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.containers
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class Record(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
VALUE_FIELD_NUMBER: builtins.int
|
||||
TIMERECEIVED_FIELD_NUMBER: builtins.int
|
||||
key: builtins.bytes
|
||||
value: builtins.bytes
|
||||
timeReceived: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
key: builtins.bytes = ...,
|
||||
value: builtins.bytes = ...,
|
||||
timeReceived: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ...
|
||||
|
||||
global___Record = Record
|
||||
|
||||
@typing.final
|
||||
class Message(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _MessageType:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
PUT_VALUE: Message._MessageType.ValueType # 0
|
||||
GET_VALUE: Message._MessageType.ValueType # 1
|
||||
ADD_PROVIDER: Message._MessageType.ValueType # 2
|
||||
GET_PROVIDERS: Message._MessageType.ValueType # 3
|
||||
FIND_NODE: Message._MessageType.ValueType # 4
|
||||
PING: Message._MessageType.ValueType # 5
|
||||
|
||||
class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ...
|
||||
PUT_VALUE: Message.MessageType.ValueType # 0
|
||||
GET_VALUE: Message.MessageType.ValueType # 1
|
||||
ADD_PROVIDER: Message.MessageType.ValueType # 2
|
||||
GET_PROVIDERS: Message.MessageType.ValueType # 3
|
||||
FIND_NODE: Message.MessageType.ValueType # 4
|
||||
PING: Message.MessageType.ValueType # 5
|
||||
|
||||
class _ConnectionType:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
NOT_CONNECTED: Message._ConnectionType.ValueType # 0
|
||||
CONNECTED: Message._ConnectionType.ValueType # 1
|
||||
CAN_CONNECT: Message._ConnectionType.ValueType # 2
|
||||
CANNOT_CONNECT: Message._ConnectionType.ValueType # 3
|
||||
|
||||
class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ...
|
||||
NOT_CONNECTED: Message.ConnectionType.ValueType # 0
|
||||
CONNECTED: Message.ConnectionType.ValueType # 1
|
||||
CAN_CONNECT: Message.ConnectionType.ValueType # 2
|
||||
CANNOT_CONNECT: Message.ConnectionType.ValueType # 3
|
||||
|
||||
@typing.final
|
||||
class Peer(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
ID_FIELD_NUMBER: builtins.int
|
||||
ADDRS_FIELD_NUMBER: builtins.int
|
||||
CONNECTION_FIELD_NUMBER: builtins.int
|
||||
id: builtins.bytes
|
||||
connection: global___Message.ConnectionType.ValueType
|
||||
@property
|
||||
def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id: builtins.bytes = ...,
|
||||
addrs: collections.abc.Iterable[builtins.bytes] | None = ...,
|
||||
connection: global___Message.ConnectionType.ValueType = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ...
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int
|
||||
KEY_FIELD_NUMBER: builtins.int
|
||||
RECORD_FIELD_NUMBER: builtins.int
|
||||
CLOSERPEERS_FIELD_NUMBER: builtins.int
|
||||
PROVIDERPEERS_FIELD_NUMBER: builtins.int
|
||||
type: global___Message.MessageType.ValueType
|
||||
clusterLevelRaw: builtins.int
|
||||
key: builtins.bytes
|
||||
@property
|
||||
def record(self) -> global___Record: ...
|
||||
@property
|
||||
def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
|
||||
@property
|
||||
def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___Message.MessageType.ValueType = ...,
|
||||
clusterLevelRaw: builtins.int = ...,
|
||||
key: builtins.bytes = ...,
|
||||
record: global___Record | None = ...,
|
||||
closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
|
||||
providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ...
|
||||
|
||||
global___Message = Message
|
||||
libp2p/kad_dht/peer_routing.py (new file, 418 lines)

@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Peer routing implementation for Kademlia DHT.
|
||||
|
||||
This module implements the peer routing interface using Kademlia's algorithm
|
||||
to efficiently locate peers in a distributed network.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import trio
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetStream,
|
||||
IPeerRouting,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
from .routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from .utils import (
|
||||
sort_peer_ids_by_distance,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.peer_routing")
|
||||
logger = logging.getLogger("kademlia-example.peer_routing")
|
||||
|
||||
# Constants for the Kademlia algorithm
|
||||
ALPHA = 3 # Concurrency parameter
|
||||
MAX_PEER_LOOKUP_ROUNDS = 20 # Maximum number of rounds in peer lookup
|
||||
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
|
||||
|
||||
|
||||
class PeerRouting(IPeerRouting):
|
||||
"""
|
||||
Implementation of peer routing using the Kademlia algorithm.
|
||||
|
||||
This class provides methods to find peers in the DHT network
|
||||
and helps maintain the routing table.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, routing_table: RoutingTable):
|
||||
"""
|
||||
Initialize the peer routing service.
|
||||
|
||||
:param host: The libp2p host
|
||||
:param routing_table: The Kademlia routing table
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.routing_table = routing_table
|
||||
self.protocol_id = PROTOCOL_ID
|
||||
|
||||
async def find_peer(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""
|
||||
Find a peer with the given ID.
|
||||
|
||||
:param peer_id: The ID of the peer to find
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[PeerInfo]
|
||||
The peer information if found, None otherwise
|
||||
|
||||
"""
|
||||
# Check if this is actually our peer ID
|
||||
if peer_id == self.host.get_id():
|
||||
try:
|
||||
# Return our own peer info
|
||||
return PeerInfo(peer_id, self.host.get_addrs())
|
||||
except Exception:
|
||||
logger.exception("Error getting our own peer info")
|
||||
return None
|
||||
|
||||
# First check if the peer is in our routing table
|
||||
peer_info = self.routing_table.get_peer_info(peer_id)
|
||||
if peer_info:
|
||||
logger.debug(f"Found peer {peer_id} in routing table")
|
||||
return peer_info
|
||||
|
||||
# Then check if the peer is in our peerstore
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
logger.debug(f"Found peer {peer_id} in peerstore")
|
||||
return PeerInfo(peer_id, addrs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If not found locally, search the network
|
||||
try:
|
||||
closest_peers = await self.find_closest_peers_network(peer_id.to_bytes())
|
||||
logger.info(f"Closest peers found: {closest_peers}")
|
||||
|
||||
# Check if we found the peer we're looking for
|
||||
for found_peer in closest_peers:
|
||||
if found_peer == peer_id:
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(found_peer)
|
||||
if addrs:
|
||||
return PeerInfo(found_peer, addrs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching for peer {peer_id}: {e}")
|
||||
|
||||
# Not found
|
||||
logger.info(f"Peer {peer_id} not found")
|
||||
return None
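# Minimal usage sketch under assumed setup (not part of this module): `host`,
# `routing_table` and `target_peer_id` are presumed to come from an already
# running node. It simply exercises the lookup order described above: routing
# table first, then the peerstore, then the network.
async def _example_find_peer(host: IHost, routing_table: RoutingTable, target_peer_id: ID) -> None:
    router = PeerRouting(host, routing_table)
    info = await router.find_peer(target_peer_id)
    if info is not None:
        print(f"found {info.peer_id} with {len(info.addrs)} addresses")
    else:
        print("peer not found locally or in the network")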
|
||||
|
||||
async def _query_single_peer_for_closest(
|
||||
self, peer: ID, target_key: bytes, new_peers: list[ID]
|
||||
) -> None:
|
||||
"""
|
||||
Query a single peer for closest peers and append results to the shared list.
|
||||
|
||||
params: peer : ID
|
||||
The peer to query
|
||||
params: target_key : bytes
|
||||
The target key to find closest peers for
|
||||
params: new_peers : list[ID]
|
||||
Shared list to append results to
|
||||
|
||||
"""
|
||||
try:
|
||||
result = await self._query_peer_for_closest(peer, target_key)
# Add deduplication to prevent duplicate peers
added = 0
for peer_id in result:
    if peer_id not in new_peers:
        new_peers.append(peer_id)
        added += 1
logger.debug(
    "Queried peer %s for closest peers, got %d results (%d new unique)",
    peer,
    len(result),
    added,
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Query to peer {peer} failed: {e}")
|
||||
|
||||
async def find_closest_peers_network(
|
||||
self, target_key: bytes, count: int = 20
|
||||
) -> list[ID]:
|
||||
"""
|
||||
Find the closest peers to a target key in the entire network.
|
||||
|
||||
Performs an iterative lookup by querying peers for their closest peers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
Closest peer IDs
|
||||
|
||||
"""
|
||||
# Start with closest peers from our routing table
|
||||
closest_peers = self.routing_table.find_local_closest_peers(target_key, count)
|
||||
logger.debug("Local closest peers: %d found", len(closest_peers))
|
||||
queried_peers: set[ID] = set()
|
||||
rounds = 0
|
||||
|
||||
# Return early if we have no peers to start with
|
||||
if not closest_peers:
|
||||
logger.warning("No local peers available for network lookup")
|
||||
return []
|
||||
|
||||
# Iterative lookup until convergence
|
||||
while rounds < MAX_PEER_LOOKUP_ROUNDS:
|
||||
rounds += 1
|
||||
logger.debug(f"Lookup round {rounds}/{MAX_PEER_LOOKUP_ROUNDS}")
|
||||
|
||||
# Find peers we haven't queried yet
|
||||
peers_to_query = [p for p in closest_peers if p not in queried_peers]
|
||||
if not peers_to_query:
|
||||
logger.debug("No more unqueried peers available, ending lookup")
|
||||
break # No more peers to query
|
||||
|
||||
# Query these peers for their closest peers to target
|
||||
peers_batch = peers_to_query[:ALPHA] # Limit to ALPHA peers at a time
|
||||
|
||||
# Mark these peers as queried before we actually query them
|
||||
for peer in peers_batch:
|
||||
queried_peers.add(peer)
|
||||
|
||||
# Run queries in parallel for this batch using trio nursery
|
||||
new_peers: list[ID] = [] # Shared array to collect all results
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for peer in peers_batch:
|
||||
nursery.start_soon(
|
||||
self._query_single_peer_for_closest, peer, target_key, new_peers
|
||||
)
|
||||
|
||||
# If we got no new peers, we're done
|
||||
if not new_peers:
|
||||
logger.debug("No new peers discovered in this round, ending lookup")
|
||||
break
|
||||
|
||||
# Update our list of closest peers
|
||||
all_candidates = closest_peers + new_peers
|
||||
old_closest_peers = closest_peers[:]
|
||||
closest_peers = sort_peer_ids_by_distance(target_key, all_candidates)[
|
||||
:count
|
||||
]
|
||||
logger.debug(f"Updated closest peers count: {len(closest_peers)}")
|
||||
|
||||
# Check if we made any progress (found closer peers)
|
||||
if closest_peers == old_closest_peers:
|
||||
logger.debug("No improvement in closest peers, ending lookup")
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Network lookup completed after {rounds} rounds, "
|
||||
f"found {len(closest_peers)} peers"
|
||||
)
|
||||
return closest_peers
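# Toy illustration of the convergence rule used above, with no networking:
# plain same-length byte strings stand in for peer IDs, `neighbours` is an
# assumed in-memory adjacency map, and the loop stops once a round no longer
# improves the `count` closest candidates by XOR distance.
def _converge_example(target: bytes, start: list[bytes], neighbours: dict[bytes, list[bytes]], count: int = 3) -> list[bytes]:
    def dist(key: bytes) -> int:
        return int.from_bytes(target, "big") ^ int.from_bytes(key, "big")

    closest = sorted(start, key=dist)[:count]
    queried: set[bytes] = set()
    while True:
        to_query = [p for p in closest if p not in queried]
        if not to_query:
            return closest
        queried.update(to_query)
        discovered = [n for p in to_query for n in neighbours.get(p, [])]
        updated = sorted(set(closest + discovered), key=dist)[:count]
        if updated == closest:
            return closest
        closest = updated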
|
||||
|
||||
async def _query_peer_for_closest(self, peer: ID, target_key: bytes) -> list[ID]:
|
||||
"""
|
||||
Query a peer for their closest peers
|
||||
to the target key using varint length prefix
|
||||
"""
|
||||
stream = None
|
||||
results = []
|
||||
try:
|
||||
# Add the peer to our routing table regardless of query outcome
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer, addrs)
|
||||
await self.routing_table.add_peer(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to add peer {peer} to routing table: {e}")
|
||||
|
||||
# Open a stream to the peer using the Kademlia protocol
|
||||
logger.debug(f"Opening stream to {peer} for closest peers query")
|
||||
try:
|
||||
stream = await self.host.new_stream(peer, [self.protocol_id])
|
||||
logger.debug(f"Stream opened to {peer}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to open stream to {peer}: {e}")
|
||||
return []
|
||||
|
||||
# Create and send FIND_NODE request using protobuf
|
||||
find_node_msg = Message()
|
||||
find_node_msg.type = Message.MessageType.FIND_NODE
|
||||
find_node_msg.key = target_key # Set target key directly as bytes
|
||||
|
||||
# Serialize and send the protobuf message with varint length prefix
|
||||
proto_bytes = find_node_msg.SerializeToString()
|
||||
logger.debug(
|
||||
f"Sending FIND_NODE: {proto_bytes.hex()} (len={len(proto_bytes)})"
|
||||
)
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
|
||||
# Read varint-prefixed response length
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
logger.warning(
|
||||
"Error reading varint length from stream: connection closed"
|
||||
)
|
||||
return []
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
logger.debug(f"Connection closed by peer {peer} while reading data")
|
||||
return []
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse the protobuf response
|
||||
response_msg = Message()
|
||||
response_msg.ParseFromString(response_bytes)
|
||||
logger.debug(
|
||||
"Received response from %s with %d peers",
|
||||
peer,
|
||||
len(response_msg.closerPeers),
|
||||
)
|
||||
|
||||
# Process closest peers from response
|
||||
if response_msg.type == Message.MessageType.FIND_NODE:
|
||||
for peer_data in response_msg.closerPeers:
|
||||
new_peer_id = ID(peer_data.id)
|
||||
if new_peer_id not in results:
|
||||
results.append(new_peer_id)
|
||||
if peer_data.addrs:
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
|
||||
addrs = [Multiaddr(addr) for addr in peer_data.addrs]
|
||||
self.host.get_peerstore().add_addrs(new_peer_id, addrs, 3600)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error querying peer {peer} for closest: {e}")
|
||||
|
||||
finally:
|
||||
if stream:
|
||||
await stream.close()
|
||||
return results
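# Helper sketch that factors out the framing pattern used above: read a varint
# length prefix byte-by-byte, then read exactly that many payload bytes. The
# `read` argument is assumed to behave like INetStream.read; this helper is not
# part of the module's API.
async def _read_varint_frame(read) -> bytes | None:
    length_bytes = b""
    while True:
        byte = await read(1)
        if not byte:
            return None  # connection closed before the length prefix completed
        length_bytes += byte
        if byte[0] & 0x80 == 0:
            break
    remaining = varint.decode_bytes(length_bytes)
    payload = b""
    while remaining > 0:
        chunk = await read(remaining)
        if not chunk:
            return None  # connection closed mid-payload
        payload += chunk
        remaining -= len(chunk)
    return payload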
|
||||
|
||||
async def _handle_kad_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle incoming Kademlia protocol streams.
|
||||
|
||||
params: stream: The incoming stream
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
try:
|
||||
# Read message length
|
||||
length_bytes = await stream.read(4)
|
||||
if not length_bytes:
|
||||
return
|
||||
|
||||
message_length = int.from_bytes(length_bytes, byteorder="big")
|
||||
|
||||
# Read message
|
||||
message_bytes = await stream.read(message_length)
|
||||
if not message_bytes:
|
||||
return
|
||||
|
||||
# Parse protobuf message
|
||||
kad_message = Message()
|
||||
try:
|
||||
kad_message.ParseFromString(message_bytes)
|
||||
|
||||
if kad_message.type == Message.MessageType.FIND_NODE:
|
||||
# Get target key directly from protobuf message
|
||||
target_key = kad_message.key
|
||||
|
||||
# Find closest peers to target
|
||||
closest_peers = self.routing_table.find_local_closest_peers(
|
||||
target_key, 20
|
||||
)
|
||||
|
||||
# Create protobuf response
|
||||
response = Message()
|
||||
response.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Add peer information to response
|
||||
for peer_id in closest_peers:
|
||||
peer_proto = response.closerPeers.add()
|
||||
peer_proto.id = peer_id.to_bytes()
|
||||
peer_proto.connection = Message.ConnectionType.CAN_CONNECT
|
||||
|
||||
# Add addresses if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
for addr in addrs:
|
||||
peer_proto.addrs.append(addr.to_bytes())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Send response
|
||||
response_bytes = response.SerializeToString()
|
||||
await stream.write(len(response_bytes).to_bytes(4, byteorder="big"))
|
||||
await stream.write(response_bytes)
|
||||
|
||||
except Exception as parse_err:
|
||||
logger.error(f"Failed to parse protocol buffer message: {parse_err}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error handling Kademlia stream: {e}")
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
async def refresh_routing_table(self) -> None:
|
||||
"""
|
||||
Refresh the routing table by performing lookups for random keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
logger.info("Refreshing routing table")
|
||||
|
||||
# Perform a lookup for ourselves to populate the routing table
|
||||
local_id = self.host.get_id()
|
||||
closest_peers = await self.find_closest_peers_network(local_id.to_bytes())
|
||||
|
||||
# Add discovered peers to routing table
|
||||
for peer_id in closest_peers:
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
await self.routing_table.add_peer(peer_info)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to add discovered peer {peer_id}: {e}")
libp2p/kad_dht/provider_store.py (new file, 575 lines)
@@ -0,0 +1,575 @@
|
||||
"""
|
||||
Provider record storage for Kademlia DHT.
|
||||
|
||||
This module implements the storage for content provider records in the Kademlia DHT.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.provider_store")
|
||||
logger = logging.getLogger("kademlia-example.provider_store")
|
||||
|
||||
# Constants for provider records (based on IPFS standards)
|
||||
PROVIDER_RECORD_REPUBLISH_INTERVAL = 22 * 60 * 60 # 22 hours in seconds
|
||||
PROVIDER_RECORD_EXPIRATION_INTERVAL = 48 * 60 * 60 # 48 hours in seconds
|
||||
PROVIDER_ADDRESS_TTL = 30 * 60 # 30 minutes in seconds
|
||||
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
|
||||
ALPHA = 3 # Number of parallel queries/advertisements
|
||||
QUERY_TIMEOUT = 10 # Timeout for each query in seconds
|
||||
|
||||
|
||||
class ProviderRecord:
|
||||
"""
|
||||
A record for a content provider in the DHT.
|
||||
|
||||
Contains the peer information and timestamp.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_info: PeerInfo,
|
||||
timestamp: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a new provider record.
|
||||
|
||||
:param provider_info: The provider's peer information
|
||||
:param timestamp: Time this record was created/updated
|
||||
(defaults to current time)
|
||||
|
||||
"""
|
||||
self.provider_info = provider_info
|
||||
self.timestamp = timestamp or time.time()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
Check if this provider record has expired.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the record has expired
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
return (current_time - self.timestamp) >= PROVIDER_RECORD_EXPIRATION_INTERVAL
|
||||
|
||||
def should_republish(self) -> bool:
|
||||
"""
|
||||
Check if this provider record should be republished.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the record should be republished
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
return (current_time - self.timestamp) >= PROVIDER_RECORD_REPUBLISH_INTERVAL
|
||||
|
||||
@property
|
||||
def peer_id(self) -> ID:
|
||||
"""Get the provider's peer ID."""
|
||||
return self.provider_info.peer_id
|
||||
|
||||
@property
|
||||
def addresses(self) -> list[Multiaddr]:
|
||||
"""Get the provider's addresses."""
|
||||
return self.provider_info.addrs
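# Small illustration of the record lifecycle defined above: a record becomes
# eligible for republish after 22 hours and expires after 48 hours. Any
# PeerInfo works here; addresses are not consulted by these checks.
def _example_record_lifecycle(provider_info: PeerInfo) -> None:
    fresh = ProviderRecord(provider_info)
    assert not fresh.should_republish() and not fresh.is_expired()

    old = ProviderRecord(provider_info, timestamp=time.time() - 23 * 60 * 60)
    assert old.should_republish() and not old.is_expired()

    dead = ProviderRecord(provider_info, timestamp=time.time() - 49 * 60 * 60)
    assert dead.is_expired() and dead.should_republish()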
|
||||
|
||||
|
||||
class ProviderStore:
|
||||
"""
|
||||
Store for content provider records in the Kademlia DHT.
|
||||
|
||||
Maps content keys to provider records, with support for expiration.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, peer_routing: Any = None) -> None:
|
||||
"""
|
||||
Initialize a new provider store.
|
||||
|
||||
:param host: The libp2p host instance
:param peer_routing: The peer routing instance (optional)
|
||||
"""
|
||||
# Maps content keys to a dict of provider records (peer_id -> record)
|
||||
self.providers: dict[bytes, dict[str, ProviderRecord]] = {}
|
||||
self.host = host
|
||||
self.peer_routing = peer_routing
|
||||
self.providing_keys: set[bytes] = set()
|
||||
self.local_peer_id = host.get_id()
|
||||
|
||||
async def _republish_provider_records(self) -> None:
|
||||
"""Republish all provider records for content this node is providing."""
|
||||
# First, republish keys we're actively providing
|
||||
for key in self.providing_keys:
|
||||
logger.debug(f"Republishing provider record for key {key.hex()}")
|
||||
await self.provide(key)
|
||||
|
||||
# Also check for any records that should be republished
|
||||
for key, providers in self.providers.items():
|
||||
for peer_id_str, record in providers.items():
|
||||
# Only republish records for our own peer
|
||||
if self.local_peer_id and str(self.local_peer_id) == peer_id_str:
|
||||
if record.should_republish():
|
||||
logger.debug(
|
||||
f"Republishing old provider record for key {key.hex()}"
|
||||
)
|
||||
await self.provide(key)
|
||||
|
||||
async def provide(self, key: bytes) -> bool:
|
||||
"""
|
||||
Advertise that this node can provide a piece of content.
|
||||
|
||||
Finds the k closest peers to the key and sends them ADD_PROVIDER messages.
|
||||
|
||||
:param key: The content key (multihash) to advertise
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the advertisement was successful
|
||||
|
||||
"""
|
||||
if not self.host or not self.peer_routing:
|
||||
logger.error("Host or peer_routing not initialized, cannot provide content")
|
||||
return False
|
||||
|
||||
# Add to local provider store
|
||||
local_addrs = []
|
||||
for addr in self.host.get_addrs():
|
||||
local_addrs.append(addr)
|
||||
|
||||
local_peer_info = PeerInfo(self.host.get_id(), local_addrs)
|
||||
self.add_provider(key, local_peer_info)
|
||||
|
||||
# Track that we're providing this key
|
||||
self.providing_keys.add(key)
|
||||
|
||||
# Find the k closest peers to the key
|
||||
closest_peers = await self.peer_routing.find_closest_peers_network(key)
|
||||
logger.debug(
|
||||
"Found %d peers close to key %s for provider advertisement",
|
||||
len(closest_peers),
|
||||
key.hex(),
|
||||
)
|
||||
|
||||
# Send ADD_PROVIDER messages to the closest peers, in parallel batches of ALPHA.
|
||||
success_count = 0
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
results: list[bool] = [False] * len(batch)
|
||||
|
||||
async def send_one(
|
||||
idx: int, peer_id: ID, results: list[bool] = results
|
||||
) -> None:
|
||||
if peer_id == self.local_peer_id:
|
||||
return
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
success = await self._send_add_provider(peer_id, key)
|
||||
results[idx] = success
|
||||
if not success:
|
||||
logger.warning(f"Failed to send ADD_PROVIDER to {peer_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for idx, peer_id in enumerate(batch):
|
||||
nursery.start_soon(send_one, idx, peer_id, results)
|
||||
success_count += sum(results)
|
||||
|
||||
logger.info(f"Successfully advertised to {success_count} peers")
|
||||
return success_count > 0
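# Usage sketch under assumed setup: `store` is a ProviderStore already wired to
# a running host and PeerRouting instance, and `content` is arbitrary
# application data. Deriving the DHT key via create_key_from_binary mirrors the
# utils module; callers may derive keys differently.
async def _example_advertise_and_lookup(store: "ProviderStore", content: bytes) -> list[PeerInfo]:
    from libp2p.kad_dht.utils import create_key_from_binary

    key = create_key_from_binary(content)
    if not await store.provide(key):
        logger.warning("no peer accepted the ADD_PROVIDER advertisement")
    return await store.find_providers(key)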
|
||||
|
||||
async def _send_add_provider(self, peer_id: ID, key: bytes) -> bool:
|
||||
"""
|
||||
Send ADD_PROVIDER message to a specific peer.
|
||||
|
||||
:param peer_id: The peer to send the message to
|
||||
:param key: The content key being provided
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the message was successfully sent and acknowledged
|
||||
|
||||
"""
|
||||
result = False
stream = None
try:
|
||||
# Open a stream to the peer
|
||||
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
|
||||
|
||||
# Get our addresses to include in the message
|
||||
addrs = []
|
||||
for addr in self.host.get_addrs():
|
||||
addrs.append(addr.to_bytes())
|
||||
|
||||
# Create the ADD_PROVIDER message
|
||||
message = Message()
|
||||
message.type = Message.MessageType.ADD_PROVIDER
|
||||
message.key = key
|
||||
|
||||
# Add our provider info
|
||||
provider = message.providerPeers.add()
|
||||
provider.id = self.local_peer_id.to_bytes()
|
||||
provider.addrs.extend(addrs)
|
||||
|
||||
# Serialize and send the message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
logger.debug(f"Sent ADD_PROVIDER to {peer_id} for key {key.hex()}")
|
||||
# Read response length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
logger.debug("Reading response length prefix in add provider")
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
return False
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
return False
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse response
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check response type
if response.type == Message.MessageType.ADD_PROVIDER:
    result = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}")
|
||||
|
||||
finally:
|
||||
if stream:
    await stream.close()
|
||||
return result
|
||||
|
||||
async def find_providers(self, key: bytes, count: int = 20) -> list[PeerInfo]:
|
||||
"""
|
||||
Find content providers for a given key.
|
||||
|
||||
:param key: The content key to look for
|
||||
:param count: Maximum number of providers to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]
|
||||
List of content providers
|
||||
|
||||
"""
|
||||
if not self.host or not self.peer_routing:
|
||||
logger.error("Host or peer_routing not initialized, cannot find providers")
|
||||
return []
|
||||
|
||||
# Check local provider store first
|
||||
local_providers = self.get_providers(key)
|
||||
if local_providers:
|
||||
logger.debug(
|
||||
f"Found {len(local_providers)} providers locally for {key.hex()}"
|
||||
)
|
||||
return local_providers[:count]
|
||||
logger.debug("local providers are %s", local_providers)
|
||||
|
||||
# Find the closest peers to the key
|
||||
closest_peers = await self.peer_routing.find_closest_peers_network(key)
|
||||
logger.debug(
|
||||
f"Searching {len(closest_peers)} peers for providers of {key.hex()}"
|
||||
)
|
||||
|
||||
# Query these peers for providers in batches of ALPHA, in parallel, with timeout
|
||||
all_providers = []
|
||||
for i in range(0, len(closest_peers), ALPHA):
|
||||
batch = closest_peers[i : i + ALPHA]
|
||||
batch_results: list[list[PeerInfo]] = [[] for _ in batch]
|
||||
|
||||
async def get_one(
|
||||
idx: int,
|
||||
peer_id: ID,
|
||||
batch_results: list[list[PeerInfo]] = batch_results,
|
||||
) -> None:
|
||||
if peer_id == self.local_peer_id:
|
||||
return
|
||||
try:
|
||||
with trio.move_on_after(QUERY_TIMEOUT):
|
||||
providers = await self._get_providers_from_peer(peer_id, key)
|
||||
if providers:
|
||||
for provider in providers:
|
||||
self.add_provider(key, provider)
|
||||
batch_results[idx] = providers
|
||||
else:
|
||||
logger.debug(f"No providers found at peer {peer_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get providers from {peer_id}: {e}")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for idx, peer_id in enumerate(batch):
|
||||
nursery.start_soon(get_one, idx, peer_id, batch_results)
|
||||
|
||||
for providers in batch_results:
|
||||
all_providers.extend(providers)
|
||||
if len(all_providers) >= count:
|
||||
return all_providers[:count]
|
||||
|
||||
return all_providers[:count]
|
||||
|
||||
async def _get_providers_from_peer(self, peer_id: ID, key: bytes) -> list[PeerInfo]:
|
||||
"""
|
||||
Get content providers from a specific peer.
|
||||
|
||||
:param peer_id: The peer to query
|
||||
:param key: The content key to look for
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]
|
||||
List of provider information
|
||||
|
||||
"""
|
||||
providers: list[PeerInfo] = []
|
||||
try:
|
||||
# Open a stream to the peer
|
||||
stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
|
||||
|
||||
try:
|
||||
# Create the GET_PROVIDERS message
|
||||
message = Message()
|
||||
message.type = Message.MessageType.GET_PROVIDERS
|
||||
message.key = key
|
||||
|
||||
# Serialize and send the message
|
||||
proto_bytes = message.SerializeToString()
|
||||
await stream.write(varint.encode(len(proto_bytes)))
|
||||
await stream.write(proto_bytes)
|
||||
|
||||
# Read response length prefix
|
||||
length_bytes = b""
|
||||
while True:
|
||||
b = await stream.read(1)
|
||||
if not b:
|
||||
return []
|
||||
length_bytes += b
|
||||
if b[0] & 0x80 == 0:
|
||||
break
|
||||
|
||||
response_length = varint.decode_bytes(length_bytes)
|
||||
# Read response data
|
||||
response_bytes = b""
|
||||
remaining = response_length
|
||||
while remaining > 0:
|
||||
chunk = await stream.read(remaining)
|
||||
if not chunk:
|
||||
return []
|
||||
response_bytes += chunk
|
||||
remaining -= len(chunk)
|
||||
|
||||
# Parse response
|
||||
response = Message()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check response type
|
||||
if response.type != Message.MessageType.GET_PROVIDERS:
|
||||
return []
|
||||
|
||||
# Extract provider information
|
||||
providers = []
|
||||
for provider_proto in response.providerPeers:
|
||||
try:
|
||||
# Create peer ID from bytes
|
||||
provider_id = ID(provider_proto.id)
|
||||
|
||||
# Convert addresses to Multiaddr
|
||||
addrs = []
|
||||
for addr_bytes in provider_proto.addrs:
|
||||
try:
|
||||
addrs.append(Multiaddr(addr_bytes))
|
||||
except Exception:
|
||||
pass # Skip invalid addresses
|
||||
|
||||
# Create PeerInfo and add to result
|
||||
providers.append(PeerInfo(provider_id, addrs))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse provider info: {e}")
|
||||
|
||||
finally:
|
||||
await stream.close()
|
||||
return providers
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting providers from {peer_id}: {e}")
|
||||
return []
|
||||
|
||||
def add_provider(self, key: bytes, provider: PeerInfo) -> None:
|
||||
"""
|
||||
Add a provider for a given content key.
|
||||
|
||||
:param key: The content key
|
||||
:param provider: The provider's peer information
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
# Initialize providers for this key if needed
|
||||
if key not in self.providers:
|
||||
self.providers[key] = {}
|
||||
|
||||
# Add or update the provider record
|
||||
peer_id_str = str(provider.peer_id) # Use string representation as dict key
|
||||
self.providers[key][peer_id_str] = ProviderRecord(
|
||||
provider_info=provider, timestamp=time.time()
|
||||
)
|
||||
logger.debug(f"Added provider {provider.peer_id} for key {key.hex()}")
|
||||
|
||||
def get_providers(self, key: bytes) -> list[PeerInfo]:
|
||||
"""
|
||||
Get all providers for a given content key.
|
||||
|
||||
:param key: The content key
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[PeerInfo]
|
||||
List of providers for the key
|
||||
|
||||
"""
|
||||
if key not in self.providers:
|
||||
return []
|
||||
|
||||
# Collect valid provider records (not expired)
|
||||
result = []
|
||||
current_time = time.time()
|
||||
expired_peers = []
|
||||
|
||||
for peer_id_str, record in self.providers[key].items():
|
||||
# Check if the record has expired
|
||||
if current_time - record.timestamp > PROVIDER_RECORD_EXPIRATION_INTERVAL:
|
||||
expired_peers.append(peer_id_str)
|
||||
continue
|
||||
|
||||
# Use addresses only if they haven't expired
|
||||
addresses = []
|
||||
if current_time - record.timestamp <= PROVIDER_ADDRESS_TTL:
|
||||
addresses = record.addresses
|
||||
|
||||
# Create PeerInfo and add to results
|
||||
result.append(PeerInfo(record.peer_id, addresses))
|
||||
|
||||
# Clean up expired records
|
||||
for peer_id in expired_peers:
|
||||
del self.providers[key][peer_id]
|
||||
|
||||
# Remove the key if no providers left
|
||||
if not self.providers[key]:
|
||||
del self.providers[key]
|
||||
|
||||
return result
|
||||
|
||||
def cleanup_expired(self) -> None:
|
||||
"""Remove expired provider records."""
|
||||
current_time = time.time()
|
||||
expired_keys = []
|
||||
|
||||
for key, providers in self.providers.items():
|
||||
expired_providers = []
|
||||
|
||||
for peer_id_str, record in providers.items():
|
||||
if (
|
||||
current_time - record.timestamp
|
||||
> PROVIDER_RECORD_EXPIRATION_INTERVAL
|
||||
):
|
||||
expired_providers.append(peer_id_str)
|
||||
logger.debug(
|
||||
f"Removing expired provider {peer_id_str} for key {key.hex()}"
|
||||
)
|
||||
|
||||
# Remove expired providers
|
||||
for peer_id in expired_providers:
|
||||
del providers[peer_id]
|
||||
|
||||
# Track empty keys for removal
|
||||
if not providers:
|
||||
expired_keys.append(key)
|
||||
|
||||
# Remove empty keys
|
||||
for key in expired_keys:
|
||||
del self.providers[key]
|
||||
logger.debug(f"Removed key with no providers: {key.hex()}")
|
||||
|
||||
def get_provided_keys(self, peer_id: ID) -> list[bytes]:
|
||||
"""
|
||||
Get all content keys provided by a specific peer.
|
||||
|
||||
:param peer_id: The peer ID to look for
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[bytes]
|
||||
List of content keys provided by the peer
|
||||
|
||||
"""
|
||||
peer_id_str = str(peer_id)
|
||||
result = []
|
||||
|
||||
for key, providers in self.providers.items():
|
||||
if peer_id_str in providers:
|
||||
result.append(key)
|
||||
|
||||
return result
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
Get the total number of provider records in the store.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Total number of provider records across all keys
|
||||
|
||||
"""
|
||||
total = 0
|
||||
for providers in self.providers.values():
|
||||
total += len(providers)
|
||||
return total
libp2p/kad_dht/routing_table.py (new file, 601 lines)
@@ -0,0 +1,601 @@
|
||||
"""
|
||||
Kademlia DHT routing table implementation.
|
||||
"""
|
||||
|
||||
from collections import (
|
||||
OrderedDict,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.kad_dht.utils import xor_distance
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.routing_table")
|
||||
logger = logging.getLogger("kademlia-example.routing_table")
|
||||
|
||||
# Default parameters
|
||||
BUCKET_SIZE = 20 # k in the Kademlia paper
|
||||
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
|
||||
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
|
||||
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
|
||||
|
||||
|
||||
class KBucket:
|
||||
"""
|
||||
A k-bucket implementation for the Kademlia DHT.
|
||||
|
||||
Each k-bucket stores up to k (BUCKET_SIZE) peers, sorted by least-recently seen.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
bucket_size: int = BUCKET_SIZE,
|
||||
min_range: int = 0,
|
||||
max_range: int = 2**256,
|
||||
):
|
||||
"""
|
||||
Initialize a new k-bucket.
|
||||
|
||||
:param host: The host this bucket belongs to
|
||||
:param bucket_size: Maximum number of peers to store in the bucket
|
||||
:param min_range: Lower boundary of the bucket's key range (inclusive)
|
||||
:param max_range: Upper boundary of the bucket's key range (exclusive)
|
||||
|
||||
"""
|
||||
self.bucket_size = bucket_size
|
||||
self.host = host
|
||||
self.min_range = min_range
|
||||
self.max_range = max_range
|
||||
# Store PeerInfo objects along with last-seen timestamp
|
||||
self.peers: OrderedDict[ID, tuple[PeerInfo, float]] = OrderedDict()
|
||||
|
||||
def peer_ids(self) -> list[ID]:
|
||||
"""Get all peer IDs in the bucket."""
|
||||
return list(self.peers.keys())
|
||||
|
||||
def peer_infos(self) -> list[PeerInfo]:
|
||||
"""Get all PeerInfo objects in the bucket."""
|
||||
return [info for info, _ in self.peers.values()]
|
||||
|
||||
def get_oldest_peer(self) -> ID | None:
|
||||
"""Get the least-recently seen peer."""
|
||||
if not self.peers:
|
||||
return None
|
||||
return next(iter(self.peers.keys()))
|
||||
|
||||
async def add_peer(self, peer_info: PeerInfo) -> bool:
|
||||
"""
|
||||
Add a peer to the bucket. Returns True if the peer was added or updated,
|
||||
False if the bucket is full.
|
||||
"""
|
||||
current_time = time.time()
|
||||
peer_id = peer_info.peer_id
|
||||
|
||||
# If peer is already in the bucket, move it to the end (most recently seen)
|
||||
if peer_id in self.peers:
|
||||
self.refresh_peer_last_seen(peer_id)
|
||||
return True
|
||||
|
||||
# If bucket has space, add the peer
|
||||
if len(self.peers) < self.bucket_size:
|
||||
self.peers[peer_id] = (peer_info, current_time)
|
||||
return True
|
||||
|
||||
# If bucket is full, we need to replace the least-recently seen peer
|
||||
# Get the least-recently seen peer
|
||||
oldest_peer_id = self.get_oldest_peer()
|
||||
if oldest_peer_id is None:
|
||||
logger.warning("No oldest peer found when bucket is full")
|
||||
return False
|
||||
|
||||
# Check if the old peer is responsive to ping request
|
||||
try:
|
||||
# Try to ping the oldest peer, not the new peer
|
||||
response = await self._ping_peer(oldest_peer_id)
|
||||
if response:
|
||||
# If the old peer is still alive, we will not add the new peer
|
||||
logger.debug(
|
||||
"Old peer %s is still alive, cannot add new peer %s",
|
||||
oldest_peer_id,
|
||||
peer_id,
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
# If the old peer is unresponsive, we can replace it with the new peer
|
||||
logger.debug(
|
||||
"Old peer %s is unresponsive, replacing with new peer %s: %s",
|
||||
oldest_peer_id,
|
||||
peer_id,
|
||||
str(e),
|
||||
)
|
||||
self.peers.popitem(last=False) # Remove oldest peer
|
||||
self.peers[peer_id] = (peer_info, current_time)
|
||||
return True
|
||||
|
||||
# If we got here, the oldest peer responded but we couldn't add the new peer
|
||||
return False
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Remove a peer from the bucket.
|
||||
Returns True if the peer was in the bucket, False otherwise.
|
||||
"""
|
||||
if peer_id in self.peers:
|
||||
del self.peers[peer_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def has_peer(self, peer_id: ID) -> bool:
|
||||
"""Check if the peer is in the bucket."""
|
||||
return peer_id in self.peers
|
||||
|
||||
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""Get the PeerInfo for a given peer ID if it exists in the bucket."""
|
||||
if peer_id in self.peers:
|
||||
return self.peers[peer_id][0]
|
||||
return None
|
||||
|
||||
def size(self) -> int:
|
||||
"""Get the number of peers in the bucket."""
|
||||
return len(self.peers)
|
||||
|
||||
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
|
||||
"""
|
||||
Get peers that haven't been pinged recently.
|
||||
|
||||
params: stale_threshold_seconds: Time in seconds
    after which a peer is considered stale
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
List of peer IDs that need to be refreshed
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
stale_peers = []
|
||||
|
||||
for peer_id, (_, last_seen) in self.peers.items():
|
||||
if current_time - last_seen > stale_threshold_seconds:
|
||||
stale_peers.append(peer_id)
|
||||
|
||||
return stale_peers
|
||||
|
||||
async def _periodic_peer_refresh(self) -> None:
|
||||
"""Background task to periodically refresh peers"""
|
||||
try:
|
||||
while True:
|
||||
await trio.sleep(PEER_REFRESH_INTERVAL) # Check every minute
|
||||
|
||||
# Find stale peers (not pinged in last hour)
|
||||
stale_peers = self.get_stale_peers(
|
||||
stale_threshold_seconds=STALE_PEER_THRESHOLD
|
||||
)
|
||||
if stale_peers:
|
||||
logger.debug(f"Found {len(stale_peers)} stale peers to refresh")
|
||||
|
||||
for peer_id in stale_peers:
|
||||
try:
|
||||
# Try to ping the peer
|
||||
logger.debug("Pinging stale peer %s", peer_id)
|
||||
response = await self._ping_peer(peer_id)
if response:
|
||||
# Update the last seen time
|
||||
self.refresh_peer_last_seen(peer_id)
|
||||
logger.debug(f"Refreshed peer {peer_id}")
|
||||
else:
|
||||
# If ping fails, remove the peer
|
||||
logger.debug(f"Failed to ping peer {peer_id}")
|
||||
self.remove_peer(peer_id)
|
||||
logger.info(f"Removed unresponsive peer {peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
# If ping fails, remove the peer
|
||||
logger.debug(
|
||||
"Failed to ping peer %s: %s",
|
||||
peer_id,
|
||||
e,
|
||||
)
|
||||
self.remove_peer(peer_id)
|
||||
logger.info(f"Removed unresponsive peer {peer_id}")
|
||||
except trio.Cancelled:
|
||||
logger.debug("Peer refresh task cancelled")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in peer refresh task: {e}", exc_info=True)
|
||||
|
||||
async def _ping_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Ping a peer using protobuf message to check
|
||||
if it's still alive and update last seen time.
|
||||
|
||||
params: peer_id: The ID of the peer to ping
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if ping successful, False otherwise
|
||||
|
||||
"""
|
||||
result = False
|
||||
# Get peer info directly from the bucket
|
||||
peer_info = self.get_peer_info(peer_id)
|
||||
if not peer_info:
|
||||
raise ValueError(f"Peer {peer_id} not in bucket")
|
||||
|
||||
# Default protocol ID for Kademlia DHT
|
||||
protocol_id = TProtocol("/ipfs/kad/1.0.0")
|
||||
|
||||
try:
|
||||
# Open a stream to the peer with the DHT protocol
|
||||
stream = await self.host.new_stream(peer_id, [protocol_id])
|
||||
|
||||
try:
|
||||
# Create ping protobuf message
|
||||
ping_msg = Message()
|
||||
ping_msg.type = Message.PING # Use correct enum
|
||||
|
||||
# Serialize and send with length prefix (4 bytes big-endian)
|
||||
msg_bytes = ping_msg.SerializeToString()
|
||||
logger.debug(
|
||||
f"Sending PING message to {peer_id}, size: {len(msg_bytes)} bytes"
|
||||
)
|
||||
await stream.write(len(msg_bytes).to_bytes(4, byteorder="big"))
|
||||
await stream.write(msg_bytes)
|
||||
|
||||
# Wait for response with timeout
|
||||
with trio.move_on_after(2): # 2 second timeout
|
||||
# Read response length (4 bytes)
|
||||
length_bytes = await stream.read(4)
|
||||
if not length_bytes or len(length_bytes) < 4:
|
||||
logger.warning(f"Peer {peer_id} disconnected during ping")
|
||||
return False
|
||||
|
||||
msg_len = int.from_bytes(length_bytes, byteorder="big")
|
||||
if (
|
||||
msg_len <= 0 or msg_len > 1024 * 1024
|
||||
): # Sanity check on message size
|
||||
logger.warning(
|
||||
f"Invalid message length from {peer_id}: {msg_len}"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.debug(
|
||||
f"Receiving response from {peer_id}, size: {msg_len} bytes"
|
||||
)
|
||||
|
||||
# Read full message
|
||||
response_bytes = await stream.read(msg_len)
|
||||
if not response_bytes:
|
||||
logger.warning(f"Failed to read response from {peer_id}")
|
||||
return False
|
||||
|
||||
# Parse protobuf response
|
||||
response = Message()
|
||||
try:
|
||||
response.ParseFromString(response_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse protobuf response from {peer_id}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
if response.type == Message.PING:
|
||||
# Update the last seen timestamp for this peer
|
||||
logger.debug(f"Successfully pinged peer {peer_id}")
|
||||
result = True
|
||||
return result
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected response type from {peer_id}: {response.type}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If we get here, the ping timed out
|
||||
logger.warning(f"Ping to peer {peer_id} timed out")
|
||||
return False
|
||||
|
||||
finally:
|
||||
await stream.close()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error pinging peer {peer_id}: {str(e)}")
|
||||
return False
|
||||
|
||||
def refresh_peer_last_seen(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Update the last-seen timestamp for a peer in the bucket.
|
||||
|
||||
params: peer_id: The ID of the peer to refresh
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the peer was found and refreshed, False otherwise
|
||||
|
||||
"""
|
||||
if peer_id in self.peers:
|
||||
# Get current peer info and update the timestamp
|
||||
peer_info, _ = self.peers[peer_id]
|
||||
current_time = time.time()
|
||||
self.peers[peer_id] = (peer_info, current_time)
|
||||
# Move to end of ordered dict to mark as most recently seen
|
||||
self.peers.move_to_end(peer_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def key_in_range(self, key: bytes) -> bool:
|
||||
"""
|
||||
Check if a key is in the range of this bucket.
|
||||
|
||||
params: key: The key to check (bytes)
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the key is in range, False otherwise
|
||||
|
||||
"""
|
||||
key_int = int.from_bytes(key, byteorder="big")
|
||||
return self.min_range <= key_int < self.max_range
|
||||
|
||||
def split(self) -> tuple["KBucket", "KBucket"]:
|
||||
"""
|
||||
Split the bucket into two buckets.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple
|
||||
(lower_bucket, upper_bucket)
|
||||
|
||||
"""
|
||||
midpoint = (self.min_range + self.max_range) // 2
|
||||
lower_bucket = KBucket(self.host, self.bucket_size, self.min_range, midpoint)
|
||||
upper_bucket = KBucket(self.host, self.bucket_size, midpoint, self.max_range)
|
||||
|
||||
# Redistribute peers
|
||||
for peer_id, (peer_info, timestamp) in self.peers.items():
|
||||
peer_key = int.from_bytes(peer_id.to_bytes(), byteorder="big")
|
||||
if peer_key < midpoint:
|
||||
lower_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||
else:
|
||||
upper_bucket.peers[peer_id] = (peer_info, timestamp)
|
||||
|
||||
return lower_bucket, upper_bucket
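# Illustration of the split rule above: the key range is halved at its midpoint
# and peers are redistributed by the integer value of their ID bytes. Passing
# None for the host is a shortcut for this sketch only; split() does not use it.
def _example_split_ranges() -> None:
    bucket = KBucket(host=None, bucket_size=BUCKET_SIZE, min_range=0, max_range=2**256)
    lower, upper = bucket.split()
    midpoint = (0 + 2**256) // 2
    assert (lower.min_range, lower.max_range) == (0, midpoint)
    assert (upper.min_range, upper.max_range) == (midpoint, 2**256)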
|
||||
|
||||
|
||||
class RoutingTable:
|
||||
"""
|
||||
The Kademlia routing table maintains information on which peers to contact for any
|
||||
given peer ID in the network.
|
||||
"""
|
||||
|
||||
def __init__(self, local_id: ID, host: IHost) -> None:
|
||||
"""
|
||||
Initialize the routing table.
|
||||
|
||||
:param local_id: The ID of the local node.
|
||||
:param host: The host this routing table belongs to.
|
||||
|
||||
"""
|
||||
self.local_id = local_id
|
||||
self.host = host
|
||||
self.buckets = [KBucket(host, BUCKET_SIZE)]
|
||||
|
||||
async def add_peer(self, peer_obj: PeerInfo | ID) -> bool:
|
||||
"""
|
||||
Add a peer to the routing table.
|
||||
|
||||
:param peer_obj: Either PeerInfo object or peer ID to add
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the peer was added or updated, False otherwise
|
||||
|
||||
"""
|
||||
peer_id = None
|
||||
peer_info = None
|
||||
|
||||
try:
|
||||
# Handle different types of input
|
||||
if isinstance(peer_obj, PeerInfo):
|
||||
# Already have PeerInfo object
|
||||
peer_info = peer_obj
|
||||
peer_id = peer_obj.peer_id
|
||||
else:
|
||||
# Assume it's a peer ID
|
||||
peer_id = peer_obj
|
||||
# Try to get addresses from the peerstore if available
|
||||
try:
|
||||
addrs = self.host.get_peerstore().addrs(peer_id)
|
||||
if addrs:
|
||||
# Create PeerInfo object
|
||||
peer_info = PeerInfo(peer_id, addrs)
|
||||
else:
|
||||
logger.debug(
|
||||
"No addresses found for peer %s in peerstore, skipping",
|
||||
peer_id,
|
||||
)
|
||||
return False
|
||||
except Exception as peerstore_error:
|
||||
# Handle case where peer is not in peerstore yet
|
||||
logger.debug(
|
||||
"Peer %s not found in peerstore: %s, skipping",
|
||||
peer_id,
|
||||
str(peerstore_error),
|
||||
)
|
||||
return False
|
||||
|
||||
# Don't add ourselves
|
||||
if peer_id == self.local_id:
|
||||
return False
|
||||
|
||||
# Find the right bucket for this peer
|
||||
bucket = self.find_bucket(peer_id)
|
||||
|
||||
# Try to add to the bucket
|
||||
success = await bucket.add_peer(peer_info)
|
||||
if success:
|
||||
logger.debug(f"Successfully added peer {peer_id} to routing table")
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding peer {peer_obj} to routing table: {e}")
|
||||
return False
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Remove a peer from the routing table.
|
||||
|
||||
:param peer_id: The ID of the peer to remove
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the peer was removed, False otherwise
|
||||
|
||||
"""
|
||||
bucket = self.find_bucket(peer_id)
|
||||
return bucket.remove_peer(peer_id)
|
||||
|
||||
def find_bucket(self, peer_id: ID) -> KBucket:
|
||||
"""
|
||||
Find the bucket that would contain the given peer ID.

:param peer_id: The peer ID to locate a bucket for
|
||||
|
||||
Returns
|
||||
-------
|
||||
KBucket: The bucket for this peer
|
||||
|
||||
"""
|
||||
for bucket in self.buckets:
|
||||
if bucket.key_in_range(peer_id.to_bytes()):
|
||||
return bucket
|
||||
|
||||
return self.buckets[0]
|
||||
|
||||
def find_local_closest_peers(self, key: bytes, count: int = 20) -> list[ID]:
|
||||
"""
|
||||
Find the closest peers to a given key.
|
||||
|
||||
:param key: The key to find closest peers to (bytes)
|
||||
:param count: Maximum number of peers to return
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ID]: List of peer IDs closest to the key
|
||||
|
||||
"""
|
||||
# Get all peers from all buckets
|
||||
all_peers = []
|
||||
for bucket in self.buckets:
|
||||
all_peers.extend(bucket.peer_ids())
|
||||
|
||||
# Sort by XOR distance to the key
|
||||
all_peers.sort(key=lambda p: xor_distance(p.to_bytes(), key))
|
||||
|
||||
return all_peers[:count]
|
||||
|
||||
def get_peer_ids(self) -> list[ID]:
|
||||
"""
|
||||
Get all peer IDs in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ID]: List of all peer IDs
|
||||
|
||||
"""
|
||||
peers = []
|
||||
for bucket in self.buckets:
|
||||
peers.extend(bucket.peer_ids())
|
||||
return peers
|
||||
|
||||
def get_peer_info(self, peer_id: ID) -> PeerInfo | None:
|
||||
"""
|
||||
Get the peer info for a specific peer.
|
||||
|
||||
:param peer_id: The ID of the peer to get info for
|
||||
|
||||
Returns
|
||||
-------
|
||||
PeerInfo: The peer info, or None if not found
|
||||
|
||||
"""
|
||||
bucket = self.find_bucket(peer_id)
|
||||
return bucket.get_peer_info(peer_id)
|
||||
|
||||
def peer_in_table(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer is in the routing table.
|
||||
|
||||
:param peer_id: The ID of the peer to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the peer is in the routing table, False otherwise
|
||||
|
||||
"""
|
||||
bucket = self.find_bucket(peer_id)
|
||||
return bucket.has_peer(peer_id)
|
||||
|
||||
def size(self) -> int:
|
||||
"""
|
||||
Get the number of peers in the routing table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: Number of peers
|
||||
|
||||
"""
|
||||
count = 0
|
||||
for bucket in self.buckets:
|
||||
count += bucket.size()
|
||||
return count
|
||||
|
||||
def get_stale_peers(self, stale_threshold_seconds: int = 3600) -> list[ID]:
|
||||
"""
|
||||
Get all stale peers from all buckets
|
||||
|
||||
params: stale_threshold_seconds:
|
||||
Time in seconds after which a peer is considered stale
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
List of stale peer IDs
|
||||
|
||||
"""
|
||||
stale_peers = []
|
||||
for bucket in self.buckets:
|
||||
stale_peers.extend(bucket.get_stale_peers(stale_threshold_seconds))
|
||||
return stale_peers
|
||||
|
||||
def cleanup_routing_table(self) -> None:
|
||||
"""
|
||||
Cleanup the routing table by removing all data.
|
||||
This is useful for resetting the routing table during tests or reinitialization.
|
||||
"""
|
||||
self.buckets = [KBucket(self.host, BUCKET_SIZE)]
|
||||
logger.info("Routing table cleaned up, all data removed.")
libp2p/kad_dht/utils.py (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Utility functions for Kademlia DHT implementation.
|
||||
"""
|
||||
|
||||
import base58
|
||||
import multihash
|
||||
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
|
||||
def create_key_from_binary(binary_data: bytes) -> bytes:
|
||||
"""
|
||||
Creates a key for the DHT by hashing binary data with SHA-256.
|
||||
|
||||
params: binary_data: The binary data to hash.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bytes: The resulting key.
|
||||
|
||||
"""
|
||||
return multihash.digest(binary_data, "sha2-256").digest
|
||||
|
||||
|
||||
def xor_distance(key1: bytes, key2: bytes) -> int:
|
||||
"""
|
||||
Calculate the XOR distance between two keys.
|
||||
|
||||
params: key1: First key (bytes)
|
||||
params: key2: Second key (bytes)
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: The XOR distance between the keys
|
||||
|
||||
"""
|
||||
# Ensure the inputs are bytes
|
||||
if not isinstance(key1, bytes) or not isinstance(key2, bytes):
|
||||
raise TypeError("Both key1 and key2 must be bytes objects")
|
||||
|
||||
# Convert to integers
|
||||
k1 = int.from_bytes(key1, byteorder="big")
|
||||
k2 = int.from_bytes(key2, byteorder="big")
|
||||
|
||||
# Calculate XOR distance
|
||||
return k1 ^ k2
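# Worked example for the metric above: the distance is symmetric and is zero
# only for identical keys, e.g. 0b0011 XOR 0b0101 = 0b0110 = 6.
def _example_xor_distance() -> None:
    assert xor_distance(b"\x03", b"\x05") == 6  # 0b0011 ^ 0b0101 == 0b0110
    assert xor_distance(b"\x05", b"\x03") == 6  # symmetric
    assert xor_distance(b"\x07", b"\x07") == 0  # identical keys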
|
||||
|
||||
|
||||
def bytes_to_base58(data: bytes) -> str:
|
||||
"""
|
||||
Convert bytes to base58 encoded string.
|
||||
|
||||
params: data: Input bytes
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: Base58 encoded string
|
||||
|
||||
"""
|
||||
return base58.b58encode(data).decode("utf-8")
|
||||
|
||||
|
||||
def sort_peer_ids_by_distance(target_key: bytes, peer_ids: list[ID]) -> list[ID]:
|
||||
"""
|
||||
Sort a list of peer IDs by their distance to the target key.
|
||||
|
||||
params: target_key: The target key to measure distance from
|
||||
params: peer_ids: List of peer IDs to sort
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ID]: Sorted list of peer IDs from closest to furthest
|
||||
|
||||
"""
|
||||
|
||||
def get_distance(peer_id: ID) -> int:
|
||||
# Hash the peer ID bytes to get a key for distance calculation
|
||||
peer_hash = multihash.digest(peer_id.to_bytes(), "sha2-256").digest
|
||||
return xor_distance(target_key, peer_hash)
|
||||
|
||||
return sorted(peer_ids, key=get_distance)
|
||||
|
||||
|
||||
def shared_prefix_len(first: bytes, second: bytes) -> int:
|
||||
"""
|
||||
Calculate the number of prefix bits shared by two byte sequences.
|
||||
|
||||
params: first: First byte sequence
|
||||
params: second: Second byte sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
int: Number of shared prefix bits
|
||||
|
||||
"""
|
||||
# Compare each byte to find the first bit difference
|
||||
common_length = 0
|
||||
for i in range(min(len(first), len(second))):
|
||||
byte_first = first[i]
|
||||
byte_second = second[i]
|
||||
|
||||
if byte_first == byte_second:
|
||||
common_length += 8
|
||||
else:
|
||||
# Find specific bit where they differ
|
||||
xor = byte_first ^ byte_second
|
||||
# Count leading zeros in the xor result
|
||||
for j in range(7, -1, -1):
|
||||
if (xor >> j) & 1 == 1:
|
||||
return common_length + (7 - j)
|
||||
|
||||
# This shouldn't be reached if xor != 0
|
||||
return common_length + 8
|
||||
|
||||
return common_length
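# Worked example for the prefix helper above: 0xF0 and 0xF8 agree on their
# first four bits, and an identical leading byte contributes a full 8 bits.
def _example_shared_prefix_len() -> None:
    assert shared_prefix_len(b"\xf0", b"\xf8") == 4  # 1111_0000 vs 1111_1000
    assert shared_prefix_len(b"\xaa\xaa", b"\xaa\xab") == 15  # differ only in the last bit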
libp2p/kad_dht/value_store.py (new file, 393 lines)
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Value store implementation for Kademlia DHT.
|
||||
|
||||
Provides a way to store and retrieve key-value pairs with optional expiration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import varint
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
from .pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.kademlia.value_store")
|
||||
logger = logging.getLogger("kademlia-example.value_store")
|
||||
|
||||
# Default time to live for values in seconds (24 hours)
|
||||
DEFAULT_TTL = 24 * 60 * 60
|
||||
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
|
||||
|
||||
|
||||
class ValueStore:
|
||||
"""
|
||||
Store for key-value pairs in a Kademlia DHT.
|
||||
|
||||
Values are stored with a timestamp and optional expiration time.
|
||||
"""
|
||||
|
||||
def __init__(self, host: IHost, local_peer_id: ID):
|
||||
"""
|
||||
Initialize an empty value store.
|
||||
|
||||
:param host: The libp2p host instance.
|
||||
:param local_peer_id: The local peer ID to ignore in peer requests.
|
||||
|
||||
"""
|
||||
# Store format: {key: (value, validity)}
|
||||
self.store: dict[bytes, tuple[bytes, float]] = {}
|
||||
# Store references to the host and local peer ID for making requests
|
||||
self.host = host
|
||||
self.local_peer_id = local_peer_id
|
||||
|
||||
def put(self, key: bytes, value: bytes, validity: float = 0.0) -> None:
|
||||
"""
|
||||
Store a value in the DHT.
|
||||
|
||||
:param key: The key to store the value under
|
||||
:param value: The value to store
|
||||
:param validity: validity in seconds before the value expires.
|
||||
Defaults to `DEFAULT_TTL` if set to 0.0.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
"""
|
||||
if validity == 0.0:
|
||||
validity = time.time() + DEFAULT_TTL
|
||||
logger.debug(
|
||||
"Storing value for key %s... with validity %s", key.hex(), validity
|
||||
)
|
||||
self.store[key] = (value, validity)
|
||||
logger.debug(f"Stored value for key {key.hex()}")
|
||||
|
||||
    async def _store_at_peer(self, peer_id: ID, key: bytes, value: bytes) -> bool:
        """
        Store a value at a specific peer.

        :param peer_id: The ID of the peer to store the value at
        :param key: The key to store
        :param value: The value to store

        Returns
        -------
        bool
            True if the value was successfully stored, False otherwise

        """
        result = False
        stream = None
        try:
            # Don't try to store at ourselves
            if self.local_peer_id and peer_id == self.local_peer_id:
                result = True
                return result

            if not self.host:
                logger.error("Host not initialized, cannot store value at peer")
                return False

            logger.debug(f"Storing value for key {key.hex()} at peer {peer_id}")

            # Open a stream to the peer
            stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
            logger.debug(f"Opened stream to peer {peer_id}")

            # Create the PUT_VALUE message with protobuf
            message = Message()
            message.type = Message.MessageType.PUT_VALUE

            # Set message fields
            message.key = key
            message.record.key = key
            message.record.value = value
            message.record.timeReceived = str(time.time())

            # Serialize and send the protobuf message with length prefix
            proto_bytes = message.SerializeToString()
            await stream.write(varint.encode(len(proto_bytes)))
            await stream.write(proto_bytes)
            logger.debug("Sent PUT_VALUE protobuf message with varint length")

            # Read varint-prefixed response length
            length_bytes = b""
            while True:
                logger.debug("Reading varint length prefix for response...")
                b = await stream.read(1)
                if not b:
                    logger.warning("Connection closed while reading varint length")
                    return False
                length_bytes += b
                if b[0] & 0x80 == 0:
                    break
            logger.debug(f"Received varint length bytes: {length_bytes.hex()}")
            response_length = varint.decode_bytes(length_bytes)
            logger.debug("Response length: %d bytes", response_length)

            # Read response data
            response_bytes = b""
            remaining = response_length
            while remaining > 0:
                chunk = await stream.read(remaining)
                if not chunk:
                    logger.debug(
                        f"Connection closed by peer {peer_id} while reading data"
                    )
                    return False
                response_bytes += chunk
                remaining -= len(chunk)

            # Parse protobuf response
            response = Message()
            response.ParseFromString(response_bytes)

            # Check if response is valid
            if response.type == Message.MessageType.PUT_VALUE:
                if response.key:
                    result = True
            return result

        except Exception as e:
            logger.warning(f"Failed to store value at peer {peer_id}: {e}")
            return False

        finally:
            if stream:
                await stream.close()

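    # Wire-format sketch (illustrative, assuming the third-party ``varint``
    # package already imported by this module): every request and response above
    # is framed as varint(len(payload)) followed by the payload, where payload is
    # a serialized kademlia_pb2.Message.
    #
    #     payload = message.SerializeToString()
    #     frame = varint.encode(len(payload)) + payload
    #     # A reader accumulates bytes until the varint terminates (high bit
    #     # clear), decodes the length, then reads exactly that many payload
    #     # bytes -- mirroring the loops in the methods above and below.
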
    def get(self, key: bytes) -> bytes | None:
        """
        Retrieve a value from the DHT.

        :param key: The key to look up

        Returns
        -------
        bytes | None
            The stored value, or None if not found or expired

        """
        logger.debug("Retrieving value for key %s...", key.hex()[:8])
        if key not in self.store:
            return None

        value, validity = self.store[key]
        logger.debug(
            "Found value for key %s... with validity %s",
            key.hex(),
            validity,
        )
        # Check if the value has expired
        if validity is not None and validity < time.time():
            logger.debug(
                "Value for key %s... has expired, removing it",
                key.hex()[:8],
            )
            self.remove(key)
            return None

        return value

    async def _get_from_peer(self, peer_id: ID, key: bytes) -> bytes | None:
        """
        Retrieve a value from a specific peer.

        :param peer_id: The ID of the peer to retrieve the value from
        :param key: The key to retrieve

        Returns
        -------
        bytes | None
            The value if found, None otherwise

        """
        stream = None
        try:
            # Don't try to get from ourselves
            if peer_id == self.local_peer_id:
                return None

            logger.debug(f"Getting value for key {key.hex()} from peer {peer_id}")

            # Open a stream to the peer
            stream = await self.host.new_stream(peer_id, [TProtocol(PROTOCOL_ID)])
            logger.debug(f"Opened stream to peer {peer_id} for GET_VALUE")

            # Create the GET_VALUE message using protobuf
            message = Message()
            message.type = Message.MessageType.GET_VALUE
            message.key = key

            # Serialize and send the protobuf message
            proto_bytes = message.SerializeToString()
            await stream.write(varint.encode(len(proto_bytes)))
            await stream.write(proto_bytes)

            # Read response length
            length_bytes = b""
            while True:
                b = await stream.read(1)
                if not b:
                    logger.warning("Connection closed while reading length")
                    return None
                length_bytes += b
                if b[0] & 0x80 == 0:
                    break
            response_length = varint.decode_bytes(length_bytes)

            # Read response data
            response_bytes = b""
            remaining = response_length
            while remaining > 0:
                chunk = await stream.read(remaining)
                if not chunk:
                    logger.debug(
                        f"Connection closed by peer {peer_id} while reading data"
                    )
                    return None
                response_bytes += chunk
                remaining -= len(chunk)

            # Parse protobuf response
            try:
                response = Message()
                response.ParseFromString(response_bytes)
                logger.debug(
                    f"Received protobuf response from peer"
                    f" {peer_id}, type: {response.type}"
                )

                # Process protobuf response
                if (
                    response.type == Message.MessageType.GET_VALUE
                    and response.HasField("record")
                    and response.record.value
                ):
                    logger.debug(
                        f"Received value for key {key.hex()} from peer {peer_id}"
                    )
                    return response.record.value

                # Handle case where value is not found but peer infos are returned
                else:
                    logger.debug(
                        f"Value not found for key {key.hex()} from peer {peer_id},"
                        f" received {len(response.closerPeers)} closer peers"
                    )
                    return None

            except Exception as proto_err:
                logger.warning(f"Failed to parse as protobuf: {proto_err}")

            return None

        except Exception as e:
            logger.warning(f"Failed to get value from peer {peer_id}: {e}")
            return None

        finally:
            if stream:
                await stream.close()

    def remove(self, key: bytes) -> bool:
        """
        Remove a value from the DHT.

        :param key: The key to remove

        Returns
        -------
        bool
            True if the key was found and removed, False otherwise

        """
        if key in self.store:
            del self.store[key]
            logger.debug(f"Removed value for key {key.hex()[:8]}...")
            return True
        return False

    def has(self, key: bytes) -> bool:
        """
        Check if a key exists in the store and hasn't expired.

        :param key: The key to check

        Returns
        -------
        bool
            True if the key exists and hasn't expired, False otherwise

        """
        if key not in self.store:
            return False

        _, validity = self.store[key]
        if validity is not None and time.time() > validity:
            self.remove(key)
            return False

        return True

    def cleanup_expired(self) -> int:
        """
        Remove all expired values from the store.

        Returns
        -------
        int
            The number of expired values that were removed

        """
        current_time = time.time()
        expired_keys = [
            key for key, (_, validity) in self.store.items() if current_time > validity
        ]

        for key in expired_keys:
            del self.store[key]

        if expired_keys:
            logger.debug(f"Cleaned up {len(expired_keys)} expired values")

        return len(expired_keys)

    def get_keys(self) -> list[bytes]:
        """
        Get all non-expired keys in the store.

        Returns
        -------
        list[bytes]
            List of keys

        """
        # Clean up expired values first
        self.cleanup_expired()
        return list(self.store.keys())

    def size(self) -> int:
        """
        Get the number of items in the store (after removing expired entries).

        Returns
        -------
        int
            Number of items

        """
        self.cleanup_expired()
        return len(self.store)
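A minimal local usage sketch of the value store above (illustrative only; it assumes `store` is an already-constructed ValueStore instance and relies only on the put/get/has semantics shown in this file):

    import time

    store.put(b"greeting", b"hello", validity=time.time() + 30)
    assert store.get(b"greeting") == b"hello"
    assert store.has(b"greeting")
    # Once the validity timestamp passes, get() returns None and the expired
    # entry is purged from store.store.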
newsfragments/579.feature.rst | 1 + (new file)
@@ -0,0 +1 @@
Added support for ``Kademlia DHT`` in py-libp2p.
tests/core/kad_dht/test_kad_dht.py | 168 + (new file)
@@ -0,0 +1,168 @@
"""
|
||||
Tests for the Kademlia DHT implementation.
|
||||
|
||||
This module tests core functionality of the Kademlia DHT including:
|
||||
- Node discovery (find_node)
|
||||
- Value storage and retrieval (put_value, get_value)
|
||||
- Content provider advertisement and discovery (provide, find_providers)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.kad_dht.kad_dht import (
|
||||
DHTMode,
|
||||
KadDHT,
|
||||
)
|
||||
from libp2p.kad_dht.utils import (
|
||||
create_key_from_binary,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
)
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger("test.kad_dht")
|
||||
|
||||
# Constants for the tests
|
||||
TEST_TIMEOUT = 5 # Timeout in seconds
|
||||
|
||||
|
||||
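# Note (assumption, for readers of these tests): create_key_from_binary is only
# relied on here to be deterministic -- calling it twice with the same input must
# yield the same DHT key so that put_value on node A and get_value on node B
# address the same record. For example:
#
#     assert create_key_from_binary(b"test-key") == create_key_from_binary(b"test-key")
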
@pytest.fixture
async def dht_pair(security_protocol):
    """Create a pair of connected DHT nodes for testing."""
    async with host_pair_factory(security_protocol=security_protocol) as (
        host_a,
        host_b,
    ):
        # Get peer info for bootstrapping
        peer_b_info = PeerInfo(host_b.get_id(), host_b.get_addrs())
        peer_a_info = PeerInfo(host_a.get_id(), host_a.get_addrs())

        # Create DHT nodes from the hosts and seed each routing table with the other peer
        dht_a: KadDHT = KadDHT(host_a, mode=DHTMode.SERVER)
        dht_b: KadDHT = KadDHT(host_b, mode=DHTMode.SERVER)
        await dht_a.peer_routing.routing_table.add_peer(peer_b_info)
        await dht_b.peer_routing.routing_table.add_peer(peer_a_info)

        # Start both DHT services
        async with background_trio_service(dht_a), background_trio_service(dht_b):
            # Allow time for bootstrap to complete and connections to establish
            await trio.sleep(0.1)

            logger.debug(
                "After bootstrap: Node A peers: %s", dht_a.routing_table.get_peer_ids()
            )
            logger.debug(
                "After bootstrap: Node B peers: %s", dht_b.routing_table.get_peer_ids()
            )

            # Return the DHT pair
            yield (dht_a, dht_b)

@pytest.mark.trio
async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
    """Test that nodes can find each other in the DHT."""
    dht_a, dht_b = dht_pair

    # Node A should be able to find Node B
    with trio.fail_after(TEST_TIMEOUT):
        found_info = await dht_a.find_peer(dht_b.host.get_id())

    # Verify that the found peer has the correct peer ID
    assert found_info is not None, "Failed to find the target peer"
    assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"

@pytest.mark.trio
async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
    """Test storing and retrieving values in the DHT."""
    dht_a, dht_b = dht_pair
    peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())

    # Generate a key and value for the test
    key = create_key_from_binary(b"test-key")
    value = b"test-value"

    # First add the value directly to node A's store to verify storage works
    dht_a.value_store.put(key, value)
    logger.debug("Local value store: %s", dht_a.value_store.store)
    local_value = dht_a.value_store.get(key)
    assert local_value == value, "Local value storage failed"

    logger.debug(
        "Number of peers in peer store: %s", dht_a.host.get_peerstore().peer_ids()
    )
    await dht_a.routing_table.add_peer(peer_b_info)
    logger.debug("Routing table of A has %s", dht_a.routing_table.get_peer_ids())

    # Store the value using the first node (this will also store locally)
    with trio.fail_after(TEST_TIMEOUT):
        await dht_a.put_value(key, value)

    # Log debugging information
    logger.debug("Put value with key %s...", key.hex()[:10])
    logger.debug("Node A value store: %s", dht_a.value_store.store)

    # Allow more time for the value to propagate
    await trio.sleep(0.5)

    # Check the routing tables to ensure the nodes are properly linked
    logger.debug("Node A peers: %s", dht_a.routing_table.get_peer_ids())
    logger.debug("Node B peers: %s", dht_b.routing_table.get_peer_ids())

    # Retrieve the value using the second node
    with trio.fail_after(TEST_TIMEOUT):
        retrieved_value = await dht_b.get_value(key)
        logger.debug("Node B value store size: %s", dht_b.get_value_store_size())
        logger.debug("Retrieved value: %s", retrieved_value)

    # Verify that the retrieved value matches the original
    assert retrieved_value == value, "Retrieved value does not match the stored value"

@pytest.mark.trio
async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
    """Test advertising and finding content providers."""
    dht_a, dht_b = dht_pair

    # Generate a random content ID
    content = f"test-content-{uuid.uuid4()}".encode()
    content_id = hashlib.sha256(content).digest()

    # Store content on the first node
    dht_a.value_store.put(content_id, content)

    # Advertise the first node as a provider
    with trio.fail_after(TEST_TIMEOUT):
        success = await dht_a.provide(content_id)
        assert success, "Failed to advertise as provider"

    # Allow time for the provider record to propagate
    await trio.sleep(0.1)

    # Find providers using the second node
    with trio.fail_after(TEST_TIMEOUT):
        providers = await dht_b.find_providers(content_id)

    # Verify that we found the first node as a provider
    assert providers, "No providers found"
    assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
        "Expected provider not found"
    )

    # Retrieve the content using the provider information
    with trio.fail_after(TEST_TIMEOUT):
        retrieved_value = await dht_b.get_value(content_id)
        assert retrieved_value == content, (
            "Retrieved content does not match the original"
        )
tests/core/kad_dht/test_unit_peer_routing.py | 459 + (new file)
@@ -0,0 +1,459 @@
"""
|
||||
Unit tests for the PeerRouting class in Kademlia DHT.
|
||||
|
||||
This module tests the core functionality of peer routing including:
|
||||
- Peer discovery and lookup
|
||||
- Network queries for closest peers
|
||||
- Protocol message handling
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
Mock,
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import varint
|
||||
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.kad_dht.pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
from libp2p.kad_dht.peer_routing import (
|
||||
ALPHA,
|
||||
MAX_PEER_LOOKUP_ROUNDS,
|
||||
PROTOCOL_ID,
|
||||
PeerRouting,
|
||||
)
|
||||
from libp2p.kad_dht.routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
|
||||
def create_valid_peer_id(name: str) -> ID:
|
||||
"""Create a valid peer ID for testing."""
|
||||
key_pair = create_new_key_pair()
|
||||
return ID.from_pubkey(key_pair.public_key)
|
||||
|
||||
|
||||
class TestPeerRouting:
|
||||
"""Test suite for PeerRouting class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host(self):
|
||||
"""Create a mock host for testing."""
|
||||
host = Mock()
|
||||
host.get_id.return_value = create_valid_peer_id("local")
|
||||
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
host.connect = AsyncMock()
|
||||
return host
|
||||
|
||||
@pytest.fixture
|
||||
def mock_routing_table(self, mock_host):
|
||||
"""Create a mock routing table for testing."""
|
||||
local_id = create_valid_peer_id("local")
|
||||
routing_table = RoutingTable(local_id, mock_host)
|
||||
return routing_table
|
||||
|
||||
@pytest.fixture
|
||||
def peer_routing(self, mock_host, mock_routing_table):
|
||||
"""Create a PeerRouting instance for testing."""
|
||||
return PeerRouting(mock_host, mock_routing_table)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_peer_info(self):
|
||||
"""Create sample peer info for testing."""
|
||||
peer_id = create_valid_peer_id("sample")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
|
||||
return PeerInfo(peer_id, addresses)
|
||||
|
||||
def test_init_peer_routing(self, mock_host, mock_routing_table):
|
||||
"""Test PeerRouting initialization."""
|
||||
peer_routing = PeerRouting(mock_host, mock_routing_table)
|
||||
|
||||
assert peer_routing.host == mock_host
|
||||
assert peer_routing.routing_table == mock_routing_table
|
||||
assert peer_routing.protocol_id == PROTOCOL_ID
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_local_host(self, peer_routing, mock_host):
|
||||
"""Test finding our own peer."""
|
||||
local_id = mock_host.get_id()
|
||||
|
||||
result = await peer_routing.find_peer(local_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.peer_id == local_id
|
||||
assert result.addrs == mock_host.get_addrs()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_in_routing_table(self, peer_routing, sample_peer_info):
|
||||
"""Test finding peer that exists in routing table."""
|
||||
# Add peer to routing table
|
||||
await peer_routing.routing_table.add_peer(sample_peer_info)
|
||||
|
||||
result = await peer_routing.find_peer(sample_peer_info.peer_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.peer_id == sample_peer_info.peer_id
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_in_peerstore(self, peer_routing, mock_host):
|
||||
"""Test finding peer that exists in peerstore."""
|
||||
peer_id = create_valid_peer_id("peerstore")
|
||||
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8002")]
|
||||
|
||||
# Mock peerstore to return addresses
|
||||
mock_host.get_peerstore().addrs.return_value = mock_addrs
|
||||
|
||||
result = await peer_routing.find_peer(peer_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.peer_id == peer_id
|
||||
assert result.addrs == mock_addrs
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_not_found(self, peer_routing, mock_host):
|
||||
"""Test finding peer that doesn't exist anywhere."""
|
||||
peer_id = create_valid_peer_id("nonexistent")
|
||||
|
||||
# Mock peerstore to return no addresses
|
||||
mock_host.get_peerstore().addrs.return_value = []
|
||||
|
||||
# Mock network search to return empty results
|
||||
with patch.object(peer_routing, "find_closest_peers_network", return_value=[]):
|
||||
result = await peer_routing.find_peer(peer_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_closest_peers_network_empty_start(self, peer_routing):
|
||||
"""Test network search with no local peers."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Mock routing table to return empty list
|
||||
with patch.object(
|
||||
peer_routing.routing_table, "find_local_closest_peers", return_value=[]
|
||||
):
|
||||
result = await peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_closest_peers_network_with_peers(self, peer_routing, mock_host):
|
||||
"""Test network search with some initial peers."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create some test peers
|
||||
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(3)]
|
||||
|
||||
# Mock routing table to return initial peers
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
# Mock _query_peer_for_closest to return empty results (no new peers found)
|
||||
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
|
||||
result = await peer_routing.find_closest_peers_network(
|
||||
target_key, count=5
|
||||
)
|
||||
|
||||
assert len(result) <= 5
|
||||
# Should return the initial peers since no new ones were discovered
|
||||
assert all(peer in initial_peers for peer in result)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_closest_peers_convergence(self, peer_routing):
|
||||
"""Test that network search converges properly."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create test peers
|
||||
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(2)]
|
||||
|
||||
# Mock to simulate convergence (no improvement in closest peers)
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
|
||||
with patch(
|
||||
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
result = await peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
assert result == initial_peers
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_peer_for_closest_success(
|
||||
self, peer_routing, mock_host, sample_peer_info
|
||||
):
|
||||
"""Test successful peer query for closest peers."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create mock stream
|
||||
mock_stream = AsyncMock()
|
||||
mock_host.new_stream.return_value = mock_stream
|
||||
|
||||
# Create mock response
|
||||
response_msg = Message()
|
||||
response_msg.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Add a peer to the response
|
||||
peer_proto = response_msg.closerPeers.add()
|
||||
response_peer_id = create_valid_peer_id("response_peer")
|
||||
peer_proto.id = response_peer_id.to_bytes()
|
||||
peer_proto.addrs.append(Multiaddr("/ip4/127.0.0.1/tcp/8003").to_bytes())
|
||||
|
||||
response_bytes = response_msg.SerializeToString()
|
||||
|
||||
# Mock stream reading
|
||||
varint_length = varint.encode(len(response_bytes))
|
||||
mock_stream.read.side_effect = [varint_length, response_bytes]
|
||||
|
||||
# Mock peerstore
|
||||
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
|
||||
mock_host.get_peerstore().add_addrs = Mock()
|
||||
|
||||
result = await peer_routing._query_peer_for_closest(
|
||||
sample_peer_info.peer_id, target_key
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == response_peer_id
|
||||
mock_stream.write.assert_called()
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_peer_for_closest_stream_failure(self, peer_routing, mock_host):
|
||||
"""Test peer query when stream creation fails."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
|
||||
# Mock stream creation failure
|
||||
mock_host.new_stream.side_effect = Exception("Stream failed")
|
||||
mock_host.get_peerstore().addrs.return_value = []
|
||||
|
||||
result = await peer_routing._query_peer_for_closest(peer_id, target_key)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_peer_for_closest_read_failure(
|
||||
self, peer_routing, mock_host, sample_peer_info
|
||||
):
|
||||
"""Test peer query when reading response fails."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create mock stream that fails to read
|
||||
mock_stream = AsyncMock()
|
||||
mock_stream.read.side_effect = [b""] # Empty read simulates connection close
|
||||
mock_host.new_stream.return_value = mock_stream
|
||||
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
|
||||
|
||||
result = await peer_routing._query_peer_for_closest(
|
||||
sample_peer_info.peer_id, target_key
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_refresh_routing_table(self, peer_routing, mock_host):
|
||||
"""Test routing table refresh."""
|
||||
local_id = mock_host.get_id()
|
||||
discovered_peers = [create_valid_peer_id(f"discovered{i}") for i in range(3)]
|
||||
|
||||
# Mock find_closest_peers_network to return discovered peers
|
||||
with patch.object(
|
||||
peer_routing, "find_closest_peers_network", return_value=discovered_peers
|
||||
):
|
||||
# Mock peerstore to return addresses for discovered peers
|
||||
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8003")]
|
||||
mock_host.get_peerstore().addrs.return_value = mock_addrs
|
||||
|
||||
await peer_routing.refresh_routing_table()
|
||||
|
||||
# Should perform lookup for local ID
|
||||
peer_routing.find_closest_peers_network.assert_called_once_with(
|
||||
local_id.to_bytes()
|
||||
)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_kad_stream_find_node(self, peer_routing, mock_host):
|
||||
"""Test handling incoming FIND_NODE requests."""
|
||||
# Create mock stream
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Create FIND_NODE request
|
||||
request_msg = Message()
|
||||
request_msg.type = Message.MessageType.FIND_NODE
|
||||
request_msg.key = b"target_key"
|
||||
|
||||
request_bytes = request_msg.SerializeToString()
|
||||
|
||||
# Mock stream reading
|
||||
mock_stream.read.side_effect = [
|
||||
len(request_bytes).to_bytes(4, byteorder="big"),
|
||||
request_bytes,
|
||||
]
|
||||
|
||||
# Mock routing table to return some peers
|
||||
closest_peers = [create_valid_peer_id(f"close{i}") for i in range(2)]
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=closest_peers,
|
||||
):
|
||||
mock_host.get_peerstore().addrs.return_value = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/8004")
|
||||
]
|
||||
|
||||
await peer_routing._handle_kad_stream(mock_stream)
|
||||
|
||||
# Should write response
|
||||
mock_stream.write.assert_called()
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_kad_stream_invalid_message(self, peer_routing):
|
||||
"""Test handling stream with invalid message."""
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Mock stream to return invalid data
|
||||
mock_stream.read.side_effect = [
|
||||
(10).to_bytes(4, byteorder="big"),
|
||||
b"invalid_proto_data",
|
||||
]
|
||||
|
||||
# Should handle gracefully without raising exception
|
||||
await peer_routing._handle_kad_stream(mock_stream)
|
||||
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_kad_stream_connection_closed(self, peer_routing):
|
||||
"""Test handling stream when connection is closed early."""
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Mock stream to return empty data (connection closed)
|
||||
mock_stream.read.return_value = b""
|
||||
|
||||
await peer_routing._handle_kad_stream(mock_stream)
|
||||
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_single_peer_for_closest_success(self, peer_routing):
|
||||
"""Test _query_single_peer_for_closest method."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
new_peers = []
|
||||
|
||||
# Mock successful query
|
||||
mock_result = [create_valid_peer_id("result1"), create_valid_peer_id("result2")]
|
||||
with patch.object(
|
||||
peer_routing, "_query_peer_for_closest", return_value=mock_result
|
||||
):
|
||||
await peer_routing._query_single_peer_for_closest(
|
||||
peer_id, target_key, new_peers
|
||||
)
|
||||
|
||||
assert len(new_peers) == 2
|
||||
assert all(peer in new_peers for peer in mock_result)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_single_peer_for_closest_failure(self, peer_routing):
|
||||
"""Test _query_single_peer_for_closest when query fails."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
new_peers = []
|
||||
|
||||
# Mock query failure
|
||||
with patch.object(
|
||||
peer_routing,
|
||||
"_query_peer_for_closest",
|
||||
side_effect=Exception("Query failed"),
|
||||
):
|
||||
await peer_routing._query_single_peer_for_closest(
|
||||
peer_id, target_key, new_peers
|
||||
)
|
||||
|
||||
# Should handle exception gracefully
|
||||
assert len(new_peers) == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_single_peer_deduplication(self, peer_routing):
|
||||
"""Test that _query_single_peer_for_closest deduplicates peers."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
duplicate_peer = create_valid_peer_id("duplicate")
|
||||
new_peers = [duplicate_peer] # Pre-existing peer
|
||||
|
||||
# Mock query to return the same peer
|
||||
mock_result = [duplicate_peer, create_valid_peer_id("new")]
|
||||
with patch.object(
|
||||
peer_routing, "_query_peer_for_closest", return_value=mock_result
|
||||
):
|
||||
await peer_routing._query_single_peer_for_closest(
|
||||
peer_id, target_key, new_peers
|
||||
)
|
||||
|
||||
# Should not add duplicate
|
||||
assert len(new_peers) == 2 # Original + 1 new peer
|
||||
assert new_peers.count(duplicate_peer) == 1
|
||||
|
||||
def test_constants(self):
|
||||
"""Test that important constants are properly defined."""
|
||||
assert ALPHA == 3
|
||||
assert MAX_PEER_LOOKUP_ROUNDS == 20
|
||||
assert PROTOCOL_ID == "/ipfs/kad/1.0.0"
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_edge_case_max_rounds_reached(self, peer_routing):
|
||||
"""Test that lookup stops after maximum rounds."""
|
||||
target_key = b"target_key"
|
||||
initial_peers = [create_valid_peer_id("peer1")]
|
||||
|
||||
# Mock to always return new peers to force max rounds
|
||||
def mock_query_side_effect(peer, key):
|
||||
return [create_valid_peer_id(f"new_peer_{time.time()}")]
|
||||
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
with patch.object(
|
||||
peer_routing,
|
||||
"_query_peer_for_closest",
|
||||
side_effect=mock_query_side_effect,
|
||||
):
|
||||
with patch(
|
||||
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance"
|
||||
) as mock_sort:
|
||||
# Always return different peers to prevent convergence
|
||||
mock_sort.side_effect = lambda key, peers: peers[:20]
|
||||
|
||||
result = await peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
# Should stop after max rounds, not infinite loop
|
||||
assert isinstance(result, list)
|
||||
tests/core/kad_dht/test_unit_provider_store.py | 805 + (new file)
@@ -0,0 +1,805 @@
"""
|
||||
Unit tests for the ProviderStore and ProviderRecord classes in Kademlia DHT.
|
||||
|
||||
This module tests the core functionality of provider record management including:
|
||||
- ProviderRecord creation, expiration, and republish logic
|
||||
- ProviderStore operations (add, get, cleanup)
|
||||
- Expiration and TTL handling
|
||||
- Network operations (mocked)
|
||||
- Edge cases and error conditions
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
Mock,
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
|
||||
from libp2p.kad_dht.provider_store import (
|
||||
PROVIDER_ADDRESS_TTL,
|
||||
PROVIDER_RECORD_EXPIRATION_INTERVAL,
|
||||
PROVIDER_RECORD_REPUBLISH_INTERVAL,
|
||||
ProviderRecord,
|
||||
ProviderStore,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
mock_host = Mock()
|
||||
|
||||
|
||||
class TestProviderRecord:
|
||||
"""Test suite for ProviderRecord class."""
|
||||
|
||||
def test_init_with_default_timestamp(self):
|
||||
"""Test ProviderRecord initialization with default timestamp."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
peer_info = PeerInfo(peer_id, addresses)
|
||||
|
||||
start_time = time.time()
|
||||
record = ProviderRecord(peer_info)
|
||||
end_time = time.time()
|
||||
|
||||
assert record.provider_info == peer_info
|
||||
assert start_time <= record.timestamp <= end_time
|
||||
assert record.peer_id == peer_id
|
||||
assert record.addresses == addresses
|
||||
|
||||
def test_init_with_custom_timestamp(self):
|
||||
"""Test ProviderRecord initialization with custom timestamp."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
custom_timestamp = time.time() - 3600 # 1 hour ago
|
||||
|
||||
record = ProviderRecord(peer_info, timestamp=custom_timestamp)
|
||||
|
||||
assert record.timestamp == custom_timestamp
|
||||
|
||||
def test_is_expired_fresh_record(self):
|
||||
"""Test that fresh records are not expired."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert not record.is_expired()
|
||||
|
||||
def test_is_expired_old_record(self):
|
||||
"""Test that old records are expired."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
old_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
record = ProviderRecord(peer_info, timestamp=old_timestamp)
|
||||
|
||||
assert record.is_expired()
|
||||
|
||||
def test_is_expired_boundary_condition(self):
|
||||
"""Test expiration at exact boundary."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
boundary_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL
|
||||
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
|
||||
|
||||
# At the exact boundary, should be expired (implementation uses >)
|
||||
assert record.is_expired()
|
||||
|
||||
def test_should_republish_fresh_record(self):
|
||||
"""Test that fresh records don't need republishing."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert not record.should_republish()
|
||||
|
||||
def test_should_republish_old_record(self):
|
||||
"""Test that old records need republishing."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
old_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1
|
||||
record = ProviderRecord(peer_info, timestamp=old_timestamp)
|
||||
|
||||
assert record.should_republish()
|
||||
|
||||
def test_should_republish_boundary_condition(self):
|
||||
"""Test republish at exact boundary."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
boundary_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL
|
||||
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
|
||||
|
||||
# At the exact boundary, should need republishing (implementation uses >)
|
||||
assert record.should_republish()
|
||||
|
||||
def test_properties(self):
|
||||
"""Test peer_id and addresses properties."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
|
||||
Multiaddr("/ip6/::1/tcp/8001"),
|
||||
]
|
||||
peer_info = PeerInfo(peer_id, addresses)
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert record.peer_id == peer_id
|
||||
assert record.addresses == addresses
|
||||
|
||||
def test_empty_addresses(self):
|
||||
"""Test ProviderRecord with empty address list."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert record.addresses == []
|
||||
|
||||
|
||||
class TestProviderStore:
|
||||
"""Test suite for ProviderStore class."""
|
||||
|
||||
def test_init_empty_store(self):
|
||||
"""Test that a new ProviderStore is initialized empty."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
assert len(store.providers) == 0
|
||||
assert store.peer_routing is None
|
||||
assert len(store.providing_keys) == 0
|
||||
|
||||
def test_init_with_host(self):
|
||||
"""Test initialization with host."""
|
||||
mock_host = Mock()
|
||||
mock_peer_id = ID.from_base58("QmTest123")
|
||||
mock_host.get_id.return_value = mock_peer_id
|
||||
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
assert store.host == mock_host
|
||||
assert store.local_peer_id == mock_peer_id
|
||||
assert len(store.providers) == 0
|
||||
|
||||
def test_init_with_host_and_peer_routing(self):
|
||||
"""Test initialization with both host and peer routing."""
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = Mock()
|
||||
mock_peer_id = ID.from_base58("QmTest123")
|
||||
mock_host.get_id.return_value = mock_peer_id
|
||||
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
assert store.host == mock_host
|
||||
assert store.peer_routing == mock_peer_routing
|
||||
assert store.local_peer_id == mock_peer_id
|
||||
|
||||
def test_add_provider_new_key(self):
|
||||
"""Test adding a provider for a new key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
provider = PeerInfo(peer_id, addresses)
|
||||
|
||||
store.add_provider(key, provider)
|
||||
|
||||
assert key in store.providers
|
||||
assert str(peer_id) in store.providers[key]
|
||||
|
||||
record = store.providers[key][str(peer_id)]
|
||||
assert record.provider_info == provider
|
||||
assert isinstance(record.timestamp, float)
|
||||
|
||||
def test_add_provider_existing_key(self):
|
||||
"""Test adding multiple providers for the same key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add first provider
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider1)
|
||||
|
||||
# Add second provider
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
store.add_provider(key, provider2)
|
||||
|
||||
assert len(store.providers[key]) == 2
|
||||
assert str(peer_id1) in store.providers[key]
|
||||
assert str(peer_id2) in store.providers[key]
|
||||
|
||||
def test_add_provider_update_existing(self):
|
||||
"""Test updating an existing provider."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
# Add initial provider
|
||||
provider1 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
|
||||
store.add_provider(key, provider1)
|
||||
first_timestamp = store.providers[key][str(peer_id)].timestamp
|
||||
|
||||
# Small delay to ensure timestamp difference
|
||||
time.sleep(0.001)
|
||||
|
||||
# Update provider
|
||||
provider2 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
|
||||
store.add_provider(key, provider2)
|
||||
|
||||
# Should have same peer but updated info
|
||||
assert len(store.providers[key]) == 1
|
||||
assert str(peer_id) in store.providers[key]
|
||||
|
||||
record = store.providers[key][str(peer_id)]
|
||||
assert record.provider_info == provider2
|
||||
assert record.timestamp > first_timestamp
|
||||
|
||||
def test_get_providers_empty_key(self):
|
||||
"""Test getting providers for non-existent key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"nonexistent_key"
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert providers == []
|
||||
|
||||
def test_get_providers_valid_records(self):
|
||||
"""Test getting providers with valid records."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add multiple providers
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider1 = PeerInfo(peer_id1, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
|
||||
provider2 = PeerInfo(peer_id2, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
|
||||
|
||||
store.add_provider(key, provider1)
|
||||
store.add_provider(key, provider2)
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 2
|
||||
provider_ids = {p.peer_id for p in providers}
|
||||
assert peer_id1 in provider_ids
|
||||
assert peer_id2 in provider_ids
|
||||
|
||||
def test_get_providers_expired_records(self):
|
||||
"""Test that expired records are filtered out and cleaned up."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add valid provider
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider1)
|
||||
|
||||
# Add expired provider manually
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key][str(peer_id2)] = ProviderRecord(
|
||||
provider2, expired_timestamp
|
||||
)
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
# Should only return valid provider
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id1
|
||||
|
||||
# Expired provider should be cleaned up
|
||||
assert str(peer_id2) not in store.providers[key]
|
||||
|
||||
def test_get_providers_address_ttl(self):
|
||||
"""Test address TTL handling in get_providers."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
provider = PeerInfo(peer_id, addresses)
|
||||
|
||||
# Add provider with old timestamp (addresses expired but record valid)
|
||||
old_timestamp = time.time() - PROVIDER_ADDRESS_TTL - 1
|
||||
store.providers[key] = {str(peer_id): ProviderRecord(provider, old_timestamp)}
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
# Should return provider but with empty addresses
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
assert providers[0].addrs == []
|
||||
|
||||
def test_get_providers_cleanup_empty_key(self):
|
||||
"""Test that keys with no valid providers are removed."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add only expired providers
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key] = {
|
||||
str(peer_id): ProviderRecord(provider, expired_timestamp)
|
||||
}
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert providers == []
|
||||
assert key not in store.providers # Key should be removed
|
||||
|
||||
def test_cleanup_expired_no_expired_records(self):
|
||||
"""Test cleanup when there are no expired records."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
|
||||
# Add valid providers
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
|
||||
store.add_provider(key1, provider1)
|
||||
store.add_provider(key2, provider2)
|
||||
|
||||
initial_size = store.size()
|
||||
store.cleanup_expired()
|
||||
|
||||
assert store.size() == initial_size
|
||||
assert key1 in store.providers
|
||||
assert key2 in store.providers
|
||||
|
||||
def test_cleanup_expired_with_expired_records(self):
|
||||
"""Test cleanup removes expired records."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add valid provider
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider1)
|
||||
|
||||
# Add expired provider
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key][str(peer_id2)] = ProviderRecord(
|
||||
provider2, expired_timestamp
|
||||
)
|
||||
|
||||
assert store.size() == 2
|
||||
store.cleanup_expired()
|
||||
|
||||
assert store.size() == 1
|
||||
assert str(peer_id1) in store.providers[key]
|
||||
assert str(peer_id2) not in store.providers[key]
|
||||
|
||||
def test_cleanup_expired_remove_empty_keys(self):
|
||||
"""Test that keys with only expired providers are removed."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
|
||||
# Add valid provider to key1
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key1, provider1)
|
||||
|
||||
# Add only expired provider to key2
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key2] = {
|
||||
str(peer_id2): ProviderRecord(provider2, expired_timestamp)
|
||||
}
|
||||
|
||||
store.cleanup_expired()
|
||||
|
||||
assert key1 in store.providers
|
||||
assert key2 not in store.providers
|
||||
|
||||
def test_get_provided_keys_empty_store(self):
|
||||
"""Test get_provided_keys with empty store."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
keys = store.get_provided_keys(peer_id)
|
||||
|
||||
assert keys == []
|
||||
|
||||
def test_get_provided_keys_single_peer(self):
|
||||
"""Test get_provided_keys for a specific peer."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
key3 = b"key3"
|
||||
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
|
||||
# peer_id1 provides key1 and key2
|
||||
store.add_provider(key1, provider1)
|
||||
store.add_provider(key2, provider1)
|
||||
|
||||
# peer_id2 provides key2 and key3
|
||||
store.add_provider(key2, provider2)
|
||||
store.add_provider(key3, provider2)
|
||||
|
||||
keys1 = store.get_provided_keys(peer_id1)
|
||||
keys2 = store.get_provided_keys(peer_id2)
|
||||
|
||||
assert len(keys1) == 2
|
||||
assert key1 in keys1
|
||||
assert key2 in keys1
|
||||
|
||||
assert len(keys2) == 2
|
||||
assert key2 in keys2
|
||||
assert key3 in keys2
|
||||
|
||||
def test_get_provided_keys_nonexistent_peer(self):
|
||||
"""Test get_provided_keys for peer that provides nothing."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
|
||||
# Add provider for peer_id1 only
|
||||
key = b"key"
|
||||
provider = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider)
|
||||
|
||||
# Query for peer_id2 (provides nothing)
|
||||
keys = store.get_provided_keys(peer_id2)
|
||||
|
||||
assert keys == []
|
||||
|
||||
def test_size_empty_store(self):
|
||||
"""Test size() with empty store."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
assert store.size() == 0
|
||||
|
||||
def test_size_with_providers(self):
|
||||
"""Test size() with multiple providers."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Add providers
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
peer_id3 = ID.from_base58("QmTest789")
|
||||
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
provider3 = PeerInfo(peer_id3, [])
|
||||
|
||||
store.add_provider(key1, provider1)
|
||||
store.add_provider(key1, provider2) # 2 providers for key1
|
||||
store.add_provider(key2, provider3) # 1 provider for key2
|
||||
|
||||
assert store.size() == 3
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_no_host(self):
|
||||
"""Test provide() returns False when no host is configured."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_no_peer_routing(self):
|
||||
"""Test provide() returns False when no peer routing is configured."""
|
||||
mock_host = Mock()
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_success(self):
|
||||
"""Test successful provide operation."""
|
||||
# Setup mocks
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = AsyncMock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
mock_host.get_id.return_value = peer_id
|
||||
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
|
||||
# Mock finding closest peers
|
||||
closest_peers = [ID.from_base58("QmPeer1"), ID.from_base58("QmPeer2")]
|
||||
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
|
||||
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
# Mock _send_add_provider to return success
|
||||
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
|
||||
key = b"test_key"
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is True
|
||||
assert key in store.providing_keys
|
||||
assert key in store.providers
|
||||
|
||||
# Should have called _send_add_provider for each peer
|
||||
assert mock_send.call_count == len(closest_peers)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_skip_local_peer(self):
|
||||
"""Test that provide() skips sending to local peer."""
|
||||
# Setup mocks
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = AsyncMock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
mock_host.get_id.return_value = peer_id
|
||||
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
|
||||
# Include local peer in closest peers
|
||||
closest_peers = [peer_id, ID.from_base58("QmPeer1")]
|
||||
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
|
||||
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
|
||||
key = b"test_key"
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is True
|
||||
# Should only call _send_add_provider once (skip local peer)
|
||||
assert mock_send.call_count == 1
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_providers_no_host(self):
|
||||
"""Test find_providers() returns empty list when no host."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
result = await store.find_providers(key)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_providers_local_only(self):
|
||||
"""Test find_providers() returns local providers."""
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = Mock()
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
# Add local providers
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
store.add_provider(key, provider)
|
||||
|
||||
result = await store.find_providers(key)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].peer_id == peer_id
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_providers_network_search(self):
|
||||
"""Test find_providers() searches network when no local providers."""
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = AsyncMock()
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
# Mock network search
|
||||
closest_peers = [ID.from_base58("QmPeer1")]
|
||||
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
|
||||
|
||||
# Mock provider response
|
||||
remote_peer_id = ID.from_base58("QmRemote123")
|
||||
remote_providers = [PeerInfo(remote_peer_id, [])]
|
||||
|
||||
with patch.object(
|
||||
store, "_get_providers_from_peer", return_value=remote_providers
|
||||
):
|
||||
key = b"test_key"
|
||||
result = await store.find_providers(key)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].peer_id == remote_peer_id
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_providers_from_peer_no_host(self):
|
||||
"""Test _get_providers_from_peer without host."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
key = b"test_key"
|
||||
|
||||
# Should handle missing host gracefully
|
||||
result = await store._get_providers_from_peer(peer_id, key)
|
||||
assert result == []
|
||||
|
||||
def test_edge_case_empty_key(self):
|
||||
"""Test handling of empty key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
|
||||
store.add_provider(key, provider)
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
|
||||
def test_edge_case_large_key(self):
|
||||
"""Test handling of large key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"x" * 10000 # 10KB key
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
|
||||
store.add_provider(key, provider)
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
|
||||
def test_concurrent_operations(self):
|
||||
"""Test multiple concurrent operations."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Add many providers
|
||||
num_keys = 100
|
||||
num_providers_per_key = 5
|
||||
|
||||
for i in range(num_keys):
|
||||
_key = f"key_{i}".encode()
|
||||
for j in range(num_providers_per_key):
|
||||
# Generate unique valid Base58 peer IDs
|
||||
# Use a different approach that ensures uniqueness
|
||||
                unique_id = i * num_providers_per_key + j + 1  # Ensure > 0
                _peer_id_str = f"QmPeer{unique_id:06d}".replace("0", "A") + "1" * 38
                peer_id = ID.from_base58(_peer_id_str)
                provider = PeerInfo(peer_id, [])
                store.add_provider(_key, provider)

        # Verify total size
        expected_size = num_keys * num_providers_per_key
        assert store.size() == expected_size

        # Verify individual keys
        for i in range(num_keys):
            _key = f"key_{i}".encode()
            providers = store.get_providers(_key)
            assert len(providers) == num_providers_per_key

    def test_memory_efficiency_large_dataset(self):
        """Test memory behavior with large datasets."""
        store = ProviderStore(host=mock_host)

        # Add large number of providers
        num_entries = 1000
        for i in range(num_entries):
            _key = f"key_{i:05d}".encode()
            # Generate valid Base58 peer IDs (replace 0 with valid characters)
            peer_str = f"QmPeer{i:05d}".replace("0", "1") + "1" * 35
            peer_id = ID.from_base58(peer_str)
            provider = PeerInfo(peer_id, [])
            store.add_provider(_key, provider)

        assert store.size() == num_entries

        # Clean up all entries by making them expired
        current_time = time.time()
        for _key, providers in store.providers.items():
            for _peer_id_str, record in providers.items():
                record.timestamp = (
                    current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
                )

        store.cleanup_expired()
        assert store.size() == 0
        assert len(store.providers) == 0

    def test_unicode_key_handling(self):
        """Test handling of unicode content in keys."""
        store = ProviderStore(host=mock_host)

        # Test various unicode keys
        unicode_keys = [
            b"hello",
            "héllo".encode(),
            "🔑".encode(),
            "ключ".encode(),  # Russian
            "键".encode(),  # Chinese
        ]

        for i, key in enumerate(unicode_keys):
            # Generate valid Base58 peer IDs
            peer_id = ID.from_base58(f"QmPeer{i + 1}" + "1" * 42)  # Valid base58
            provider = PeerInfo(peer_id, [])
            store.add_provider(key, provider)

            providers = store.get_providers(key)
            assert len(providers) == 1
            assert providers[0].peer_id == peer_id

    def test_multiple_addresses_per_provider(self):
        """Test providers with multiple addresses."""
        store = ProviderStore(host=mock_host)
        key = b"test_key"
        peer_id = ID.from_base58("QmTest123")

        addresses = [
            Multiaddr("/ip4/127.0.0.1/tcp/8000"),
            Multiaddr("/ip6/::1/tcp/8001"),
            Multiaddr("/ip4/192.168.1.100/tcp/8002"),
        ]
        provider = PeerInfo(peer_id, addresses)

        store.add_provider(key, provider)
        providers = store.get_providers(key)

        assert len(providers) == 1
        assert providers[0].peer_id == peer_id
        assert len(providers[0].addrs) == len(addresses)
        assert all(addr in providers[0].addrs for addr in addresses)

    @pytest.mark.trio
    async def test_republish_provider_records_no_keys(self):
        """Test _republish_provider_records with no providing keys."""
        store = ProviderStore(host=mock_host)

        # Should complete without error even with no providing keys
        await store._republish_provider_records()

        assert len(store.providing_keys) == 0

    def test_expiration_boundary_conditions(self):
        """Test expiration around boundary conditions."""
        store = ProviderStore(host=mock_host)
        peer_id = ID.from_base58("QmTest123")
        provider = PeerInfo(peer_id, [])

        current_time = time.time()

        # Test records at various timestamps
        timestamps = [
            current_time,  # Fresh
            current_time - PROVIDER_ADDRESS_TTL + 1,  # Addresses valid
            current_time - PROVIDER_ADDRESS_TTL - 1,  # Addresses expired
            current_time
            - PROVIDER_RECORD_REPUBLISH_INTERVAL
            + 1,  # No republish needed
            current_time - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1,  # Republish needed
            current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL + 1,  # Not expired
            current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1,  # Expired
        ]

        for i, timestamp in enumerate(timestamps):
            test_key = f"key_{i}".encode()
            record = ProviderRecord(provider, timestamp)
            store.providers[test_key] = {str(peer_id): record}

        # Test various operations
        for i, timestamp in enumerate(timestamps):
            test_key = f"key_{i}".encode()
            providers = store.get_providers(test_key)

            if timestamp <= current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL:
                # Should be expired and removed
                assert len(providers) == 0
                assert test_key not in store.providers
            else:
                # Should be present
                assert len(providers) == 1
                assert providers[0].peer_id == peer_id
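For readers tracing the boundary assertions above: the expiry rule the test relies on is an inclusive age check, as also noted in the "inclusive comparison" fix in the change log. A minimal restatement of that predicate follows; the helper name and signature are illustrative, not part of the ProviderStore API.

import time

def provider_record_expired(timestamp: float, expiration_interval: float) -> bool:
    # A record counts as expired once its age reaches the interval (inclusive),
    # mirroring the `timestamp <= current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL`
    # comparison used in test_expiration_boundary_conditions.
    return timestamp <= time.time() - expiration_interval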
tests/core/kad_dht/test_unit_routing_table.py (new file, 371 lines)
@@ -0,0 +1,371 @@
"""
Unit tests for the RoutingTable and KBucket classes in Kademlia DHT.

This module tests the core functionality of the routing table including:
- KBucket operations (add, remove, split, ping)
- RoutingTable management (peer addition, closest peer finding)
- Distance calculations and peer ordering
- Bucket splitting and range management
"""

import time
from unittest.mock import (
    AsyncMock,
    Mock,
    patch,
)

import pytest
from multiaddr import (
    Multiaddr,
)
import trio

from libp2p.crypto.secp256k1 import (
    create_new_key_pair,
)
from libp2p.kad_dht.routing_table import (
    BUCKET_SIZE,
    KBucket,
    RoutingTable,
)
from libp2p.kad_dht.utils import (
    create_key_from_binary,
    xor_distance,
)
from libp2p.peer.id import (
    ID,
)
from libp2p.peer.peerinfo import (
    PeerInfo,
)


def create_valid_peer_id(name: str) -> ID:
    """Create a valid peer ID for testing."""
    # Use crypto to generate valid peer IDs
    key_pair = create_new_key_pair()
    return ID.from_pubkey(key_pair.public_key)


class TestKBucket:
    """Test suite for KBucket class."""

    @pytest.fixture
    def mock_host(self):
        """Create a mock host for testing."""
        host = Mock()
        host.get_peerstore.return_value = Mock()
        host.new_stream = AsyncMock()
        return host

    @pytest.fixture
    def sample_peer_info(self):
        """Create sample peer info for testing."""
        peer_id = create_valid_peer_id("test")
        addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
        return PeerInfo(peer_id, addresses)

    def test_init_default_parameters(self, mock_host):
        """Test KBucket initialization with default parameters."""
        bucket = KBucket(mock_host)

        assert bucket.bucket_size == BUCKET_SIZE
        assert bucket.host == mock_host
        assert bucket.min_range == 0
        assert bucket.max_range == 2**256
        assert len(bucket.peers) == 0

    def test_peer_operations(self, mock_host, sample_peer_info):
        """Test basic peer operations: add, check, and remove."""
        bucket = KBucket(mock_host)

        # Test empty bucket
        assert bucket.peer_ids() == []
        assert bucket.size() == 0
        assert not bucket.has_peer(sample_peer_info.peer_id)

        # Add peer manually
        bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())

        # Test with peer
        assert len(bucket.peer_ids()) == 1
        assert sample_peer_info.peer_id in bucket.peer_ids()
        assert bucket.size() == 1
        assert bucket.has_peer(sample_peer_info.peer_id)
        assert bucket.get_peer_info(sample_peer_info.peer_id) == sample_peer_info

        # Remove peer
        result = bucket.remove_peer(sample_peer_info.peer_id)
        assert result is True
        assert bucket.size() == 0
        assert not bucket.has_peer(sample_peer_info.peer_id)

    @pytest.mark.trio
    async def test_add_peer_functionality(self, mock_host):
        """Test add_peer method with different scenarios."""
        bucket = KBucket(mock_host, bucket_size=2)  # Small bucket for testing

        # Add first peer
        peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
        result = await bucket.add_peer(peer1)
        assert result is True
        assert bucket.size() == 1

        # Add second peer
        peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
        result = await bucket.add_peer(peer2)
        assert result is True
        assert bucket.size() == 2

        # Add same peer again (should update timestamp)
        await trio.sleep(0.001)
        result = await bucket.add_peer(peer1)
        assert result is True
        assert bucket.size() == 2  # Still 2 peers

        # Try to add third peer when bucket is full
        peer3 = PeerInfo(create_valid_peer_id("peer3"), [])
        with patch.object(bucket, "_ping_peer", return_value=True):
            result = await bucket.add_peer(peer3)
            assert result is False  # Should fail if oldest peer responds

    def test_get_oldest_peer(self, mock_host):
        """Test get_oldest_peer method."""
        bucket = KBucket(mock_host)

        # Empty bucket
        assert bucket.get_oldest_peer() is None

        # Add peers with different timestamps
        peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
        peer2 = PeerInfo(create_valid_peer_id("peer2"), [])

        current_time = time.time()
        bucket.peers[peer1.peer_id] = (peer1, current_time - 300)  # Older
        bucket.peers[peer2.peer_id] = (peer2, current_time)  # Newer

        oldest = bucket.get_oldest_peer()
        assert oldest == peer1.peer_id

    def test_stale_peers(self, mock_host):
        """Test stale peer identification."""
        bucket = KBucket(mock_host)

        current_time = time.time()
        fresh_peer = PeerInfo(create_valid_peer_id("fresh"), [])
        stale_peer = PeerInfo(create_valid_peer_id("stale"), [])

        bucket.peers[fresh_peer.peer_id] = (fresh_peer, current_time)
        bucket.peers[stale_peer.peer_id] = (
            stale_peer,
            current_time - 7200,
        )  # 2 hours ago

        stale_peers = bucket.get_stale_peers(3600)  # 1 hour threshold
        assert len(stale_peers) == 1
        assert stale_peer.peer_id in stale_peers

    def test_key_in_range(self, mock_host):
        """Test key_in_range method."""
        bucket = KBucket(mock_host, min_range=100, max_range=200)

        # Test keys within range
        key_in_range = (150).to_bytes(32, byteorder="big")
        assert bucket.key_in_range(key_in_range) is True

        # Test keys outside range
        key_below = (50).to_bytes(32, byteorder="big")
        assert bucket.key_in_range(key_below) is False

        key_above = (250).to_bytes(32, byteorder="big")
        assert bucket.key_in_range(key_above) is False

        # Test boundary conditions
        key_min = (100).to_bytes(32, byteorder="big")
        assert bucket.key_in_range(key_min) is True

        key_max = (200).to_bytes(32, byteorder="big")
        assert bucket.key_in_range(key_max) is False

    def test_split_bucket(self, mock_host):
        """Test bucket splitting functionality."""
        bucket = KBucket(mock_host, min_range=0, max_range=256)

        lower_bucket, upper_bucket = bucket.split()

        # Check ranges
        assert lower_bucket.min_range == 0
        assert lower_bucket.max_range == 128
        assert upper_bucket.min_range == 128
        assert upper_bucket.max_range == 256

        # Check properties
        assert lower_bucket.bucket_size == bucket.bucket_size
        assert upper_bucket.bucket_size == bucket.bucket_size
        assert lower_bucket.host == mock_host
        assert upper_bucket.host == mock_host

    @pytest.mark.trio
    async def test_ping_peer_scenarios(self, mock_host, sample_peer_info):
        """Test different ping scenarios."""
        bucket = KBucket(mock_host)
        bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())

        # Test ping peer not in bucket
        other_peer_id = create_valid_peer_id("other")
        with pytest.raises(ValueError, match="Peer .* not in bucket"):
            await bucket._ping_peer(other_peer_id)

        # Test ping failure due to stream error
        mock_host.new_stream.side_effect = Exception("Stream failed")
        result = await bucket._ping_peer(sample_peer_info.peer_id)
        assert result is False


class TestRoutingTable:
    """Test suite for RoutingTable class."""

    @pytest.fixture
    def mock_host(self):
        """Create a mock host for testing."""
        host = Mock()
        host.get_peerstore.return_value = Mock()
        return host

    @pytest.fixture
    def local_peer_id(self):
        """Create a local peer ID for testing."""
        return create_valid_peer_id("local")

    @pytest.fixture
    def sample_peer_info(self):
        """Create sample peer info for testing."""
        peer_id = create_valid_peer_id("sample")
        addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
        return PeerInfo(peer_id, addresses)

    def test_init_routing_table(self, mock_host, local_peer_id):
        """Test RoutingTable initialization."""
        routing_table = RoutingTable(local_peer_id, mock_host)

        assert routing_table.local_id == local_peer_id
        assert routing_table.host == mock_host
        assert len(routing_table.buckets) == 1
        assert isinstance(routing_table.buckets[0], KBucket)

    @pytest.mark.trio
    async def test_add_peer_operations(
        self, mock_host, local_peer_id, sample_peer_info
    ):
        """Test adding peers to routing table."""
        routing_table = RoutingTable(local_peer_id, mock_host)

        # Test adding peer with PeerInfo
        result = await routing_table.add_peer(sample_peer_info)
        assert result is True
        assert routing_table.size() == 1
        assert routing_table.peer_in_table(sample_peer_info.peer_id)

        # Test adding peer with just ID
        peer_id = create_valid_peer_id("test")
        mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
        mock_host.get_peerstore().addrs.return_value = mock_addrs

        result = await routing_table.add_peer(peer_id)
        assert result is True
        assert routing_table.size() == 2

        # Test adding peer with no addresses
        no_addr_peer_id = create_valid_peer_id("no_addr")
        mock_host.get_peerstore().addrs.return_value = []

        result = await routing_table.add_peer(no_addr_peer_id)
        assert result is False
        assert routing_table.size() == 2

        # Test adding local peer (should be ignored)
        result = await routing_table.add_peer(local_peer_id)
        assert result is False
        assert routing_table.size() == 2

    def test_find_bucket(self, mock_host, local_peer_id):
        """Test finding appropriate bucket for peers."""
        routing_table = RoutingTable(local_peer_id, mock_host)

        # Test with peer ID
        peer_id = create_valid_peer_id("test")
        bucket = routing_table.find_bucket(peer_id)
        assert isinstance(bucket, KBucket)

    def test_peer_management(self, mock_host, local_peer_id, sample_peer_info):
        """Test peer management operations."""
        routing_table = RoutingTable(local_peer_id, mock_host)

        # Add peer manually
        bucket = routing_table.find_bucket(sample_peer_info.peer_id)
        bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())

        # Test peer queries
        assert routing_table.peer_in_table(sample_peer_info.peer_id)
        assert routing_table.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
        assert routing_table.size() == 1
        assert len(routing_table.get_peer_ids()) == 1

        # Test remove peer
        result = routing_table.remove_peer(sample_peer_info.peer_id)
        assert result is True
        assert not routing_table.peer_in_table(sample_peer_info.peer_id)
        assert routing_table.size() == 0

    def test_find_closest_peers(self, mock_host, local_peer_id):
        """Test finding closest peers."""
        routing_table = RoutingTable(local_peer_id, mock_host)

        # Empty table
        target_key = create_key_from_binary(b"target_key")
        closest_peers = routing_table.find_local_closest_peers(target_key, 5)
        assert closest_peers == []

        # Add some peers
        bucket = routing_table.buckets[0]
        test_peers = []
        for i in range(5):
            peer = PeerInfo(create_valid_peer_id(f"peer{i}"), [])
            test_peers.append(peer)
            bucket.peers[peer.peer_id] = (peer, time.time())

        closest_peers = routing_table.find_local_closest_peers(target_key, 3)
        assert len(closest_peers) <= 3
        assert len(closest_peers) <= len(test_peers)
        assert all(isinstance(peer_id, ID) for peer_id in closest_peers)

    def test_distance_calculation(self, mock_host, local_peer_id):
        """Test XOR distance calculation."""
        # Test same keys
        key = b"\x42" * 32
        distance = xor_distance(key, key)
        assert distance == 0

        # Test different keys
        key1 = b"\x00" * 32
        key2 = b"\xff" * 32
        distance = xor_distance(key1, key2)
        expected = int.from_bytes(b"\xff" * 32, byteorder="big")
        assert distance == expected

    def test_edge_cases(self, mock_host, local_peer_id):
        """Test various edge cases."""
        routing_table = RoutingTable(local_peer_id, mock_host)

        # Test with a valid peer ID that is not in the table
        nonexistent_peer_id = create_valid_peer_id("nonexistent")
        assert not routing_table.peer_in_table(nonexistent_peer_id)
        assert routing_table.get_peer_info(nonexistent_peer_id) is None
        assert routing_table.remove_peer(nonexistent_peer_id) is False

        # Test bucket splitting scenario
        assert len(routing_table.buckets) == 1
        initial_bucket = routing_table.buckets[0]
        assert initial_bucket.min_range == 0
        assert initial_bucket.max_range == 2**256
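As background for test_distance_calculation above: the XOR metric treats both 32-byte keys as big-endian integers and XORs them, so identical keys are at distance 0 and complementary keys are at the maximum distance. The following is a minimal sketch under that assumption; it is not the exact body of libp2p.kad_dht.utils.xor_distance.

def xor_distance_sketch(key1: bytes, key2: bytes) -> int:
    # Interpret both keys as big-endian integers and XOR them.
    a = int.from_bytes(key1, byteorder="big")
    b = int.from_bytes(key2, byteorder="big")
    return a ^ b

# The assertions in test_distance_calculation follow directly:
assert xor_distance_sketch(b"\x42" * 32, b"\x42" * 32) == 0
assert xor_distance_sketch(b"\x00" * 32, b"\xff" * 32) == int.from_bytes(
    b"\xff" * 32, byteorder="big"
)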
tests/core/kad_dht/test_unit_value_store.py (new file, 504 lines)
@@ -0,0 +1,504 @@
"""
Unit tests for the ValueStore class in Kademlia DHT.

This module tests the core functionality of the ValueStore including:
- Basic storage and retrieval operations
- Expiration and TTL handling
- Edge cases and error conditions
- Store management operations
"""

import time
from unittest.mock import (
    Mock,
)

import pytest

from libp2p.kad_dht.value_store import (
    DEFAULT_TTL,
    ValueStore,
)
from libp2p.peer.id import (
    ID,
)

mock_host = Mock()
peer_id = ID.from_base58("QmTest123")


class TestValueStore:
    """Test suite for ValueStore class."""

    def test_init_empty_store(self):
        """Test that a new ValueStore is initialized empty."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        assert len(store.store) == 0

    def test_init_with_host_and_peer_id(self):
        """Test initialization with host and local peer ID."""
        mock_host = Mock()
        peer_id = ID.from_base58("QmTest123")

        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        assert store.host == mock_host
        assert store.local_peer_id == peer_id
        assert len(store.store) == 0

    def test_put_basic(self):
        """Test basic put operation."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"

        store.put(key, value)

        assert key in store.store
        stored_value, validity = store.store[key]
        assert stored_value == value
        assert validity is not None
        assert validity > time.time()  # Should be in the future

    def test_put_with_custom_validity(self):
        """Test put operation with custom validity time."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"
        custom_validity = time.time() + 3600  # 1 hour from now

        store.put(key, value, validity=custom_validity)

        stored_value, validity = store.store[key]
        assert stored_value == value
        assert validity == custom_validity

    def test_put_overwrite_existing(self):
        """Test that put overwrites existing values."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value1 = b"value1"
        value2 = b"value2"

        store.put(key, value1)
        store.put(key, value2)

        assert len(store.store) == 1
        stored_value, _ = store.store[key]
        assert stored_value == value2

    def test_get_existing_valid_value(self):
        """Test retrieving an existing, non-expired value."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"

        store.put(key, value)
        retrieved_value = store.get(key)

        assert retrieved_value == value

    def test_get_nonexistent_key(self):
        """Test retrieving a non-existent key returns None."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"nonexistent_key"

        retrieved_value = store.get(key)

        assert retrieved_value is None

    def test_get_expired_value(self):
        """Test that expired values are automatically removed and return None."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"
        expired_validity = time.time() - 1  # 1 second ago

        # Manually insert expired value
        store.store[key] = (value, expired_validity)

        retrieved_value = store.get(key)

        assert retrieved_value is None
        assert key not in store.store  # Should be removed

    def test_remove_existing_key(self):
        """Test removing an existing key."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"

        store.put(key, value)
        result = store.remove(key)

        assert result is True
        assert key not in store.store

    def test_remove_nonexistent_key(self):
        """Test removing a non-existent key returns False."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"nonexistent_key"

        result = store.remove(key)

        assert result is False

    def test_has_existing_valid_key(self):
        """Test has() returns True for existing, valid keys."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"

        store.put(key, value)
        result = store.has(key)

        assert result is True

    def test_has_nonexistent_key(self):
        """Test has() returns False for non-existent keys."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"nonexistent_key"

        result = store.has(key)

        assert result is False

    def test_has_expired_key(self):
        """Test has() returns False for expired keys and removes them."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"test_key"
        value = b"test_value"
        expired_validity = time.time() - 1

        # Manually insert expired value
        store.store[key] = (value, expired_validity)

        result = store.has(key)

        assert result is False
        assert key not in store.store  # Should be removed

    def test_cleanup_expired_no_expired_values(self):
        """Test cleanup when there are no expired values."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key1 = b"key1"
        key2 = b"key2"
        value = b"value"

        store.put(key1, value)
        store.put(key2, value)

        expired_count = store.cleanup_expired()

        assert expired_count == 0
        assert len(store.store) == 2

    def test_cleanup_expired_with_expired_values(self):
        """Test cleanup removes expired values."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key1 = b"valid_key"
        key2 = b"expired_key1"
        key3 = b"expired_key2"
        value = b"value"
        expired_validity = time.time() - 1

        store.put(key1, value)  # Valid
        store.store[key2] = (value, expired_validity)  # Expired
        store.store[key3] = (value, expired_validity)  # Expired

        expired_count = store.cleanup_expired()

        assert expired_count == 2
        assert len(store.store) == 1
        assert key1 in store.store
        assert key2 not in store.store
        assert key3 not in store.store

    def test_cleanup_expired_mixed_validity_types(self):
        """Test cleanup with mix of values with and without expiration."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key1 = b"no_expiry"
        key2 = b"valid_expiry"
        key3 = b"expired"
        value = b"value"

        # Default TTL (no explicit validity passed)
        store.put(key1, value)
        # Valid expiration
        store.put(key2, value, validity=time.time() + 3600)
        # Expired
        store.store[key3] = (value, time.time() - 1)

        expired_count = store.cleanup_expired()

        assert expired_count == 1
        assert len(store.store) == 2
        assert key1 in store.store
        assert key2 in store.store
        assert key3 not in store.store

    def test_get_keys_empty_store(self):
        """Test get_keys() returns empty list for empty store."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)

        keys = store.get_keys()

        assert keys == []

    def test_get_keys_with_valid_values(self):
        """Test get_keys() returns all non-expired keys."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key1 = b"key1"
        key2 = b"key2"
        key3 = b"expired_key"
        value = b"value"

        store.put(key1, value)
        store.put(key2, value)
        store.store[key3] = (value, time.time() - 1)  # Expired

        keys = store.get_keys()

        assert len(keys) == 2
        assert key1 in keys
        assert key2 in keys
        assert key3 not in keys

    def test_size_empty_store(self):
        """Test size() returns 0 for empty store."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)

        size = store.size()

        assert size == 0

    def test_size_with_valid_values(self):
        """Test size() returns correct count after cleaning expired values."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key1 = b"key1"
        key2 = b"key2"
        key3 = b"expired_key"
        value = b"value"

        store.put(key1, value)
        store.put(key2, value)
        store.store[key3] = (value, time.time() - 1)  # Expired

        size = store.size()

        assert size == 2

    def test_edge_case_empty_key(self):
        """Test handling of empty key."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b""
        value = b"value"

        store.put(key, value)
        retrieved_value = store.get(key)

        assert retrieved_value == value

    def test_edge_case_empty_value(self):
        """Test handling of empty value."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"key"
        value = b""

        store.put(key, value)
        retrieved_value = store.get(key)

        assert retrieved_value == value

    def test_edge_case_large_key_value(self):
        """Test handling of large keys and values."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"x" * 10000  # 10KB key
        value = b"y" * 100000  # 100KB value

        store.put(key, value)
        retrieved_value = store.get(key)

        assert retrieved_value == value

    def test_edge_case_negative_validity(self):
        """Test handling of negative validity time."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"key"
        value = b"value"

        store.put(key, value, validity=-1)

        # Should be expired
        retrieved_value = store.get(key)
        assert retrieved_value is None

    def test_default_ttl_calculation(self):
        """Test that default TTL is correctly applied."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"key"
        value = b"value"
        start_time = time.time()

        store.put(key, value)

        _, validity = store.store[key]
        expected_validity = start_time + DEFAULT_TTL

        # Allow small time difference for execution
        assert abs(validity - expected_validity) < 1

    def test_concurrent_operations(self):
        """Test that multiple operations don't interfere with each other."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)

        # Add multiple key-value pairs
        for i in range(100):
            key = f"key_{i}".encode()
            value = f"value_{i}".encode()
            store.put(key, value)

        # Verify all are stored
        assert store.size() == 100

        # Remove every other key
        for i in range(0, 100, 2):
            key = f"key_{i}".encode()
            store.remove(key)

        # Verify correct count
        assert store.size() == 50

        # Verify remaining keys are correct
        for i in range(1, 100, 2):
            key = f"key_{i}".encode()
            assert store.has(key)

    def test_expiration_boundary_conditions(self):
        """Test expiration around current time boundary."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key1 = b"key1"
        key2 = b"key2"
        key3 = b"key3"
        value = b"value"
        current_time = time.time()

        # Just expired
        store.store[key1] = (value, current_time - 0.001)
        # Valid for a longer time to account for test execution time
        store.store[key2] = (value, current_time + 1.0)
        # Exactly current time (should be expired)
        store.store[key3] = (value, current_time)

        # Small delay to ensure time has passed
        time.sleep(0.002)

        assert not store.has(key1)  # Should be expired
        assert store.has(key2)  # Should be valid
        assert not store.has(key3)  # Should be expired (exactly at current time)

    def test_store_internal_structure(self):
        """Test that internal store structure is maintained correctly."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"key"
        value = b"value"
        validity = time.time() + 3600

        store.put(key, value, validity=validity)

        # Verify internal structure
        assert isinstance(store.store, dict)
        assert key in store.store
        stored_tuple = store.store[key]
        assert isinstance(stored_tuple, tuple)
        assert len(stored_tuple) == 2
        assert stored_tuple[0] == value
        assert stored_tuple[1] == validity

    @pytest.mark.trio
    async def test_store_at_peer_local_peer(self):
        """Test _store_at_peer returns True when storing at local peer."""
        mock_host = Mock()
        peer_id = ID.from_base58("QmTest123")
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"key"
        value = b"value"

        result = await store._store_at_peer(peer_id, key, value)

        assert result is True

    @pytest.mark.trio
    async def test_get_from_peer_local_peer(self):
        """Test _get_from_peer returns None when querying local peer."""
        mock_host = Mock()
        peer_id = ID.from_base58("QmTest123")
        store = ValueStore(host=mock_host, local_peer_id=peer_id)
        key = b"key"

        result = await store._get_from_peer(peer_id, key)

        assert result is None

    def test_memory_efficiency_large_dataset(self):
        """Test memory behavior with large datasets."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)

        # Add a large number of entries
        num_entries = 10000
        for i in range(num_entries):
            key = f"key_{i:05d}".encode()
            value = f"value_{i:05d}".encode()
            store.put(key, value)

        assert store.size() == num_entries

        # Clean up all entries
        for i in range(num_entries):
            key = f"key_{i:05d}".encode()
            store.remove(key)

        assert store.size() == 0
        assert len(store.store) == 0

    def test_key_collision_resistance(self):
        """Test that similar keys don't collide."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)

        # Test keys that might cause collisions
        keys = [
            b"key",
            b"key\x00",
            b"key1",
            b"Key",  # Different case
            b"key ",  # With space
            b" key",  # Leading space
        ]

        for i, key in enumerate(keys):
            value = f"value_{i}".encode()
            store.put(key, value)

        # Verify all keys are stored separately
        assert store.size() == len(keys)

        for i, key in enumerate(keys):
            expected_value = f"value_{i}".encode()
            assert store.get(key) == expected_value

    def test_unicode_key_handling(self):
        """Test handling of unicode content in keys."""
        store = ValueStore(host=mock_host, local_peer_id=peer_id)

        # Test various unicode keys
        unicode_keys = [
            b"hello",
            "héllo".encode(),
            "🔑".encode(),
            "ключ".encode(),  # Russian
            "键".encode(),  # Chinese
        ]

        for i, key in enumerate(unicode_keys):
            value = f"value_{i}".encode()
            store.put(key, value)
            assert store.get(key) == value
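The expiration behaviour these ValueStore tests exercise can be summarised as a map of key/value pairs carrying an absolute validity timestamp, with entries dropped on access once that time has been reached. The following is a minimal sketch under that assumption; the class name, default TTL value, and structure are illustrative and not the library's implementation.

import time

class ExpiringStoreSketch:
    # Illustrative only: each value is kept alongside an absolute expiry timestamp.
    def __init__(self):
        self.store = {}

    def put(self, key, value, validity=None):
        # Without an explicit validity, fall back to some default TTL from "now"
        # (placeholder value, not the real DEFAULT_TTL constant).
        default_ttl = 24 * 60 * 60
        expiry = validity if validity is not None else time.time() + default_ttl
        self.store[key] = (value, expiry)

    def get(self, key):
        entry = self.store.get(key)
        if entry is None:
            return None
        value, validity = entry
        # Inclusive check: an entry whose validity equals "now" counts as expired,
        # matching the boundary behaviour asserted in the tests above.
        if validity <= time.time():
            del self.store[key]
            return None
        return value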