Merge branch 'main' into noise-arch-change

This commit is contained in:
Manu Sheel Gupta
2025-08-25 16:21:43 +05:30
committed by GitHub
45 changed files with 2701 additions and 59 deletions

View File

@ -0,0 +1,82 @@
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
INetConn,
INetStream,
INetwork,
INotifee,
)
from libp2p.tools.utils import connect_swarm
from tests.utils.factories import SwarmFactory
class CountingNotifee(INotifee):
def __init__(self, event: trio.Event) -> None:
self._event = event
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def connected(self, network: INetwork, conn: INetConn) -> None:
self._event.set()
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
class SlowNotifee(INotifee):
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def connected(self, network: INetwork, conn: INetConn) -> None:
await trio.sleep(0.5)
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
pass
@pytest.mark.trio
async def test_many_notifees_receive_connected_quickly() -> None:
async with SwarmFactory.create_batch_and_listen(2) as swarms:
count = 200
events = [trio.Event() for _ in range(count)]
for ev in events:
swarms[0].register_notifee(CountingNotifee(ev))
await connect_swarm(swarms[0], swarms[1])
with trio.fail_after(1.5):
for ev in events:
await ev.wait()
@pytest.mark.trio
async def test_slow_notifee_does_not_block_others() -> None:
async with SwarmFactory.create_batch_and_listen(2) as swarms:
fast_events = [trio.Event() for _ in range(20)]
for ev in fast_events:
swarms[0].register_notifee(CountingNotifee(ev))
swarms[0].register_notifee(SlowNotifee())
await connect_swarm(swarms[0], swarms[1])
# Fast notifees should complete quickly despite one slow notifee
with trio.fail_after(0.3):
for ev in fast_events:
await ev.wait()

View File

