Kademlia DHT implementation in py-libp2p (#579)

* initialise the module

* added content routing

* added routing module

* added peer routing

* added value store

* added utility functions

* added main kademlia file

* fixed create_key_from_binary function

* example to test kademlia dht

* added protocol ID and enhanced logging for peer store size in provider and consumer nodes

* refactor: specify stream type in handle_stream method and add peer in routing table

* removed content routing

* added default value of count for finding closest peers

* added functions to find close peers

* refactor: remove content routing and enhance peer discovery

* added put value function

* added get value function

* fix: improve logging and handle key encoding in get_value method

* refactor: remove ContentRouting import from __init__.py

* refactor: improved basic kademlia example

* added protobuf files

* replaced json with protobuf

* refactor: enhance peer discovery and routing logic in KadDHT

* refactor: enhance Kademlia routing table to use PeerInfo objects and improve peer management

* refactor: enhance peer addition logic to utilize PeerInfo objects in routing table

* feat: implement content provider functionality in Kademlia DHT

* refactor: update value store to use datetime for validity management

* refactor: update RoutingTable initialization to include host reference

* refactor: enhance KBucket and RoutingTable for improved peer management and functionality

* refactor: streamline peer discovery and value storage methods in KadDHT

* refactor: update KadDHT and related classes for async peer management and enhanced value storage

* refactor: enhance ProviderStore initialization and improve peer routing integration

* test: add tests for Kademlia DHT functionality

* fix linting issues

* pydocstyle issues fixed

* CICD pipeline issues solved

* fix: update docstring format for find_peer method

* refactor: improve logging and remove unused code in DHT implementation

* refactor: clean up logging and remove unused imports in DHT and test files

* Refactor logging setup and improve DHT stream handling with varint length prefixes

* Update bootstrap peer handling in basic_dht example and refactor peer routing to accept string addresses

* Enhance peer querying in Kademlia DHT by implementing parallel queries using Trio.

* Enhance peer querying by adding deduplication checks

* Refactor DHT implementation to use varint for length prefixes and enhance logging for better traceability

* Add base58 encoding for value storage and enhance logging in basic_dht example

* Refactor Kademlia DHT to support server/client modes

* Added unit tests

* Refactor documentation to fix some warnings

* Add unit tests and remove outdated tests

* Fixed pre-commit errors

* Refactor error handling test to raise StringParseError for invalid bootstrap addresses

* Add libp2p.kad_dht to the list of subpackages in documentation

* Fix expiration and republish checks to use inclusive comparison

* Add __init__.py file to libp2p.kad_dht.pb package

* Refactor get value and put value to run in parallel with query timeout

* Refactor provider message handling to use parallel processing with timeout

* Add methods for provider store in KadDHT class

* Refactor KadDHT and ProviderStore methods to improve type hints and enhance parallel processing

* Add documentation for libp2p.kad_dht.pb module.

* Update documentation for libp2p.kad_dht package to include subpackages and correct formatting

* Fix formatting in documentation for libp2p.kad_dht package by correcting the subpackage reference

* Fix header formatting in libp2p.kad_dht.pb documentation

* Change log level from info to debug for various logging statements.

* fix CICD issues (post revamp)

* fixed value store unit test

* Refactored kademlia example

* Refactor Kademlia example: enhance logging, improve bootstrap node connection, and streamline server address handling

* removed bootstrap module

* Refactor Kademlia DHT example and core modules: enhance logging, remove unused code, and improve peer handling

* Added docs of kad dht example

* Update server address log file path to use the script's directory

* Refactor: Introduce DHTMode enum for clearer mode management

* moved xor_distance function to utils.py

* Enhance logging in ValueStore and KadDHT: include decoded value in debug logs and update parameter description for validity

* Add handling for closest peers in GET_VALUE response when value is not found

* Handled failure scenario for PUT_VALUE

* Remove kademlia demo from project scripts and contributing documentation

* spelling and logging

---------

Co-authored-by: pacrob <5199899+pacrob@users.noreply.github.com>
Authored by Sumanjeet on 2025-06-17 02:16:40 +05:30, committed by GitHub
parent 733ef86e62
commit d61bca78ab
24 changed files with 5790 additions and 1 deletion
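
Several entries in the commit log above replace JSON with protobuf and move stream framing to varint length prefixes, and the peer-routing tests below mock exactly that framing: a varint-encoded length followed by the serialized message. Below is a minimal, self-contained sketch of the framing itself (not the library's stream API), using the same `varint` package the tests import and an in-memory buffer as a stand-in for a libp2p stream:

```python
import io

import varint  # the same package the unit tests below use to mock stream reads


def frame_message(message_bytes: bytes) -> bytes:
    """Prefix a serialized protobuf message with its varint-encoded length."""
    return varint.encode(len(message_bytes)) + message_bytes


def read_framed_message(stream: io.BytesIO) -> bytes:
    """Read one varint length prefix, then exactly that many payload bytes."""
    length = varint.decode_stream(stream)
    return stream.read(length)


payload = b"\x08\x04"  # stand-in for a tiny serialized protobuf message
buf = io.BytesIO(frame_message(payload))
assert read_framed_message(buf) == payload
```

Compared with a fixed 4-byte prefix, the varint prefix costs a single byte for small messages while still scaling to large ones, which is the usual trade-off behind varint-delimited protobuf framing.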


@@ -0,0 +1,168 @@
"""
Tests for the Kademlia DHT implementation.
This module tests core functionality of the Kademlia DHT including:
- Node discovery (find_node)
- Value storage and retrieval (put_value, get_value)
- Content provider advertisement and discovery (provide, find_providers)
"""
import hashlib
import logging
import uuid
import pytest
import trio
from libp2p.kad_dht.kad_dht import (
DHTMode,
KadDHT,
)
from libp2p.kad_dht.utils import (
create_key_from_binary,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from tests.utils.factories import (
host_pair_factory,
)
# Configure logger
logger = logging.getLogger("test.kad_dht")
# Constants for the tests
TEST_TIMEOUT = 5 # Timeout in seconds
@pytest.fixture
async def dht_pair(security_protocol):
"""Create a pair of connected DHT nodes for testing."""
async with host_pair_factory(security_protocol=security_protocol) as (
host_a,
host_b,
):
# Get peer info for bootstrapping
peer_b_info = PeerInfo(host_b.get_id(), host_b.get_addrs())
peer_a_info = PeerInfo(host_a.get_id(), host_a.get_addrs())
        # Create DHT nodes from the hosts and seed each routing table with the other peer
dht_a: KadDHT = KadDHT(host_a, mode=DHTMode.SERVER)
dht_b: KadDHT = KadDHT(host_b, mode=DHTMode.SERVER)
await dht_a.peer_routing.routing_table.add_peer(peer_b_info)
await dht_b.peer_routing.routing_table.add_peer(peer_a_info)
# Start both DHT services
async with background_trio_service(dht_a), background_trio_service(dht_b):
# Allow time for bootstrap to complete and connections to establish
await trio.sleep(0.1)
logger.debug(
"After bootstrap: Node A peers: %s", dht_a.routing_table.get_peer_ids()
)
logger.debug(
"After bootstrap: Node B peers: %s", dht_b.routing_table.get_peer_ids()
)
# Return the DHT pair
yield (dht_a, dht_b)
@pytest.mark.trio
async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
"""Test that nodes can find each other in the DHT."""
dht_a, dht_b = dht_pair
# Node A should be able to find Node B
with trio.fail_after(TEST_TIMEOUT):
found_info = await dht_a.find_peer(dht_b.host.get_id())
# Verify that the found peer has the correct peer ID
assert found_info is not None, "Failed to find the target peer"
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
@pytest.mark.trio
async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
"""Test storing and retrieving values in the DHT."""
dht_a, dht_b = dht_pair
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
# Generate a random key and value
key = create_key_from_binary(b"test-key")
value = b"test-value"
# First add the value directly to node A's store to verify storage works
dht_a.value_store.put(key, value)
logger.debug("Local value store: %s", dht_a.value_store.store)
local_value = dht_a.value_store.get(key)
assert local_value == value, "Local value storage failed"
print("number of nodes in peer store", dht_a.host.get_peerstore().peer_ids())
await dht_a.routing_table.add_peer(peer_b_info)
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
# Store the value using the first node (this will also store locally)
with trio.fail_after(TEST_TIMEOUT):
await dht_a.put_value(key, value)
    # Log debugging information
    logger.debug("Put value with key %s...", key.hex()[:10])
    logger.debug("Node A value store: %s", dht_a.value_store.store)
    # Allow time for the value to propagate
    await trio.sleep(0.5)
    # Log both routing tables to confirm the nodes are linked
logger.debug("Node A peers: %s", dht_a.routing_table.get_peer_ids())
logger.debug("Node B peers: %s", dht_b.routing_table.get_peer_ids())
# Retrieve the value using the second node
with trio.fail_after(TEST_TIMEOUT):
retrieved_value = await dht_b.get_value(key)
print("the value stored in node b is", dht_b.get_value_store_size())
logger.debug("Retrieved value: %s", retrieved_value)
# Verify that the retrieved value matches the original
assert retrieved_value == value, "Retrieved value does not match the stored value"
@pytest.mark.trio
async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
"""Test advertising and finding content providers."""
dht_a, dht_b = dht_pair
# Generate a random content ID
content = f"test-content-{uuid.uuid4()}".encode()
content_id = hashlib.sha256(content).digest()
# Store content on the first node
dht_a.value_store.put(content_id, content)
# Advertise the first node as a provider
with trio.fail_after(TEST_TIMEOUT):
success = await dht_a.provide(content_id)
assert success, "Failed to advertise as provider"
# Allow time for the provider record to propagate
await trio.sleep(0.1)
# Find providers using the second node
with trio.fail_after(TEST_TIMEOUT):
providers = await dht_b.find_providers(content_id)
# Verify that we found the first node as a provider
assert providers, "No providers found"
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
"Expected provider not found"
)
# Retrieve the content using the provider information
with trio.fail_after(TEST_TIMEOUT):
retrieved_value = await dht_b.get_value(content_id)
assert retrieved_value == content, (
"Retrieved content does not match the original"
)


@@ -0,0 +1,459 @@
"""
Unit tests for the PeerRouting class in Kademlia DHT.
This module tests the core functionality of peer routing including:
- Peer discovery and lookup
- Network queries for closest peers
- Protocol message handling
- Error handling and edge cases
"""
import time
from unittest.mock import (
AsyncMock,
Mock,
patch,
)
import pytest
from multiaddr import (
Multiaddr,
)
import varint
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.kad_dht.pb.kademlia_pb2 import (
Message,
)
from libp2p.kad_dht.peer_routing import (
ALPHA,
MAX_PEER_LOOKUP_ROUNDS,
PROTOCOL_ID,
PeerRouting,
)
from libp2p.kad_dht.routing_table import (
RoutingTable,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
def create_valid_peer_id(name: str) -> ID:
"""Create a valid peer ID for testing."""
key_pair = create_new_key_pair()
return ID.from_pubkey(key_pair.public_key)
class TestPeerRouting:
"""Test suite for PeerRouting class."""
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_id.return_value = create_valid_peer_id("local")
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
host.connect = AsyncMock()
return host
@pytest.fixture
def mock_routing_table(self, mock_host):
"""Create a mock routing table for testing."""
local_id = create_valid_peer_id("local")
routing_table = RoutingTable(local_id, mock_host)
return routing_table
@pytest.fixture
def peer_routing(self, mock_host, mock_routing_table):
"""Create a PeerRouting instance for testing."""
return PeerRouting(mock_host, mock_routing_table)
@pytest.fixture
def sample_peer_info(self):
"""Create sample peer info for testing."""
peer_id = create_valid_peer_id("sample")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
return PeerInfo(peer_id, addresses)
def test_init_peer_routing(self, mock_host, mock_routing_table):
"""Test PeerRouting initialization."""
peer_routing = PeerRouting(mock_host, mock_routing_table)
assert peer_routing.host == mock_host
assert peer_routing.routing_table == mock_routing_table
assert peer_routing.protocol_id == PROTOCOL_ID
@pytest.mark.trio
async def test_find_peer_local_host(self, peer_routing, mock_host):
"""Test finding our own peer."""
local_id = mock_host.get_id()
result = await peer_routing.find_peer(local_id)
assert result is not None
assert result.peer_id == local_id
assert result.addrs == mock_host.get_addrs()
@pytest.mark.trio
async def test_find_peer_in_routing_table(self, peer_routing, sample_peer_info):
"""Test finding peer that exists in routing table."""
# Add peer to routing table
await peer_routing.routing_table.add_peer(sample_peer_info)
result = await peer_routing.find_peer(sample_peer_info.peer_id)
assert result is not None
assert result.peer_id == sample_peer_info.peer_id
@pytest.mark.trio
async def test_find_peer_in_peerstore(self, peer_routing, mock_host):
"""Test finding peer that exists in peerstore."""
peer_id = create_valid_peer_id("peerstore")
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8002")]
# Mock peerstore to return addresses
mock_host.get_peerstore().addrs.return_value = mock_addrs
result = await peer_routing.find_peer(peer_id)
assert result is not None
assert result.peer_id == peer_id
assert result.addrs == mock_addrs
@pytest.mark.trio
async def test_find_peer_not_found(self, peer_routing, mock_host):
"""Test finding peer that doesn't exist anywhere."""
peer_id = create_valid_peer_id("nonexistent")
# Mock peerstore to return no addresses
mock_host.get_peerstore().addrs.return_value = []
# Mock network search to return empty results
with patch.object(peer_routing, "find_closest_peers_network", return_value=[]):
result = await peer_routing.find_peer(peer_id)
assert result is None
@pytest.mark.trio
async def test_find_closest_peers_network_empty_start(self, peer_routing):
"""Test network search with no local peers."""
target_key = b"target_key"
# Mock routing table to return empty list
with patch.object(
peer_routing.routing_table, "find_local_closest_peers", return_value=[]
):
result = await peer_routing.find_closest_peers_network(target_key)
assert result == []
@pytest.mark.trio
async def test_find_closest_peers_network_with_peers(self, peer_routing, mock_host):
"""Test network search with some initial peers."""
target_key = b"target_key"
# Create some test peers
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(3)]
# Mock routing table to return initial peers
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=initial_peers,
):
# Mock _query_peer_for_closest to return empty results (no new peers found)
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
result = await peer_routing.find_closest_peers_network(
target_key, count=5
)
assert len(result) <= 5
# Should return the initial peers since no new ones were discovered
assert all(peer in initial_peers for peer in result)
@pytest.mark.trio
async def test_find_closest_peers_convergence(self, peer_routing):
"""Test that network search converges properly."""
target_key = b"target_key"
# Create test peers
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(2)]
# Mock to simulate convergence (no improvement in closest peers)
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=initial_peers,
):
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
with patch(
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance",
return_value=initial_peers,
):
result = await peer_routing.find_closest_peers_network(target_key)
assert result == initial_peers
@pytest.mark.trio
async def test_query_peer_for_closest_success(
self, peer_routing, mock_host, sample_peer_info
):
"""Test successful peer query for closest peers."""
target_key = b"target_key"
# Create mock stream
mock_stream = AsyncMock()
mock_host.new_stream.return_value = mock_stream
# Create mock response
response_msg = Message()
response_msg.type = Message.MessageType.FIND_NODE
# Add a peer to the response
peer_proto = response_msg.closerPeers.add()
response_peer_id = create_valid_peer_id("response_peer")
peer_proto.id = response_peer_id.to_bytes()
peer_proto.addrs.append(Multiaddr("/ip4/127.0.0.1/tcp/8003").to_bytes())
response_bytes = response_msg.SerializeToString()
# Mock stream reading
varint_length = varint.encode(len(response_bytes))
mock_stream.read.side_effect = [varint_length, response_bytes]
# Mock peerstore
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
mock_host.get_peerstore().add_addrs = Mock()
result = await peer_routing._query_peer_for_closest(
sample_peer_info.peer_id, target_key
)
assert len(result) == 1
assert result[0] == response_peer_id
mock_stream.write.assert_called()
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_query_peer_for_closest_stream_failure(self, peer_routing, mock_host):
"""Test peer query when stream creation fails."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
# Mock stream creation failure
mock_host.new_stream.side_effect = Exception("Stream failed")
mock_host.get_peerstore().addrs.return_value = []
result = await peer_routing._query_peer_for_closest(peer_id, target_key)
assert result == []
@pytest.mark.trio
async def test_query_peer_for_closest_read_failure(
self, peer_routing, mock_host, sample_peer_info
):
"""Test peer query when reading response fails."""
target_key = b"target_key"
# Create mock stream that fails to read
mock_stream = AsyncMock()
mock_stream.read.side_effect = [b""] # Empty read simulates connection close
mock_host.new_stream.return_value = mock_stream
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
result = await peer_routing._query_peer_for_closest(
sample_peer_info.peer_id, target_key
)
assert result == []
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_refresh_routing_table(self, peer_routing, mock_host):
"""Test routing table refresh."""
local_id = mock_host.get_id()
discovered_peers = [create_valid_peer_id(f"discovered{i}") for i in range(3)]
# Mock find_closest_peers_network to return discovered peers
with patch.object(
peer_routing, "find_closest_peers_network", return_value=discovered_peers
):
# Mock peerstore to return addresses for discovered peers
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8003")]
mock_host.get_peerstore().addrs.return_value = mock_addrs
await peer_routing.refresh_routing_table()
# Should perform lookup for local ID
peer_routing.find_closest_peers_network.assert_called_once_with(
local_id.to_bytes()
)
@pytest.mark.trio
async def test_handle_kad_stream_find_node(self, peer_routing, mock_host):
"""Test handling incoming FIND_NODE requests."""
# Create mock stream
mock_stream = AsyncMock()
# Create FIND_NODE request
request_msg = Message()
request_msg.type = Message.MessageType.FIND_NODE
request_msg.key = b"target_key"
request_bytes = request_msg.SerializeToString()
# Mock stream reading
mock_stream.read.side_effect = [
len(request_bytes).to_bytes(4, byteorder="big"),
request_bytes,
]
# Mock routing table to return some peers
closest_peers = [create_valid_peer_id(f"close{i}") for i in range(2)]
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=closest_peers,
):
mock_host.get_peerstore().addrs.return_value = [
Multiaddr("/ip4/127.0.0.1/tcp/8004")
]
await peer_routing._handle_kad_stream(mock_stream)
# Should write response
mock_stream.write.assert_called()
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_handle_kad_stream_invalid_message(self, peer_routing):
"""Test handling stream with invalid message."""
mock_stream = AsyncMock()
# Mock stream to return invalid data
mock_stream.read.side_effect = [
(10).to_bytes(4, byteorder="big"),
b"invalid_proto_data",
]
# Should handle gracefully without raising exception
await peer_routing._handle_kad_stream(mock_stream)
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_handle_kad_stream_connection_closed(self, peer_routing):
"""Test handling stream when connection is closed early."""
mock_stream = AsyncMock()
# Mock stream to return empty data (connection closed)
mock_stream.read.return_value = b""
await peer_routing._handle_kad_stream(mock_stream)
mock_stream.close.assert_called_once()
@pytest.mark.trio
async def test_query_single_peer_for_closest_success(self, peer_routing):
"""Test _query_single_peer_for_closest method."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
new_peers = []
# Mock successful query
mock_result = [create_valid_peer_id("result1"), create_valid_peer_id("result2")]
with patch.object(
peer_routing, "_query_peer_for_closest", return_value=mock_result
):
await peer_routing._query_single_peer_for_closest(
peer_id, target_key, new_peers
)
assert len(new_peers) == 2
assert all(peer in new_peers for peer in mock_result)
@pytest.mark.trio
async def test_query_single_peer_for_closest_failure(self, peer_routing):
"""Test _query_single_peer_for_closest when query fails."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
new_peers = []
# Mock query failure
with patch.object(
peer_routing,
"_query_peer_for_closest",
side_effect=Exception("Query failed"),
):
await peer_routing._query_single_peer_for_closest(
peer_id, target_key, new_peers
)
# Should handle exception gracefully
assert len(new_peers) == 0
@pytest.mark.trio
async def test_query_single_peer_deduplication(self, peer_routing):
"""Test that _query_single_peer_for_closest deduplicates peers."""
target_key = b"target_key"
peer_id = create_valid_peer_id("test")
duplicate_peer = create_valid_peer_id("duplicate")
new_peers = [duplicate_peer] # Pre-existing peer
# Mock query to return the same peer
mock_result = [duplicate_peer, create_valid_peer_id("new")]
with patch.object(
peer_routing, "_query_peer_for_closest", return_value=mock_result
):
await peer_routing._query_single_peer_for_closest(
peer_id, target_key, new_peers
)
# Should not add duplicate
assert len(new_peers) == 2 # Original + 1 new peer
assert new_peers.count(duplicate_peer) == 1
def test_constants(self):
"""Test that important constants are properly defined."""
assert ALPHA == 3
assert MAX_PEER_LOOKUP_ROUNDS == 20
assert PROTOCOL_ID == "/ipfs/kad/1.0.0"
@pytest.mark.trio
async def test_edge_case_max_rounds_reached(self, peer_routing):
"""Test that lookup stops after maximum rounds."""
target_key = b"target_key"
initial_peers = [create_valid_peer_id("peer1")]
# Mock to always return new peers to force max rounds
def mock_query_side_effect(peer, key):
return [create_valid_peer_id(f"new_peer_{time.time()}")]
with patch.object(
peer_routing.routing_table,
"find_local_closest_peers",
return_value=initial_peers,
):
with patch.object(
peer_routing,
"_query_peer_for_closest",
side_effect=mock_query_side_effect,
):
with patch(
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance"
) as mock_sort:
# Always return different peers to prevent convergence
mock_sort.side_effect = lambda key, peers: peers[:20]
result = await peer_routing.find_closest_peers_network(target_key)
# Should stop after max rounds, not infinite loop
assert isinstance(result, list)
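
The peer-routing tests above exercise the iterative lookup through mocks: ALPHA parallel queries per round, results appended to a shared new_peers list with deduplication, and a cap of MAX_PEER_LOOKUP_ROUNDS. A rough, self-contained Trio sketch of that pattern follows; the peer type, query function, and sorting step are stand-ins rather than the library's actual API:

```python
from collections.abc import Awaitable, Callable

import trio

ALPHA = 3                    # queries issued in parallel per round
MAX_PEER_LOOKUP_ROUNDS = 20  # hard cap so the lookup always terminates


async def iterative_lookup_sketch(
    start_peers: list[str],
    query_peer: Callable[[str], Awaitable[list[str]]],
) -> list[str]:
    """Query the best-known peers in rounds until no new peers are discovered."""
    known: set[str] = set(start_peers)
    frontier = list(start_peers)
    for _ in range(MAX_PEER_LOOKUP_ROUNDS):
        new_peers: list[str] = []

        async def query_one(peer: str) -> None:
            try:
                for found in await query_peer(peer):
                    # Deduplicate, as the tests above expect
                    if found not in known and found not in new_peers:
                        new_peers.append(found)
            except Exception:
                pass  # a failed query simply contributes nothing

        async with trio.open_nursery() as nursery:
            for peer in frontier[:ALPHA]:
                nursery.start_soon(query_one, peer)

        if not new_peers:
            break  # converged: nothing new was learned this round
        known.update(new_peers)
        frontier = new_peers  # the real code re-sorts candidates by XOR distance
    return list(known)
```

Running the ALPHA queries inside one nursery gives the per-round parallelism, while the shared new_peers list mirrors the `_query_single_peer_for_closest(peer, key, new_peers)` signature the tests patch.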


@@ -0,0 +1,805 @@
"""
Unit tests for the ProviderStore and ProviderRecord classes in Kademlia DHT.
This module tests the core functionality of provider record management including:
- ProviderRecord creation, expiration, and republish logic
- ProviderStore operations (add, get, cleanup)
- Expiration and TTL handling
- Network operations (mocked)
- Edge cases and error conditions
"""
import time
from unittest.mock import (
AsyncMock,
Mock,
patch,
)
import pytest
from multiaddr import (
Multiaddr,
)
from libp2p.kad_dht.provider_store import (
PROVIDER_ADDRESS_TTL,
PROVIDER_RECORD_EXPIRATION_INTERVAL,
PROVIDER_RECORD_REPUBLISH_INTERVAL,
ProviderRecord,
ProviderStore,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
mock_host = Mock()
class TestProviderRecord:
"""Test suite for ProviderRecord class."""
def test_init_with_default_timestamp(self):
"""Test ProviderRecord initialization with default timestamp."""
peer_id = ID.from_base58("QmTest123")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
peer_info = PeerInfo(peer_id, addresses)
start_time = time.time()
record = ProviderRecord(peer_info)
end_time = time.time()
assert record.provider_info == peer_info
assert start_time <= record.timestamp <= end_time
assert record.peer_id == peer_id
assert record.addresses == addresses
def test_init_with_custom_timestamp(self):
"""Test ProviderRecord initialization with custom timestamp."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
custom_timestamp = time.time() - 3600 # 1 hour ago
record = ProviderRecord(peer_info, timestamp=custom_timestamp)
assert record.timestamp == custom_timestamp
def test_is_expired_fresh_record(self):
"""Test that fresh records are not expired."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
record = ProviderRecord(peer_info)
assert not record.is_expired()
def test_is_expired_old_record(self):
"""Test that old records are expired."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
old_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
record = ProviderRecord(peer_info, timestamp=old_timestamp)
assert record.is_expired()
def test_is_expired_boundary_condition(self):
"""Test expiration at exact boundary."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
boundary_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
        # At the exact boundary the record counts as expired (inclusive comparison)
assert record.is_expired()
def test_should_republish_fresh_record(self):
"""Test that fresh records don't need republishing."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
record = ProviderRecord(peer_info)
assert not record.should_republish()
def test_should_republish_old_record(self):
"""Test that old records need republishing."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
old_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1
record = ProviderRecord(peer_info, timestamp=old_timestamp)
assert record.should_republish()
def test_should_republish_boundary_condition(self):
"""Test republish at exact boundary."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
boundary_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
        # At the exact boundary the record should be republished (inclusive comparison)
assert record.should_republish()
def test_properties(self):
"""Test peer_id and addresses properties."""
peer_id = ID.from_base58("QmTest123")
addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
Multiaddr("/ip6/::1/tcp/8001"),
]
peer_info = PeerInfo(peer_id, addresses)
record = ProviderRecord(peer_info)
assert record.peer_id == peer_id
assert record.addresses == addresses
def test_empty_addresses(self):
"""Test ProviderRecord with empty address list."""
peer_id = ID.from_base58("QmTest123")
peer_info = PeerInfo(peer_id, [])
record = ProviderRecord(peer_info)
assert record.addresses == []
class TestProviderStore:
"""Test suite for ProviderStore class."""
def test_init_empty_store(self):
"""Test that a new ProviderStore is initialized empty."""
store = ProviderStore(host=mock_host)
assert len(store.providers) == 0
assert store.peer_routing is None
assert len(store.providing_keys) == 0
def test_init_with_host(self):
"""Test initialization with host."""
mock_host = Mock()
mock_peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = mock_peer_id
store = ProviderStore(host=mock_host)
assert store.host == mock_host
assert store.local_peer_id == mock_peer_id
assert len(store.providers) == 0
def test_init_with_host_and_peer_routing(self):
"""Test initialization with both host and peer routing."""
mock_host = Mock()
mock_peer_routing = Mock()
mock_peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = mock_peer_id
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
assert store.host == mock_host
assert store.peer_routing == mock_peer_routing
assert store.local_peer_id == mock_peer_id
def test_add_provider_new_key(self):
"""Test adding a provider for a new key."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
provider = PeerInfo(peer_id, addresses)
store.add_provider(key, provider)
assert key in store.providers
assert str(peer_id) in store.providers[key]
record = store.providers[key][str(peer_id)]
assert record.provider_info == provider
assert isinstance(record.timestamp, float)
def test_add_provider_existing_key(self):
"""Test adding multiple providers for the same key."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add first provider
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key, provider1)
# Add second provider
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
store.add_provider(key, provider2)
assert len(store.providers[key]) == 2
assert str(peer_id1) in store.providers[key]
assert str(peer_id2) in store.providers[key]
def test_add_provider_update_existing(self):
"""Test updating an existing provider."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
# Add initial provider
provider1 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
store.add_provider(key, provider1)
first_timestamp = store.providers[key][str(peer_id)].timestamp
# Small delay to ensure timestamp difference
time.sleep(0.001)
# Update provider
provider2 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
store.add_provider(key, provider2)
# Should have same peer but updated info
assert len(store.providers[key]) == 1
assert str(peer_id) in store.providers[key]
record = store.providers[key][str(peer_id)]
assert record.provider_info == provider2
assert record.timestamp > first_timestamp
def test_get_providers_empty_key(self):
"""Test getting providers for non-existent key."""
store = ProviderStore(host=mock_host)
key = b"nonexistent_key"
providers = store.get_providers(key)
assert providers == []
def test_get_providers_valid_records(self):
"""Test getting providers with valid records."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add multiple providers
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
provider1 = PeerInfo(peer_id1, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
provider2 = PeerInfo(peer_id2, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
store.add_provider(key, provider1)
store.add_provider(key, provider2)
providers = store.get_providers(key)
assert len(providers) == 2
provider_ids = {p.peer_id for p in providers}
assert peer_id1 in provider_ids
assert peer_id2 in provider_ids
def test_get_providers_expired_records(self):
"""Test that expired records are filtered out and cleaned up."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add valid provider
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key, provider1)
# Add expired provider manually
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key][str(peer_id2)] = ProviderRecord(
provider2, expired_timestamp
)
providers = store.get_providers(key)
# Should only return valid provider
assert len(providers) == 1
assert providers[0].peer_id == peer_id1
# Expired provider should be cleaned up
assert str(peer_id2) not in store.providers[key]
def test_get_providers_address_ttl(self):
"""Test address TTL handling in get_providers."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
provider = PeerInfo(peer_id, addresses)
# Add provider with old timestamp (addresses expired but record valid)
old_timestamp = time.time() - PROVIDER_ADDRESS_TTL - 1
store.providers[key] = {str(peer_id): ProviderRecord(provider, old_timestamp)}
providers = store.get_providers(key)
# Should return provider but with empty addresses
assert len(providers) == 1
assert providers[0].peer_id == peer_id
assert providers[0].addrs == []
def test_get_providers_cleanup_empty_key(self):
"""Test that keys with no valid providers are removed."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add only expired providers
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key] = {
str(peer_id): ProviderRecord(provider, expired_timestamp)
}
providers = store.get_providers(key)
assert providers == []
assert key not in store.providers # Key should be removed
def test_cleanup_expired_no_expired_records(self):
"""Test cleanup when there are no expired records."""
store = ProviderStore(host=mock_host)
key1 = b"key1"
key2 = b"key2"
# Add valid providers
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
provider1 = PeerInfo(peer_id1, [])
provider2 = PeerInfo(peer_id2, [])
store.add_provider(key1, provider1)
store.add_provider(key2, provider2)
initial_size = store.size()
store.cleanup_expired()
assert store.size() == initial_size
assert key1 in store.providers
assert key2 in store.providers
def test_cleanup_expired_with_expired_records(self):
"""Test cleanup removes expired records."""
store = ProviderStore(host=mock_host)
key = b"test_key"
# Add valid provider
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key, provider1)
# Add expired provider
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key][str(peer_id2)] = ProviderRecord(
provider2, expired_timestamp
)
assert store.size() == 2
store.cleanup_expired()
assert store.size() == 1
assert str(peer_id1) in store.providers[key]
assert str(peer_id2) not in store.providers[key]
def test_cleanup_expired_remove_empty_keys(self):
"""Test that keys with only expired providers are removed."""
store = ProviderStore(host=mock_host)
key1 = b"key1"
key2 = b"key2"
# Add valid provider to key1
peer_id1 = ID.from_base58("QmTest123")
provider1 = PeerInfo(peer_id1, [])
store.add_provider(key1, provider1)
# Add only expired provider to key2
peer_id2 = ID.from_base58("QmTest456")
provider2 = PeerInfo(peer_id2, [])
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
store.providers[key2] = {
str(peer_id2): ProviderRecord(provider2, expired_timestamp)
}
store.cleanup_expired()
assert key1 in store.providers
assert key2 not in store.providers
def test_get_provided_keys_empty_store(self):
"""Test get_provided_keys with empty store."""
store = ProviderStore(host=mock_host)
peer_id = ID.from_base58("QmTest123")
keys = store.get_provided_keys(peer_id)
assert keys == []
def test_get_provided_keys_single_peer(self):
"""Test get_provided_keys for a specific peer."""
store = ProviderStore(host=mock_host)
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
key1 = b"key1"
key2 = b"key2"
key3 = b"key3"
provider1 = PeerInfo(peer_id1, [])
provider2 = PeerInfo(peer_id2, [])
# peer_id1 provides key1 and key2
store.add_provider(key1, provider1)
store.add_provider(key2, provider1)
# peer_id2 provides key2 and key3
store.add_provider(key2, provider2)
store.add_provider(key3, provider2)
keys1 = store.get_provided_keys(peer_id1)
keys2 = store.get_provided_keys(peer_id2)
assert len(keys1) == 2
assert key1 in keys1
assert key2 in keys1
assert len(keys2) == 2
assert key2 in keys2
assert key3 in keys2
def test_get_provided_keys_nonexistent_peer(self):
"""Test get_provided_keys for peer that provides nothing."""
store = ProviderStore(host=mock_host)
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
# Add provider for peer_id1 only
key = b"key"
provider = PeerInfo(peer_id1, [])
store.add_provider(key, provider)
# Query for peer_id2 (provides nothing)
keys = store.get_provided_keys(peer_id2)
assert keys == []
def test_size_empty_store(self):
"""Test size() with empty store."""
store = ProviderStore(host=mock_host)
assert store.size() == 0
def test_size_with_providers(self):
"""Test size() with multiple providers."""
store = ProviderStore(host=mock_host)
# Add providers
key1 = b"key1"
key2 = b"key2"
peer_id1 = ID.from_base58("QmTest123")
peer_id2 = ID.from_base58("QmTest456")
peer_id3 = ID.from_base58("QmTest789")
provider1 = PeerInfo(peer_id1, [])
provider2 = PeerInfo(peer_id2, [])
provider3 = PeerInfo(peer_id3, [])
store.add_provider(key1, provider1)
store.add_provider(key1, provider2) # 2 providers for key1
store.add_provider(key2, provider3) # 1 provider for key2
assert store.size() == 3
@pytest.mark.trio
async def test_provide_no_host(self):
"""Test provide() returns False when no host is configured."""
store = ProviderStore(host=mock_host)
key = b"test_key"
result = await store.provide(key)
assert result is False
@pytest.mark.trio
async def test_provide_no_peer_routing(self):
"""Test provide() returns False when no peer routing is configured."""
mock_host = Mock()
store = ProviderStore(host=mock_host)
key = b"test_key"
result = await store.provide(key)
assert result is False
@pytest.mark.trio
async def test_provide_success(self):
"""Test successful provide operation."""
# Setup mocks
mock_host = Mock()
mock_peer_routing = AsyncMock()
peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = peer_id
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
# Mock finding closest peers
closest_peers = [ID.from_base58("QmPeer1"), ID.from_base58("QmPeer2")]
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
# Mock _send_add_provider to return success
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
key = b"test_key"
result = await store.provide(key)
assert result is True
assert key in store.providing_keys
assert key in store.providers
# Should have called _send_add_provider for each peer
assert mock_send.call_count == len(closest_peers)
@pytest.mark.trio
async def test_provide_skip_local_peer(self):
"""Test that provide() skips sending to local peer."""
# Setup mocks
mock_host = Mock()
mock_peer_routing = AsyncMock()
peer_id = ID.from_base58("QmTest123")
mock_host.get_id.return_value = peer_id
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
# Include local peer in closest peers
closest_peers = [peer_id, ID.from_base58("QmPeer1")]
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
key = b"test_key"
result = await store.provide(key)
assert result is True
# Should only call _send_add_provider once (skip local peer)
assert mock_send.call_count == 1
@pytest.mark.trio
async def test_find_providers_no_host(self):
"""Test find_providers() returns empty list when no host."""
store = ProviderStore(host=mock_host)
key = b"test_key"
result = await store.find_providers(key)
assert result == []
@pytest.mark.trio
async def test_find_providers_local_only(self):
"""Test find_providers() returns local providers."""
mock_host = Mock()
mock_peer_routing = Mock()
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
# Add local providers
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
result = await store.find_providers(key)
assert len(result) == 1
assert result[0].peer_id == peer_id
@pytest.mark.trio
async def test_find_providers_network_search(self):
"""Test find_providers() searches network when no local providers."""
mock_host = Mock()
mock_peer_routing = AsyncMock()
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
# Mock network search
closest_peers = [ID.from_base58("QmPeer1")]
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
# Mock provider response
remote_peer_id = ID.from_base58("QmRemote123")
remote_providers = [PeerInfo(remote_peer_id, [])]
with patch.object(
store, "_get_providers_from_peer", return_value=remote_providers
):
key = b"test_key"
result = await store.find_providers(key)
assert len(result) == 1
assert result[0].peer_id == remote_peer_id
@pytest.mark.trio
async def test_get_providers_from_peer_no_host(self):
"""Test _get_providers_from_peer without host."""
store = ProviderStore(host=mock_host)
peer_id = ID.from_base58("QmTest123")
key = b"test_key"
# Should handle missing host gracefully
result = await store._get_providers_from_peer(peer_id, key)
assert result == []
def test_edge_case_empty_key(self):
"""Test handling of empty key."""
store = ProviderStore(host=mock_host)
key = b""
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
def test_edge_case_large_key(self):
"""Test handling of large key."""
store = ProviderStore(host=mock_host)
key = b"x" * 10000 # 10KB key
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
def test_concurrent_operations(self):
"""Test multiple concurrent operations."""
store = ProviderStore(host=mock_host)
# Add many providers
num_keys = 100
num_providers_per_key = 5
for i in range(num_keys):
_key = f"key_{i}".encode()
for j in range(num_providers_per_key):
# Generate unique valid Base58 peer IDs
# Use a different approach that ensures uniqueness
unique_id = i * num_providers_per_key + j + 1 # Ensure > 0
_peer_id_str = f"QmPeer{unique_id:06d}".replace("0", "A") + "1" * 38
peer_id = ID.from_base58(_peer_id_str)
provider = PeerInfo(peer_id, [])
store.add_provider(_key, provider)
# Verify total size
expected_size = num_keys * num_providers_per_key
assert store.size() == expected_size
# Verify individual keys
for i in range(num_keys):
_key = f"key_{i}".encode()
providers = store.get_providers(_key)
assert len(providers) == num_providers_per_key
def test_memory_efficiency_large_dataset(self):
"""Test memory behavior with large datasets."""
store = ProviderStore(host=mock_host)
# Add large number of providers
num_entries = 1000
for i in range(num_entries):
_key = f"key_{i:05d}".encode()
# Generate valid Base58 peer IDs (replace 0 with valid characters)
peer_str = f"QmPeer{i:05d}".replace("0", "1") + "1" * 35
peer_id = ID.from_base58(peer_str)
provider = PeerInfo(peer_id, [])
store.add_provider(_key, provider)
assert store.size() == num_entries
# Clean up all entries by making them expired
current_time = time.time()
for _key, providers in store.providers.items():
for _peer_id_str, record in providers.items():
record.timestamp = (
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
)
store.cleanup_expired()
assert store.size() == 0
assert len(store.providers) == 0
def test_unicode_key_handling(self):
"""Test handling of unicode content in keys."""
store = ProviderStore(host=mock_host)
# Test various unicode keys
unicode_keys = [
b"hello",
"héllo".encode(),
"🔑".encode(),
"ключ".encode(), # Russian
"".encode(), # Chinese
]
for i, key in enumerate(unicode_keys):
# Generate valid Base58 peer IDs
peer_id = ID.from_base58(f"QmPeer{i + 1}" + "1" * 42) # Valid base58
provider = PeerInfo(peer_id, [])
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
def test_multiple_addresses_per_provider(self):
"""Test providers with multiple addresses."""
store = ProviderStore(host=mock_host)
key = b"test_key"
peer_id = ID.from_base58("QmTest123")
addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
Multiaddr("/ip6/::1/tcp/8001"),
Multiaddr("/ip4/192.168.1.100/tcp/8002"),
]
provider = PeerInfo(peer_id, addresses)
store.add_provider(key, provider)
providers = store.get_providers(key)
assert len(providers) == 1
assert providers[0].peer_id == peer_id
assert len(providers[0].addrs) == len(addresses)
assert all(addr in providers[0].addrs for addr in addresses)
@pytest.mark.trio
async def test_republish_provider_records_no_keys(self):
"""Test _republish_provider_records with no providing keys."""
store = ProviderStore(host=mock_host)
# Should complete without error even with no providing keys
await store._republish_provider_records()
assert len(store.providing_keys) == 0
def test_expiration_boundary_conditions(self):
"""Test expiration around boundary conditions."""
store = ProviderStore(host=mock_host)
peer_id = ID.from_base58("QmTest123")
provider = PeerInfo(peer_id, [])
current_time = time.time()
# Test records at various timestamps
timestamps = [
current_time, # Fresh
current_time - PROVIDER_ADDRESS_TTL + 1, # Addresses valid
current_time - PROVIDER_ADDRESS_TTL - 1, # Addresses expired
current_time
- PROVIDER_RECORD_REPUBLISH_INTERVAL
+ 1, # No republish needed
current_time - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1, # Republish needed
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL + 1, # Not expired
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1, # Expired
]
for i, timestamp in enumerate(timestamps):
test_key = f"key_{i}".encode()
record = ProviderRecord(provider, timestamp)
store.providers[test_key] = {str(peer_id): record}
# Test various operations
for i, timestamp in enumerate(timestamps):
test_key = f"key_{i}".encode()
providers = store.get_providers(test_key)
if timestamp <= current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL:
# Should be expired and removed
assert len(providers) == 0
assert test_key not in store.providers
else:
# Should be present
assert len(providers) == 1
assert providers[0].peer_id == peer_id
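
The boundary tests above pin down the behavior named in the commit log entry "Fix expiration and republish checks to use inclusive comparison": a record whose age equals the relevant interval already counts as expired or due for republishing. A small sketch of record-age checks consistent with those tests; the interval values here are illustrative placeholders, not the constants defined in libp2p.kad_dht.provider_store:

```python
import time
from dataclasses import dataclass, field

# Illustrative values only; the real constants live in libp2p.kad_dht.provider_store.
PROVIDER_ADDRESS_TTL = 30 * 60                       # drop cached addresses after 30 min
PROVIDER_RECORD_REPUBLISH_INTERVAL = 22 * 60 * 60    # re-announce after 22 h
PROVIDER_RECORD_EXPIRATION_INTERVAL = 48 * 60 * 60   # forget the record after 48 h


@dataclass
class ProviderRecordSketch:
    timestamp: float = field(default_factory=time.time)

    def _age(self) -> float:
        return time.time() - self.timestamp

    def is_expired(self) -> bool:
        # Inclusive: a record exactly at the boundary is treated as expired.
        return self._age() >= PROVIDER_RECORD_EXPIRATION_INTERVAL

    def should_republish(self) -> bool:
        # Inclusive as well, matching the boundary tests above.
        return self._age() >= PROVIDER_RECORD_REPUBLISH_INTERVAL

    def addresses_are_fresh(self) -> bool:
        return self._age() < PROVIDER_ADDRESS_TTL
```

Because the address TTL is shorter than the record lifetime, get_providers can legitimately return a provider with an empty address list, as test_get_providers_address_ttl asserts.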


@@ -0,0 +1,371 @@
"""
Unit tests for the RoutingTable and KBucket classes in Kademlia DHT.
This module tests the core functionality of the routing table including:
- KBucket operations (add, remove, split, ping)
- RoutingTable management (peer addition, closest peer finding)
- Distance calculations and peer ordering
- Bucket splitting and range management
"""
import time
from unittest.mock import (
AsyncMock,
Mock,
patch,
)
import pytest
from multiaddr import (
Multiaddr,
)
import trio
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
from libp2p.kad_dht.routing_table import (
BUCKET_SIZE,
KBucket,
RoutingTable,
)
from libp2p.kad_dht.utils import (
create_key_from_binary,
xor_distance,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
def create_valid_peer_id(name: str) -> ID:
"""Create a valid peer ID for testing."""
# Use crypto to generate valid peer IDs
key_pair = create_new_key_pair()
return ID.from_pubkey(key_pair.public_key)
class TestKBucket:
"""Test suite for KBucket class."""
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
return host
@pytest.fixture
def sample_peer_info(self):
"""Create sample peer info for testing."""
peer_id = create_valid_peer_id("test")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
return PeerInfo(peer_id, addresses)
def test_init_default_parameters(self, mock_host):
"""Test KBucket initialization with default parameters."""
bucket = KBucket(mock_host)
assert bucket.bucket_size == BUCKET_SIZE
assert bucket.host == mock_host
assert bucket.min_range == 0
assert bucket.max_range == 2**256
assert len(bucket.peers) == 0
def test_peer_operations(self, mock_host, sample_peer_info):
"""Test basic peer operations: add, check, and remove."""
bucket = KBucket(mock_host)
# Test empty bucket
assert bucket.peer_ids() == []
assert bucket.size() == 0
assert not bucket.has_peer(sample_peer_info.peer_id)
# Add peer manually
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
# Test with peer
assert len(bucket.peer_ids()) == 1
assert sample_peer_info.peer_id in bucket.peer_ids()
assert bucket.size() == 1
assert bucket.has_peer(sample_peer_info.peer_id)
assert bucket.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
# Remove peer
result = bucket.remove_peer(sample_peer_info.peer_id)
assert result is True
assert bucket.size() == 0
assert not bucket.has_peer(sample_peer_info.peer_id)
@pytest.mark.trio
async def test_add_peer_functionality(self, mock_host):
"""Test add_peer method with different scenarios."""
bucket = KBucket(mock_host, bucket_size=2) # Small bucket for testing
# Add first peer
peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
result = await bucket.add_peer(peer1)
assert result is True
assert bucket.size() == 1
# Add second peer
peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
result = await bucket.add_peer(peer2)
assert result is True
assert bucket.size() == 2
# Add same peer again (should update timestamp)
await trio.sleep(0.001)
result = await bucket.add_peer(peer1)
assert result is True
assert bucket.size() == 2 # Still 2 peers
# Try to add third peer when bucket is full
peer3 = PeerInfo(create_valid_peer_id("peer3"), [])
with patch.object(bucket, "_ping_peer", return_value=True):
result = await bucket.add_peer(peer3)
assert result is False # Should fail if oldest peer responds
def test_get_oldest_peer(self, mock_host):
"""Test get_oldest_peer method."""
bucket = KBucket(mock_host)
# Empty bucket
assert bucket.get_oldest_peer() is None
# Add peers with different timestamps
peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
current_time = time.time()
bucket.peers[peer1.peer_id] = (peer1, current_time - 300) # Older
bucket.peers[peer2.peer_id] = (peer2, current_time) # Newer
oldest = bucket.get_oldest_peer()
assert oldest == peer1.peer_id
def test_stale_peers(self, mock_host):
"""Test stale peer identification."""
bucket = KBucket(mock_host)
current_time = time.time()
fresh_peer = PeerInfo(create_valid_peer_id("fresh"), [])
stale_peer = PeerInfo(create_valid_peer_id("stale"), [])
bucket.peers[fresh_peer.peer_id] = (fresh_peer, current_time)
bucket.peers[stale_peer.peer_id] = (
stale_peer,
current_time - 7200,
) # 2 hours ago
stale_peers = bucket.get_stale_peers(3600) # 1 hour threshold
assert len(stale_peers) == 1
assert stale_peer.peer_id in stale_peers
def test_key_in_range(self, mock_host):
"""Test key_in_range method."""
bucket = KBucket(mock_host, min_range=100, max_range=200)
# Test keys within range
key_in_range = (150).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_in_range) is True
# Test keys outside range
key_below = (50).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_below) is False
key_above = (250).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_above) is False
# Test boundary conditions
key_min = (100).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_min) is True
key_max = (200).to_bytes(32, byteorder="big")
assert bucket.key_in_range(key_max) is False
def test_split_bucket(self, mock_host):
"""Test bucket splitting functionality."""
bucket = KBucket(mock_host, min_range=0, max_range=256)
lower_bucket, upper_bucket = bucket.split()
# Check ranges
assert lower_bucket.min_range == 0
assert lower_bucket.max_range == 128
assert upper_bucket.min_range == 128
assert upper_bucket.max_range == 256
# Check properties
assert lower_bucket.bucket_size == bucket.bucket_size
assert upper_bucket.bucket_size == bucket.bucket_size
assert lower_bucket.host == mock_host
assert upper_bucket.host == mock_host
@pytest.mark.trio
async def test_ping_peer_scenarios(self, mock_host, sample_peer_info):
"""Test different ping scenarios."""
bucket = KBucket(mock_host)
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
# Test ping peer not in bucket
other_peer_id = create_valid_peer_id("other")
with pytest.raises(ValueError, match="Peer .* not in bucket"):
await bucket._ping_peer(other_peer_id)
# Test ping failure due to stream error
mock_host.new_stream.side_effect = Exception("Stream failed")
result = await bucket._ping_peer(sample_peer_info.peer_id)
assert result is False
class TestRoutingTable:
"""Test suite for RoutingTable class."""
@pytest.fixture
def mock_host(self):
"""Create a mock host for testing."""
host = Mock()
host.get_peerstore.return_value = Mock()
return host
@pytest.fixture
def local_peer_id(self):
"""Create a local peer ID for testing."""
return create_valid_peer_id("local")
@pytest.fixture
def sample_peer_info(self):
"""Create sample peer info for testing."""
peer_id = create_valid_peer_id("sample")
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
return PeerInfo(peer_id, addresses)
def test_init_routing_table(self, mock_host, local_peer_id):
"""Test RoutingTable initialization."""
routing_table = RoutingTable(local_peer_id, mock_host)
assert routing_table.local_id == local_peer_id
assert routing_table.host == mock_host
assert len(routing_table.buckets) == 1
assert isinstance(routing_table.buckets[0], KBucket)
@pytest.mark.trio
async def test_add_peer_operations(
self, mock_host, local_peer_id, sample_peer_info
):
"""Test adding peers to routing table."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Test adding peer with PeerInfo
result = await routing_table.add_peer(sample_peer_info)
assert result is True
assert routing_table.size() == 1
assert routing_table.peer_in_table(sample_peer_info.peer_id)
# Test adding peer with just ID
peer_id = create_valid_peer_id("test")
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
mock_host.get_peerstore().addrs.return_value = mock_addrs
result = await routing_table.add_peer(peer_id)
assert result is True
assert routing_table.size() == 2
# Test adding peer with no addresses
no_addr_peer_id = create_valid_peer_id("no_addr")
mock_host.get_peerstore().addrs.return_value = []
result = await routing_table.add_peer(no_addr_peer_id)
assert result is False
assert routing_table.size() == 2
# Test adding local peer (should be ignored)
result = await routing_table.add_peer(local_peer_id)
assert result is False
assert routing_table.size() == 2
def test_find_bucket(self, mock_host, local_peer_id):
"""Test finding appropriate bucket for peers."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Test with peer ID
peer_id = create_valid_peer_id("test")
bucket = routing_table.find_bucket(peer_id)
assert isinstance(bucket, KBucket)
def test_peer_management(self, mock_host, local_peer_id, sample_peer_info):
"""Test peer management operations."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Add peer manually
bucket = routing_table.find_bucket(sample_peer_info.peer_id)
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
# Test peer queries
assert routing_table.peer_in_table(sample_peer_info.peer_id)
assert routing_table.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
assert routing_table.size() == 1
assert len(routing_table.get_peer_ids()) == 1
# Test remove peer
result = routing_table.remove_peer(sample_peer_info.peer_id)
assert result is True
assert not routing_table.peer_in_table(sample_peer_info.peer_id)
assert routing_table.size() == 0
def test_find_closest_peers(self, mock_host, local_peer_id):
"""Test finding closest peers."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Empty table
target_key = create_key_from_binary(b"target_key")
closest_peers = routing_table.find_local_closest_peers(target_key, 5)
assert closest_peers == []
# Add some peers
bucket = routing_table.buckets[0]
test_peers = []
for i in range(5):
peer = PeerInfo(create_valid_peer_id(f"peer{i}"), [])
test_peers.append(peer)
bucket.peers[peer.peer_id] = (peer, time.time())
closest_peers = routing_table.find_local_closest_peers(target_key, 3)
assert len(closest_peers) <= 3
assert len(closest_peers) <= len(test_peers)
assert all(isinstance(peer_id, ID) for peer_id in closest_peers)
def test_distance_calculation(self, mock_host, local_peer_id):
"""Test XOR distance calculation."""
# Test same keys
key = b"\x42" * 32
distance = xor_distance(key, key)
assert distance == 0
# Test different keys
key1 = b"\x00" * 32
key2 = b"\xff" * 32
distance = xor_distance(key1, key2)
expected = int.from_bytes(b"\xff" * 32, byteorder="big")
assert distance == expected
def test_edge_cases(self, mock_host, local_peer_id):
"""Test various edge cases."""
routing_table = RoutingTable(local_peer_id, mock_host)
# Test with invalid peer ID
nonexistent_peer_id = create_valid_peer_id("nonexistent")
assert not routing_table.peer_in_table(nonexistent_peer_id)
assert routing_table.get_peer_info(nonexistent_peer_id) is None
assert routing_table.remove_peer(nonexistent_peer_id) is False
# Test bucket splitting scenario
assert len(routing_table.buckets) == 1
initial_bucket = routing_table.buckets[0]
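# A fresh table has a single bucket spanning the full 256-bit key space.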
assert initial_bucket.min_range == 0
assert initial_bucket.max_range == 2**256

"""
Unit tests for the ValueStore class in Kademlia DHT.
This module tests the core functionality of the ValueStore including:
- Basic storage and retrieval operations
- Expiration and TTL handling
- Edge cases and error conditions
- Store management operations
"""
import time
from unittest.mock import (
Mock,
)
import pytest
from libp2p.kad_dht.value_store import (
DEFAULT_TTL,
ValueStore,
)
from libp2p.peer.id import (
ID,
)
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
class TestValueStore:
"""Test suite for ValueStore class."""
def test_init_empty_store(self):
"""Test that a new ValueStore is initialized empty."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
assert len(store.store) == 0
def test_init_with_host_and_peer_id(self):
"""Test initialization with host and local peer ID."""
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
store = ValueStore(host=mock_host, local_peer_id=peer_id)
assert store.host == mock_host
assert store.local_peer_id == peer_id
assert len(store.store) == 0
def test_put_basic(self):
"""Test basic put operation."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
assert key in store.store
stored_value, validity = store.store[key]
assert stored_value == value
assert validity is not None
assert validity > time.time() # Should be in the future
def test_put_with_custom_validity(self):
"""Test put operation with custom validity time."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
custom_validity = time.time() + 3600 # 1 hour from now
store.put(key, value, validity=custom_validity)
stored_value, validity = store.store[key]
assert stored_value == value
assert validity == custom_validity
def test_put_overwrite_existing(self):
"""Test that put overwrites existing values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value1 = b"value1"
value2 = b"value2"
store.put(key, value1)
store.put(key, value2)
assert len(store.store) == 1
stored_value, _ = store.store[key]
assert stored_value == value2
def test_get_existing_valid_value(self):
"""Test retrieving an existing, non-expired value."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_get_nonexistent_key(self):
"""Test retrieving a non-existent key returns None."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"nonexistent_key"
retrieved_value = store.get(key)
assert retrieved_value is None
def test_get_expired_value(self):
"""Test that expired values are automatically removed and return None."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
expired_validity = time.time() - 1 # 1 second ago
# Manually insert expired value
store.store[key] = (value, expired_validity)
retrieved_value = store.get(key)
assert retrieved_value is None
assert key not in store.store # Should be removed
def test_remove_existing_key(self):
"""Test removing an existing key."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
result = store.remove(key)
assert result is True
assert key not in store.store
def test_remove_nonexistent_key(self):
"""Test removing a non-existent key returns False."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"nonexistent_key"
result = store.remove(key)
assert result is False
def test_has_existing_valid_key(self):
"""Test has() returns True for existing, valid keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
store.put(key, value)
result = store.has(key)
assert result is True
def test_has_nonexistent_key(self):
"""Test has() returns False for non-existent keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"nonexistent_key"
result = store.has(key)
assert result is False
def test_has_expired_key(self):
"""Test has() returns False for expired keys and removes them."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"test_key"
value = b"test_value"
expired_validity = time.time() - 1
# Manually insert expired value
store.store[key] = (value, expired_validity)
result = store.has(key)
assert result is False
assert key not in store.store # Should be removed
def test_cleanup_expired_no_expired_values(self):
"""Test cleanup when there are no expired values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
value = b"value"
store.put(key1, value)
store.put(key2, value)
expired_count = store.cleanup_expired()
assert expired_count == 0
assert len(store.store) == 2
def test_cleanup_expired_with_expired_values(self):
"""Test cleanup removes expired values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"valid_key"
key2 = b"expired_key1"
key3 = b"expired_key2"
value = b"value"
expired_validity = time.time() - 1
store.put(key1, value) # Valid
store.store[key2] = (value, expired_validity) # Expired
store.store[key3] = (value, expired_validity) # Expired
expired_count = store.cleanup_expired()
assert expired_count == 2
assert len(store.store) == 1
assert key1 in store.store
assert key2 not in store.store
assert key3 not in store.store
def test_cleanup_expired_mixed_validity_types(self):
"""Test cleanup with mix of values with and without expiration."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"no_expiry"
key2 = b"valid_expiry"
key3 = b"expired"
value = b"value"
# No expiration (None validity)
store.put(key1, value)
# Valid expiration
store.put(key2, value, validity=time.time() + 3600)
# Expired
store.store[key3] = (value, time.time() - 1)
expired_count = store.cleanup_expired()
assert expired_count == 1
assert len(store.store) == 2
assert key1 in store.store
assert key2 in store.store
assert key3 not in store.store
def test_get_keys_empty_store(self):
"""Test get_keys() returns empty list for empty store."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
keys = store.get_keys()
assert keys == []
def test_get_keys_with_valid_values(self):
"""Test get_keys() returns all non-expired keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
key3 = b"expired_key"
value = b"value"
store.put(key1, value)
store.put(key2, value)
store.store[key3] = (value, time.time() - 1) # Expired
keys = store.get_keys()
assert len(keys) == 2
assert key1 in keys
assert key2 in keys
assert key3 not in keys
def test_size_empty_store(self):
"""Test size() returns 0 for empty store."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
size = store.size()
assert size == 0
def test_size_with_valid_values(self):
"""Test size() returns correct count after cleaning expired values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
key3 = b"expired_key"
value = b"value"
store.put(key1, value)
store.put(key2, value)
store.store[key3] = (value, time.time() - 1) # Expired
size = store.size()
assert size == 2
def test_edge_case_empty_key(self):
"""Test handling of empty key."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b""
value = b"value"
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_edge_case_empty_value(self):
"""Test handling of empty value."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b""
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_edge_case_large_key_value(self):
"""Test handling of large keys and values."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"x" * 10000 # 10KB key
value = b"y" * 100000 # 100KB value
store.put(key, value)
retrieved_value = store.get(key)
assert retrieved_value == value
def test_edge_case_negative_validity(self):
"""Test handling of negative validity time."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
store.put(key, value, validity=-1)
# Should be expired
retrieved_value = store.get(key)
assert retrieved_value is None
def test_default_ttl_calculation(self):
"""Test that default TTL is correctly applied."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
start_time = time.time()
store.put(key, value)
_, validity = store.store[key]
expected_validity = start_time + DEFAULT_TTL
# Allow small time difference for execution
assert abs(validity - expected_validity) < 1
def test_concurrent_operations(self):
"""Test that multiple operations don't interfere with each other."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Add multiple key-value pairs
for i in range(100):
key = f"key_{i}".encode()
value = f"value_{i}".encode()
store.put(key, value)
# Verify all are stored
assert store.size() == 100
# Remove every other key
for i in range(0, 100, 2):
key = f"key_{i}".encode()
store.remove(key)
# Verify correct count
assert store.size() == 50
# Verify remaining keys are correct
for i in range(1, 100, 2):
key = f"key_{i}".encode()
assert store.has(key)
def test_expiration_boundary_conditions(self):
"""Test expiration around current time boundary."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key1 = b"key1"
key2 = b"key2"
key3 = b"key3"
value = b"value"
current_time = time.time()
# Just expired
store.store[key1] = (value, current_time - 0.001)
# Valid for a longer time to account for test execution time
store.store[key2] = (value, current_time + 1.0)
# Exactly current time (should be expired)
store.store[key3] = (value, current_time)
# Small delay to ensure time has passed
time.sleep(0.002)
assert not store.has(key1) # Should be expired
assert store.has(key2) # Should be valid
assert not store.has(key3) # Should be expired (exactly at current time)
def test_store_internal_structure(self):
"""Test that internal store structure is maintained correctly."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
validity = time.time() + 3600
store.put(key, value, validity=validity)
# Verify internal structure
assert isinstance(store.store, dict)
assert key in store.store
stored_tuple = store.store[key]
assert isinstance(stored_tuple, tuple)
assert len(stored_tuple) == 2
assert stored_tuple[0] == value
assert stored_tuple[1] == validity
@pytest.mark.trio
async def test_store_at_peer_local_peer(self):
"""Test _store_at_peer returns True when storing at local peer."""
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
value = b"value"
result = await store._store_at_peer(peer_id, key, value)
assert result is True
@pytest.mark.trio
async def test_get_from_peer_local_peer(self):
"""Test _get_from_peer returns None when querying local peer."""
mock_host = Mock()
peer_id = ID.from_base58("QmTest123")
store = ValueStore(host=mock_host, local_peer_id=peer_id)
key = b"key"
result = await store._get_from_peer(peer_id, key)
assert result is None
def test_memory_efficiency_large_dataset(self):
"""Test memory behavior with large datasets."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Add a large number of entries
num_entries = 10000
for i in range(num_entries):
key = f"key_{i:05d}".encode()
value = f"value_{i:05d}".encode()
store.put(key, value)
assert store.size() == num_entries
# Clean up all entries
for i in range(num_entries):
key = f"key_{i:05d}".encode()
store.remove(key)
assert store.size() == 0
assert len(store.store) == 0
def test_key_collision_resistance(self):
"""Test that similar keys don't collide."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Test keys that might cause collisions
keys = [
b"key",
b"key\x00",
b"key1",
b"Key", # Different case
b"key ", # With space
b" key", # Leading space
]
for i, key in enumerate(keys):
value = f"value_{i}".encode()
store.put(key, value)
# Verify all keys are stored separately
assert store.size() == len(keys)
for i, key in enumerate(keys):
expected_value = f"value_{i}".encode()
assert store.get(key) == expected_value
def test_unicode_key_handling(self):
"""Test handling of unicode content in keys."""
store = ValueStore(host=mock_host, local_peer_id=peer_id)
# Test various unicode keys
unicode_keys = [
b"hello",
"héllo".encode(),
"🔑".encode(),
"ключ".encode(), # Russian
"".encode(), # Chinese
]
for i, key in enumerate(unicode_keys):
value = f"value_{i}".encode()
store.put(key, value)
assert store.get(key) == value