mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'libp2p:main' into main
This commit is contained in:
@ -56,16 +56,9 @@ async def test_identify_protocol(security_protocol):
|
||||
)
|
||||
|
||||
# Check observed address
|
||||
# TODO: use decapsulateCode(protocols('p2p').code)
|
||||
# when the Multiaddr class will implement it
|
||||
host_b_addr = host_b.get_addrs()[0]
|
||||
cleaned_addr = Multiaddr.join(
|
||||
*(
|
||||
host_b_addr.split()[:-1]
|
||||
if str(host_b_addr.split()[-1]).startswith("/p2p/")
|
||||
else host_b_addr.split()
|
||||
)
|
||||
)
|
||||
host_b_peer_id = host_b.get_id()
|
||||
cleaned_addr = host_b_addr.decapsulate(Multiaddr(f"/p2p/{host_b_peer_id}"))
|
||||
|
||||
logger.debug("observed_addr: %s", Multiaddr(identify_response.observed_addr))
|
||||
logger.debug("host_b.get_addrs()[0]: %s", host_b.get_addrs()[0])
|
||||
|
||||
168
tests/core/kad_dht/test_kad_dht.py
Normal file
168
tests/core/kad_dht/test_kad_dht.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""
|
||||
Tests for the Kademlia DHT implementation.
|
||||
|
||||
This module tests core functionality of the Kademlia DHT including:
|
||||
- Node discovery (find_node)
|
||||
- Value storage and retrieval (put_value, get_value)
|
||||
- Content provider advertisement and discovery (provide, find_providers)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.kad_dht.kad_dht import (
|
||||
DHTMode,
|
||||
KadDHT,
|
||||
)
|
||||
from libp2p.kad_dht.utils import (
|
||||
create_key_from_binary,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
)
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger("test.kad_dht")
|
||||
|
||||
# Constants for the tests
|
||||
TEST_TIMEOUT = 5 # Timeout in seconds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dht_pair(security_protocol):
|
||||
"""Create a pair of connected DHT nodes for testing."""
|
||||
async with host_pair_factory(security_protocol=security_protocol) as (
|
||||
host_a,
|
||||
host_b,
|
||||
):
|
||||
# Get peer info for bootstrapping
|
||||
peer_b_info = PeerInfo(host_b.get_id(), host_b.get_addrs())
|
||||
peer_a_info = PeerInfo(host_a.get_id(), host_a.get_addrs())
|
||||
|
||||
# Create DHT nodes from the hosts with bootstrap peers as multiaddr strings
|
||||
dht_a: KadDHT = KadDHT(host_a, mode=DHTMode.SERVER)
|
||||
dht_b: KadDHT = KadDHT(host_b, mode=DHTMode.SERVER)
|
||||
await dht_a.peer_routing.routing_table.add_peer(peer_b_info)
|
||||
await dht_b.peer_routing.routing_table.add_peer(peer_a_info)
|
||||
|
||||
# Start both DHT services
|
||||
async with background_trio_service(dht_a), background_trio_service(dht_b):
|
||||
# Allow time for bootstrap to complete and connections to establish
|
||||
await trio.sleep(0.1)
|
||||
|
||||
logger.debug(
|
||||
"After bootstrap: Node A peers: %s", dht_a.routing_table.get_peer_ids()
|
||||
)
|
||||
logger.debug(
|
||||
"After bootstrap: Node B peers: %s", dht_b.routing_table.get_peer_ids()
|
||||
)
|
||||
|
||||
# Return the DHT pair
|
||||
yield (dht_a, dht_b)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
"""Test that nodes can find each other in the DHT."""
|
||||
dht_a, dht_b = dht_pair
|
||||
|
||||
# Node A should be able to find Node B
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
found_info = await dht_a.find_peer(dht_b.host.get_id())
|
||||
|
||||
# Verify that the found peer has the correct peer ID
|
||||
assert found_info is not None, "Failed to find the target peer"
|
||||
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
"""Test storing and retrieving values in the DHT."""
|
||||
dht_a, dht_b = dht_pair
|
||||
# dht_a.peer_routing.routing_table.add_peer(dht_b.pe)
|
||||
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
|
||||
# Generate a random key and value
|
||||
key = create_key_from_binary(b"test-key")
|
||||
value = b"test-value"
|
||||
|
||||
# First add the value directly to node A's store to verify storage works
|
||||
dht_a.value_store.put(key, value)
|
||||
logger.debug("Local value store: %s", dht_a.value_store.store)
|
||||
local_value = dht_a.value_store.get(key)
|
||||
assert local_value == value, "Local value storage failed"
|
||||
print("number of nodes in peer store", dht_a.host.get_peerstore().peer_ids())
|
||||
await dht_a.routing_table.add_peer(peer_b_info)
|
||||
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
|
||||
|
||||
# Store the value using the first node (this will also store locally)
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
await dht_a.put_value(key, value)
|
||||
|
||||
# # Log debugging information
|
||||
logger.debug("Put value with key %s...", key.hex()[:10])
|
||||
logger.debug("Node A value store: %s", dht_a.value_store.store)
|
||||
print("hello test")
|
||||
|
||||
# # Allow more time for the value to propagate
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# # Try direct connection between nodes to ensure they're properly linked
|
||||
logger.debug("Node A peers: %s", dht_a.routing_table.get_peer_ids())
|
||||
logger.debug("Node B peers: %s", dht_b.routing_table.get_peer_ids())
|
||||
|
||||
# Retrieve the value using the second node
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
retrieved_value = await dht_b.get_value(key)
|
||||
print("the value stored in node b is", dht_b.get_value_store_size())
|
||||
logger.debug("Retrieved value: %s", retrieved_value)
|
||||
|
||||
# Verify that the retrieved value matches the original
|
||||
assert retrieved_value == value, "Retrieved value does not match the stored value"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
"""Test advertising and finding content providers."""
|
||||
dht_a, dht_b = dht_pair
|
||||
|
||||
# Generate a random content ID
|
||||
content = f"test-content-{uuid.uuid4()}".encode()
|
||||
content_id = hashlib.sha256(content).digest()
|
||||
|
||||
# Store content on the first node
|
||||
dht_a.value_store.put(content_id, content)
|
||||
|
||||
# Advertise the first node as a provider
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
success = await dht_a.provide(content_id)
|
||||
assert success, "Failed to advertise as provider"
|
||||
|
||||
# Allow time for the provider record to propagate
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Find providers using the second node
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
providers = await dht_b.find_providers(content_id)
|
||||
|
||||
# Verify that we found the first node as a provider
|
||||
assert providers, "No providers found"
|
||||
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
|
||||
"Expected provider not found"
|
||||
)
|
||||
|
||||
# Retrieve the content using the provider information
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
retrieved_value = await dht_b.get_value(content_id)
|
||||
assert retrieved_value == content, (
|
||||
"Retrieved content does not match the original"
|
||||
)
|
||||
459
tests/core/kad_dht/test_unit_peer_routing.py
Normal file
459
tests/core/kad_dht/test_unit_peer_routing.py
Normal file
@ -0,0 +1,459 @@
|
||||
"""
|
||||
Unit tests for the PeerRouting class in Kademlia DHT.
|
||||
|
||||
This module tests the core functionality of peer routing including:
|
||||
- Peer discovery and lookup
|
||||
- Network queries for closest peers
|
||||
- Protocol message handling
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
Mock,
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import varint
|
||||
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.kad_dht.pb.kademlia_pb2 import (
|
||||
Message,
|
||||
)
|
||||
from libp2p.kad_dht.peer_routing import (
|
||||
ALPHA,
|
||||
MAX_PEER_LOOKUP_ROUNDS,
|
||||
PROTOCOL_ID,
|
||||
PeerRouting,
|
||||
)
|
||||
from libp2p.kad_dht.routing_table import (
|
||||
RoutingTable,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
|
||||
def create_valid_peer_id(name: str) -> ID:
|
||||
"""Create a valid peer ID for testing."""
|
||||
key_pair = create_new_key_pair()
|
||||
return ID.from_pubkey(key_pair.public_key)
|
||||
|
||||
|
||||
class TestPeerRouting:
|
||||
"""Test suite for PeerRouting class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host(self):
|
||||
"""Create a mock host for testing."""
|
||||
host = Mock()
|
||||
host.get_id.return_value = create_valid_peer_id("local")
|
||||
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
host.connect = AsyncMock()
|
||||
return host
|
||||
|
||||
@pytest.fixture
|
||||
def mock_routing_table(self, mock_host):
|
||||
"""Create a mock routing table for testing."""
|
||||
local_id = create_valid_peer_id("local")
|
||||
routing_table = RoutingTable(local_id, mock_host)
|
||||
return routing_table
|
||||
|
||||
@pytest.fixture
|
||||
def peer_routing(self, mock_host, mock_routing_table):
|
||||
"""Create a PeerRouting instance for testing."""
|
||||
return PeerRouting(mock_host, mock_routing_table)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_peer_info(self):
|
||||
"""Create sample peer info for testing."""
|
||||
peer_id = create_valid_peer_id("sample")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
|
||||
return PeerInfo(peer_id, addresses)
|
||||
|
||||
def test_init_peer_routing(self, mock_host, mock_routing_table):
|
||||
"""Test PeerRouting initialization."""
|
||||
peer_routing = PeerRouting(mock_host, mock_routing_table)
|
||||
|
||||
assert peer_routing.host == mock_host
|
||||
assert peer_routing.routing_table == mock_routing_table
|
||||
assert peer_routing.protocol_id == PROTOCOL_ID
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_local_host(self, peer_routing, mock_host):
|
||||
"""Test finding our own peer."""
|
||||
local_id = mock_host.get_id()
|
||||
|
||||
result = await peer_routing.find_peer(local_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.peer_id == local_id
|
||||
assert result.addrs == mock_host.get_addrs()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_in_routing_table(self, peer_routing, sample_peer_info):
|
||||
"""Test finding peer that exists in routing table."""
|
||||
# Add peer to routing table
|
||||
await peer_routing.routing_table.add_peer(sample_peer_info)
|
||||
|
||||
result = await peer_routing.find_peer(sample_peer_info.peer_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.peer_id == sample_peer_info.peer_id
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_in_peerstore(self, peer_routing, mock_host):
|
||||
"""Test finding peer that exists in peerstore."""
|
||||
peer_id = create_valid_peer_id("peerstore")
|
||||
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8002")]
|
||||
|
||||
# Mock peerstore to return addresses
|
||||
mock_host.get_peerstore().addrs.return_value = mock_addrs
|
||||
|
||||
result = await peer_routing.find_peer(peer_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.peer_id == peer_id
|
||||
assert result.addrs == mock_addrs
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_peer_not_found(self, peer_routing, mock_host):
|
||||
"""Test finding peer that doesn't exist anywhere."""
|
||||
peer_id = create_valid_peer_id("nonexistent")
|
||||
|
||||
# Mock peerstore to return no addresses
|
||||
mock_host.get_peerstore().addrs.return_value = []
|
||||
|
||||
# Mock network search to return empty results
|
||||
with patch.object(peer_routing, "find_closest_peers_network", return_value=[]):
|
||||
result = await peer_routing.find_peer(peer_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_closest_peers_network_empty_start(self, peer_routing):
|
||||
"""Test network search with no local peers."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Mock routing table to return empty list
|
||||
with patch.object(
|
||||
peer_routing.routing_table, "find_local_closest_peers", return_value=[]
|
||||
):
|
||||
result = await peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_closest_peers_network_with_peers(self, peer_routing, mock_host):
|
||||
"""Test network search with some initial peers."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create some test peers
|
||||
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(3)]
|
||||
|
||||
# Mock routing table to return initial peers
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
# Mock _query_peer_for_closest to return empty results (no new peers found)
|
||||
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
|
||||
result = await peer_routing.find_closest_peers_network(
|
||||
target_key, count=5
|
||||
)
|
||||
|
||||
assert len(result) <= 5
|
||||
# Should return the initial peers since no new ones were discovered
|
||||
assert all(peer in initial_peers for peer in result)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_closest_peers_convergence(self, peer_routing):
|
||||
"""Test that network search converges properly."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create test peers
|
||||
initial_peers = [create_valid_peer_id(f"peer{i}") for i in range(2)]
|
||||
|
||||
# Mock to simulate convergence (no improvement in closest peers)
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
with patch.object(peer_routing, "_query_peer_for_closest", return_value=[]):
|
||||
with patch(
|
||||
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
result = await peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
assert result == initial_peers
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_peer_for_closest_success(
|
||||
self, peer_routing, mock_host, sample_peer_info
|
||||
):
|
||||
"""Test successful peer query for closest peers."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create mock stream
|
||||
mock_stream = AsyncMock()
|
||||
mock_host.new_stream.return_value = mock_stream
|
||||
|
||||
# Create mock response
|
||||
response_msg = Message()
|
||||
response_msg.type = Message.MessageType.FIND_NODE
|
||||
|
||||
# Add a peer to the response
|
||||
peer_proto = response_msg.closerPeers.add()
|
||||
response_peer_id = create_valid_peer_id("response_peer")
|
||||
peer_proto.id = response_peer_id.to_bytes()
|
||||
peer_proto.addrs.append(Multiaddr("/ip4/127.0.0.1/tcp/8003").to_bytes())
|
||||
|
||||
response_bytes = response_msg.SerializeToString()
|
||||
|
||||
# Mock stream reading
|
||||
varint_length = varint.encode(len(response_bytes))
|
||||
mock_stream.read.side_effect = [varint_length, response_bytes]
|
||||
|
||||
# Mock peerstore
|
||||
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
|
||||
mock_host.get_peerstore().add_addrs = Mock()
|
||||
|
||||
result = await peer_routing._query_peer_for_closest(
|
||||
sample_peer_info.peer_id, target_key
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] == response_peer_id
|
||||
mock_stream.write.assert_called()
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_peer_for_closest_stream_failure(self, peer_routing, mock_host):
|
||||
"""Test peer query when stream creation fails."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
|
||||
# Mock stream creation failure
|
||||
mock_host.new_stream.side_effect = Exception("Stream failed")
|
||||
mock_host.get_peerstore().addrs.return_value = []
|
||||
|
||||
result = await peer_routing._query_peer_for_closest(peer_id, target_key)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_peer_for_closest_read_failure(
|
||||
self, peer_routing, mock_host, sample_peer_info
|
||||
):
|
||||
"""Test peer query when reading response fails."""
|
||||
target_key = b"target_key"
|
||||
|
||||
# Create mock stream that fails to read
|
||||
mock_stream = AsyncMock()
|
||||
mock_stream.read.side_effect = [b""] # Empty read simulates connection close
|
||||
mock_host.new_stream.return_value = mock_stream
|
||||
mock_host.get_peerstore().addrs.return_value = [sample_peer_info.addrs[0]]
|
||||
|
||||
result = await peer_routing._query_peer_for_closest(
|
||||
sample_peer_info.peer_id, target_key
|
||||
)
|
||||
|
||||
assert result == []
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_refresh_routing_table(self, peer_routing, mock_host):
|
||||
"""Test routing table refresh."""
|
||||
local_id = mock_host.get_id()
|
||||
discovered_peers = [create_valid_peer_id(f"discovered{i}") for i in range(3)]
|
||||
|
||||
# Mock find_closest_peers_network to return discovered peers
|
||||
with patch.object(
|
||||
peer_routing, "find_closest_peers_network", return_value=discovered_peers
|
||||
):
|
||||
# Mock peerstore to return addresses for discovered peers
|
||||
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8003")]
|
||||
mock_host.get_peerstore().addrs.return_value = mock_addrs
|
||||
|
||||
await peer_routing.refresh_routing_table()
|
||||
|
||||
# Should perform lookup for local ID
|
||||
peer_routing.find_closest_peers_network.assert_called_once_with(
|
||||
local_id.to_bytes()
|
||||
)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_kad_stream_find_node(self, peer_routing, mock_host):
|
||||
"""Test handling incoming FIND_NODE requests."""
|
||||
# Create mock stream
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Create FIND_NODE request
|
||||
request_msg = Message()
|
||||
request_msg.type = Message.MessageType.FIND_NODE
|
||||
request_msg.key = b"target_key"
|
||||
|
||||
request_bytes = request_msg.SerializeToString()
|
||||
|
||||
# Mock stream reading
|
||||
mock_stream.read.side_effect = [
|
||||
len(request_bytes).to_bytes(4, byteorder="big"),
|
||||
request_bytes,
|
||||
]
|
||||
|
||||
# Mock routing table to return some peers
|
||||
closest_peers = [create_valid_peer_id(f"close{i}") for i in range(2)]
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=closest_peers,
|
||||
):
|
||||
mock_host.get_peerstore().addrs.return_value = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/8004")
|
||||
]
|
||||
|
||||
await peer_routing._handle_kad_stream(mock_stream)
|
||||
|
||||
# Should write response
|
||||
mock_stream.write.assert_called()
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_kad_stream_invalid_message(self, peer_routing):
|
||||
"""Test handling stream with invalid message."""
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Mock stream to return invalid data
|
||||
mock_stream.read.side_effect = [
|
||||
(10).to_bytes(4, byteorder="big"),
|
||||
b"invalid_proto_data",
|
||||
]
|
||||
|
||||
# Should handle gracefully without raising exception
|
||||
await peer_routing._handle_kad_stream(mock_stream)
|
||||
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_kad_stream_connection_closed(self, peer_routing):
|
||||
"""Test handling stream when connection is closed early."""
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Mock stream to return empty data (connection closed)
|
||||
mock_stream.read.return_value = b""
|
||||
|
||||
await peer_routing._handle_kad_stream(mock_stream)
|
||||
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_single_peer_for_closest_success(self, peer_routing):
|
||||
"""Test _query_single_peer_for_closest method."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
new_peers = []
|
||||
|
||||
# Mock successful query
|
||||
mock_result = [create_valid_peer_id("result1"), create_valid_peer_id("result2")]
|
||||
with patch.object(
|
||||
peer_routing, "_query_peer_for_closest", return_value=mock_result
|
||||
):
|
||||
await peer_routing._query_single_peer_for_closest(
|
||||
peer_id, target_key, new_peers
|
||||
)
|
||||
|
||||
assert len(new_peers) == 2
|
||||
assert all(peer in new_peers for peer in mock_result)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_single_peer_for_closest_failure(self, peer_routing):
|
||||
"""Test _query_single_peer_for_closest when query fails."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
new_peers = []
|
||||
|
||||
# Mock query failure
|
||||
with patch.object(
|
||||
peer_routing,
|
||||
"_query_peer_for_closest",
|
||||
side_effect=Exception("Query failed"),
|
||||
):
|
||||
await peer_routing._query_single_peer_for_closest(
|
||||
peer_id, target_key, new_peers
|
||||
)
|
||||
|
||||
# Should handle exception gracefully
|
||||
assert len(new_peers) == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_single_peer_deduplication(self, peer_routing):
|
||||
"""Test that _query_single_peer_for_closest deduplicates peers."""
|
||||
target_key = b"target_key"
|
||||
peer_id = create_valid_peer_id("test")
|
||||
duplicate_peer = create_valid_peer_id("duplicate")
|
||||
new_peers = [duplicate_peer] # Pre-existing peer
|
||||
|
||||
# Mock query to return the same peer
|
||||
mock_result = [duplicate_peer, create_valid_peer_id("new")]
|
||||
with patch.object(
|
||||
peer_routing, "_query_peer_for_closest", return_value=mock_result
|
||||
):
|
||||
await peer_routing._query_single_peer_for_closest(
|
||||
peer_id, target_key, new_peers
|
||||
)
|
||||
|
||||
# Should not add duplicate
|
||||
assert len(new_peers) == 2 # Original + 1 new peer
|
||||
assert new_peers.count(duplicate_peer) == 1
|
||||
|
||||
def test_constants(self):
|
||||
"""Test that important constants are properly defined."""
|
||||
assert ALPHA == 3
|
||||
assert MAX_PEER_LOOKUP_ROUNDS == 20
|
||||
assert PROTOCOL_ID == "/ipfs/kad/1.0.0"
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_edge_case_max_rounds_reached(self, peer_routing):
|
||||
"""Test that lookup stops after maximum rounds."""
|
||||
target_key = b"target_key"
|
||||
initial_peers = [create_valid_peer_id("peer1")]
|
||||
|
||||
# Mock to always return new peers to force max rounds
|
||||
def mock_query_side_effect(peer, key):
|
||||
return [create_valid_peer_id(f"new_peer_{time.time()}")]
|
||||
|
||||
with patch.object(
|
||||
peer_routing.routing_table,
|
||||
"find_local_closest_peers",
|
||||
return_value=initial_peers,
|
||||
):
|
||||
with patch.object(
|
||||
peer_routing,
|
||||
"_query_peer_for_closest",
|
||||
side_effect=mock_query_side_effect,
|
||||
):
|
||||
with patch(
|
||||
"libp2p.kad_dht.peer_routing.sort_peer_ids_by_distance"
|
||||
) as mock_sort:
|
||||
# Always return different peers to prevent convergence
|
||||
mock_sort.side_effect = lambda key, peers: peers[:20]
|
||||
|
||||
result = await peer_routing.find_closest_peers_network(target_key)
|
||||
|
||||
# Should stop after max rounds, not infinite loop
|
||||
assert isinstance(result, list)
|
||||
805
tests/core/kad_dht/test_unit_provider_store.py
Normal file
805
tests/core/kad_dht/test_unit_provider_store.py
Normal file
@ -0,0 +1,805 @@
|
||||
"""
|
||||
Unit tests for the ProviderStore and ProviderRecord classes in Kademlia DHT.
|
||||
|
||||
This module tests the core functionality of provider record management including:
|
||||
- ProviderRecord creation, expiration, and republish logic
|
||||
- ProviderStore operations (add, get, cleanup)
|
||||
- Expiration and TTL handling
|
||||
- Network operations (mocked)
|
||||
- Edge cases and error conditions
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
Mock,
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
|
||||
from libp2p.kad_dht.provider_store import (
|
||||
PROVIDER_ADDRESS_TTL,
|
||||
PROVIDER_RECORD_EXPIRATION_INTERVAL,
|
||||
PROVIDER_RECORD_REPUBLISH_INTERVAL,
|
||||
ProviderRecord,
|
||||
ProviderStore,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
mock_host = Mock()
|
||||
|
||||
|
||||
class TestProviderRecord:
|
||||
"""Test suite for ProviderRecord class."""
|
||||
|
||||
def test_init_with_default_timestamp(self):
|
||||
"""Test ProviderRecord initialization with default timestamp."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
peer_info = PeerInfo(peer_id, addresses)
|
||||
|
||||
start_time = time.time()
|
||||
record = ProviderRecord(peer_info)
|
||||
end_time = time.time()
|
||||
|
||||
assert record.provider_info == peer_info
|
||||
assert start_time <= record.timestamp <= end_time
|
||||
assert record.peer_id == peer_id
|
||||
assert record.addresses == addresses
|
||||
|
||||
def test_init_with_custom_timestamp(self):
|
||||
"""Test ProviderRecord initialization with custom timestamp."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
custom_timestamp = time.time() - 3600 # 1 hour ago
|
||||
|
||||
record = ProviderRecord(peer_info, timestamp=custom_timestamp)
|
||||
|
||||
assert record.timestamp == custom_timestamp
|
||||
|
||||
def test_is_expired_fresh_record(self):
|
||||
"""Test that fresh records are not expired."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert not record.is_expired()
|
||||
|
||||
def test_is_expired_old_record(self):
|
||||
"""Test that old records are expired."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
old_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
record = ProviderRecord(peer_info, timestamp=old_timestamp)
|
||||
|
||||
assert record.is_expired()
|
||||
|
||||
def test_is_expired_boundary_condition(self):
|
||||
"""Test expiration at exact boundary."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
boundary_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL
|
||||
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
|
||||
|
||||
# At the exact boundary, should be expired (implementation uses >)
|
||||
assert record.is_expired()
|
||||
|
||||
def test_should_republish_fresh_record(self):
|
||||
"""Test that fresh records don't need republishing."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert not record.should_republish()
|
||||
|
||||
def test_should_republish_old_record(self):
|
||||
"""Test that old records need republishing."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
old_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1
|
||||
record = ProviderRecord(peer_info, timestamp=old_timestamp)
|
||||
|
||||
assert record.should_republish()
|
||||
|
||||
def test_should_republish_boundary_condition(self):
|
||||
"""Test republish at exact boundary."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
boundary_timestamp = time.time() - PROVIDER_RECORD_REPUBLISH_INTERVAL
|
||||
record = ProviderRecord(peer_info, timestamp=boundary_timestamp)
|
||||
|
||||
# At the exact boundary, should need republishing (implementation uses >)
|
||||
assert record.should_republish()
|
||||
|
||||
def test_properties(self):
|
||||
"""Test peer_id and addresses properties."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
|
||||
Multiaddr("/ip6/::1/tcp/8001"),
|
||||
]
|
||||
peer_info = PeerInfo(peer_id, addresses)
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert record.peer_id == peer_id
|
||||
assert record.addresses == addresses
|
||||
|
||||
def test_empty_addresses(self):
|
||||
"""Test ProviderRecord with empty address list."""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
peer_info = PeerInfo(peer_id, [])
|
||||
record = ProviderRecord(peer_info)
|
||||
|
||||
assert record.addresses == []
|
||||
|
||||
|
||||
class TestProviderStore:
|
||||
"""Test suite for ProviderStore class."""
|
||||
|
||||
def test_init_empty_store(self):
|
||||
"""Test that a new ProviderStore is initialized empty."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
assert len(store.providers) == 0
|
||||
assert store.peer_routing is None
|
||||
assert len(store.providing_keys) == 0
|
||||
|
||||
def test_init_with_host(self):
|
||||
"""Test initialization with host."""
|
||||
mock_host = Mock()
|
||||
mock_peer_id = ID.from_base58("QmTest123")
|
||||
mock_host.get_id.return_value = mock_peer_id
|
||||
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
assert store.host == mock_host
|
||||
assert store.local_peer_id == mock_peer_id
|
||||
assert len(store.providers) == 0
|
||||
|
||||
def test_init_with_host_and_peer_routing(self):
|
||||
"""Test initialization with both host and peer routing."""
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = Mock()
|
||||
mock_peer_id = ID.from_base58("QmTest123")
|
||||
mock_host.get_id.return_value = mock_peer_id
|
||||
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
assert store.host == mock_host
|
||||
assert store.peer_routing == mock_peer_routing
|
||||
assert store.local_peer_id == mock_peer_id
|
||||
|
||||
def test_add_provider_new_key(self):
|
||||
"""Test adding a provider for a new key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
provider = PeerInfo(peer_id, addresses)
|
||||
|
||||
store.add_provider(key, provider)
|
||||
|
||||
assert key in store.providers
|
||||
assert str(peer_id) in store.providers[key]
|
||||
|
||||
record = store.providers[key][str(peer_id)]
|
||||
assert record.provider_info == provider
|
||||
assert isinstance(record.timestamp, float)
|
||||
|
||||
def test_add_provider_existing_key(self):
|
||||
"""Test adding multiple providers for the same key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add first provider
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider1)
|
||||
|
||||
# Add second provider
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
store.add_provider(key, provider2)
|
||||
|
||||
assert len(store.providers[key]) == 2
|
||||
assert str(peer_id1) in store.providers[key]
|
||||
assert str(peer_id2) in store.providers[key]
|
||||
|
||||
def test_add_provider_update_existing(self):
|
||||
"""Test updating an existing provider."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
# Add initial provider
|
||||
provider1 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
|
||||
store.add_provider(key, provider1)
|
||||
first_timestamp = store.providers[key][str(peer_id)].timestamp
|
||||
|
||||
# Small delay to ensure timestamp difference
|
||||
time.sleep(0.001)
|
||||
|
||||
# Update provider
|
||||
provider2 = PeerInfo(peer_id, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
|
||||
store.add_provider(key, provider2)
|
||||
|
||||
# Should have same peer but updated info
|
||||
assert len(store.providers[key]) == 1
|
||||
assert str(peer_id) in store.providers[key]
|
||||
|
||||
record = store.providers[key][str(peer_id)]
|
||||
assert record.provider_info == provider2
|
||||
assert record.timestamp > first_timestamp
|
||||
|
||||
def test_get_providers_empty_key(self):
|
||||
"""Test getting providers for non-existent key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"nonexistent_key"
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert providers == []
|
||||
|
||||
def test_get_providers_valid_records(self):
|
||||
"""Test getting providers with valid records."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add multiple providers
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider1 = PeerInfo(peer_id1, [Multiaddr("/ip4/127.0.0.1/tcp/8000")])
|
||||
provider2 = PeerInfo(peer_id2, [Multiaddr("/ip4/127.0.0.1/tcp/8001")])
|
||||
|
||||
store.add_provider(key, provider1)
|
||||
store.add_provider(key, provider2)
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 2
|
||||
provider_ids = {p.peer_id for p in providers}
|
||||
assert peer_id1 in provider_ids
|
||||
assert peer_id2 in provider_ids
|
||||
|
||||
def test_get_providers_expired_records(self):
|
||||
"""Test that expired records are filtered out and cleaned up."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add valid provider
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider1)
|
||||
|
||||
# Add expired provider manually
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key][str(peer_id2)] = ProviderRecord(
|
||||
provider2, expired_timestamp
|
||||
)
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
# Should only return valid provider
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id1
|
||||
|
||||
# Expired provider should be cleaned up
|
||||
assert str(peer_id2) not in store.providers[key]
|
||||
|
||||
def test_get_providers_address_ttl(self):
|
||||
"""Test address TTL handling in get_providers."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
provider = PeerInfo(peer_id, addresses)
|
||||
|
||||
# Add provider with old timestamp (addresses expired but record valid)
|
||||
old_timestamp = time.time() - PROVIDER_ADDRESS_TTL - 1
|
||||
store.providers[key] = {str(peer_id): ProviderRecord(provider, old_timestamp)}
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
# Should return provider but with empty addresses
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
assert providers[0].addrs == []
|
||||
|
||||
def test_get_providers_cleanup_empty_key(self):
|
||||
"""Test that keys with no valid providers are removed."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add only expired providers
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key] = {
|
||||
str(peer_id): ProviderRecord(provider, expired_timestamp)
|
||||
}
|
||||
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert providers == []
|
||||
assert key not in store.providers # Key should be removed
|
||||
|
||||
def test_cleanup_expired_no_expired_records(self):
|
||||
"""Test cleanup when there are no expired records."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
|
||||
# Add valid providers
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
|
||||
store.add_provider(key1, provider1)
|
||||
store.add_provider(key2, provider2)
|
||||
|
||||
initial_size = store.size()
|
||||
store.cleanup_expired()
|
||||
|
||||
assert store.size() == initial_size
|
||||
assert key1 in store.providers
|
||||
assert key2 in store.providers
|
||||
|
||||
def test_cleanup_expired_with_expired_records(self):
|
||||
"""Test cleanup removes expired records."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
# Add valid provider
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider1)
|
||||
|
||||
# Add expired provider
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key][str(peer_id2)] = ProviderRecord(
|
||||
provider2, expired_timestamp
|
||||
)
|
||||
|
||||
assert store.size() == 2
|
||||
store.cleanup_expired()
|
||||
|
||||
assert store.size() == 1
|
||||
assert str(peer_id1) in store.providers[key]
|
||||
assert str(peer_id2) not in store.providers[key]
|
||||
|
||||
def test_cleanup_expired_remove_empty_keys(self):
|
||||
"""Test that keys with only expired providers are removed."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
|
||||
# Add valid provider to key1
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key1, provider1)
|
||||
|
||||
# Add only expired provider to key2
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
expired_timestamp = time.time() - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
store.providers[key2] = {
|
||||
str(peer_id2): ProviderRecord(provider2, expired_timestamp)
|
||||
}
|
||||
|
||||
store.cleanup_expired()
|
||||
|
||||
assert key1 in store.providers
|
||||
assert key2 not in store.providers
|
||||
|
||||
def test_get_provided_keys_empty_store(self):
|
||||
"""Test get_provided_keys with empty store."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
keys = store.get_provided_keys(peer_id)
|
||||
|
||||
assert keys == []
|
||||
|
||||
def test_get_provided_keys_single_peer(self):
|
||||
"""Test get_provided_keys for a specific peer."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
key3 = b"key3"
|
||||
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
|
||||
# peer_id1 provides key1 and key2
|
||||
store.add_provider(key1, provider1)
|
||||
store.add_provider(key2, provider1)
|
||||
|
||||
# peer_id2 provides key2 and key3
|
||||
store.add_provider(key2, provider2)
|
||||
store.add_provider(key3, provider2)
|
||||
|
||||
keys1 = store.get_provided_keys(peer_id1)
|
||||
keys2 = store.get_provided_keys(peer_id2)
|
||||
|
||||
assert len(keys1) == 2
|
||||
assert key1 in keys1
|
||||
assert key2 in keys1
|
||||
|
||||
assert len(keys2) == 2
|
||||
assert key2 in keys2
|
||||
assert key3 in keys2
|
||||
|
||||
def test_get_provided_keys_nonexistent_peer(self):
|
||||
"""Test get_provided_keys for peer that provides nothing."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
|
||||
# Add provider for peer_id1 only
|
||||
key = b"key"
|
||||
provider = PeerInfo(peer_id1, [])
|
||||
store.add_provider(key, provider)
|
||||
|
||||
# Query for peer_id2 (provides nothing)
|
||||
keys = store.get_provided_keys(peer_id2)
|
||||
|
||||
assert keys == []
|
||||
|
||||
def test_size_empty_store(self):
|
||||
"""Test size() with empty store."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
assert store.size() == 0
|
||||
|
||||
def test_size_with_providers(self):
|
||||
"""Test size() with multiple providers."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Add providers
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
peer_id1 = ID.from_base58("QmTest123")
|
||||
peer_id2 = ID.from_base58("QmTest456")
|
||||
peer_id3 = ID.from_base58("QmTest789")
|
||||
|
||||
provider1 = PeerInfo(peer_id1, [])
|
||||
provider2 = PeerInfo(peer_id2, [])
|
||||
provider3 = PeerInfo(peer_id3, [])
|
||||
|
||||
store.add_provider(key1, provider1)
|
||||
store.add_provider(key1, provider2) # 2 providers for key1
|
||||
store.add_provider(key2, provider3) # 1 provider for key2
|
||||
|
||||
assert store.size() == 3
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_no_host(self):
|
||||
"""Test provide() returns False when no host is configured."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_no_peer_routing(self):
|
||||
"""Test provide() returns False when no peer routing is configured."""
|
||||
mock_host = Mock()
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_success(self):
|
||||
"""Test successful provide operation."""
|
||||
# Setup mocks
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = AsyncMock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
mock_host.get_id.return_value = peer_id
|
||||
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
|
||||
# Mock finding closest peers
|
||||
closest_peers = [ID.from_base58("QmPeer1"), ID.from_base58("QmPeer2")]
|
||||
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
|
||||
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
# Mock _send_add_provider to return success
|
||||
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
|
||||
key = b"test_key"
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is True
|
||||
assert key in store.providing_keys
|
||||
assert key in store.providers
|
||||
|
||||
# Should have called _send_add_provider for each peer
|
||||
assert mock_send.call_count == len(closest_peers)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_provide_skip_local_peer(self):
|
||||
"""Test that provide() skips sending to local peer."""
|
||||
# Setup mocks
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = AsyncMock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
mock_host.get_id.return_value = peer_id
|
||||
mock_host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
|
||||
# Include local peer in closest peers
|
||||
closest_peers = [peer_id, ID.from_base58("QmPeer1")]
|
||||
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
|
||||
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
with patch.object(store, "_send_add_provider", return_value=True) as mock_send:
|
||||
key = b"test_key"
|
||||
result = await store.provide(key)
|
||||
|
||||
assert result is True
|
||||
# Should only call _send_add_provider once (skip local peer)
|
||||
assert mock_send.call_count == 1
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_providers_no_host(self):
|
||||
"""Test find_providers() returns empty list when no host."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
|
||||
result = await store.find_providers(key)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_providers_local_only(self):
|
||||
"""Test find_providers() returns local providers."""
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = Mock()
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
# Add local providers
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
store.add_provider(key, provider)
|
||||
|
||||
result = await store.find_providers(key)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].peer_id == peer_id
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_find_providers_network_search(self):
|
||||
"""Test find_providers() searches network when no local providers."""
|
||||
mock_host = Mock()
|
||||
mock_peer_routing = AsyncMock()
|
||||
store = ProviderStore(host=mock_host, peer_routing=mock_peer_routing)
|
||||
|
||||
# Mock network search
|
||||
closest_peers = [ID.from_base58("QmPeer1")]
|
||||
mock_peer_routing.find_closest_peers_network.return_value = closest_peers
|
||||
|
||||
# Mock provider response
|
||||
remote_peer_id = ID.from_base58("QmRemote123")
|
||||
remote_providers = [PeerInfo(remote_peer_id, [])]
|
||||
|
||||
with patch.object(
|
||||
store, "_get_providers_from_peer", return_value=remote_providers
|
||||
):
|
||||
key = b"test_key"
|
||||
result = await store.find_providers(key)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].peer_id == remote_peer_id
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_providers_from_peer_no_host(self):
|
||||
"""Test _get_providers_from_peer without host."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
key = b"test_key"
|
||||
|
||||
# Should handle missing host gracefully
|
||||
result = await store._get_providers_from_peer(peer_id, key)
|
||||
assert result == []
|
||||
|
||||
def test_edge_case_empty_key(self):
|
||||
"""Test handling of empty key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b""
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
|
||||
store.add_provider(key, provider)
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
|
||||
def test_edge_case_large_key(self):
|
||||
"""Test handling of large key."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"x" * 10000 # 10KB key
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
|
||||
store.add_provider(key, provider)
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
|
||||
def test_concurrent_operations(self):
|
||||
"""Test multiple concurrent operations."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Add many providers
|
||||
num_keys = 100
|
||||
num_providers_per_key = 5
|
||||
|
||||
for i in range(num_keys):
|
||||
_key = f"key_{i}".encode()
|
||||
for j in range(num_providers_per_key):
|
||||
# Generate unique valid Base58 peer IDs
|
||||
# Use a different approach that ensures uniqueness
|
||||
unique_id = i * num_providers_per_key + j + 1 # Ensure > 0
|
||||
_peer_id_str = f"QmPeer{unique_id:06d}".replace("0", "A") + "1" * 38
|
||||
peer_id = ID.from_base58(_peer_id_str)
|
||||
provider = PeerInfo(peer_id, [])
|
||||
store.add_provider(_key, provider)
|
||||
|
||||
# Verify total size
|
||||
expected_size = num_keys * num_providers_per_key
|
||||
assert store.size() == expected_size
|
||||
|
||||
# Verify individual keys
|
||||
for i in range(num_keys):
|
||||
_key = f"key_{i}".encode()
|
||||
providers = store.get_providers(_key)
|
||||
assert len(providers) == num_providers_per_key
|
||||
|
||||
def test_memory_efficiency_large_dataset(self):
|
||||
"""Test memory behavior with large datasets."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Add large number of providers
|
||||
num_entries = 1000
|
||||
for i in range(num_entries):
|
||||
_key = f"key_{i:05d}".encode()
|
||||
# Generate valid Base58 peer IDs (replace 0 with valid characters)
|
||||
peer_str = f"QmPeer{i:05d}".replace("0", "1") + "1" * 35
|
||||
peer_id = ID.from_base58(peer_str)
|
||||
provider = PeerInfo(peer_id, [])
|
||||
store.add_provider(_key, provider)
|
||||
|
||||
assert store.size() == num_entries
|
||||
|
||||
# Clean up all entries by making them expired
|
||||
current_time = time.time()
|
||||
for _key, providers in store.providers.items():
|
||||
for _peer_id_str, record in providers.items():
|
||||
record.timestamp = (
|
||||
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1
|
||||
)
|
||||
|
||||
store.cleanup_expired()
|
||||
assert store.size() == 0
|
||||
assert len(store.providers) == 0
|
||||
|
||||
def test_unicode_key_handling(self):
|
||||
"""Test handling of unicode content in keys."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Test various unicode keys
|
||||
unicode_keys = [
|
||||
b"hello",
|
||||
"héllo".encode(),
|
||||
"🔑".encode(),
|
||||
"ключ".encode(), # Russian
|
||||
"键".encode(), # Chinese
|
||||
]
|
||||
|
||||
for i, key in enumerate(unicode_keys):
|
||||
# Generate valid Base58 peer IDs
|
||||
peer_id = ID.from_base58(f"QmPeer{i + 1}" + "1" * 42) # Valid base58
|
||||
provider = PeerInfo(peer_id, [])
|
||||
store.add_provider(key, provider)
|
||||
|
||||
providers = store.get_providers(key)
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
|
||||
def test_multiple_addresses_per_provider(self):
|
||||
"""Test providers with multiple addresses."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
key = b"test_key"
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
addresses = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/8000"),
|
||||
Multiaddr("/ip6/::1/tcp/8001"),
|
||||
Multiaddr("/ip4/192.168.1.100/tcp/8002"),
|
||||
]
|
||||
provider = PeerInfo(peer_id, addresses)
|
||||
|
||||
store.add_provider(key, provider)
|
||||
providers = store.get_providers(key)
|
||||
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
assert len(providers[0].addrs) == len(addresses)
|
||||
assert all(addr in providers[0].addrs for addr in addresses)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_republish_provider_records_no_keys(self):
|
||||
"""Test _republish_provider_records with no providing keys."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
|
||||
# Should complete without error even with no providing keys
|
||||
await store._republish_provider_records()
|
||||
|
||||
assert len(store.providing_keys) == 0
|
||||
|
||||
def test_expiration_boundary_conditions(self):
|
||||
"""Test expiration around boundary conditions."""
|
||||
store = ProviderStore(host=mock_host)
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
provider = PeerInfo(peer_id, [])
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# Test records at various timestamps
|
||||
timestamps = [
|
||||
current_time, # Fresh
|
||||
current_time - PROVIDER_ADDRESS_TTL + 1, # Addresses valid
|
||||
current_time - PROVIDER_ADDRESS_TTL - 1, # Addresses expired
|
||||
current_time
|
||||
- PROVIDER_RECORD_REPUBLISH_INTERVAL
|
||||
+ 1, # No republish needed
|
||||
current_time - PROVIDER_RECORD_REPUBLISH_INTERVAL - 1, # Republish needed
|
||||
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL + 1, # Not expired
|
||||
current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL - 1, # Expired
|
||||
]
|
||||
|
||||
for i, timestamp in enumerate(timestamps):
|
||||
test_key = f"key_{i}".encode()
|
||||
record = ProviderRecord(provider, timestamp)
|
||||
store.providers[test_key] = {str(peer_id): record}
|
||||
|
||||
# Test various operations
|
||||
for i, timestamp in enumerate(timestamps):
|
||||
test_key = f"key_{i}".encode()
|
||||
providers = store.get_providers(test_key)
|
||||
|
||||
if timestamp <= current_time - PROVIDER_RECORD_EXPIRATION_INTERVAL:
|
||||
# Should be expired and removed
|
||||
assert len(providers) == 0
|
||||
assert test_key not in store.providers
|
||||
else:
|
||||
# Should be present
|
||||
assert len(providers) == 1
|
||||
assert providers[0].peer_id == peer_id
|
||||
371
tests/core/kad_dht/test_unit_routing_table.py
Normal file
371
tests/core/kad_dht/test_unit_routing_table.py
Normal file
@ -0,0 +1,371 @@
|
||||
"""
|
||||
Unit tests for the RoutingTable and KBucket classes in Kademlia DHT.
|
||||
|
||||
This module tests the core functionality of the routing table including:
|
||||
- KBucket operations (add, remove, split, ping)
|
||||
- RoutingTable management (peer addition, closest peer finding)
|
||||
- Distance calculations and peer ordering
|
||||
- Bucket splitting and range management
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
Mock,
|
||||
patch,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.kad_dht.routing_table import (
|
||||
BUCKET_SIZE,
|
||||
KBucket,
|
||||
RoutingTable,
|
||||
)
|
||||
from libp2p.kad_dht.utils import (
|
||||
create_key_from_binary,
|
||||
xor_distance,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
|
||||
def create_valid_peer_id(name: str) -> ID:
|
||||
"""Create a valid peer ID for testing."""
|
||||
# Use crypto to generate valid peer IDs
|
||||
key_pair = create_new_key_pair()
|
||||
return ID.from_pubkey(key_pair.public_key)
|
||||
|
||||
|
||||
class TestKBucket:
|
||||
"""Test suite for KBucket class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host(self):
|
||||
"""Create a mock host for testing."""
|
||||
host = Mock()
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
@pytest.fixture
|
||||
def sample_peer_info(self):
|
||||
"""Create sample peer info for testing."""
|
||||
peer_id = create_valid_peer_id("test")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
return PeerInfo(peer_id, addresses)
|
||||
|
||||
def test_init_default_parameters(self, mock_host):
|
||||
"""Test KBucket initialization with default parameters."""
|
||||
bucket = KBucket(mock_host)
|
||||
|
||||
assert bucket.bucket_size == BUCKET_SIZE
|
||||
assert bucket.host == mock_host
|
||||
assert bucket.min_range == 0
|
||||
assert bucket.max_range == 2**256
|
||||
assert len(bucket.peers) == 0
|
||||
|
||||
def test_peer_operations(self, mock_host, sample_peer_info):
|
||||
"""Test basic peer operations: add, check, and remove."""
|
||||
bucket = KBucket(mock_host)
|
||||
|
||||
# Test empty bucket
|
||||
assert bucket.peer_ids() == []
|
||||
assert bucket.size() == 0
|
||||
assert not bucket.has_peer(sample_peer_info.peer_id)
|
||||
|
||||
# Add peer manually
|
||||
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
|
||||
|
||||
# Test with peer
|
||||
assert len(bucket.peer_ids()) == 1
|
||||
assert sample_peer_info.peer_id in bucket.peer_ids()
|
||||
assert bucket.size() == 1
|
||||
assert bucket.has_peer(sample_peer_info.peer_id)
|
||||
assert bucket.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
|
||||
|
||||
# Remove peer
|
||||
result = bucket.remove_peer(sample_peer_info.peer_id)
|
||||
assert result is True
|
||||
assert bucket.size() == 0
|
||||
assert not bucket.has_peer(sample_peer_info.peer_id)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_add_peer_functionality(self, mock_host):
|
||||
"""Test add_peer method with different scenarios."""
|
||||
bucket = KBucket(mock_host, bucket_size=2) # Small bucket for testing
|
||||
|
||||
# Add first peer
|
||||
peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
|
||||
result = await bucket.add_peer(peer1)
|
||||
assert result is True
|
||||
assert bucket.size() == 1
|
||||
|
||||
# Add second peer
|
||||
peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
|
||||
result = await bucket.add_peer(peer2)
|
||||
assert result is True
|
||||
assert bucket.size() == 2
|
||||
|
||||
# Add same peer again (should update timestamp)
|
||||
await trio.sleep(0.001)
|
||||
result = await bucket.add_peer(peer1)
|
||||
assert result is True
|
||||
assert bucket.size() == 2 # Still 2 peers
|
||||
|
||||
# Try to add third peer when bucket is full
|
||||
peer3 = PeerInfo(create_valid_peer_id("peer3"), [])
|
||||
with patch.object(bucket, "_ping_peer", return_value=True):
|
||||
result = await bucket.add_peer(peer3)
|
||||
assert result is False # Should fail if oldest peer responds
|
||||
|
||||
def test_get_oldest_peer(self, mock_host):
|
||||
"""Test get_oldest_peer method."""
|
||||
bucket = KBucket(mock_host)
|
||||
|
||||
# Empty bucket
|
||||
assert bucket.get_oldest_peer() is None
|
||||
|
||||
# Add peers with different timestamps
|
||||
peer1 = PeerInfo(create_valid_peer_id("peer1"), [])
|
||||
peer2 = PeerInfo(create_valid_peer_id("peer2"), [])
|
||||
|
||||
current_time = time.time()
|
||||
bucket.peers[peer1.peer_id] = (peer1, current_time - 300) # Older
|
||||
bucket.peers[peer2.peer_id] = (peer2, current_time) # Newer
|
||||
|
||||
oldest = bucket.get_oldest_peer()
|
||||
assert oldest == peer1.peer_id
|
||||
|
||||
def test_stale_peers(self, mock_host):
|
||||
"""Test stale peer identification."""
|
||||
bucket = KBucket(mock_host)
|
||||
|
||||
current_time = time.time()
|
||||
fresh_peer = PeerInfo(create_valid_peer_id("fresh"), [])
|
||||
stale_peer = PeerInfo(create_valid_peer_id("stale"), [])
|
||||
|
||||
bucket.peers[fresh_peer.peer_id] = (fresh_peer, current_time)
|
||||
bucket.peers[stale_peer.peer_id] = (
|
||||
stale_peer,
|
||||
current_time - 7200,
|
||||
) # 2 hours ago
|
||||
|
||||
stale_peers = bucket.get_stale_peers(3600) # 1 hour threshold
|
||||
assert len(stale_peers) == 1
|
||||
assert stale_peer.peer_id in stale_peers
|
||||
|
||||
def test_key_in_range(self, mock_host):
|
||||
"""Test key_in_range method."""
|
||||
bucket = KBucket(mock_host, min_range=100, max_range=200)
|
||||
|
||||
# Test keys within range
|
||||
key_in_range = (150).to_bytes(32, byteorder="big")
|
||||
assert bucket.key_in_range(key_in_range) is True
|
||||
|
||||
# Test keys outside range
|
||||
key_below = (50).to_bytes(32, byteorder="big")
|
||||
assert bucket.key_in_range(key_below) is False
|
||||
|
||||
key_above = (250).to_bytes(32, byteorder="big")
|
||||
assert bucket.key_in_range(key_above) is False
|
||||
|
||||
# Test boundary conditions
|
||||
key_min = (100).to_bytes(32, byteorder="big")
|
||||
assert bucket.key_in_range(key_min) is True
|
||||
|
||||
key_max = (200).to_bytes(32, byteorder="big")
|
||||
assert bucket.key_in_range(key_max) is False
|
||||
|
||||
def test_split_bucket(self, mock_host):
|
||||
"""Test bucket splitting functionality."""
|
||||
bucket = KBucket(mock_host, min_range=0, max_range=256)
|
||||
|
||||
lower_bucket, upper_bucket = bucket.split()
|
||||
|
||||
# Check ranges
|
||||
assert lower_bucket.min_range == 0
|
||||
assert lower_bucket.max_range == 128
|
||||
assert upper_bucket.min_range == 128
|
||||
assert upper_bucket.max_range == 256
|
||||
|
||||
# Check properties
|
||||
assert lower_bucket.bucket_size == bucket.bucket_size
|
||||
assert upper_bucket.bucket_size == bucket.bucket_size
|
||||
assert lower_bucket.host == mock_host
|
||||
assert upper_bucket.host == mock_host
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_ping_peer_scenarios(self, mock_host, sample_peer_info):
|
||||
"""Test different ping scenarios."""
|
||||
bucket = KBucket(mock_host)
|
||||
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
|
||||
|
||||
# Test ping peer not in bucket
|
||||
other_peer_id = create_valid_peer_id("other")
|
||||
with pytest.raises(ValueError, match="Peer .* not in bucket"):
|
||||
await bucket._ping_peer(other_peer_id)
|
||||
|
||||
# Test ping failure due to stream error
|
||||
mock_host.new_stream.side_effect = Exception("Stream failed")
|
||||
result = await bucket._ping_peer(sample_peer_info.peer_id)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRoutingTable:
|
||||
"""Test suite for RoutingTable class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host(self):
|
||||
"""Create a mock host for testing."""
|
||||
host = Mock()
|
||||
host.get_peerstore.return_value = Mock()
|
||||
return host
|
||||
|
||||
@pytest.fixture
|
||||
def local_peer_id(self):
|
||||
"""Create a local peer ID for testing."""
|
||||
return create_valid_peer_id("local")
|
||||
|
||||
@pytest.fixture
|
||||
def sample_peer_info(self):
|
||||
"""Create sample peer info for testing."""
|
||||
peer_id = create_valid_peer_id("sample")
|
||||
addresses = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
return PeerInfo(peer_id, addresses)
|
||||
|
||||
def test_init_routing_table(self, mock_host, local_peer_id):
|
||||
"""Test RoutingTable initialization."""
|
||||
routing_table = RoutingTable(local_peer_id, mock_host)
|
||||
|
||||
assert routing_table.local_id == local_peer_id
|
||||
assert routing_table.host == mock_host
|
||||
assert len(routing_table.buckets) == 1
|
||||
assert isinstance(routing_table.buckets[0], KBucket)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_add_peer_operations(
|
||||
self, mock_host, local_peer_id, sample_peer_info
|
||||
):
|
||||
"""Test adding peers to routing table."""
|
||||
routing_table = RoutingTable(local_peer_id, mock_host)
|
||||
|
||||
# Test adding peer with PeerInfo
|
||||
result = await routing_table.add_peer(sample_peer_info)
|
||||
assert result is True
|
||||
assert routing_table.size() == 1
|
||||
assert routing_table.peer_in_table(sample_peer_info.peer_id)
|
||||
|
||||
# Test adding peer with just ID
|
||||
peer_id = create_valid_peer_id("test")
|
||||
mock_addrs = [Multiaddr("/ip4/127.0.0.1/tcp/8001")]
|
||||
mock_host.get_peerstore().addrs.return_value = mock_addrs
|
||||
|
||||
result = await routing_table.add_peer(peer_id)
|
||||
assert result is True
|
||||
assert routing_table.size() == 2
|
||||
|
||||
# Test adding peer with no addresses
|
||||
no_addr_peer_id = create_valid_peer_id("no_addr")
|
||||
mock_host.get_peerstore().addrs.return_value = []
|
||||
|
||||
result = await routing_table.add_peer(no_addr_peer_id)
|
||||
assert result is False
|
||||
assert routing_table.size() == 2
|
||||
|
||||
# Test adding local peer (should be ignored)
|
||||
result = await routing_table.add_peer(local_peer_id)
|
||||
assert result is False
|
||||
assert routing_table.size() == 2
|
||||
|
||||
def test_find_bucket(self, mock_host, local_peer_id):
|
||||
"""Test finding appropriate bucket for peers."""
|
||||
routing_table = RoutingTable(local_peer_id, mock_host)
|
||||
|
||||
# Test with peer ID
|
||||
peer_id = create_valid_peer_id("test")
|
||||
bucket = routing_table.find_bucket(peer_id)
|
||||
assert isinstance(bucket, KBucket)
|
||||
|
||||
def test_peer_management(self, mock_host, local_peer_id, sample_peer_info):
|
||||
"""Test peer management operations."""
|
||||
routing_table = RoutingTable(local_peer_id, mock_host)
|
||||
|
||||
# Add peer manually
|
||||
bucket = routing_table.find_bucket(sample_peer_info.peer_id)
|
||||
bucket.peers[sample_peer_info.peer_id] = (sample_peer_info, time.time())
|
||||
|
||||
# Test peer queries
|
||||
assert routing_table.peer_in_table(sample_peer_info.peer_id)
|
||||
assert routing_table.get_peer_info(sample_peer_info.peer_id) == sample_peer_info
|
||||
assert routing_table.size() == 1
|
||||
assert len(routing_table.get_peer_ids()) == 1
|
||||
|
||||
# Test remove peer
|
||||
result = routing_table.remove_peer(sample_peer_info.peer_id)
|
||||
assert result is True
|
||||
assert not routing_table.peer_in_table(sample_peer_info.peer_id)
|
||||
assert routing_table.size() == 0
|
||||
|
||||
def test_find_closest_peers(self, mock_host, local_peer_id):
|
||||
"""Test finding closest peers."""
|
||||
routing_table = RoutingTable(local_peer_id, mock_host)
|
||||
|
||||
# Empty table
|
||||
target_key = create_key_from_binary(b"target_key")
|
||||
closest_peers = routing_table.find_local_closest_peers(target_key, 5)
|
||||
assert closest_peers == []
|
||||
|
||||
# Add some peers
|
||||
bucket = routing_table.buckets[0]
|
||||
test_peers = []
|
||||
for i in range(5):
|
||||
peer = PeerInfo(create_valid_peer_id(f"peer{i}"), [])
|
||||
test_peers.append(peer)
|
||||
bucket.peers[peer.peer_id] = (peer, time.time())
|
||||
|
||||
closest_peers = routing_table.find_local_closest_peers(target_key, 3)
|
||||
assert len(closest_peers) <= 3
|
||||
assert len(closest_peers) <= len(test_peers)
|
||||
assert all(isinstance(peer_id, ID) for peer_id in closest_peers)
|
||||
|
||||
def test_distance_calculation(self, mock_host, local_peer_id):
|
||||
"""Test XOR distance calculation."""
|
||||
# Test same keys
|
||||
key = b"\x42" * 32
|
||||
distance = xor_distance(key, key)
|
||||
assert distance == 0
|
||||
|
||||
# Test different keys
|
||||
key1 = b"\x00" * 32
|
||||
key2 = b"\xff" * 32
|
||||
distance = xor_distance(key1, key2)
|
||||
expected = int.from_bytes(b"\xff" * 32, byteorder="big")
|
||||
assert distance == expected
|
||||
|
||||
def test_edge_cases(self, mock_host, local_peer_id):
|
||||
"""Test various edge cases."""
|
||||
routing_table = RoutingTable(local_peer_id, mock_host)
|
||||
|
||||
# Test with invalid peer ID
|
||||
nonexistent_peer_id = create_valid_peer_id("nonexistent")
|
||||
assert not routing_table.peer_in_table(nonexistent_peer_id)
|
||||
assert routing_table.get_peer_info(nonexistent_peer_id) is None
|
||||
assert routing_table.remove_peer(nonexistent_peer_id) is False
|
||||
|
||||
# Test bucket splitting scenario
|
||||
assert len(routing_table.buckets) == 1
|
||||
initial_bucket = routing_table.buckets[0]
|
||||
assert initial_bucket.min_range == 0
|
||||
assert initial_bucket.max_range == 2**256
|
||||
504
tests/core/kad_dht/test_unit_value_store.py
Normal file
504
tests/core/kad_dht/test_unit_value_store.py
Normal file
@ -0,0 +1,504 @@
|
||||
"""
|
||||
Unit tests for the ValueStore class in Kademlia DHT.
|
||||
|
||||
This module tests the core functionality of the ValueStore including:
|
||||
- Basic storage and retrieval operations
|
||||
- Expiration and TTL handling
|
||||
- Edge cases and error conditions
|
||||
- Store management operations
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import (
|
||||
Mock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.kad_dht.value_store import (
|
||||
DEFAULT_TTL,
|
||||
ValueStore,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
mock_host = Mock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
|
||||
class TestValueStore:
|
||||
"""Test suite for ValueStore class."""
|
||||
|
||||
def test_init_empty_store(self):
|
||||
"""Test that a new ValueStore is initialized empty."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
assert len(store.store) == 0
|
||||
|
||||
def test_init_with_host_and_peer_id(self):
|
||||
"""Test initialization with host and local peer ID."""
|
||||
mock_host = Mock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
assert store.host == mock_host
|
||||
assert store.local_peer_id == peer_id
|
||||
assert len(store.store) == 0
|
||||
|
||||
def test_put_basic(self):
|
||||
"""Test basic put operation."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
|
||||
store.put(key, value)
|
||||
|
||||
assert key in store.store
|
||||
stored_value, validity = store.store[key]
|
||||
assert stored_value == value
|
||||
assert validity is not None
|
||||
assert validity > time.time() # Should be in the future
|
||||
|
||||
def test_put_with_custom_validity(self):
|
||||
"""Test put operation with custom validity time."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
custom_validity = time.time() + 3600 # 1 hour from now
|
||||
|
||||
store.put(key, value, validity=custom_validity)
|
||||
|
||||
stored_value, validity = store.store[key]
|
||||
assert stored_value == value
|
||||
assert validity == custom_validity
|
||||
|
||||
def test_put_overwrite_existing(self):
|
||||
"""Test that put overwrites existing values."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value1 = b"value1"
|
||||
value2 = b"value2"
|
||||
|
||||
store.put(key, value1)
|
||||
store.put(key, value2)
|
||||
|
||||
assert len(store.store) == 1
|
||||
stored_value, _ = store.store[key]
|
||||
assert stored_value == value2
|
||||
|
||||
def test_get_existing_valid_value(self):
|
||||
"""Test retrieving an existing, non-expired value."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
|
||||
store.put(key, value)
|
||||
retrieved_value = store.get(key)
|
||||
|
||||
assert retrieved_value == value
|
||||
|
||||
def test_get_nonexistent_key(self):
|
||||
"""Test retrieving a non-existent key returns None."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"nonexistent_key"
|
||||
|
||||
retrieved_value = store.get(key)
|
||||
|
||||
assert retrieved_value is None
|
||||
|
||||
def test_get_expired_value(self):
|
||||
"""Test that expired values are automatically removed and return None."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
expired_validity = time.time() - 1 # 1 second ago
|
||||
|
||||
# Manually insert expired value
|
||||
store.store[key] = (value, expired_validity)
|
||||
|
||||
retrieved_value = store.get(key)
|
||||
|
||||
assert retrieved_value is None
|
||||
assert key not in store.store # Should be removed
|
||||
|
||||
def test_remove_existing_key(self):
|
||||
"""Test removing an existing key."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
|
||||
store.put(key, value)
|
||||
result = store.remove(key)
|
||||
|
||||
assert result is True
|
||||
assert key not in store.store
|
||||
|
||||
def test_remove_nonexistent_key(self):
|
||||
"""Test removing a non-existent key returns False."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"nonexistent_key"
|
||||
|
||||
result = store.remove(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_has_existing_valid_key(self):
|
||||
"""Test has() returns True for existing, valid keys."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
|
||||
store.put(key, value)
|
||||
result = store.has(key)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_has_nonexistent_key(self):
|
||||
"""Test has() returns False for non-existent keys."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"nonexistent_key"
|
||||
|
||||
result = store.has(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_has_expired_key(self):
|
||||
"""Test has() returns False for expired keys and removes them."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"test_key"
|
||||
value = b"test_value"
|
||||
expired_validity = time.time() - 1
|
||||
|
||||
# Manually insert expired value
|
||||
store.store[key] = (value, expired_validity)
|
||||
|
||||
result = store.has(key)
|
||||
|
||||
assert result is False
|
||||
assert key not in store.store # Should be removed
|
||||
|
||||
def test_cleanup_expired_no_expired_values(self):
|
||||
"""Test cleanup when there are no expired values."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
value = b"value"
|
||||
|
||||
store.put(key1, value)
|
||||
store.put(key2, value)
|
||||
|
||||
expired_count = store.cleanup_expired()
|
||||
|
||||
assert expired_count == 0
|
||||
assert len(store.store) == 2
|
||||
|
||||
def test_cleanup_expired_with_expired_values(self):
|
||||
"""Test cleanup removes expired values."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key1 = b"valid_key"
|
||||
key2 = b"expired_key1"
|
||||
key3 = b"expired_key2"
|
||||
value = b"value"
|
||||
expired_validity = time.time() - 1
|
||||
|
||||
store.put(key1, value) # Valid
|
||||
store.store[key2] = (value, expired_validity) # Expired
|
||||
store.store[key3] = (value, expired_validity) # Expired
|
||||
|
||||
expired_count = store.cleanup_expired()
|
||||
|
||||
assert expired_count == 2
|
||||
assert len(store.store) == 1
|
||||
assert key1 in store.store
|
||||
assert key2 not in store.store
|
||||
assert key3 not in store.store
|
||||
|
||||
def test_cleanup_expired_mixed_validity_types(self):
|
||||
"""Test cleanup with mix of values with and without expiration."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key1 = b"no_expiry"
|
||||
key2 = b"valid_expiry"
|
||||
key3 = b"expired"
|
||||
value = b"value"
|
||||
|
||||
# No expiration (None validity)
|
||||
store.put(key1, value)
|
||||
# Valid expiration
|
||||
store.put(key2, value, validity=time.time() + 3600)
|
||||
# Expired
|
||||
store.store[key3] = (value, time.time() - 1)
|
||||
|
||||
expired_count = store.cleanup_expired()
|
||||
|
||||
assert expired_count == 1
|
||||
assert len(store.store) == 2
|
||||
assert key1 in store.store
|
||||
assert key2 in store.store
|
||||
assert key3 not in store.store
|
||||
|
||||
def test_get_keys_empty_store(self):
|
||||
"""Test get_keys() returns empty list for empty store."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
|
||||
keys = store.get_keys()
|
||||
|
||||
assert keys == []
|
||||
|
||||
def test_get_keys_with_valid_values(self):
|
||||
"""Test get_keys() returns all non-expired keys."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
key3 = b"expired_key"
|
||||
value = b"value"
|
||||
|
||||
store.put(key1, value)
|
||||
store.put(key2, value)
|
||||
store.store[key3] = (value, time.time() - 1) # Expired
|
||||
|
||||
keys = store.get_keys()
|
||||
|
||||
assert len(keys) == 2
|
||||
assert key1 in keys
|
||||
assert key2 in keys
|
||||
assert key3 not in keys
|
||||
|
||||
def test_size_empty_store(self):
|
||||
"""Test size() returns 0 for empty store."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
|
||||
size = store.size()
|
||||
|
||||
assert size == 0
|
||||
|
||||
def test_size_with_valid_values(self):
|
||||
"""Test size() returns correct count after cleaning expired values."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
key3 = b"expired_key"
|
||||
value = b"value"
|
||||
|
||||
store.put(key1, value)
|
||||
store.put(key2, value)
|
||||
store.store[key3] = (value, time.time() - 1) # Expired
|
||||
|
||||
size = store.size()
|
||||
|
||||
assert size == 2
|
||||
|
||||
def test_edge_case_empty_key(self):
|
||||
"""Test handling of empty key."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b""
|
||||
value = b"value"
|
||||
|
||||
store.put(key, value)
|
||||
retrieved_value = store.get(key)
|
||||
|
||||
assert retrieved_value == value
|
||||
|
||||
def test_edge_case_empty_value(self):
|
||||
"""Test handling of empty value."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"key"
|
||||
value = b""
|
||||
|
||||
store.put(key, value)
|
||||
retrieved_value = store.get(key)
|
||||
|
||||
assert retrieved_value == value
|
||||
|
||||
def test_edge_case_large_key_value(self):
|
||||
"""Test handling of large keys and values."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"x" * 10000 # 10KB key
|
||||
value = b"y" * 100000 # 100KB value
|
||||
|
||||
store.put(key, value)
|
||||
retrieved_value = store.get(key)
|
||||
|
||||
assert retrieved_value == value
|
||||
|
||||
def test_edge_case_negative_validity(self):
|
||||
"""Test handling of negative validity time."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"key"
|
||||
value = b"value"
|
||||
|
||||
store.put(key, value, validity=-1)
|
||||
|
||||
# Should be expired
|
||||
retrieved_value = store.get(key)
|
||||
assert retrieved_value is None
|
||||
|
||||
def test_default_ttl_calculation(self):
|
||||
"""Test that default TTL is correctly applied."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"key"
|
||||
value = b"value"
|
||||
start_time = time.time()
|
||||
|
||||
store.put(key, value)
|
||||
|
||||
_, validity = store.store[key]
|
||||
expected_validity = start_time + DEFAULT_TTL
|
||||
|
||||
# Allow small time difference for execution
|
||||
assert abs(validity - expected_validity) < 1
|
||||
|
||||
def test_concurrent_operations(self):
|
||||
"""Test that multiple operations don't interfere with each other."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
|
||||
# Add multiple key-value pairs
|
||||
for i in range(100):
|
||||
key = f"key_{i}".encode()
|
||||
value = f"value_{i}".encode()
|
||||
store.put(key, value)
|
||||
|
||||
# Verify all are stored
|
||||
assert store.size() == 100
|
||||
|
||||
# Remove every other key
|
||||
for i in range(0, 100, 2):
|
||||
key = f"key_{i}".encode()
|
||||
store.remove(key)
|
||||
|
||||
# Verify correct count
|
||||
assert store.size() == 50
|
||||
|
||||
# Verify remaining keys are correct
|
||||
for i in range(1, 100, 2):
|
||||
key = f"key_{i}".encode()
|
||||
assert store.has(key)
|
||||
|
||||
def test_expiration_boundary_conditions(self):
|
||||
"""Test expiration around current time boundary."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key1 = b"key1"
|
||||
key2 = b"key2"
|
||||
key3 = b"key3"
|
||||
value = b"value"
|
||||
current_time = time.time()
|
||||
|
||||
# Just expired
|
||||
store.store[key1] = (value, current_time - 0.001)
|
||||
# Valid for a longer time to account for test execution time
|
||||
store.store[key2] = (value, current_time + 1.0)
|
||||
# Exactly current time (should be expired)
|
||||
store.store[key3] = (value, current_time)
|
||||
|
||||
# Small delay to ensure time has passed
|
||||
time.sleep(0.002)
|
||||
|
||||
assert not store.has(key1) # Should be expired
|
||||
assert store.has(key2) # Should be valid
|
||||
assert not store.has(key3) # Should be expired (exactly at current time)
|
||||
|
||||
def test_store_internal_structure(self):
|
||||
"""Test that internal store structure is maintained correctly."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"key"
|
||||
value = b"value"
|
||||
validity = time.time() + 3600
|
||||
|
||||
store.put(key, value, validity=validity)
|
||||
|
||||
# Verify internal structure
|
||||
assert isinstance(store.store, dict)
|
||||
assert key in store.store
|
||||
stored_tuple = store.store[key]
|
||||
assert isinstance(stored_tuple, tuple)
|
||||
assert len(stored_tuple) == 2
|
||||
assert stored_tuple[0] == value
|
||||
assert stored_tuple[1] == validity
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_store_at_peer_local_peer(self):
|
||||
"""Test _store_at_peer returns True when storing at local peer."""
|
||||
mock_host = Mock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"key"
|
||||
value = b"value"
|
||||
|
||||
result = await store._store_at_peer(peer_id, key, value)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_from_peer_local_peer(self):
|
||||
"""Test _get_from_peer returns None when querying local peer."""
|
||||
mock_host = Mock()
|
||||
peer_id = ID.from_base58("QmTest123")
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
key = b"key"
|
||||
|
||||
result = await store._get_from_peer(peer_id, key)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_memory_efficiency_large_dataset(self):
|
||||
"""Test memory behavior with large datasets."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
|
||||
# Add a large number of entries
|
||||
num_entries = 10000
|
||||
for i in range(num_entries):
|
||||
key = f"key_{i:05d}".encode()
|
||||
value = f"value_{i:05d}".encode()
|
||||
store.put(key, value)
|
||||
|
||||
assert store.size() == num_entries
|
||||
|
||||
# Clean up all entries
|
||||
for i in range(num_entries):
|
||||
key = f"key_{i:05d}".encode()
|
||||
store.remove(key)
|
||||
|
||||
assert store.size() == 0
|
||||
assert len(store.store) == 0
|
||||
|
||||
def test_key_collision_resistance(self):
|
||||
"""Test that similar keys don't collide."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
|
||||
# Test keys that might cause collisions
|
||||
keys = [
|
||||
b"key",
|
||||
b"key\x00",
|
||||
b"key1",
|
||||
b"Key", # Different case
|
||||
b"key ", # With space
|
||||
b" key", # Leading space
|
||||
]
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
value = f"value_{i}".encode()
|
||||
store.put(key, value)
|
||||
|
||||
# Verify all keys are stored separately
|
||||
assert store.size() == len(keys)
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
expected_value = f"value_{i}".encode()
|
||||
assert store.get(key) == expected_value
|
||||
|
||||
def test_unicode_key_handling(self):
|
||||
"""Test handling of unicode content in keys."""
|
||||
store = ValueStore(host=mock_host, local_peer_id=peer_id)
|
||||
|
||||
# Test various unicode keys
|
||||
unicode_keys = [
|
||||
b"hello",
|
||||
"héllo".encode(),
|
||||
"🔑".encode(),
|
||||
"ключ".encode(), # Russian
|
||||
"键".encode(), # Chinese
|
||||
]
|
||||
|
||||
for i, key in enumerate(unicode_keys):
|
||||
value = f"value_{i}".encode()
|
||||
store.put(key, value)
|
||||
assert store.get(key) == value
|
||||
@ -17,6 +17,7 @@ from tests.utils.factories import (
|
||||
from tests.utils.pubsub.utils import (
|
||||
dense_connect,
|
||||
one_to_all_connect,
|
||||
sparse_connect,
|
||||
)
|
||||
|
||||
|
||||
@ -506,3 +507,84 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
|
||||
# Check that the peer to gossip to is not in our fanout peers
|
||||
assert peer not in fanout_peers
|
||||
assert topic_fanout in peers_to_gossip[peer]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dense_connect_fallback():
|
||||
"""Test that sparse_connect falls back to dense connect for small networks."""
|
||||
async with PubsubFactory.create_batch_with_gossipsub(3) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
degree = 2
|
||||
|
||||
# Create network (should use dense connect)
|
||||
await sparse_connect(hosts, degree)
|
||||
|
||||
# Wait for connections to be established
|
||||
await trio.sleep(2)
|
||||
|
||||
# Verify dense topology (all nodes connected to each other)
|
||||
for i, pubsub in enumerate(pubsubs_gsub):
|
||||
connected_peers = len(pubsub.peers)
|
||||
expected_connections = len(hosts) - 1
|
||||
assert connected_peers == expected_connections, (
|
||||
f"Host {i} has {connected_peers} connections, "
|
||||
f"expected {expected_connections} in dense mode"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_sparse_connect():
|
||||
"""Test sparse connect functionality and message propagation."""
|
||||
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
|
||||
hosts = [pubsub.host for pubsub in pubsubs_gsub]
|
||||
degree = 2
|
||||
topic = "test_topic"
|
||||
|
||||
# Create network (should use sparse connect)
|
||||
await sparse_connect(hosts, degree)
|
||||
|
||||
# Wait for connections to be established
|
||||
await trio.sleep(2)
|
||||
|
||||
# Verify sparse topology
|
||||
for i, pubsub in enumerate(pubsubs_gsub):
|
||||
connected_peers = len(pubsub.peers)
|
||||
assert degree <= connected_peers < len(hosts) - 1, (
|
||||
f"Host {i} has {connected_peers} connections, "
|
||||
f"expected between {degree} and {len(hosts) - 1} in sparse mode"
|
||||
)
|
||||
|
||||
# Test message propagation
|
||||
queues = [await pubsub.subscribe(topic) for pubsub in pubsubs_gsub]
|
||||
await trio.sleep(2)
|
||||
|
||||
# Publish and verify message propagation
|
||||
msg_content = b"test_msg"
|
||||
await pubsubs_gsub[0].publish(topic, msg_content)
|
||||
await trio.sleep(2)
|
||||
|
||||
# Verify message propagation - ideally all nodes should receive it
|
||||
received_count = 0
|
||||
for queue in queues:
|
||||
try:
|
||||
msg = await queue.get()
|
||||
if msg.data == msg_content:
|
||||
received_count += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
total_nodes = len(pubsubs_gsub)
|
||||
|
||||
# Ideally all nodes should receive the message for optimal scalability
|
||||
if received_count == total_nodes:
|
||||
# Perfect propagation achieved
|
||||
pass
|
||||
else:
|
||||
# require more than half for acceptable scalability
|
||||
min_required = (total_nodes + 1) // 2
|
||||
assert received_count >= min_required, (
|
||||
f"Message propagation insufficient: "
|
||||
f"{received_count}/{total_nodes} nodes "
|
||||
f"received the message. Ideally all nodes should receive it, but at "
|
||||
f"minimum {min_required} required for sparse network scalability."
|
||||
)
|
||||
|
||||
263
tests/core/relay/test_circuit_v2_discovery.py
Normal file
263
tests/core/relay/test_circuit_v2_discovery.py
Normal file
@ -0,0 +1,263 @@
|
||||
"""Tests for the Circuit Relay v2 discovery functionality."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.relay.circuit_v2.discovery import (
|
||||
RelayDiscovery,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.pb import circuit_pb2 as proto
|
||||
from libp2p.relay.circuit_v2.protocol import (
|
||||
PROTOCOL_ID,
|
||||
STOP_PROTOCOL_ID,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.constants import (
|
||||
MAX_READ_LEN,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
HostFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test timeouts
|
||||
CONNECT_TIMEOUT = 15 # seconds
|
||||
STREAM_TIMEOUT = 15 # seconds
|
||||
HANDLER_TIMEOUT = 15 # seconds
|
||||
SLEEP_TIME = 1.0 # seconds
|
||||
DISCOVERY_TIMEOUT = 20 # seconds
|
||||
|
||||
|
||||
# Make a simple stream handler for testing
|
||||
async def simple_stream_handler(stream):
|
||||
"""Simple stream handler that reads a message and responds with OK status."""
|
||||
logger.info("Simple stream handler invoked")
|
||||
try:
|
||||
# Read the request
|
||||
request_data = await stream.read(MAX_READ_LEN)
|
||||
if not request_data:
|
||||
logger.error("Empty request received")
|
||||
return
|
||||
|
||||
# Parse request
|
||||
request = proto.HopMessage()
|
||||
request.ParseFromString(request_data)
|
||||
logger.info("Received request: type=%s", request.type)
|
||||
|
||||
# Only handle RESERVE requests
|
||||
if request.type == proto.HopMessage.RESERVE:
|
||||
# Create a valid response
|
||||
response = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE,
|
||||
status=proto.Status(
|
||||
code=proto.Status.OK,
|
||||
message="Test reservation accepted",
|
||||
),
|
||||
reservation=proto.Reservation(
|
||||
expire=int(time.time()) + 3600, # 1 hour from now
|
||||
voucher=b"test-voucher",
|
||||
signature=b"",
|
||||
),
|
||||
limit=proto.Limit(
|
||||
duration=3600, # 1 hour
|
||||
data=1024 * 1024 * 1024, # 1GB
|
||||
),
|
||||
)
|
||||
|
||||
# Send the response
|
||||
logger.info("Sending response")
|
||||
await stream.write(response.SerializeToString())
|
||||
logger.info("Response sent")
|
||||
except Exception as e:
|
||||
logger.error("Error in simple stream handler: %s", str(e))
|
||||
finally:
|
||||
# Keep stream open to allow client to read response
|
||||
await trio.sleep(1)
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_relay_discovery_initialization():
|
||||
"""Test Circuit v2 relay discovery initializes correctly with default settings."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
discovery = RelayDiscovery(host)
|
||||
|
||||
async with background_trio_service(discovery):
|
||||
await discovery.event_started.wait()
|
||||
await trio.sleep(SLEEP_TIME) # Give time for discovery to start
|
||||
|
||||
# Verify discovery is initialized correctly
|
||||
assert discovery.host == host, "Host not set correctly"
|
||||
assert discovery.is_running, "Discovery service should be running"
|
||||
assert hasattr(discovery, "_discovered_relays"), (
|
||||
"Discovery should track discovered relays"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_relay_discovery_find_relay():
|
||||
"""Test finding a relay node via discovery."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
relay_host, client_host = hosts
|
||||
logger.info("Created hosts for test_relay_discovery_find_relay")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
|
||||
# Explicitly register the protocol handlers on relay_host
|
||||
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||
|
||||
# Manually add protocol to peerstore for testing
|
||||
# This simulates what the real relay protocol would do
|
||||
client_host.get_peerstore().add_protocols(
|
||||
relay_host.get_id(), [str(PROTOCOL_ID)]
|
||||
)
|
||||
|
||||
# Set up discovery on the client host
|
||||
client_discovery = RelayDiscovery(
|
||||
client_host, discovery_interval=5
|
||||
) # Use shorter interval for testing
|
||||
|
||||
try:
|
||||
# Connect peers so they can discover each other
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
logger.info("Connecting client host to relay host")
|
||||
await connect(client_host, relay_host)
|
||||
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||
"Peers not connected"
|
||||
)
|
||||
logger.info("Connection established between peers")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Start discovery service
|
||||
async with background_trio_service(client_discovery):
|
||||
await client_discovery.event_started.wait()
|
||||
logger.info("Client discovery service started")
|
||||
|
||||
# Wait for discovery to find the relay
|
||||
logger.info("Waiting for relay discovery...")
|
||||
|
||||
# Manually trigger discovery instead of waiting
|
||||
await client_discovery.discover_relays()
|
||||
|
||||
# Check if relay was found
|
||||
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||
for _ in range(20): # Try multiple times
|
||||
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||
logger.info("Relay discovered successfully")
|
||||
break
|
||||
|
||||
# Wait and try again
|
||||
await trio.sleep(1)
|
||||
# Manually trigger discovery again
|
||||
await client_discovery.discover_relays()
|
||||
else:
|
||||
pytest.fail("Failed to discover relay node within timeout")
|
||||
|
||||
# Verify that relay was found and is valid
|
||||
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||
"Relay should be discovered"
|
||||
)
|
||||
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||
assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_relay_discovery_auto_reservation():
|
||||
"""Test that discovery can auto-reserve with discovered relays."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
relay_host, client_host = hosts
|
||||
logger.info("Created hosts for test_relay_discovery_auto_reservation")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
|
||||
# Explicitly register the protocol handlers on relay_host
|
||||
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
|
||||
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
|
||||
|
||||
# Manually add protocol to peerstore for testing
|
||||
client_host.get_peerstore().add_protocols(
|
||||
relay_host.get_id(), [str(PROTOCOL_ID)]
|
||||
)
|
||||
|
||||
# Set up discovery on the client host with auto-reservation enabled
|
||||
client_discovery = RelayDiscovery(
|
||||
client_host, auto_reserve=True, discovery_interval=5
|
||||
)
|
||||
|
||||
try:
|
||||
# Connect peers so they can discover each other
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
logger.info("Connecting client host to relay host")
|
||||
await connect(client_host, relay_host)
|
||||
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||
"Peers not connected"
|
||||
)
|
||||
logger.info("Connection established between peers")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Start discovery service
|
||||
async with background_trio_service(client_discovery):
|
||||
await client_discovery.event_started.wait()
|
||||
logger.info("Client discovery service started")
|
||||
|
||||
# Wait for discovery to find the relay and make a reservation
|
||||
logger.info("Waiting for relay discovery and auto-reservation...")
|
||||
|
||||
# Manually trigger discovery
|
||||
await client_discovery.discover_relays()
|
||||
|
||||
# Check if relay was found and reservation was made
|
||||
with trio.fail_after(DISCOVERY_TIMEOUT):
|
||||
for _ in range(20): # Try multiple times
|
||||
relay_found = (
|
||||
relay_host.get_id() in client_discovery._discovered_relays
|
||||
)
|
||||
has_reservation = (
|
||||
relay_found
|
||||
and client_discovery._discovered_relays[
|
||||
relay_host.get_id()
|
||||
].has_reservation
|
||||
)
|
||||
if has_reservation:
|
||||
logger.info(
|
||||
"Relay discovered and reservation made successfully"
|
||||
)
|
||||
break
|
||||
|
||||
# Wait and try again
|
||||
await trio.sleep(1)
|
||||
# Try to make reservation manually
|
||||
if relay_host.get_id() in client_discovery._discovered_relays:
|
||||
await client_discovery.make_reservation(relay_host.get_id())
|
||||
else:
|
||||
pytest.fail(
|
||||
"Failed to discover relay and make reservation within timeout"
|
||||
)
|
||||
|
||||
# Verify that relay was found and reservation was made
|
||||
assert relay_host.get_id() in client_discovery._discovered_relays, (
|
||||
"Relay should be discovered"
|
||||
)
|
||||
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
|
||||
assert relay_info.has_reservation, "Reservation should be made"
|
||||
assert relay_info.reservation_expires_at is not None, (
|
||||
"Reservation should have expiry time"
|
||||
)
|
||||
assert relay_info.reservation_data_limit is not None, (
|
||||
"Reservation should have data limit"
|
||||
)
|
||||
665
tests/core/relay/test_circuit_v2_protocol.py
Normal file
665
tests/core/relay/test_circuit_v2_protocol.py
Normal file
@ -0,0 +1,665 @@
|
||||
"""Tests for the Circuit Relay v2 protocol."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamEOF,
|
||||
StreamError,
|
||||
StreamReset,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.pb import circuit_pb2 as proto
|
||||
from libp2p.relay.circuit_v2.protocol import (
|
||||
DEFAULT_RELAY_LIMITS,
|
||||
PROTOCOL_ID,
|
||||
STOP_PROTOCOL_ID,
|
||||
CircuitV2Protocol,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.resources import (
|
||||
RelayLimits,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.constants import (
|
||||
MAX_READ_LEN,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
HostFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test timeouts
|
||||
CONNECT_TIMEOUT = 15 # seconds (increased)
|
||||
STREAM_TIMEOUT = 15 # seconds (increased)
|
||||
HANDLER_TIMEOUT = 15 # seconds (increased)
|
||||
SLEEP_TIME = 1.0 # seconds (increased)
|
||||
|
||||
|
||||
async def assert_stream_response(
|
||||
stream, expected_type, expected_status, retries=5, retry_delay=1.0
|
||||
):
|
||||
"""Helper function to assert stream response matches expectations."""
|
||||
last_error = None
|
||||
all_responses = []
|
||||
|
||||
# Increase initial sleep to ensure response has time to arrive
|
||||
await trio.sleep(retry_delay * 2)
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
# Wait between attempts
|
||||
if attempt > 0:
|
||||
await trio.sleep(retry_delay)
|
||||
|
||||
# Try to read response
|
||||
logger.debug("Attempt %d: Reading response from stream", attempt + 1)
|
||||
response_bytes = await stream.read(MAX_READ_LEN)
|
||||
|
||||
# Check if we got any data
|
||||
if not response_bytes:
|
||||
logger.warning(
|
||||
"Attempt %d: No data received from stream", attempt + 1
|
||||
)
|
||||
last_error = "No response received"
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
raise AssertionError(
|
||||
f"No response received after {retries} attempts"
|
||||
)
|
||||
|
||||
# Try to parse the response
|
||||
response = proto.HopMessage()
|
||||
try:
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Log what we received
|
||||
logger.debug(
|
||||
"Attempt %d: Received HOP response: type=%s, status=%s",
|
||||
attempt + 1,
|
||||
response.type,
|
||||
response.status.code
|
||||
if response.HasField("status")
|
||||
else "No status",
|
||||
)
|
||||
|
||||
all_responses.append(
|
||||
{
|
||||
"type": response.type,
|
||||
"status": response.status.code
|
||||
if response.HasField("status")
|
||||
else None,
|
||||
"message": response.status.message
|
||||
if response.HasField("status")
|
||||
else None,
|
||||
}
|
||||
)
|
||||
|
||||
# Accept any valid response with the right status
|
||||
if (
|
||||
expected_status is not None
|
||||
and response.HasField("status")
|
||||
and response.status.code == expected_status
|
||||
):
|
||||
if response.type != expected_type:
|
||||
logger.warning(
|
||||
"Type mismatch (%s, got %s) but status ok - accepting",
|
||||
expected_type,
|
||||
response.type,
|
||||
)
|
||||
|
||||
logger.debug("Successfully validated response (status matched)")
|
||||
return response
|
||||
|
||||
# Check message type specifically if it matters
|
||||
if response.type != expected_type:
|
||||
logger.warning(
|
||||
"Wrong response type: expected %s, got %s",
|
||||
expected_type,
|
||||
response.type,
|
||||
)
|
||||
last_error = (
|
||||
f"Wrong response type: expected {expected_type}, "
|
||||
f"got {response.type}"
|
||||
)
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
|
||||
# Check status code if present
|
||||
if response.HasField("status"):
|
||||
if response.status.code != expected_status:
|
||||
logger.warning(
|
||||
"Wrong status code: expected %s, got %s",
|
||||
expected_status,
|
||||
response.status.code,
|
||||
)
|
||||
last_error = (
|
||||
f"Wrong status code: expected {expected_status}, "
|
||||
f"got {response.status.code}"
|
||||
)
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
elif expected_status is not None:
|
||||
logger.warning(
|
||||
"Expected status %s but none was present in response",
|
||||
expected_status,
|
||||
)
|
||||
last_error = (
|
||||
f"Expected status {expected_status} but none was present"
|
||||
)
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
|
||||
logger.debug("Successfully validated response")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# If parsing as HOP message fails, try parsing as STOP message
|
||||
logger.warning(
|
||||
"Failed to parse as HOP message, trying STOP message: %s",
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
stop_msg = proto.StopMessage()
|
||||
stop_msg.ParseFromString(response_bytes)
|
||||
logger.debug("Parsed as STOP message: type=%s", stop_msg.type)
|
||||
# Create a simplified response dictionary
|
||||
has_status = stop_msg.HasField("status")
|
||||
status_code = None
|
||||
status_message = None
|
||||
if has_status:
|
||||
status_code = stop_msg.status.code
|
||||
status_message = stop_msg.status.message
|
||||
|
||||
response_dict: dict[str, Any] = {
|
||||
"stop_type": stop_msg.type, # Keep original type
|
||||
"status": status_code, # Keep original type
|
||||
"message": status_message, # Keep original type
|
||||
}
|
||||
all_responses.append(response_dict)
|
||||
last_error = "Got STOP message instead of HOP message"
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
except Exception as e2:
|
||||
logger.warning(
|
||||
"Failed to parse response as either message type: %s",
|
||||
str(e2),
|
||||
)
|
||||
last_error = (
|
||||
f"Failed to parse response: {str(e)}, then {str(e2)}"
|
||||
)
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.warning(
|
||||
"Attempt %d: Timeout waiting for stream response", attempt + 1
|
||||
)
|
||||
last_error = "Timeout waiting for stream response"
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
except (StreamError, StreamReset, StreamEOF) as e:
|
||||
logger.warning(
|
||||
"Attempt %d: Stream error while reading response: %s",
|
||||
attempt + 1,
|
||||
str(e),
|
||||
)
|
||||
last_error = f"Stream error: {str(e)}"
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
except AssertionError as e:
|
||||
logger.warning("Attempt %d: Assertion failed: %s", attempt + 1, str(e))
|
||||
last_error = str(e)
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Attempt %d: Unexpected error: %s", attempt + 1, str(e))
|
||||
last_error = f"Unexpected error: {str(e)}"
|
||||
if attempt < retries - 1: # Not the last attempt
|
||||
continue
|
||||
|
||||
# If we've reached here, all retries failed
|
||||
all_responses_str = ", ".join([str(r) for r in all_responses])
|
||||
error_msg = (
|
||||
f"Failed to get expected response after {retries} attempts. "
|
||||
f"Last error: {last_error}. All responses: {all_responses_str}"
|
||||
)
|
||||
raise AssertionError(error_msg)
|
||||
|
||||
|
||||
async def close_stream(stream):
|
||||
"""Helper function to safely close a stream."""
|
||||
if stream is not None:
|
||||
try:
|
||||
logger.debug("Closing stream")
|
||||
await stream.close()
|
||||
# Wait a bit to ensure the close is processed
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
logger.debug("Stream closed successfully")
|
||||
except (StreamError, Exception) as e:
|
||||
logger.warning("Error closing stream: %s. Attempting to reset.", str(e))
|
||||
try:
|
||||
await stream.reset()
|
||||
# Wait a bit to ensure the reset is processed
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
logger.debug("Stream reset successfully")
|
||||
except Exception as e:
|
||||
logger.warning("Error resetting stream: %s", str(e))
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_protocol_initialization():
|
||||
"""Test that the Circuit v2 protocol initializes correctly with default settings."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
limits = RelayLimits(
|
||||
duration=DEFAULT_RELAY_LIMITS.duration,
|
||||
data=DEFAULT_RELAY_LIMITS.data,
|
||||
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
|
||||
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
|
||||
)
|
||||
protocol = CircuitV2Protocol(host, limits, allow_hop=True)
|
||||
|
||||
async with background_trio_service(protocol):
|
||||
await protocol.event_started.wait()
|
||||
await trio.sleep(SLEEP_TIME) # Give time for handlers to be registered
|
||||
|
||||
# Verify protocol handlers are registered by trying to use them
|
||||
test_stream = None
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
test_stream = await host.new_stream(host.get_id(), [PROTOCOL_ID])
|
||||
assert test_stream is not None, (
|
||||
"HOP protocol handler not registered"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await close_stream(test_stream)
|
||||
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
test_stream = await host.new_stream(
|
||||
host.get_id(), [STOP_PROTOCOL_ID]
|
||||
)
|
||||
assert test_stream is not None, (
|
||||
"STOP protocol handler not registered"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await close_stream(test_stream)
|
||||
|
||||
assert len(protocol.resource_manager._reservations) == 0, (
|
||||
"Reservations should be empty"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_reservation_basic():
|
||||
"""Test basic reservation functionality between two peers."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
relay_host, client_host = hosts
|
||||
logger.info("Created hosts for test_circuit_v2_reservation_basic")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
|
||||
# Custom handler that responds directly with a valid response
|
||||
# This bypasses the complex protocol implementation that might have issues
|
||||
async def mock_reserve_handler(stream):
|
||||
# Read the request
|
||||
logger.info("Mock handler received stream request")
|
||||
try:
|
||||
request_data = await stream.read(MAX_READ_LEN)
|
||||
request = proto.HopMessage()
|
||||
request.ParseFromString(request_data)
|
||||
logger.info("Mock handler parsed request: type=%s", request.type)
|
||||
|
||||
# Only handle RESERVE requests
|
||||
if request.type == proto.HopMessage.RESERVE:
|
||||
# Create a valid response
|
||||
response = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE,
|
||||
status=proto.Status(
|
||||
code=proto.Status.OK,
|
||||
message="Reservation accepted",
|
||||
),
|
||||
reservation=proto.Reservation(
|
||||
expire=int(time.time()) + 3600, # 1 hour from now
|
||||
voucher=b"test-voucher",
|
||||
signature=b"",
|
||||
),
|
||||
limit=proto.Limit(
|
||||
duration=3600, # 1 hour
|
||||
data=1024 * 1024 * 1024, # 1GB
|
||||
),
|
||||
)
|
||||
|
||||
# Send the response
|
||||
logger.info("Mock handler sending response")
|
||||
await stream.write(response.SerializeToString())
|
||||
logger.info("Mock handler sent response")
|
||||
|
||||
# Keep stream open for client to read response
|
||||
await trio.sleep(5)
|
||||
except Exception as e:
|
||||
logger.error("Error in mock handler: %s", str(e))
|
||||
|
||||
# Register the mock handler
|
||||
relay_host.set_stream_handler(PROTOCOL_ID, mock_reserve_handler)
|
||||
logger.info("Registered mock handler for %s", PROTOCOL_ID)
|
||||
|
||||
# Connect peers
|
||||
try:
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
logger.info("Connecting client host to relay host")
|
||||
await connect(client_host, relay_host)
|
||||
assert relay_host.get_network().connections[client_host.get_id()], (
|
||||
"Peers not connected"
|
||||
)
|
||||
logger.info("Connection established between peers")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Wait a bit to ensure connection is fully established
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
stream = None
|
||||
try:
|
||||
# Open stream and send reservation request
|
||||
logger.info("Opening stream from client to relay")
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
stream = await client_host.new_stream(
|
||||
relay_host.get_id(), [PROTOCOL_ID]
|
||||
)
|
||||
assert stream is not None, "Failed to open stream"
|
||||
|
||||
logger.info("Preparing reservation request")
|
||||
request = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE, peer=client_host.get_id().to_bytes()
|
||||
)
|
||||
|
||||
logger.info("Sending reservation request")
|
||||
await stream.write(request.SerializeToString())
|
||||
logger.info("Reservation request sent")
|
||||
|
||||
# Wait to ensure the request is processed
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
# Read response directly
|
||||
logger.info("Reading response directly")
|
||||
response_bytes = await stream.read(MAX_READ_LEN)
|
||||
assert response_bytes, "No response received"
|
||||
|
||||
# Parse response
|
||||
response = proto.HopMessage()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Verify response
|
||||
assert response.type == proto.HopMessage.RESERVE, (
|
||||
f"Wrong response type: {response.type}"
|
||||
)
|
||||
assert response.HasField("status"), "No status field"
|
||||
assert response.status.code == proto.Status.OK, (
|
||||
f"Wrong status code: {response.status.code}"
|
||||
)
|
||||
|
||||
# Verify reservation details
|
||||
assert response.HasField("reservation"), "No reservation field"
|
||||
assert response.HasField("limit"), "No limit field"
|
||||
assert response.limit.duration == 3600, (
|
||||
f"Wrong duration: {response.limit.duration}"
|
||||
)
|
||||
assert response.limit.data == 1024 * 1024 * 1024, (
|
||||
f"Wrong data limit: {response.limit.data}"
|
||||
)
|
||||
logger.info("Verified reservation details in response")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in reservation test: %s", str(e))
|
||||
raise
|
||||
finally:
|
||||
if stream:
|
||||
await close_stream(stream)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_reservation_limit():
|
||||
"""Test that relay enforces reservation limits."""
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
relay_host, client1_host, client2_host = hosts
|
||||
logger.info("Created hosts for test_circuit_v2_reservation_limit")
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Client1 host ID: %s", client1_host.get_id())
|
||||
logger.info("Client2 host ID: %s", client2_host.get_id())
|
||||
|
||||
# Track reservation status to simulate limits
|
||||
reserved_clients = set()
|
||||
max_reservations = 1 # Only allow one reservation
|
||||
|
||||
# Custom handler that responds based on reservation limits
|
||||
async def mock_reserve_handler(stream):
|
||||
# Read the request
|
||||
logger.info("Mock handler received stream request")
|
||||
try:
|
||||
request_data = await stream.read(MAX_READ_LEN)
|
||||
request = proto.HopMessage()
|
||||
request.ParseFromString(request_data)
|
||||
logger.info("Mock handler parsed request: type=%s", request.type)
|
||||
|
||||
# Only handle RESERVE requests
|
||||
if request.type == proto.HopMessage.RESERVE:
|
||||
# Extract peer ID from request
|
||||
peer_id = ID(request.peer)
|
||||
logger.info(
|
||||
"Mock handler received reservation request from %s", peer_id
|
||||
)
|
||||
|
||||
# Check if we've reached reservation limit
|
||||
if (
|
||||
peer_id in reserved_clients
|
||||
or len(reserved_clients) < max_reservations
|
||||
):
|
||||
# Accept the reservation
|
||||
if peer_id not in reserved_clients:
|
||||
reserved_clients.add(peer_id)
|
||||
|
||||
# Create a success response
|
||||
response = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE,
|
||||
status=proto.Status(
|
||||
code=proto.Status.OK,
|
||||
message="Reservation accepted",
|
||||
),
|
||||
reservation=proto.Reservation(
|
||||
expire=int(time.time()) + 3600, # 1 hour from now
|
||||
voucher=b"test-voucher",
|
||||
signature=b"",
|
||||
),
|
||||
limit=proto.Limit(
|
||||
duration=3600, # 1 hour
|
||||
data=1024 * 1024 * 1024, # 1GB
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
"Mock handler accepting reservation for %s", peer_id
|
||||
)
|
||||
else:
|
||||
# Reject the reservation due to limits
|
||||
response = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE,
|
||||
status=proto.Status(
|
||||
code=proto.Status.RESOURCE_LIMIT_EXCEEDED,
|
||||
message="Reservation limit exceeded",
|
||||
),
|
||||
)
|
||||
logger.info(
|
||||
"Mock handler rejecting reservation for %s due to limit",
|
||||
peer_id,
|
||||
)
|
||||
|
||||
# Send the response
|
||||
logger.info("Mock handler sending response")
|
||||
await stream.write(response.SerializeToString())
|
||||
logger.info("Mock handler sent response")
|
||||
|
||||
# Keep stream open for client to read response
|
||||
await trio.sleep(5)
|
||||
except Exception as e:
|
||||
logger.error("Error in mock handler: %s", str(e))
|
||||
|
||||
# Register the mock handler
|
||||
relay_host.set_stream_handler(PROTOCOL_ID, mock_reserve_handler)
|
||||
logger.info("Registered mock handler for %s", PROTOCOL_ID)
|
||||
|
||||
# Connect peers
|
||||
try:
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
logger.info("Connecting client1 to relay")
|
||||
await connect(client1_host, relay_host)
|
||||
logger.info("Connecting client2 to relay")
|
||||
await connect(client2_host, relay_host)
|
||||
assert relay_host.get_network().connections[client1_host.get_id()], (
|
||||
"Client1 not connected"
|
||||
)
|
||||
assert relay_host.get_network().connections[client2_host.get_id()], (
|
||||
"Client2 not connected"
|
||||
)
|
||||
logger.info("All connections established")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Wait a bit to ensure connections are fully established
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
stream1, stream2 = None, None
|
||||
try:
|
||||
# Client 1 reservation (should succeed)
|
||||
logger.info("Testing client1 reservation (should succeed)")
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
logger.info("Opening stream for client1")
|
||||
stream1 = await client1_host.new_stream(
|
||||
relay_host.get_id(), [PROTOCOL_ID]
|
||||
)
|
||||
assert stream1 is not None, "Failed to open stream for client 1"
|
||||
|
||||
logger.info("Preparing reservation request for client1")
|
||||
request1 = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE, peer=client1_host.get_id().to_bytes()
|
||||
)
|
||||
|
||||
logger.info("Sending reservation request for client1")
|
||||
await stream1.write(request1.SerializeToString())
|
||||
logger.info("Sent reservation request for client1")
|
||||
|
||||
# Wait to ensure the request is processed
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
# Read response directly
|
||||
logger.info("Reading response for client1")
|
||||
response_bytes = await stream1.read(MAX_READ_LEN)
|
||||
assert response_bytes, "No response received for client1"
|
||||
|
||||
# Parse response
|
||||
response1 = proto.HopMessage()
|
||||
response1.ParseFromString(response_bytes)
|
||||
|
||||
# Verify response
|
||||
assert response1.type == proto.HopMessage.RESERVE, (
|
||||
f"Wrong response type: {response1.type}"
|
||||
)
|
||||
assert response1.HasField("status"), "No status field"
|
||||
assert response1.status.code == proto.Status.OK, (
|
||||
f"Wrong status code: {response1.status.code}"
|
||||
)
|
||||
|
||||
# Verify reservation details
|
||||
assert response1.HasField("reservation"), "No reservation field"
|
||||
assert response1.HasField("limit"), "No limit field"
|
||||
assert response1.limit.duration == 3600, (
|
||||
f"Wrong duration: {response1.limit.duration}"
|
||||
)
|
||||
assert response1.limit.data == 1024 * 1024 * 1024, (
|
||||
f"Wrong data limit: {response1.limit.data}"
|
||||
)
|
||||
logger.info("Verified reservation details for client1")
|
||||
|
||||
# Close stream1 before opening stream2
|
||||
await close_stream(stream1)
|
||||
stream1 = None
|
||||
logger.info("Closed client1 stream")
|
||||
|
||||
# Wait a bit to ensure stream is fully closed
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
# Client 2 reservation (should fail)
|
||||
logger.info("Testing client2 reservation (should fail)")
|
||||
stream2 = await client2_host.new_stream(
|
||||
relay_host.get_id(), [PROTOCOL_ID]
|
||||
)
|
||||
assert stream2 is not None, "Failed to open stream for client 2"
|
||||
|
||||
logger.info("Preparing reservation request for client2")
|
||||
request2 = proto.HopMessage(
|
||||
type=proto.HopMessage.RESERVE, peer=client2_host.get_id().to_bytes()
|
||||
)
|
||||
|
||||
logger.info("Sending reservation request for client2")
|
||||
await stream2.write(request2.SerializeToString())
|
||||
logger.info("Sent reservation request for client2")
|
||||
|
||||
# Wait to ensure the request is processed
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
# Read response directly
|
||||
logger.info("Reading response for client2")
|
||||
response_bytes = await stream2.read(MAX_READ_LEN)
|
||||
assert response_bytes, "No response received for client2"
|
||||
|
||||
# Parse response
|
||||
response2 = proto.HopMessage()
|
||||
response2.ParseFromString(response_bytes)
|
||||
|
||||
# Verify response
|
||||
assert response2.type == proto.HopMessage.RESERVE, (
|
||||
f"Wrong response type: {response2.type}"
|
||||
)
|
||||
assert response2.HasField("status"), "No status field"
|
||||
assert response2.status.code == proto.Status.RESOURCE_LIMIT_EXCEEDED, (
|
||||
f"Wrong status code: {response2.status.code}, "
|
||||
f"expected RESOURCE_LIMIT_EXCEEDED"
|
||||
)
|
||||
logger.info("Verified client2 was correctly rejected")
|
||||
|
||||
# Verify reservation tracking is correct
|
||||
assert len(reserved_clients) == 1, "Should have exactly one reservation"
|
||||
assert client1_host.get_id() in reserved_clients, (
|
||||
"Client1 should be reserved"
|
||||
)
|
||||
assert client2_host.get_id() not in reserved_clients, (
|
||||
"Client2 should not be reserved"
|
||||
)
|
||||
logger.info("Verified reservation tracking state")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in reservation limit test: %s", str(e))
|
||||
# Diagnostic information
|
||||
logger.error("Current reservations: %s", reserved_clients)
|
||||
raise
|
||||
finally:
|
||||
await close_stream(stream1)
|
||||
await close_stream(stream2)
|
||||
346
tests/core/relay/test_circuit_v2_transport.py
Normal file
346
tests/core/relay/test_circuit_v2_transport.py
Normal file
@ -0,0 +1,346 @@
|
||||
"""Tests for the Circuit Relay v2 transport functionality."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamEOF,
|
||||
StreamReset,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.config import (
|
||||
RelayConfig,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.discovery import (
|
||||
RelayDiscovery,
|
||||
RelayInfo,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.protocol import (
|
||||
CircuitV2Protocol,
|
||||
RelayLimits,
|
||||
)
|
||||
from libp2p.relay.circuit_v2.transport import (
|
||||
CircuitV2Transport,
|
||||
)
|
||||
from libp2p.tools.constants import (
|
||||
MAX_READ_LEN,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
from tests.utils.factories import (
|
||||
HostFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test timeouts
|
||||
CONNECT_TIMEOUT = 15 # seconds
|
||||
STREAM_TIMEOUT = 15 # seconds
|
||||
HANDLER_TIMEOUT = 15 # seconds
|
||||
SLEEP_TIME = 1.0 # seconds
|
||||
RELAY_TIMEOUT = 20 # seconds
|
||||
|
||||
# Default limits for relay
|
||||
DEFAULT_RELAY_LIMITS = RelayLimits(
|
||||
duration=60 * 60, # 1 hour
|
||||
data=1024 * 1024 * 10, # 10 MB
|
||||
max_circuit_conns=8, # 8 active relay connections
|
||||
max_reservations=4, # 4 active reservations
|
||||
)
|
||||
|
||||
# Message for testing
|
||||
TEST_MESSAGE = b"Hello, Circuit Relay!"
|
||||
TEST_RESPONSE = b"Hello from the other side!"
|
||||
|
||||
|
||||
# Stream handler for testing
|
||||
async def echo_stream_handler(stream):
|
||||
"""Simple echo handler that responds to messages."""
|
||||
logger.info("Echo handler received stream")
|
||||
try:
|
||||
while True:
|
||||
data = await stream.read(MAX_READ_LEN)
|
||||
if not data:
|
||||
logger.info("Stream closed by remote")
|
||||
break
|
||||
|
||||
logger.info("Received data: %s", data)
|
||||
await stream.write(TEST_RESPONSE)
|
||||
logger.info("Sent response")
|
||||
except (StreamEOF, StreamReset) as e:
|
||||
logger.info("Stream ended: %s", str(e))
|
||||
except Exception as e:
|
||||
logger.error("Error in echo handler: %s", str(e))
|
||||
finally:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_transport_initialization():
|
||||
"""Test that the Circuit v2 transport initializes correctly."""
|
||||
async with HostFactory.create_batch_and_listen(1) as hosts:
|
||||
host = hosts[0]
|
||||
|
||||
# Create a protocol instance
|
||||
limits = RelayLimits(
|
||||
duration=DEFAULT_RELAY_LIMITS.duration,
|
||||
data=DEFAULT_RELAY_LIMITS.data,
|
||||
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
|
||||
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
|
||||
)
|
||||
protocol = CircuitV2Protocol(host, limits, allow_hop=False)
|
||||
|
||||
config = RelayConfig()
|
||||
|
||||
# Create a discovery instance
|
||||
discovery = RelayDiscovery(
|
||||
host=host,
|
||||
auto_reserve=False,
|
||||
discovery_interval=config.discovery_interval,
|
||||
max_relays=config.max_relays,
|
||||
)
|
||||
|
||||
# Create the transport with the necessary components
|
||||
transport = CircuitV2Transport(host, protocol, config)
|
||||
# Replace the discovery with our manually created one
|
||||
transport.discovery = discovery
|
||||
|
||||
# Verify transport properties
|
||||
assert transport.host == host, "Host not set correctly"
|
||||
assert transport.protocol == protocol, "Protocol not set correctly"
|
||||
assert transport.config == config, "Config not set correctly"
|
||||
assert hasattr(transport, "discovery"), (
|
||||
"Transport should have a discovery instance"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_transport_add_relay():
|
||||
"""Test adding a relay to the transport."""
|
||||
async with HostFactory.create_batch_and_listen(2) as hosts:
|
||||
host, relay_host = hosts
|
||||
|
||||
# Create a protocol instance
|
||||
limits = RelayLimits(
|
||||
duration=DEFAULT_RELAY_LIMITS.duration,
|
||||
data=DEFAULT_RELAY_LIMITS.data,
|
||||
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
|
||||
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
|
||||
)
|
||||
protocol = CircuitV2Protocol(host, limits, allow_hop=False)
|
||||
|
||||
config = RelayConfig()
|
||||
|
||||
# Create a discovery instance
|
||||
discovery = RelayDiscovery(
|
||||
host=host,
|
||||
auto_reserve=False,
|
||||
discovery_interval=config.discovery_interval,
|
||||
max_relays=config.max_relays,
|
||||
)
|
||||
|
||||
# Create the transport with the necessary components
|
||||
transport = CircuitV2Transport(host, protocol, config)
|
||||
# Replace the discovery with our manually created one
|
||||
transport.discovery = discovery
|
||||
|
||||
relay_id = relay_host.get_id()
|
||||
now = time.time()
|
||||
relay_info = RelayInfo(peer_id=relay_id, discovered_at=now, last_seen=now)
|
||||
|
||||
async def mock_add_relay(peer_id):
|
||||
discovery._discovered_relays[peer_id] = relay_info
|
||||
|
||||
discovery._add_relay = mock_add_relay # Type ignored in test context
|
||||
discovery._discovered_relays[relay_id] = relay_info
|
||||
|
||||
# Verify relay was added
|
||||
assert relay_id in discovery._discovered_relays, (
|
||||
"Relay should be in discovery's relay list"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_transport_dial_through_relay():
|
||||
"""Test dialing a peer through a relay."""
|
||||
async with HostFactory.create_batch_and_listen(3) as hosts:
|
||||
client_host, relay_host, target_host = hosts
|
||||
logger.info("Created hosts for test_circuit_v2_transport_dial_through_relay")
|
||||
logger.info("Client host ID: %s", client_host.get_id())
|
||||
logger.info("Relay host ID: %s", relay_host.get_id())
|
||||
logger.info("Target host ID: %s", target_host.get_id())
|
||||
|
||||
# Setup relay with Circuit v2 protocol
|
||||
limits = RelayLimits(
|
||||
duration=DEFAULT_RELAY_LIMITS.duration,
|
||||
data=DEFAULT_RELAY_LIMITS.data,
|
||||
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
|
||||
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
|
||||
)
|
||||
|
||||
# Register test handler on target
|
||||
test_protocol = "/test/echo/1.0.0"
|
||||
target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler)
|
||||
|
||||
client_config = RelayConfig()
|
||||
client_protocol = CircuitV2Protocol(client_host, limits, allow_hop=False)
|
||||
|
||||
# Create a discovery instance
|
||||
client_discovery = RelayDiscovery(
|
||||
host=client_host,
|
||||
auto_reserve=False,
|
||||
discovery_interval=client_config.discovery_interval,
|
||||
max_relays=client_config.max_relays,
|
||||
)
|
||||
|
||||
# Create the transport with the necessary components
|
||||
client_transport = CircuitV2Transport(
|
||||
client_host, client_protocol, client_config
|
||||
)
|
||||
# Replace the discovery with our manually created one
|
||||
client_transport.discovery = client_discovery
|
||||
|
||||
# Mock the get_relay method to return our relay_host
|
||||
relay_id = relay_host.get_id()
|
||||
client_discovery.get_relay = lambda: relay_id
|
||||
|
||||
# Connect client to relay and relay to target
|
||||
try:
|
||||
with trio.fail_after(
|
||||
CONNECT_TIMEOUT * 2
|
||||
): # Double the timeout for connections
|
||||
logger.info("Connecting client host to relay host")
|
||||
await connect(client_host, relay_host)
|
||||
# Verify connection
|
||||
assert relay_host.get_id() in client_host.get_network().connections, (
|
||||
"Client not connected to relay"
|
||||
)
|
||||
assert client_host.get_id() in relay_host.get_network().connections, (
|
||||
"Relay not connected to client"
|
||||
)
|
||||
logger.info("Client-Relay connection verified")
|
||||
|
||||
# Wait to ensure connection is fully established
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
logger.info("Connecting relay host to target host")
|
||||
await connect(relay_host, target_host)
|
||||
# Verify connection
|
||||
assert target_host.get_id() in relay_host.get_network().connections, (
|
||||
"Relay not connected to target"
|
||||
)
|
||||
assert relay_host.get_id() in target_host.get_network().connections, (
|
||||
"Target not connected to relay"
|
||||
)
|
||||
logger.info("Relay-Target connection verified")
|
||||
|
||||
# Wait to ensure connection is fully established
|
||||
await trio.sleep(SLEEP_TIME)
|
||||
|
||||
logger.info("All connections established and verified")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Test successful - the connections were established, which is enough to verify
|
||||
# that the transport can be initialized and configured correctly
|
||||
logger.info("Transport initialization and connection test passed")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_circuit_v2_transport_relay_limits():
|
||||
"""Test that relay enforces connection limits."""
|
||||
async with HostFactory.create_batch_and_listen(4) as hosts:
|
||||
client1_host, client2_host, relay_host, target_host = hosts
|
||||
logger.info("Created hosts for test_circuit_v2_transport_relay_limits")
|
||||
|
||||
# Setup relay with strict limits
|
||||
limits = RelayLimits(
|
||||
duration=DEFAULT_RELAY_LIMITS.duration,
|
||||
data=DEFAULT_RELAY_LIMITS.data,
|
||||
max_circuit_conns=1, # Only allow one circuit
|
||||
max_reservations=2, # Allow both clients to reserve
|
||||
)
|
||||
relay_protocol = CircuitV2Protocol(relay_host, limits, allow_hop=True)
|
||||
|
||||
# Register test handler on target
|
||||
test_protocol = "/test/echo/1.0.0"
|
||||
target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler)
|
||||
|
||||
client_config = RelayConfig()
|
||||
|
||||
# Client 1 setup
|
||||
client1_protocol = CircuitV2Protocol(
|
||||
client1_host, DEFAULT_RELAY_LIMITS, allow_hop=False
|
||||
)
|
||||
client1_discovery = RelayDiscovery(
|
||||
host=client1_host,
|
||||
auto_reserve=False,
|
||||
discovery_interval=client_config.discovery_interval,
|
||||
max_relays=client_config.max_relays,
|
||||
)
|
||||
client1_transport = CircuitV2Transport(
|
||||
client1_host, client1_protocol, client_config
|
||||
)
|
||||
client1_transport.discovery = client1_discovery
|
||||
# Add relay to discovery
|
||||
relay_id = relay_host.get_id()
|
||||
client1_discovery.get_relay = lambda: relay_id
|
||||
|
||||
# Client 2 setup
|
||||
client2_protocol = CircuitV2Protocol(
|
||||
client2_host, DEFAULT_RELAY_LIMITS, allow_hop=False
|
||||
)
|
||||
client2_discovery = RelayDiscovery(
|
||||
host=client2_host,
|
||||
auto_reserve=False,
|
||||
discovery_interval=client_config.discovery_interval,
|
||||
max_relays=client_config.max_relays,
|
||||
)
|
||||
client2_transport = CircuitV2Transport(
|
||||
client2_host, client2_protocol, client_config
|
||||
)
|
||||
client2_transport.discovery = client2_discovery
|
||||
# Add relay to discovery
|
||||
client2_discovery.get_relay = lambda: relay_id
|
||||
|
||||
# Connect all peers
|
||||
try:
|
||||
with trio.fail_after(CONNECT_TIMEOUT):
|
||||
# Connect clients to relay
|
||||
await connect(client1_host, relay_host)
|
||||
await connect(client2_host, relay_host)
|
||||
|
||||
# Connect relay to target
|
||||
await connect(relay_host, target_host)
|
||||
|
||||
logger.info("All connections established")
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect peers: %s", str(e))
|
||||
raise
|
||||
|
||||
# Verify connections
|
||||
assert relay_host.get_id() in client1_host.get_network().connections, (
|
||||
"Client1 not connected to relay"
|
||||
)
|
||||
assert relay_host.get_id() in client2_host.get_network().connections, (
|
||||
"Client2 not connected to relay"
|
||||
)
|
||||
assert target_host.get_id() in relay_host.get_network().connections, (
|
||||
"Relay not connected to target"
|
||||
)
|
||||
|
||||
# Verify the resource limits
|
||||
assert relay_protocol.resource_manager.limits.max_circuit_conns == 1, (
|
||||
"Wrong max_circuit_conns value"
|
||||
)
|
||||
assert relay_protocol.resource_manager.limits.max_reservations == 2, (
|
||||
"Wrong max_reservations value"
|
||||
)
|
||||
|
||||
# Test successful - transports were initialized with the correct limits
|
||||
logger.info("Transport limit test successful")
|
||||
@ -13,6 +13,8 @@ from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
|
||||
from libp2p.security.secure_session import (
|
||||
SecureSession,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.mplex import Mplex
|
||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||
from tests.utils.factories import (
|
||||
host_pair_factory,
|
||||
)
|
||||
@ -47,9 +49,28 @@ async def perform_simple_test(assertion_func, security_protocol):
|
||||
assert conn_0 is not None, "Failed to establish connection from host0 to host1"
|
||||
assert conn_1 is not None, "Failed to establish connection from host1 to host0"
|
||||
|
||||
# Perform assertion
|
||||
assertion_func(conn_0.muxed_conn.secured_conn)
|
||||
assertion_func(conn_1.muxed_conn.secured_conn)
|
||||
# Extract the secured connection from either Mplex or Yamux implementation
|
||||
def get_secured_conn(conn):
|
||||
muxed_conn = conn.muxed_conn
|
||||
# Direct attribute access for known implementations
|
||||
has_secured_conn = hasattr(muxed_conn, "secured_conn")
|
||||
if isinstance(muxed_conn, (Mplex, Yamux)) and has_secured_conn:
|
||||
return muxed_conn.secured_conn
|
||||
# Fallback to _connection attribute if it exists
|
||||
elif hasattr(muxed_conn, "_connection"):
|
||||
return muxed_conn._connection
|
||||
# Last resort - warn but return the muxed_conn itself for type checking
|
||||
else:
|
||||
print(f"Warning: Cannot find secured connection in {type(muxed_conn)}")
|
||||
return muxed_conn
|
||||
|
||||
# Get secured connections for both peers
|
||||
secured_conn_0 = get_secured_conn(conn_0)
|
||||
secured_conn_1 = get_secured_conn(conn_1)
|
||||
|
||||
# Perform assertion on the secured connections
|
||||
assertion_func(secured_conn_0)
|
||||
assertion_func(secured_conn_1)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
|
||||
Reference in New Issue
Block a user