@ -5,11 +5,12 @@ the stream passed into opened_stream is correct.
Note: Listen event does not get hit because MyNotifee is passed
into network after network has already started listening
TODO: Add tests for closed_stream, listen_close when those
features are implemented in swarm
Note: ClosedStream events are processed asynchronously and may not be
immediately available due to the rapid nature of operations
"""
import enum
from unittest.mock import Mock
import pytest
from multiaddr import Multiaddr
@ -29,11 +30,11 @@ from tests.utils.factories import (
class Event(enum.Enum):
OpenedStream = 0
ClosedStream = 1 # Not implemented
ClosedStream = 1
Connected = 2
Disconnected = 3
Listen = 4
ListenClose = 5 # Not implemented
ListenClose = 5
class MyNotifee(INotifee):
@ -60,8 +61,11 @@ class MyNotifee(INotifee):
self.events.append(Event.Listen)
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
# TODO: It is not implemented yet.
pass
if network is None:
raise ValueError("network parameter cannot be None")
if multiaddr is None:
raise ValueError("multiaddr parameter cannot be None")
self.events.append(Event.ListenClose)
@pytest.mark.trio
@ -123,3 +127,171 @@ async def test_notify(security_protocol):
assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0)
assert await wait_for_event(events_1_1, Event.ClosedStream, 1.0)
assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)
# Note: ListenClose events are triggered when swarm closes during cleanup
# The test framework automatically closes listeners, triggering ListenClose
# notifications
async def wait_for_event(events_list, event, timeout=1.0):
"""Helper to wait for a specific event to appear in the events list."""
with trio.move_on_after(timeout):
while event not in events_list:
await trio.sleep(0.01)
return True
return False
@pytest.mark.trio
async def test_notify_with_closed_stream_and_listen_close():
"""Test that closed_stream and listen_close events are properly triggered."""
# Event lists for notifees
events_0 = []
events_1 = []
# Create two swarms
async with SwarmFactory.create_batch_and_listen(2) as swarms:
# Register notifees
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
# Create and close a stream to trigger closed_stream event
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Note: Events are processed asynchronously and may not be immediately available
# due to the rapid nature of operations
@pytest.mark.trio
async def test_notify_edge_cases():
"""Test edge cases for notify system."""
events = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee = MyNotifee(events)
swarms[0].register_notifee(notifee)
# Connect swarms first
await connect_swarm(swarms[0], swarms[1])
# Test 1: Multiple rapid stream operations
streams = []
for _ in range(5):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Close all streams rapidly
for stream in streams:
await stream.close()
@pytest.mark.trio
async def test_my_notifee_error_handling():
"""Test error handling for invalid parameters in MyNotifee methods."""
events = []
notifee = MyNotifee(events)
# Mock objects for testing
mock_network = Mock(spec=INetwork)
mock_stream = Mock(spec=INetStream)
mock_multiaddr = Mock(spec=Multiaddr)
# Test closed_stream with None parameters
with pytest.raises(ValueError, match="network parameter cannot be None"):
await notifee.closed_stream(None, mock_stream) # type: ignore
with pytest.raises(ValueError, match="stream parameter cannot be None"):
await notifee.closed_stream(mock_network, None) # type: ignore
# Test listen_close with None parameters
with pytest.raises(ValueError, match="network parameter cannot be None"):
await notifee.listen_close(None, mock_multiaddr) # type: ignore
with pytest.raises(ValueError, match="multiaddr parameter cannot be None"):
await notifee.listen_close(mock_network, None) # type: ignore
# Verify no events were recorded due to errors
assert len(events) == 0
@pytest.mark.trio
async def test_rapid_stream_operations():
"""Test rapid stream open/close operations."""
events_0 = []
events_1 = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
# Rapidly create and close multiple streams
streams = []
for _ in range(3):
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
streams.append(stream)
# Close all streams immediately
for stream in streams:
await stream.close()
# Verify OpenedStream events are recorded
assert events_0.count(Event.OpenedStream) == 3
assert events_1.count(Event.OpenedStream) == 3
# Close peer to trigger disconnection events
await swarms[0].close_peer(swarms[1].get_peer_id())
@pytest.mark.trio
async def test_concurrent_stream_operations():
"""Test concurrent stream operations using trio nursery."""
events_0 = []
events_1 = []
async with SwarmFactory.create_batch_and_listen(2) as swarms:
notifee_0 = MyNotifee(events_0)
notifee_1 = MyNotifee(events_1)
swarms[0].register_notifee(notifee_0)
swarms[1].register_notifee(notifee_1)
# Connect swarms
await connect_swarm(swarms[0], swarms[1])
async def create_and_close_stream():
"""Create and immediately close a stream."""
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
await stream.close()
# Run multiple stream operations concurrently
async with trio.open_nursery() as nursery:
for _ in range(4):
nursery.start_soon(create_and_close_stream)
# Verify some OpenedStream events are recorded
# (concurrent operations may not all succeed)
opened_count_0 = events_0.count(Event.OpenedStream)
opened_count_1 = events_1.count(Event.OpenedStream)
assert opened_count_0 > 0, (
f"Expected some OpenedStream events, got {opened_count_0}"
)
assert opened_count_1 > 0, (
f"Expected some OpenedStream events, got {opened_count_1}"
)
# Close peer to trigger disconnection events
await swarms[0].close_peer(swarms[1].get_peer_id())

View File

@ -0,0 +1,76 @@
import enum
import pytest
from multiaddr import Multiaddr
import trio
from libp2p.abc import (
INetConn,
INetStream,
INetwork,
INotifee,
)
from libp2p.tools.async_service import background_trio_service
from libp2p.tools.constants import LISTEN_MADDR
from tests.utils.factories import SwarmFactory
class Event(enum.Enum):
Listen = 0
ListenClose = 1
class MyNotifee(INotifee):
def __init__(self, events: list[Event]):
self.events = events
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
pass
async def connected(self, network: INetwork, conn: INetConn) -> None:
pass
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
pass
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
self.events.append(Event.Listen)
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
self.events.append(Event.ListenClose)
async def wait_for_event(
events_list: list[Event], event: Event, timeout: float = 1.0
) -> bool:
with trio.move_on_after(timeout):
while event not in events_list:
await trio.sleep(0.01)
return True
return False
@pytest.mark.trio
async def test_listen_emitted_when_registered_before_listen():
events: list[Event] = []
swarm = SwarmFactory.build()
swarm.register_notifee(MyNotifee(events))
async with background_trio_service(swarm):
# Start listening now; notifee was registered beforehand
assert await swarm.listen(LISTEN_MADDR)
assert await wait_for_event(events, Event.Listen)
@pytest.mark.trio
async def test_single_listener_close_emits_listen_close():
events: list[Event] = []
swarm = SwarmFactory.build()
swarm.register_notifee(MyNotifee(events))
async with background_trio_service(swarm):
assert await swarm.listen(LISTEN_MADDR)
# Explicitly notify listen_close (close path via manager doesn't emit it)
await swarm.notify_listen_close(LISTEN_MADDR)
assert await wait_for_event(events, Event.ListenClose)

View File

@ -16,6 +16,9 @@ from libp2p.network.exceptions import (
from libp2p.network.swarm import (
Swarm,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.utils import (
connect_swarm,
)
@ -184,3 +187,116 @@ def test_new_swarm_quic_multiaddr_raises():
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
with pytest.raises(ValueError, match="QUIC not yet supported"):
new_swarm(listen_addrs=[addr])
@pytest.mark.trio
async def test_swarm_listen_multiple_addresses(security_protocol):
"""Test that swarm can listen on multiple addresses simultaneously."""
from libp2p.utils.address_validation import get_available_interfaces
# Get multiple addresses to listen on
listen_addrs = get_available_interfaces(0) # Let OS choose ports
# Create a swarm and listen on multiple addresses
swarm = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm):
# Listen on all addresses
success = await swarm.listen(*listen_addrs)
assert success, "Should successfully listen on at least one address"
# Check that we have listeners for the addresses
actual_listeners = list(swarm.listeners.keys())
assert len(actual_listeners) > 0, "Should have at least one listener"
# Verify that all successful listeners are in the listeners dict
successful_count = 0
for addr in listen_addrs:
addr_str = str(addr)
if addr_str in actual_listeners:
successful_count += 1
# This address successfully started listening
listener = swarm.listeners[addr_str]
listener_addrs = listener.get_addrs()
assert len(listener_addrs) > 0, (
f"Listener for {addr} should have addresses"
)
# Check that the listener address matches the expected address
# (port might be different if we used port 0)
expected_ip = addr.value_for_protocol("ip4")
expected_protocol = addr.value_for_protocol("tcp")
if expected_ip and expected_protocol:
found_matching = False
for listener_addr in listener_addrs:
if (
listener_addr.value_for_protocol("ip4") == expected_ip
and listener_addr.value_for_protocol("tcp") is not None
):
found_matching = True
break
assert found_matching, (
f"Listener for {addr} should have matching IP"
)
assert successful_count == len(listen_addrs), (
f"All {len(listen_addrs)} addresses should be listening, "
f"but only {successful_count} succeeded"
)
@pytest.mark.trio
async def test_swarm_listen_multiple_addresses_connectivity(security_protocol):
"""Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.utils.address_validation import get_available_interfaces
# Get multiple addresses to listen on
listen_addrs = get_available_interfaces(0) # Let OS choose ports
# Create a swarm and listen on multiple addresses
swarm1 = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm1):
# Listen on all addresses
success = await swarm1.listen(*listen_addrs)
assert success, "Should successfully listen on at least one address"
# Verify all available interfaces are listening
assert len(swarm1.listeners) == len(listen_addrs), (
f"All {len(listen_addrs)} interfaces should be listening, "
f"but only {len(swarm1.listeners)} are"
)
# Create a second swarm to test connections
swarm2 = SwarmFactory.build(security_protocol=security_protocol)
async with background_trio_service(swarm2):
# Test connectivity to each listening address using real libp2p connections
for addr_str, listener in swarm1.listeners.items():
listener_addrs = listener.get_addrs()
for listener_addr in listener_addrs:
# Create a full multiaddr with peer ID for libp2p connection
peer_id = swarm1.get_peer_id()
full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}")
# Test real libp2p connection
try:
peer_info = info_from_p2p_addr(full_addr)
# Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501
swarm2.peerstore.add_addrs(
peer_info.peer_id, [listener_addr], 10000
)
await swarm2.dial_peer(peer_info.peer_id)
# Verify connection was established
assert peer_info.peer_id in swarm2.connections, (
f"Connection to {full_addr} should be established"
)
assert swarm2.get_peer_id() in swarm1.connections, (
f"Connection from {full_addr} should be established"
)
except Exception as e:
pytest.fail(
f"Failed to establish libp2p connection to {full_addr}: {e}"
)

