Merge branch 'libp2p:main' into tests/notifee-coverage

This commit is contained in:
Mercy Boma Naps Nkari
2025-08-21 08:07:53 +01:00
committed by GitHub
26 changed files with 1685 additions and 22 deletions

View File

@ -1,9 +1,9 @@
from collections import deque
import pytest
import trio
from libp2p.abc import (
IMultiselectCommunicator,
)
from libp2p.abc import IMultiselectCommunicator, INetStream
from libp2p.custom_types import TProtocol
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
async def dummy_handler(stream: INetStream) -> None:
pass
class DummyMultiselectCommunicator(IMultiselectCommunicator):
"""
Dummy MultiSelectCommunicator to test out negotiate timmeout.
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
@pytest.mark.trio
async def test_select_one_of_timeout():
async def test_select_one_of_timeout() -> None:
ECHO = TProtocol("/echo/1.0.0")
communicator = DummyMultiselectCommunicator()
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
@pytest.mark.trio
async def test_query_multistream_command_timeout():
async def test_query_multistream_command_timeout() -> None:
communicator = DummyMultiselectCommunicator()
client = MultiselectClient()
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
@pytest.mark.trio
async def test_negotiate_timeout():
async def test_negotiate_timeout() -> None:
communicator = DummyMultiselectCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 2)
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
handshaked: bool
def __init__(self) -> None:
self.handshaked = False
async def write(self, msg_str: str) -> None:
if msg_str == "/multistream/1.0.0":
self.handshaked = True
return
async def read(self) -> str:
if not self.handshaked:
return "/multistream/1.0.0"
# After handshake, hang on read.
await trio.sleep_forever()
# Should not be reached.
return ""
@pytest.mark.trio
async def test_negotiate_timeout_post_handshake() -> None:
communicator = HandshakeThenHangCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 1)
class MockCommunicator(IMultiselectCommunicator):
def __init__(self, commands_to_read: list[str]):
self.read_queue = deque(commands_to_read)
self.written_data: list[str] = []
async def write(self, msg_str: str) -> None:
self.written_data.append(msg_str)
async def read(self) -> str:
if not self.read_queue:
raise EOFError
return self.read_queue.popleft()
@pytest.mark.trio
async def test_negotiate_empty_string_command() -> None:
# server receives an empty string, which means client wants `None` protocol.
server = Multiselect({None: dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check that server sent back handshake and the protocol confirmation (empty string)
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler() -> None:
# server has None handler, client sends "" to select it.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, protocol confirmation
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler_ls() -> None:
# server has None handler, client sends "ls" then empty string.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, ls, empty command
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, ls response, protocol confirmation
assert communicator.written_data[0] == "/multistream/1.0.0"
assert "/proto1" in communicator.written_data[1]
# Note: `ls` should not list the `None` protocol.
assert "None" not in communicator.written_data[1]
assert "\n\n" not in communicator.written_data[1]
assert communicator.written_data[2] == ""

View File

@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
protocols = ms.get_protocols()
assert set(protocols) == {p1, p2, p3}
@pytest.mark.trio
async def test_negotiate_optional_tprotocol(security_protocol):
with pytest.raises(Exception):
await perform_simple_test(
None,
[None],
[None],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol,
[None, PROTOCOL_ECHO],
[PROTOCOL_ECHO],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_server_none_client_other(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)

View File

@ -0,0 +1,99 @@
"""
Unit tests for the RandomWalk module in libp2p.discovery.random_walk.
"""
from unittest.mock import AsyncMock, Mock
import pytest
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
@pytest.fixture
def mock_host():
host = Mock()
peerstore = Mock()
peerstore.peers_with_addrs.return_value = []
peerstore.addrs.return_value = [Mock()]
host.get_peerstore.return_value = peerstore
host.new_stream = AsyncMock()
return host
@pytest.fixture
def dummy_query_function():
async def query(key_bytes):
return []
return query
@pytest.fixture
def dummy_peer_id():
return b"\x01" * 32
@pytest.mark.trio
async def test_random_walk_initialization(
mock_host, dummy_peer_id, dummy_query_function
):
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
assert rw.host == mock_host
assert rw.local_peer_id == dummy_peer_id
assert rw.query_function == dummy_query_function
def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function):
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
peer_id = rw.generate_random_peer_id()
assert isinstance(peer_id, str)
assert len(peer_id) == 64 # 32 bytes hex
@pytest.mark.trio
async def test_run_concurrent_random_walks(mock_host, dummy_peer_id):
# Dummy query function returns different peer IDs for each walk
call_count = {"count": 0}
async def query(key_bytes):
call_count["count"] += 1
# Return a unique peer ID for each call
return [ID(bytes([call_count["count"]] * 32))]
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.run_concurrent_random_walks(count=3)
# Should get 3 unique peers
assert len(peers) == 3
peer_ids = [peer.peer_id for peer in peers]
assert len(set(peer_ids)) == 3
@pytest.mark.trio
async def test_perform_random_walk_running(mock_host, dummy_peer_id):
# Query function returns a single peer ID
async def query(key_bytes):
return [ID(b"\x02" * 32)]
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.perform_random_walk()
assert isinstance(peers, list)
if peers:
assert isinstance(peers[0], PeerInfo)
@pytest.mark.trio
async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id):
"""Test perform_random_walk when no peers are discovered."""
# Query function returns empty list (no peers found)
async def query(key_bytes):
return []
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.perform_random_walk()
# Should return empty list when no peers are found
assert isinstance(peers, list)
assert len(peers) == 0

View File

@ -0,0 +1,451 @@
"""
Unit tests for the RTRefreshManager and related random walk logic.
"""
import time
from unittest.mock import AsyncMock, Mock, patch
import pytest
import trio
from libp2p.discovery.random_walk.config import (
MIN_RT_REFRESH_THRESHOLD,
RANDOM_WALK_CONCURRENCY,
REFRESH_INTERVAL,
)
from libp2p.discovery.random_walk.exceptions import (
RandomWalkError,
)
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
class DummyRoutingTable:
def __init__(self, size=0):
self._size = size
self.added_peers = []
def size(self):
return self._size
async def add_peer(self, peer_obj):
self.added_peers.append(peer_obj)
self._size += 1
return True
@pytest.fixture
def mock_host():
host = Mock()
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
return host
@pytest.fixture
def local_peer_id():
return ID(b"\x01" * 32)
@pytest.fixture
def dummy_query_function():
async def query(key_bytes):
return [ID(b"\x02" * 32)]
return query
@pytest.mark.trio
async def test_rt_refresh_manager_initialization(
mock_host, local_peer_id, dummy_query_function
):
rt = DummyRoutingTable(size=5)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
assert manager.host == mock_host
assert manager.routing_table == rt
assert manager.local_peer_id == local_peer_id
assert manager.query_function == dummy_query_function
@pytest.mark.trio
async def test_rt_refresh_manager_refresh_logic(
mock_host, local_peer_id, dummy_query_function
):
rt = DummyRoutingTable(size=2)
# Simulate refresh logic
if rt.size() < MIN_RT_REFRESH_THRESHOLD:
await rt.add_peer(Mock())
assert rt.size() >= 3
@pytest.mark.trio
async def test_rt_refresh_manager_random_walk_integration(
mock_host, local_peer_id, dummy_query_function
):
# Simulate random walk usage
rw = RandomWalk(mock_host, local_peer_id, dummy_query_function)
random_peer_id = rw.generate_random_peer_id()
assert isinstance(random_peer_id, str)
assert len(random_peer_id) == 64
@pytest.mark.trio
async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id):
rt = DummyRoutingTable(size=0)
async def failing_query(_):
raise RandomWalkError("Query failed")
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=failing_query,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
with pytest.raises(RandomWalkError):
await manager.query_function(b"key")
@pytest.mark.trio
async def test_rt_refresh_manager_start_method(
mock_host, local_peer_id, dummy_query_function
):
"""Test the start method functionality."""
rt = DummyRoutingTable(size=2)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=False, # Disable auto-refresh to control the test
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
):
# Test starting the manager
assert not manager._running
# Start the manager in a nursery that we can control
async with trio.open_nursery() as nursery:
nursery.start_soon(manager.start)
await trio.sleep(0.01) # Let it start
# Verify it's running
assert manager._running
# Stop the manager
await manager.stop()
assert not manager._running
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_with_auto_refresh(
mock_host, local_peer_id, dummy_query_function
):
"""Test the _main_loop method with auto-refresh enabled."""
rt = DummyRoutingTable(size=1) # Small size to trigger refresh
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
) as mock_random_walk:
manager._running = True
# Run the main loop for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._main_loop)
await trio.sleep(0.05) # Let it run briefly
manager._running = False # Stop the loop
# Verify that random walk was called (initial refresh)
mock_random_walk.assert_called()
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_without_auto_refresh(
mock_host, local_peer_id, dummy_query_function
):
"""Test the _main_loop method with auto-refresh disabled."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=False,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
manager._running = True
# Run the main loop for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._main_loop)
await trio.sleep(0.05)
manager._running = False
# Verify that random walk was not called since auto-refresh is disabled
mock_random_walk.assert_not_called()
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_initial_refresh_exception(
mock_host, local_peer_id, dummy_query_function
):
"""Test that _main_loop propagates exceptions from initial refresh."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock _do_refresh to raise an exception on the initial call
with patch.object(
manager, "_do_refresh", side_effect=Exception("Initial refresh failed")
):
manager._running = True
# The initial refresh exception should propagate
with pytest.raises(Exception, match="Initial refresh failed"):
await manager._main_loop()
@pytest.mark.trio
async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function):
"""Test _do_refresh method with force=True."""
rt = DummyRoutingTable(size=10) # Large size, but force should override
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info1 = Mock(spec=PeerInfo)
mock_peer_info2 = Mock(spec=PeerInfo)
discovered_peers = [mock_peer_info1, mock_peer_info2]
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=discovered_peers,
) as mock_random_walk:
# Force refresh should work regardless of RT size
await manager._do_refresh(force=True)
# Verify random walk was called
mock_random_walk.assert_called_once_with(
count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10
)
# Verify peers were added to routing table
assert len(rt.added_peers) == 2
assert manager._last_refresh_time > 0
@pytest.mark.trio
async def test_do_refresh_skip_due_to_interval(
mock_host, local_peer_id, dummy_query_function
):
"""Test _do_refresh skips refresh when interval hasn't elapsed."""
rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=100.0, # Long interval
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Set last refresh time to recent
manager._last_refresh_time = time.time()
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=False)
# Verify refresh was skipped
mock_random_walk.assert_not_called()
mock_logger.debug.assert_called_with(
"Skipping refresh: interval not elapsed"
)
@pytest.mark.trio
async def test_do_refresh_skip_due_to_rt_size(
mock_host, local_peer_id, dummy_query_function
):
"""Test _do_refresh skips refresh when RT size is above threshold."""
rt = DummyRoutingTable(size=20) # Large size above threshold
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1, # Short interval
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Set last refresh time to old
manager._last_refresh_time = 0.0
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=False)
# Verify refresh was skipped
mock_random_walk.assert_not_called()
mock_logger.debug.assert_called_with(
"Skipping refresh: routing table size above threshold"
)
@pytest.mark.trio
async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function):
"""Test refresh completion callbacks functionality."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Create mock callbacks
callback1 = Mock()
callback2 = Mock()
failing_callback = Mock(side_effect=Exception("Callback failed"))
# Add callbacks
manager.add_refresh_done_callback(callback1)
manager.add_refresh_done_callback(callback2)
manager.add_refresh_done_callback(failing_callback)
# Mock the random walk
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
):
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=True)
# Verify all callbacks were called
callback1.assert_called_once()
callback2.assert_called_once()
failing_callback.assert_called_once()
# Verify warning was logged for failing callback
mock_logger.warning.assert_called()
@pytest.mark.trio
async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function):
"""Test stop method when manager is not running."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Manager is not running
assert not manager._running
# Stop should return without doing anything
await manager.stop()
assert not manager._running
@pytest.mark.trio
async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function):
"""Test the _periodic_refresh_task method."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.05, # Very short interval for testing
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock _do_refresh to track calls
with patch.object(manager, "_do_refresh") as mock_do_refresh:
manager._running = True
# Run periodic refresh task for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._periodic_refresh_task)
await trio.sleep(0.12) # Let it run for ~2 intervals
manager._running = False # Stop the task
# Verify _do_refresh was called at least once
assert mock_do_refresh.call_count >= 1