View File

@ -1,9 +1,9 @@
from collections import deque
import pytest
import trio
from libp2p.abc import (
IMultiselectCommunicator,
)
from libp2p.abc import IMultiselectCommunicator, INetStream
from libp2p.custom_types import TProtocol
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
async def dummy_handler(stream: INetStream) -> None:
pass
class DummyMultiselectCommunicator(IMultiselectCommunicator):
"""
Dummy MultiSelectCommunicator to test out negotiate timmeout.
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
@pytest.mark.trio
async def test_select_one_of_timeout():
async def test_select_one_of_timeout() -> None:
ECHO = TProtocol("/echo/1.0.0")
communicator = DummyMultiselectCommunicator()
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
@pytest.mark.trio
async def test_query_multistream_command_timeout():
async def test_query_multistream_command_timeout() -> None:
communicator = DummyMultiselectCommunicator()
client = MultiselectClient()
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
@pytest.mark.trio
async def test_negotiate_timeout():
async def test_negotiate_timeout() -> None:
communicator = DummyMultiselectCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 2)
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
handshaked: bool
def __init__(self) -> None:
self.handshaked = False
async def write(self, msg_str: str) -> None:
if msg_str == "/multistream/1.0.0":
self.handshaked = True
return
async def read(self) -> str:
if not self.handshaked:
return "/multistream/1.0.0"
# After handshake, hang on read.
await trio.sleep_forever()
# Should not be reached.
return ""
@pytest.mark.trio
async def test_negotiate_timeout_post_handshake() -> None:
communicator = HandshakeThenHangCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 1)
class MockCommunicator(IMultiselectCommunicator):
def __init__(self, commands_to_read: list[str]):
self.read_queue = deque(commands_to_read)
self.written_data: list[str] = []
async def write(self, msg_str: str) -> None:
self.written_data.append(msg_str)
async def read(self) -> str:
if not self.read_queue:
raise EOFError
return self.read_queue.popleft()
@pytest.mark.trio
async def test_negotiate_empty_string_command() -> None:
# server receives an empty string, which means client wants `None` protocol.
server = Multiselect({None: dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check that server sent back handshake and the protocol confirmation (empty string)
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler() -> None:
# server has None handler, client sends "" to select it.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, then empty command
communicator = MockCommunicator(["/multistream/1.0.0", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, protocol confirmation
assert communicator.written_data == ["/multistream/1.0.0", ""]
@pytest.mark.trio
async def test_negotiate_with_none_handler_ls() -> None:
# server has None handler, client sends "ls" then empty string.
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
# Handshake, ls, empty command
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
protocol, handler = await server.negotiate(communicator)
assert protocol is None
assert handler == dummy_handler
# Check written data: handshake, ls response, protocol confirmation
assert communicator.written_data[0] == "/multistream/1.0.0"
assert "/proto1" in communicator.written_data[1]
# Note: `ls` should not list the `None` protocol.
assert "None" not in communicator.written_data[1]
assert "\n\n" not in communicator.written_data[1]
assert communicator.written_data[2] == ""

View File

@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
protocols = ms.get_protocols()
assert set(protocols) == {p1, p2, p3}
@pytest.mark.trio
async def test_negotiate_optional_tprotocol(security_protocol):
with pytest.raises(Exception):
await perform_simple_test(
None,
[None],
[None],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
expected_selected_protocol = PROTOCOL_ECHO
await perform_simple_test(
expected_selected_protocol,
[None, PROTOCOL_ECHO],
[PROTOCOL_ECHO],
security_protocol,
)
@pytest.mark.trio
async def test_negotiate_optional_tprotocol_server_none_client_other(
security_protocol,
):
with pytest.raises(Exception):
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)

View File

@ -0,0 +1,90 @@
from typing import cast
import pytest
import trio
from libp2p.tools.utils import connect
from tests.utils.factories import PubsubFactory
@pytest.mark.trio
async def test_connected_enqueues_and_adds_peer():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
await connect(p0.host, p1.host)
await p0.wait_until_ready()
# Wait until peer is added via queue processing
with trio.fail_after(1.0):
while p1.my_id not in p0.peers:
await trio.sleep(0.01)
assert p1.my_id in p0.peers
@pytest.mark.trio
async def test_disconnected_enqueues_and_removes_peer():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
await connect(p0.host, p1.host)
await p0.wait_until_ready()
# Ensure present first
with trio.fail_after(1.0):
while p1.my_id not in p0.peers:
await trio.sleep(0.01)
# Now disconnect and expect removal via dead peer queue
await p0.host.get_network().close_peer(p1.host.get_id())
with trio.fail_after(1.0):
while p1.my_id in p0.peers:
await trio.sleep(0.01)
assert p1.my_id not in p0.peers
@pytest.mark.trio
async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None:
# Ensure PubsubNotifee catches BrokenResourceError from its send channel
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
# Find the PubsubNotifee registered on the network
from libp2p.pubsub.pubsub_notifee import PubsubNotifee
network = p0.host.get_network()
notifees = getattr(network, "notifees", [])
target = None
for nf in notifees:
if isinstance(nf, cast(type, PubsubNotifee)):
target = nf
break
assert target is not None, "PubsubNotifee not found on network"
async def failing_send(_peer_id): # type: ignore[no-redef]
raise trio.BrokenResourceError
# Make initiator queue send fail; PubsubNotifee should swallow
monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send)
# Connect peers; if exceptions are swallowed, service stays running
await connect(p0.host, p1.host)
await p0.wait_until_ready()
assert True
@pytest.mark.trio
async def test_duplicate_connection_does_not_duplicate_peer_state():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
await connect(p0.host, p1.host)
await p0.wait_until_ready()
with trio.fail_after(1.0):
while p1.my_id not in p0.peers:
await trio.sleep(0.01)
# Connect again should not add duplicates
await connect(p0.host, p1.host)
await trio.sleep(0.1)
assert list(p0.peers.keys()).count(p1.my_id) == 1
@pytest.mark.trio
async def test_blacklist_blocks_peer_added_by_notifee():
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
# Blacklist before connecting
p0.add_to_blacklist(p1.my_id)
await connect(p0.host, p1.host)
await p0.wait_until_ready()
# Give handler a chance to run
await trio.sleep(0.1)
assert p1.my_id not in p0.peers

View File

@ -0,0 +1,99 @@
"""
Unit tests for the RandomWalk module in libp2p.discovery.random_walk.
"""
from unittest.mock import AsyncMock, Mock
import pytest
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
@pytest.fixture
def mock_host():
host = Mock()
peerstore = Mock()
peerstore.peers_with_addrs.return_value = []
peerstore.addrs.return_value = [Mock()]
host.get_peerstore.return_value = peerstore
host.new_stream = AsyncMock()
return host
@pytest.fixture
def dummy_query_function():
async def query(key_bytes):
return []
return query
@pytest.fixture
def dummy_peer_id():
return b"\x01" * 32
@pytest.mark.trio
async def test_random_walk_initialization(
mock_host, dummy_peer_id, dummy_query_function
):
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
assert rw.host == mock_host
assert rw.local_peer_id == dummy_peer_id
assert rw.query_function == dummy_query_function
def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function):
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
peer_id = rw.generate_random_peer_id()
assert isinstance(peer_id, str)
assert len(peer_id) == 64 # 32 bytes hex
@pytest.mark.trio
async def test_run_concurrent_random_walks(mock_host, dummy_peer_id):
# Dummy query function returns different peer IDs for each walk
call_count = {"count": 0}
async def query(key_bytes):
call_count["count"] += 1
# Return a unique peer ID for each call
return [ID(bytes([call_count["count"]] * 32))]
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.run_concurrent_random_walks(count=3)
# Should get 3 unique peers
assert len(peers) == 3
peer_ids = [peer.peer_id for peer in peers]
assert len(set(peer_ids)) == 3
@pytest.mark.trio
async def test_perform_random_walk_running(mock_host, dummy_peer_id):
# Query function returns a single peer ID
async def query(key_bytes):
return [ID(b"\x02" * 32)]
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.perform_random_walk()
assert isinstance(peers, list)
if peers:
assert isinstance(peers[0], PeerInfo)
@pytest.mark.trio
async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id):
"""Test perform_random_walk when no peers are discovered."""
# Query function returns empty list (no peers found)
async def query(key_bytes):
return []
rw = RandomWalk(mock_host, dummy_peer_id, query)
peers = await rw.perform_random_walk()
# Should return empty list when no peers are found
assert isinstance(peers, list)
assert len(peers) == 0

View File

@ -0,0 +1,451 @@
"""
Unit tests for the RTRefreshManager and related random walk logic.
"""
import time
from unittest.mock import AsyncMock, Mock, patch
import pytest
import trio
from libp2p.discovery.random_walk.config import (
MIN_RT_REFRESH_THRESHOLD,
RANDOM_WALK_CONCURRENCY,
REFRESH_INTERVAL,
)
from libp2p.discovery.random_walk.exceptions import (
RandomWalkError,
)
from libp2p.discovery.random_walk.random_walk import RandomWalk
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
class DummyRoutingTable:
def __init__(self, size=0):
self._size = size
self.added_peers = []
def size(self):
return self._size
async def add_peer(self, peer_obj):
self.added_peers.append(peer_obj)
self._size += 1
return True
@pytest.fixture
def mock_host():
host = Mock()
host.get_peerstore.return_value = Mock()
host.new_stream = AsyncMock()
return host
@pytest.fixture
def local_peer_id():
return ID(b"\x01" * 32)
@pytest.fixture
def dummy_query_function():
async def query(key_bytes):
return [ID(b"\x02" * 32)]
return query
@pytest.mark.trio
async def test_rt_refresh_manager_initialization(
mock_host, local_peer_id, dummy_query_function
):
rt = DummyRoutingTable(size=5)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
assert manager.host == mock_host
assert manager.routing_table == rt
assert manager.local_peer_id == local_peer_id
assert manager.query_function == dummy_query_function
@pytest.mark.trio
async def test_rt_refresh_manager_refresh_logic(
mock_host, local_peer_id, dummy_query_function
):
rt = DummyRoutingTable(size=2)
# Simulate refresh logic
if rt.size() < MIN_RT_REFRESH_THRESHOLD:
await rt.add_peer(Mock())
assert rt.size() >= 3
@pytest.mark.trio
async def test_rt_refresh_manager_random_walk_integration(
mock_host, local_peer_id, dummy_query_function
):
# Simulate random walk usage
rw = RandomWalk(mock_host, local_peer_id, dummy_query_function)
random_peer_id = rw.generate_random_peer_id()
assert isinstance(random_peer_id, str)
assert len(random_peer_id) == 64
@pytest.mark.trio
async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id):
rt = DummyRoutingTable(size=0)
async def failing_query(_):
raise RandomWalkError("Query failed")
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=failing_query,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
with pytest.raises(RandomWalkError):
await manager.query_function(b"key")
@pytest.mark.trio
async def test_rt_refresh_manager_start_method(
mock_host, local_peer_id, dummy_query_function
):
"""Test the start method functionality."""
rt = DummyRoutingTable(size=2)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=False, # Disable auto-refresh to control the test
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
):
# Test starting the manager
assert not manager._running
# Start the manager in a nursery that we can control
async with trio.open_nursery() as nursery:
nursery.start_soon(manager.start)
await trio.sleep(0.01) # Let it start
# Verify it's running
assert manager._running
# Stop the manager
await manager.stop()
assert not manager._running
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_with_auto_refresh(
mock_host, local_peer_id, dummy_query_function
):
"""Test the _main_loop method with auto-refresh enabled."""
rt = DummyRoutingTable(size=1) # Small size to trigger refresh
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
) as mock_random_walk:
manager._running = True
# Run the main loop for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._main_loop)
await trio.sleep(0.05) # Let it run briefly
manager._running = False # Stop the loop
# Verify that random walk was called (initial refresh)
mock_random_walk.assert_called()
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_without_auto_refresh(
mock_host, local_peer_id, dummy_query_function
):
"""Test the _main_loop method with auto-refresh disabled."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=False,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
manager._running = True
# Run the main loop for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._main_loop)
await trio.sleep(0.05)
manager._running = False
# Verify that random walk was not called since auto-refresh is disabled
mock_random_walk.assert_not_called()
@pytest.mark.trio
async def test_rt_refresh_manager_main_loop_initial_refresh_exception(
mock_host, local_peer_id, dummy_query_function
):
"""Test that _main_loop propagates exceptions from initial refresh."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock _do_refresh to raise an exception on the initial call
with patch.object(
manager, "_do_refresh", side_effect=Exception("Initial refresh failed")
):
manager._running = True
# The initial refresh exception should propagate
with pytest.raises(Exception, match="Initial refresh failed"):
await manager._main_loop()
@pytest.mark.trio
async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function):
"""Test _do_refresh method with force=True."""
rt = DummyRoutingTable(size=10) # Large size, but force should override
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=REFRESH_INTERVAL,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock the random walk to return some peers
mock_peer_info1 = Mock(spec=PeerInfo)
mock_peer_info2 = Mock(spec=PeerInfo)
discovered_peers = [mock_peer_info1, mock_peer_info2]
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=discovered_peers,
) as mock_random_walk:
# Force refresh should work regardless of RT size
await manager._do_refresh(force=True)
# Verify random walk was called
mock_random_walk.assert_called_once_with(
count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10
)
# Verify peers were added to routing table
assert len(rt.added_peers) == 2
assert manager._last_refresh_time > 0
@pytest.mark.trio
async def test_do_refresh_skip_due_to_interval(
mock_host, local_peer_id, dummy_query_function
):
"""Test _do_refresh skips refresh when interval hasn't elapsed."""
rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=100.0, # Long interval
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Set last refresh time to recent
manager._last_refresh_time = time.time()
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=False)
# Verify refresh was skipped
mock_random_walk.assert_not_called()
mock_logger.debug.assert_called_with(
"Skipping refresh: interval not elapsed"
)
@pytest.mark.trio
async def test_do_refresh_skip_due_to_rt_size(
mock_host, local_peer_id, dummy_query_function
):
"""Test _do_refresh skips refresh when RT size is above threshold."""
rt = DummyRoutingTable(size=20) # Large size above threshold
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1, # Short interval
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Set last refresh time to old
manager._last_refresh_time = 0.0
with patch.object(
manager.random_walk, "run_concurrent_random_walks"
) as mock_random_walk:
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=False)
# Verify refresh was skipped
mock_random_walk.assert_not_called()
mock_logger.debug.assert_called_with(
"Skipping refresh: routing table size above threshold"
)
@pytest.mark.trio
async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function):
"""Test refresh completion callbacks functionality."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Create mock callbacks
callback1 = Mock()
callback2 = Mock()
failing_callback = Mock(side_effect=Exception("Callback failed"))
# Add callbacks
manager.add_refresh_done_callback(callback1)
manager.add_refresh_done_callback(callback2)
manager.add_refresh_done_callback(failing_callback)
# Mock the random walk
mock_peer_info = Mock(spec=PeerInfo)
with patch.object(
manager.random_walk,
"run_concurrent_random_walks",
return_value=[mock_peer_info],
):
with patch(
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
) as mock_logger:
await manager._do_refresh(force=True)
# Verify all callbacks were called
callback1.assert_called_once()
callback2.assert_called_once()
failing_callback.assert_called_once()
# Verify warning was logged for failing callback
mock_logger.warning.assert_called()
@pytest.mark.trio
async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function):
"""Test stop method when manager is not running."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.1,
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Manager is not running
assert not manager._running
# Stop should return without doing anything
await manager.stop()
assert not manager._running
@pytest.mark.trio
async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function):
"""Test the _periodic_refresh_task method."""
rt = DummyRoutingTable(size=1)
manager = RTRefreshManager(
host=mock_host,
routing_table=rt,
local_peer_id=local_peer_id,
query_function=dummy_query_function,
enable_auto_refresh=True,
refresh_interval=0.05, # Very short interval for testing
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
)
# Mock _do_refresh to track calls
with patch.object(manager, "_do_refresh") as mock_do_refresh:
manager._running = True
# Run periodic refresh task for a short time
async with trio.open_nursery() as nursery:
nursery.start_soon(manager._periodic_refresh_task)
await trio.sleep(0.12) # Let it run for ~2 intervals
manager._running = False # Stop the task
# Verify _do_refresh was called at least once
assert mock_do_refresh.call_count >= 1

View File

@ -0,0 +1,109 @@
import contextlib
import os
from pathlib import Path
import subprocess
import sys
import time
from multiaddr import Multiaddr
from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP
# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging
# This test is intentionally lightweight and can be marked as 'integration'.
# It ensures the echo example runs and prints the new Thin Waist lines using
# Trio primitives.
current_file = Path(__file__)
project_root = current_file.parent.parent.parent
EXAMPLES_DIR: Path = project_root / "examples" / "echo"
def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path):
"""Run echo server and validate printed multiaddr and peer id."""
# Run echo example as server
cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"]
env = {**os.environ, "PYTHONUNBUFFERED": "1"}
proc: subprocess.Popen[str] = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
env=env,
)
if proc.stdout is None:
proc.terminate()
raise RuntimeError("Process stdout is None")
out_stream = proc.stdout
peer_id: str | None = None
printed_multiaddr: str | None = None
saw_waiting = False
start = time.time()
timeout_s = 8.0
try:
while time.time() - start < timeout_s:
line = out_stream.readline()
if not line:
time.sleep(0.05)
continue
s = line.strip()
if s.startswith("I am "):
peer_id = s.partition("I am ")[2]
if s.startswith("echo-demo -d "):
printed_multiaddr = s.partition("echo-demo -d ")[2]
if "Waiting for incoming connections..." in s:
saw_waiting = True
break
finally:
with contextlib.suppress(ProcessLookupError):
proc.terminate()
with contextlib.suppress(ProcessLookupError):
proc.kill()
assert peer_id, "Did not capture peer ID line"
assert printed_multiaddr, "Did not capture multiaddr line"
assert saw_waiting, "Did not capture waiting-for-connections line"
# Validate multiaddr structure using py-multiaddr protocol methods
ma = Multiaddr(printed_multiaddr) # should parse without error
# Check that the multiaddr contains the p2p protocol
try:
peer_id_from_multiaddr = ma.value_for_protocol("p2p")
assert peer_id_from_multiaddr is not None, (
"Multiaddr missing p2p protocol value"
)
assert peer_id_from_multiaddr == peer_id, (
f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}"
)
except Exception as e:
raise AssertionError(f"Failed to extract p2p protocol value: {e}")
# Validate the multiaddr structure by checking protocols
protocols = ma.protocols()
# Should have at least IP, TCP, and P2P protocols
assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), (
"Missing IP protocol"
)
assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol"
assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol"
# Extract the p2p part and validate it matches the captured peer ID
p2p_part = Multiaddr(f"/p2p/{peer_id}")
try:
# Decapsulate the p2p part to get the transport address
transport_addr = ma.decapsulate(p2p_part)
# Verify the decapsulated address doesn't contain p2p
transport_protocols = transport_addr.protocols()
assert not any(p.code == P_P2P for p in transport_protocols), (
"Decapsulation failed - still contains p2p"
)
# Verify the original multiaddr can be reconstructed
reconstructed = transport_addr.encapsulate(p2p_part)
assert str(reconstructed) == str(ma), "Reconstruction failed"
except Exception as e:
raise AssertionError(f"Multiaddr decapsulation failed: {e}")

View File

@ -0,0 +1,56 @@
import os
import pytest
from multiaddr import Multiaddr
from libp2p.utils.address_validation import (
expand_wildcard_address,
get_available_interfaces,
get_optimal_binding_address,
)
@pytest.mark.parametrize("proto", ["tcp"])
def test_get_available_interfaces(proto: str) -> None:
interfaces = get_available_interfaces(0, protocol=proto)
assert len(interfaces) > 0
for addr in interfaces:
assert isinstance(addr, Multiaddr)
assert f"/{proto}/" in str(addr)
def test_get_optimal_binding_address() -> None:
addr = get_optimal_binding_address(0)
assert isinstance(addr, Multiaddr)
# At least IPv4 or IPv6 prefix present
s = str(addr)
assert ("/ip4/" in s) or ("/ip6/" in s)
def test_expand_wildcard_address_ipv4() -> None:
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0")
expanded = expand_wildcard_address(wildcard)
assert len(expanded) > 0
for e in expanded:
assert isinstance(e, Multiaddr)
assert "/tcp/" in str(e)
def test_expand_wildcard_address_port_override() -> None:
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000")
overridden = expand_wildcard_address(wildcard, port=9001)
assert len(overridden) > 0
for e in overridden:
assert str(e).endswith("/tcp/9001")
@pytest.mark.skipif(
os.environ.get("NO_IPV6") == "1",
reason="Environment disallows IPv6",
)
def test_expand_wildcard_address_ipv6() -> None:
wildcard = Multiaddr("/ip6/::/tcp/0")
expanded = expand_wildcard_address(wildcard)
assert len(expanded) > 0
for e in expanded:
assert "/ip6/" in str(e)