mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into noise-arch-change
This commit is contained in:
82
tests/core/network/test_notifee_performance.py
Normal file
82
tests/core/network/test_notifee_performance.py
Normal file
@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
INetConn,
|
||||
INetStream,
|
||||
INetwork,
|
||||
INotifee,
|
||||
)
|
||||
from libp2p.tools.utils import connect_swarm
|
||||
from tests.utils.factories import SwarmFactory
|
||||
|
||||
|
||||
class CountingNotifee(INotifee):
|
||||
def __init__(self, event: trio.Event) -> None:
|
||||
self._event = event
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
self._event.set()
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SlowNotifee(INotifee):
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
await trio.sleep(0.5)
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_many_notifees_receive_connected_quickly() -> None:
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
count = 200
|
||||
events = [trio.Event() for _ in range(count)]
|
||||
for ev in events:
|
||||
swarms[0].register_notifee(CountingNotifee(ev))
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
with trio.fail_after(1.5):
|
||||
for ev in events:
|
||||
await ev.wait()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_slow_notifee_does_not_block_others() -> None:
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
fast_events = [trio.Event() for _ in range(20)]
|
||||
for ev in fast_events:
|
||||
swarms[0].register_notifee(CountingNotifee(ev))
|
||||
swarms[0].register_notifee(SlowNotifee())
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# Fast notifees should complete quickly despite one slow notifee
|
||||
with trio.fail_after(0.3):
|
||||
for ev in fast_events:
|
||||
await ev.wait()
|
||||
@ -5,11 +5,12 @@ the stream passed into opened_stream is correct.
|
||||
Note: Listen event does not get hit because MyNotifee is passed
|
||||
into network after network has already started listening
|
||||
|
||||
TODO: Add tests for closed_stream, listen_close when those
|
||||
features are implemented in swarm
|
||||
Note: ClosedStream events are processed asynchronously and may not be
|
||||
immediately available due to the rapid nature of operations
|
||||
"""
|
||||
|
||||
import enum
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
@ -29,11 +30,11 @@ from tests.utils.factories import (
|
||||
|
||||
class Event(enum.Enum):
|
||||
OpenedStream = 0
|
||||
ClosedStream = 1 # Not implemented
|
||||
ClosedStream = 1
|
||||
Connected = 2
|
||||
Disconnected = 3
|
||||
Listen = 4
|
||||
ListenClose = 5 # Not implemented
|
||||
ListenClose = 5
|
||||
|
||||
|
||||
class MyNotifee(INotifee):
|
||||
@ -60,8 +61,11 @@ class MyNotifee(INotifee):
|
||||
self.events.append(Event.Listen)
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
# TODO: It is not implemented yet.
|
||||
pass
|
||||
if network is None:
|
||||
raise ValueError("network parameter cannot be None")
|
||||
if multiaddr is None:
|
||||
raise ValueError("multiaddr parameter cannot be None")
|
||||
self.events.append(Event.ListenClose)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -123,3 +127,171 @@ async def test_notify(security_protocol):
|
||||
assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0)
|
||||
assert await wait_for_event(events_1_1, Event.ClosedStream, 1.0)
|
||||
assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)
|
||||
|
||||
# Note: ListenClose events are triggered when swarm closes during cleanup
|
||||
# The test framework automatically closes listeners, triggering ListenClose
|
||||
# notifications
|
||||
|
||||
|
||||
async def wait_for_event(events_list, event, timeout=1.0):
|
||||
"""Helper to wait for a specific event to appear in the events list."""
|
||||
with trio.move_on_after(timeout):
|
||||
while event not in events_list:
|
||||
await trio.sleep(0.01)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_notify_with_closed_stream_and_listen_close():
|
||||
"""Test that closed_stream and listen_close events are properly triggered."""
|
||||
# Event lists for notifees
|
||||
events_0 = []
|
||||
events_1 = []
|
||||
|
||||
# Create two swarms
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
# Register notifees
|
||||
notifee_0 = MyNotifee(events_0)
|
||||
notifee_1 = MyNotifee(events_1)
|
||||
|
||||
swarms[0].register_notifee(notifee_0)
|
||||
swarms[1].register_notifee(notifee_1)
|
||||
|
||||
# Connect swarms
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
|
||||
# Create and close a stream to trigger closed_stream event
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
await stream.close()
|
||||
|
||||
# Note: Events are processed asynchronously and may not be immediately available
|
||||
# due to the rapid nature of operations
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_notify_edge_cases():
|
||||
"""Test edge cases for notify system."""
|
||||
events = []
|
||||
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
notifee = MyNotifee(events)
|
||||
swarms[0].register_notifee(notifee)
|
||||
|
||||
# Connect swarms first
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
|
||||
# Test 1: Multiple rapid stream operations
|
||||
streams = []
|
||||
for _ in range(5):
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
streams.append(stream)
|
||||
|
||||
# Close all streams rapidly
|
||||
for stream in streams:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_my_notifee_error_handling():
|
||||
"""Test error handling for invalid parameters in MyNotifee methods."""
|
||||
events = []
|
||||
notifee = MyNotifee(events)
|
||||
|
||||
# Mock objects for testing
|
||||
mock_network = Mock(spec=INetwork)
|
||||
mock_stream = Mock(spec=INetStream)
|
||||
mock_multiaddr = Mock(spec=Multiaddr)
|
||||
|
||||
# Test closed_stream with None parameters
|
||||
with pytest.raises(ValueError, match="network parameter cannot be None"):
|
||||
await notifee.closed_stream(None, mock_stream) # type: ignore
|
||||
|
||||
with pytest.raises(ValueError, match="stream parameter cannot be None"):
|
||||
await notifee.closed_stream(mock_network, None) # type: ignore
|
||||
|
||||
# Test listen_close with None parameters
|
||||
with pytest.raises(ValueError, match="network parameter cannot be None"):
|
||||
await notifee.listen_close(None, mock_multiaddr) # type: ignore
|
||||
|
||||
with pytest.raises(ValueError, match="multiaddr parameter cannot be None"):
|
||||
await notifee.listen_close(mock_network, None) # type: ignore
|
||||
|
||||
# Verify no events were recorded due to errors
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rapid_stream_operations():
|
||||
"""Test rapid stream open/close operations."""
|
||||
events_0 = []
|
||||
events_1 = []
|
||||
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
notifee_0 = MyNotifee(events_0)
|
||||
notifee_1 = MyNotifee(events_1)
|
||||
|
||||
swarms[0].register_notifee(notifee_0)
|
||||
swarms[1].register_notifee(notifee_1)
|
||||
|
||||
# Connect swarms
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
|
||||
# Rapidly create and close multiple streams
|
||||
streams = []
|
||||
for _ in range(3):
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
streams.append(stream)
|
||||
|
||||
# Close all streams immediately
|
||||
for stream in streams:
|
||||
await stream.close()
|
||||
|
||||
# Verify OpenedStream events are recorded
|
||||
assert events_0.count(Event.OpenedStream) == 3
|
||||
assert events_1.count(Event.OpenedStream) == 3
|
||||
|
||||
# Close peer to trigger disconnection events
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_concurrent_stream_operations():
|
||||
"""Test concurrent stream operations using trio nursery."""
|
||||
events_0 = []
|
||||
events_1 = []
|
||||
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
notifee_0 = MyNotifee(events_0)
|
||||
notifee_1 = MyNotifee(events_1)
|
||||
|
||||
swarms[0].register_notifee(notifee_0)
|
||||
swarms[1].register_notifee(notifee_1)
|
||||
|
||||
# Connect swarms
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
|
||||
async def create_and_close_stream():
|
||||
"""Create and immediately close a stream."""
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
await stream.close()
|
||||
|
||||
# Run multiple stream operations concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
for _ in range(4):
|
||||
nursery.start_soon(create_and_close_stream)
|
||||
|
||||
# Verify some OpenedStream events are recorded
|
||||
# (concurrent operations may not all succeed)
|
||||
opened_count_0 = events_0.count(Event.OpenedStream)
|
||||
opened_count_1 = events_1.count(Event.OpenedStream)
|
||||
|
||||
assert opened_count_0 > 0, (
|
||||
f"Expected some OpenedStream events, got {opened_count_0}"
|
||||
)
|
||||
assert opened_count_1 > 0, (
|
||||
f"Expected some OpenedStream events, got {opened_count_1}"
|
||||
)
|
||||
|
||||
# Close peer to trigger disconnection events
|
||||
await swarms[0].close_peer(swarms[1].get_peer_id())
|
||||
|
||||
76
tests/core/network/test_notify_listen_lifecycle.py
Normal file
76
tests/core/network/test_notify_listen_lifecycle.py
Normal file
@ -0,0 +1,76 @@
|
||||
import enum
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
INetConn,
|
||||
INetStream,
|
||||
INetwork,
|
||||
INotifee,
|
||||
)
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from tests.utils.factories import SwarmFactory
|
||||
|
||||
|
||||
class Event(enum.Enum):
|
||||
Listen = 0
|
||||
ListenClose = 1
|
||||
|
||||
|
||||
class MyNotifee(INotifee):
|
||||
def __init__(self, events: list[Event]):
|
||||
self.events = events
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
self.events.append(Event.Listen)
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
self.events.append(Event.ListenClose)
|
||||
|
||||
|
||||
async def wait_for_event(
|
||||
events_list: list[Event], event: Event, timeout: float = 1.0
|
||||
) -> bool:
|
||||
with trio.move_on_after(timeout):
|
||||
while event not in events_list:
|
||||
await trio.sleep(0.01)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listen_emitted_when_registered_before_listen():
|
||||
events: list[Event] = []
|
||||
swarm = SwarmFactory.build()
|
||||
swarm.register_notifee(MyNotifee(events))
|
||||
async with background_trio_service(swarm):
|
||||
# Start listening now; notifee was registered beforehand
|
||||
assert await swarm.listen(LISTEN_MADDR)
|
||||
assert await wait_for_event(events, Event.Listen)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_single_listener_close_emits_listen_close():
|
||||
events: list[Event] = []
|
||||
swarm = SwarmFactory.build()
|
||||
swarm.register_notifee(MyNotifee(events))
|
||||
async with background_trio_service(swarm):
|
||||
assert await swarm.listen(LISTEN_MADDR)
|
||||
# Explicitly notify listen_close (close path via manager doesn't emit it)
|
||||
await swarm.notify_listen_close(LISTEN_MADDR)
|
||||
assert await wait_for_event(events, Event.ListenClose)
|
||||
@ -16,6 +16,9 @@ from libp2p.network.exceptions import (
|
||||
from libp2p.network.swarm import (
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect_swarm,
|
||||
)
|
||||
@ -184,3 +187,116 @@ def test_new_swarm_quic_multiaddr_raises():
|
||||
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
|
||||
with pytest.raises(ValueError, match="QUIC not yet supported"):
|
||||
new_swarm(listen_addrs=[addr])
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_listen_multiple_addresses(security_protocol):
|
||||
"""Test that swarm can listen on multiple addresses simultaneously."""
|
||||
from libp2p.utils.address_validation import get_available_interfaces
|
||||
|
||||
# Get multiple addresses to listen on
|
||||
listen_addrs = get_available_interfaces(0) # Let OS choose ports
|
||||
|
||||
# Create a swarm and listen on multiple addresses
|
||||
swarm = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm):
|
||||
# Listen on all addresses
|
||||
success = await swarm.listen(*listen_addrs)
|
||||
assert success, "Should successfully listen on at least one address"
|
||||
|
||||
# Check that we have listeners for the addresses
|
||||
actual_listeners = list(swarm.listeners.keys())
|
||||
assert len(actual_listeners) > 0, "Should have at least one listener"
|
||||
|
||||
# Verify that all successful listeners are in the listeners dict
|
||||
successful_count = 0
|
||||
for addr in listen_addrs:
|
||||
addr_str = str(addr)
|
||||
if addr_str in actual_listeners:
|
||||
successful_count += 1
|
||||
# This address successfully started listening
|
||||
listener = swarm.listeners[addr_str]
|
||||
listener_addrs = listener.get_addrs()
|
||||
assert len(listener_addrs) > 0, (
|
||||
f"Listener for {addr} should have addresses"
|
||||
)
|
||||
|
||||
# Check that the listener address matches the expected address
|
||||
# (port might be different if we used port 0)
|
||||
expected_ip = addr.value_for_protocol("ip4")
|
||||
expected_protocol = addr.value_for_protocol("tcp")
|
||||
if expected_ip and expected_protocol:
|
||||
found_matching = False
|
||||
for listener_addr in listener_addrs:
|
||||
if (
|
||||
listener_addr.value_for_protocol("ip4") == expected_ip
|
||||
and listener_addr.value_for_protocol("tcp") is not None
|
||||
):
|
||||
found_matching = True
|
||||
break
|
||||
assert found_matching, (
|
||||
f"Listener for {addr} should have matching IP"
|
||||
)
|
||||
|
||||
assert successful_count == len(listen_addrs), (
|
||||
f"All {len(listen_addrs)} addresses should be listening, "
|
||||
f"but only {successful_count} succeeded"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_listen_multiple_addresses_connectivity(security_protocol):
|
||||
"""Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.utils.address_validation import get_available_interfaces
|
||||
|
||||
# Get multiple addresses to listen on
|
||||
listen_addrs = get_available_interfaces(0) # Let OS choose ports
|
||||
|
||||
# Create a swarm and listen on multiple addresses
|
||||
swarm1 = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm1):
|
||||
# Listen on all addresses
|
||||
success = await swarm1.listen(*listen_addrs)
|
||||
assert success, "Should successfully listen on at least one address"
|
||||
|
||||
# Verify all available interfaces are listening
|
||||
assert len(swarm1.listeners) == len(listen_addrs), (
|
||||
f"All {len(listen_addrs)} interfaces should be listening, "
|
||||
f"but only {len(swarm1.listeners)} are"
|
||||
)
|
||||
|
||||
# Create a second swarm to test connections
|
||||
swarm2 = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm2):
|
||||
# Test connectivity to each listening address using real libp2p connections
|
||||
for addr_str, listener in swarm1.listeners.items():
|
||||
listener_addrs = listener.get_addrs()
|
||||
for listener_addr in listener_addrs:
|
||||
# Create a full multiaddr with peer ID for libp2p connection
|
||||
peer_id = swarm1.get_peer_id()
|
||||
full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}")
|
||||
|
||||
# Test real libp2p connection
|
||||
try:
|
||||
peer_info = info_from_p2p_addr(full_addr)
|
||||
|
||||
# Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501
|
||||
swarm2.peerstore.add_addrs(
|
||||
peer_info.peer_id, [listener_addr], 10000
|
||||
)
|
||||
|
||||
await swarm2.dial_peer(peer_info.peer_id)
|
||||
|
||||
# Verify connection was established
|
||||
assert peer_info.peer_id in swarm2.connections, (
|
||||
f"Connection to {full_addr} should be established"
|
||||
)
|
||||
assert swarm2.get_peer_id() in swarm1.connections, (
|
||||
f"Connection from {full_addr} should be established"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Failed to establish libp2p connection to {full_addr}: {e}"
|
||||
)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from collections import deque
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
from libp2p.abc import IMultiselectCommunicator, INetStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
@ -13,6 +13,10 @@ from libp2p.protocol_muxer.multiselect import Multiselect
|
||||
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||
|
||||
|
||||
async def dummy_handler(stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
"""
|
||||
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||
@ -31,7 +35,7 @@ class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_one_of_timeout():
|
||||
async def test_select_one_of_timeout() -> None:
|
||||
ECHO = TProtocol("/echo/1.0.0")
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
|
||||
@ -42,7 +46,7 @@ async def test_select_one_of_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_query_multistream_command_timeout():
|
||||
async def test_query_multistream_command_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
client = MultiselectClient()
|
||||
|
||||
@ -51,9 +55,95 @@ async def test_query_multistream_command_timeout():
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout():
|
||||
async def test_negotiate_timeout() -> None:
|
||||
communicator = DummyMultiselectCommunicator()
|
||||
server = Multiselect()
|
||||
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 2)
|
||||
|
||||
|
||||
class HandshakeThenHangCommunicator(IMultiselectCommunicator):
|
||||
handshaked: bool
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.handshaked = False
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
if msg_str == "/multistream/1.0.0":
|
||||
self.handshaked = True
|
||||
return
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.handshaked:
|
||||
return "/multistream/1.0.0"
|
||||
# After handshake, hang on read.
|
||||
await trio.sleep_forever()
|
||||
# Should not be reached.
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_timeout_post_handshake() -> None:
|
||||
communicator = HandshakeThenHangCommunicator()
|
||||
server = Multiselect()
|
||||
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||
await server.negotiate(communicator, 1)
|
||||
|
||||
|
||||
class MockCommunicator(IMultiselectCommunicator):
|
||||
def __init__(self, commands_to_read: list[str]):
|
||||
self.read_queue = deque(commands_to_read)
|
||||
self.written_data: list[str] = []
|
||||
|
||||
async def write(self, msg_str: str) -> None:
|
||||
self.written_data.append(msg_str)
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.read_queue:
|
||||
raise EOFError
|
||||
return self.read_queue.popleft()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_empty_string_command() -> None:
|
||||
# server receives an empty string, which means client wants `None` protocol.
|
||||
server = Multiselect({None: dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check that server sent back handshake and the protocol confirmation (empty string)
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler() -> None:
|
||||
# server has None handler, client sends "" to select it.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, then empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, protocol confirmation
|
||||
assert communicator.written_data == ["/multistream/1.0.0", ""]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_with_none_handler_ls() -> None:
|
||||
# server has None handler, client sends "ls" then empty string.
|
||||
server = Multiselect({None: dummy_handler, TProtocol("/proto1"): dummy_handler})
|
||||
# Handshake, ls, empty command
|
||||
communicator = MockCommunicator(["/multistream/1.0.0", "ls", ""])
|
||||
protocol, handler = await server.negotiate(communicator)
|
||||
assert protocol is None
|
||||
assert handler == dummy_handler
|
||||
# Check written data: handshake, ls response, protocol confirmation
|
||||
assert communicator.written_data[0] == "/multistream/1.0.0"
|
||||
assert "/proto1" in communicator.written_data[1]
|
||||
# Note: `ls` should not list the `None` protocol.
|
||||
assert "None" not in communicator.written_data[1]
|
||||
assert "\n\n" not in communicator.written_data[1]
|
||||
assert communicator.written_data[2] == ""
|
||||
|
||||
@ -159,3 +159,41 @@ async def test_get_protocols_returns_all_registered_protocols():
|
||||
protocols = ms.get_protocols()
|
||||
|
||||
assert set(protocols) == {p1, p2, p3}
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol(security_protocol):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(
|
||||
None,
|
||||
[None],
|
||||
[None],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_server_no_none(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [None], [PROTOCOL_ECHO], security_protocol)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_client_none_in_list(security_protocol):
|
||||
expected_selected_protocol = PROTOCOL_ECHO
|
||||
await perform_simple_test(
|
||||
expected_selected_protocol,
|
||||
[None, PROTOCOL_ECHO],
|
||||
[PROTOCOL_ECHO],
|
||||
security_protocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_negotiate_optional_tprotocol_server_none_client_other(
|
||||
security_protocol,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
await perform_simple_test(None, [PROTOCOL_ECHO], [None], security_protocol)
|
||||
|
||||
90
tests/core/pubsub/test_pubsub_notifee_integration.py
Normal file
90
tests/core/pubsub/test_pubsub_notifee_integration.py
Normal file
@ -0,0 +1,90 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.utils import connect
|
||||
from tests.utils.factories import PubsubFactory
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connected_enqueues_and_adds_peer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Wait until peer is added via queue processing
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
assert p1.my_id in p0.peers
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_disconnected_enqueues_and_removes_peer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Ensure present first
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
# Now disconnect and expect removal via dead peer queue
|
||||
await p0.host.get_network().close_peer(p1.host.get_id())
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
assert p1.my_id not in p0.peers
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None:
|
||||
# Ensure PubsubNotifee catches BrokenResourceError from its send channel
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
# Find the PubsubNotifee registered on the network
|
||||
from libp2p.pubsub.pubsub_notifee import PubsubNotifee
|
||||
|
||||
network = p0.host.get_network()
|
||||
notifees = getattr(network, "notifees", [])
|
||||
target = None
|
||||
for nf in notifees:
|
||||
if isinstance(nf, cast(type, PubsubNotifee)):
|
||||
target = nf
|
||||
break
|
||||
assert target is not None, "PubsubNotifee not found on network"
|
||||
|
||||
async def failing_send(_peer_id): # type: ignore[no-redef]
|
||||
raise trio.BrokenResourceError
|
||||
|
||||
# Make initiator queue send fail; PubsubNotifee should swallow
|
||||
monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send)
|
||||
|
||||
# Connect peers; if exceptions are swallowed, service stays running
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_duplicate_connection_does_not_duplicate_peer_state():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
# Connect again should not add duplicates
|
||||
await connect(p0.host, p1.host)
|
||||
await trio.sleep(0.1)
|
||||
assert list(p0.peers.keys()).count(p1.my_id) == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_blocks_peer_added_by_notifee():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
# Blacklist before connecting
|
||||
p0.add_to_blacklist(p1.my_id)
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Give handler a chance to run
|
||||
await trio.sleep(0.1)
|
||||
assert p1.my_id not in p0.peers
|
||||
99
tests/discovery/random_walk/test_random_walk.py
Normal file
99
tests/discovery/random_walk/test_random_walk.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""
|
||||
Unit tests for the RandomWalk module in libp2p.discovery.random_walk.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host():
|
||||
host = Mock()
|
||||
peerstore = Mock()
|
||||
peerstore.peers_with_addrs.return_value = []
|
||||
peerstore.addrs.return_value = [Mock()]
|
||||
host.get_peerstore.return_value = peerstore
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_query_function():
|
||||
async def query(key_bytes):
|
||||
return []
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_peer_id():
|
||||
return b"\x01" * 32
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_random_walk_initialization(
|
||||
mock_host, dummy_peer_id, dummy_query_function
|
||||
):
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
|
||||
assert rw.host == mock_host
|
||||
assert rw.local_peer_id == dummy_peer_id
|
||||
assert rw.query_function == dummy_query_function
|
||||
|
||||
|
||||
def test_generate_random_peer_id(mock_host, dummy_peer_id, dummy_query_function):
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, dummy_query_function)
|
||||
peer_id = rw.generate_random_peer_id()
|
||||
assert isinstance(peer_id, str)
|
||||
assert len(peer_id) == 64 # 32 bytes hex
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_run_concurrent_random_walks(mock_host, dummy_peer_id):
|
||||
# Dummy query function returns different peer IDs for each walk
|
||||
call_count = {"count": 0}
|
||||
|
||||
async def query(key_bytes):
|
||||
call_count["count"] += 1
|
||||
# Return a unique peer ID for each call
|
||||
return [ID(bytes([call_count["count"]] * 32))]
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.run_concurrent_random_walks(count=3)
|
||||
# Should get 3 unique peers
|
||||
assert len(peers) == 3
|
||||
peer_ids = [peer.peer_id for peer in peers]
|
||||
assert len(set(peer_ids)) == 3
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_perform_random_walk_running(mock_host, dummy_peer_id):
|
||||
# Query function returns a single peer ID
|
||||
async def query(key_bytes):
|
||||
return [ID(b"\x02" * 32)]
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.perform_random_walk()
|
||||
assert isinstance(peers, list)
|
||||
if peers:
|
||||
assert isinstance(peers[0], PeerInfo)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_perform_random_walk_no_peers_found(mock_host, dummy_peer_id):
|
||||
"""Test perform_random_walk when no peers are discovered."""
|
||||
|
||||
# Query function returns empty list (no peers found)
|
||||
async def query(key_bytes):
|
||||
return []
|
||||
|
||||
rw = RandomWalk(mock_host, dummy_peer_id, query)
|
||||
peers = await rw.perform_random_walk()
|
||||
|
||||
# Should return empty list when no peers are found
|
||||
assert isinstance(peers, list)
|
||||
assert len(peers) == 0
|
||||
451
tests/discovery/random_walk/test_rt_refresh_manager.py
Normal file
451
tests/discovery/random_walk/test_rt_refresh_manager.py
Normal file
@ -0,0 +1,451 @@
|
||||
"""
|
||||
Unit tests for the RTRefreshManager and related random walk logic.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.discovery.random_walk.config import (
|
||||
MIN_RT_REFRESH_THRESHOLD,
|
||||
RANDOM_WALK_CONCURRENCY,
|
||||
REFRESH_INTERVAL,
|
||||
)
|
||||
from libp2p.discovery.random_walk.exceptions import (
|
||||
RandomWalkError,
|
||||
)
|
||||
from libp2p.discovery.random_walk.random_walk import RandomWalk
|
||||
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
|
||||
|
||||
class DummyRoutingTable:
|
||||
def __init__(self, size=0):
|
||||
self._size = size
|
||||
self.added_peers = []
|
||||
|
||||
def size(self):
|
||||
return self._size
|
||||
|
||||
async def add_peer(self, peer_obj):
|
||||
self.added_peers.append(peer_obj)
|
||||
self._size += 1
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_host():
|
||||
host = Mock()
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
return host
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def local_peer_id():
|
||||
return ID(b"\x01" * 32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_query_function():
|
||||
async def query(key_bytes):
|
||||
return [ID(b"\x02" * 32)]
|
||||
|
||||
return query
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_initialization(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
rt = DummyRoutingTable(size=5)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
assert manager.host == mock_host
|
||||
assert manager.routing_table == rt
|
||||
assert manager.local_peer_id == local_peer_id
|
||||
assert manager.query_function == dummy_query_function
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_refresh_logic(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
rt = DummyRoutingTable(size=2)
|
||||
# Simulate refresh logic
|
||||
if rt.size() < MIN_RT_REFRESH_THRESHOLD:
|
||||
await rt.add_peer(Mock())
|
||||
assert rt.size() >= 3
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_random_walk_integration(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
# Simulate random walk usage
|
||||
rw = RandomWalk(mock_host, local_peer_id, dummy_query_function)
|
||||
random_peer_id = rw.generate_random_peer_id()
|
||||
assert isinstance(random_peer_id, str)
|
||||
assert len(random_peer_id) == 64
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_error_handling(mock_host, local_peer_id):
|
||||
rt = DummyRoutingTable(size=0)
|
||||
|
||||
async def failing_query(_):
|
||||
raise RandomWalkError("Query failed")
|
||||
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=failing_query,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
with pytest.raises(RandomWalkError):
|
||||
await manager.query_function(b"key")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_start_method(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the start method functionality."""
|
||||
rt = DummyRoutingTable(size=2)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=False, # Disable auto-refresh to control the test
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
):
|
||||
# Test starting the manager
|
||||
assert not manager._running
|
||||
|
||||
# Start the manager in a nursery that we can control
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager.start)
|
||||
await trio.sleep(0.01) # Let it start
|
||||
|
||||
# Verify it's running
|
||||
assert manager._running
|
||||
|
||||
# Stop the manager
|
||||
await manager.stop()
|
||||
assert not manager._running
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_with_auto_refresh(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the _main_loop method with auto-refresh enabled."""
|
||||
rt = DummyRoutingTable(size=1) # Small size to trigger refresh
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
) as mock_random_walk:
|
||||
manager._running = True
|
||||
|
||||
# Run the main loop for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._main_loop)
|
||||
await trio.sleep(0.05) # Let it run briefly
|
||||
manager._running = False # Stop the loop
|
||||
|
||||
# Verify that random walk was called (initial refresh)
|
||||
mock_random_walk.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_without_auto_refresh(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test the _main_loop method with auto-refresh disabled."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=False,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
manager._running = True
|
||||
|
||||
# Run the main loop for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._main_loop)
|
||||
await trio.sleep(0.05)
|
||||
manager._running = False
|
||||
|
||||
# Verify that random walk was not called since auto-refresh is disabled
|
||||
mock_random_walk.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_rt_refresh_manager_main_loop_initial_refresh_exception(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test that _main_loop propagates exceptions from initial refresh."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock _do_refresh to raise an exception on the initial call
|
||||
with patch.object(
|
||||
manager, "_do_refresh", side_effect=Exception("Initial refresh failed")
|
||||
):
|
||||
manager._running = True
|
||||
|
||||
# The initial refresh exception should propagate
|
||||
with pytest.raises(Exception, match="Initial refresh failed"):
|
||||
await manager._main_loop()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_force_refresh(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test _do_refresh method with force=True."""
|
||||
rt = DummyRoutingTable(size=10) # Large size, but force should override
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=REFRESH_INTERVAL,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock the random walk to return some peers
|
||||
mock_peer_info1 = Mock(spec=PeerInfo)
|
||||
mock_peer_info2 = Mock(spec=PeerInfo)
|
||||
discovered_peers = [mock_peer_info1, mock_peer_info2]
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=discovered_peers,
|
||||
) as mock_random_walk:
|
||||
# Force refresh should work regardless of RT size
|
||||
await manager._do_refresh(force=True)
|
||||
|
||||
# Verify random walk was called
|
||||
mock_random_walk.assert_called_once_with(
|
||||
count=RANDOM_WALK_CONCURRENCY, current_routing_table_size=10
|
||||
)
|
||||
|
||||
# Verify peers were added to routing table
|
||||
assert len(rt.added_peers) == 2
|
||||
assert manager._last_refresh_time > 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_skip_due_to_interval(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test _do_refresh skips refresh when interval hasn't elapsed."""
|
||||
rt = DummyRoutingTable(size=1) # Small size to trigger refresh normally
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=100.0, # Long interval
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Set last refresh time to recent
|
||||
manager._last_refresh_time = time.time()
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=False)
|
||||
|
||||
# Verify refresh was skipped
|
||||
mock_random_walk.assert_not_called()
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Skipping refresh: interval not elapsed"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_do_refresh_skip_due_to_rt_size(
|
||||
mock_host, local_peer_id, dummy_query_function
|
||||
):
|
||||
"""Test _do_refresh skips refresh when RT size is above threshold."""
|
||||
rt = DummyRoutingTable(size=20) # Large size above threshold
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1, # Short interval
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Set last refresh time to old
|
||||
manager._last_refresh_time = 0.0
|
||||
|
||||
with patch.object(
|
||||
manager.random_walk, "run_concurrent_random_walks"
|
||||
) as mock_random_walk:
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=False)
|
||||
|
||||
# Verify refresh was skipped
|
||||
mock_random_walk.assert_not_called()
|
||||
mock_logger.debug.assert_called_with(
|
||||
"Skipping refresh: routing table size above threshold"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_refresh_done_callbacks(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test refresh completion callbacks functionality."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Create mock callbacks
|
||||
callback1 = Mock()
|
||||
callback2 = Mock()
|
||||
failing_callback = Mock(side_effect=Exception("Callback failed"))
|
||||
|
||||
# Add callbacks
|
||||
manager.add_refresh_done_callback(callback1)
|
||||
manager.add_refresh_done_callback(callback2)
|
||||
manager.add_refresh_done_callback(failing_callback)
|
||||
|
||||
# Mock the random walk
|
||||
mock_peer_info = Mock(spec=PeerInfo)
|
||||
with patch.object(
|
||||
manager.random_walk,
|
||||
"run_concurrent_random_walks",
|
||||
return_value=[mock_peer_info],
|
||||
):
|
||||
with patch(
|
||||
"libp2p.discovery.random_walk.rt_refresh_manager.logger"
|
||||
) as mock_logger:
|
||||
await manager._do_refresh(force=True)
|
||||
|
||||
# Verify all callbacks were called
|
||||
callback1.assert_called_once()
|
||||
callback2.assert_called_once()
|
||||
failing_callback.assert_called_once()
|
||||
|
||||
# Verify warning was logged for failing callback
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stop_when_not_running(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test stop method when manager is not running."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.1,
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Manager is not running
|
||||
assert not manager._running
|
||||
|
||||
# Stop should return without doing anything
|
||||
await manager.stop()
|
||||
assert not manager._running
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_periodic_refresh_task(mock_host, local_peer_id, dummy_query_function):
|
||||
"""Test the _periodic_refresh_task method."""
|
||||
rt = DummyRoutingTable(size=1)
|
||||
manager = RTRefreshManager(
|
||||
host=mock_host,
|
||||
routing_table=rt,
|
||||
local_peer_id=local_peer_id,
|
||||
query_function=dummy_query_function,
|
||||
enable_auto_refresh=True,
|
||||
refresh_interval=0.05, # Very short interval for testing
|
||||
min_refresh_threshold=MIN_RT_REFRESH_THRESHOLD,
|
||||
)
|
||||
|
||||
# Mock _do_refresh to track calls
|
||||
with patch.object(manager, "_do_refresh") as mock_do_refresh:
|
||||
manager._running = True
|
||||
|
||||
# Run periodic refresh task for a short time
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(manager._periodic_refresh_task)
|
||||
await trio.sleep(0.12) # Let it run for ~2 intervals
|
||||
manager._running = False # Stop the task
|
||||
|
||||
# Verify _do_refresh was called at least once
|
||||
assert mock_do_refresh.call_count >= 1
|
||||
109
tests/examples/test_echo_thin_waist.py
Normal file
109
tests/examples/test_echo_thin_waist.py
Normal file
@ -0,0 +1,109 @@
|
||||
import contextlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP
|
||||
|
||||
# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging
|
||||
|
||||
# This test is intentionally lightweight and can be marked as 'integration'.
|
||||
# It ensures the echo example runs and prints the new Thin Waist lines using
|
||||
# Trio primitives.
|
||||
|
||||
current_file = Path(__file__)
|
||||
project_root = current_file.parent.parent.parent
|
||||
EXAMPLES_DIR: Path = project_root / "examples" / "echo"
|
||||
|
||||
|
||||
def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path):
|
||||
"""Run echo server and validate printed multiaddr and peer id."""
|
||||
# Run echo example as server
|
||||
cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"]
|
||||
env = {**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||
proc: subprocess.Popen[str] = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if proc.stdout is None:
|
||||
proc.terminate()
|
||||
raise RuntimeError("Process stdout is None")
|
||||
out_stream = proc.stdout
|
||||
|
||||
peer_id: str | None = None
|
||||
printed_multiaddr: str | None = None
|
||||
saw_waiting = False
|
||||
|
||||
start = time.time()
|
||||
timeout_s = 8.0
|
||||
try:
|
||||
while time.time() - start < timeout_s:
|
||||
line = out_stream.readline()
|
||||
if not line:
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
s = line.strip()
|
||||
if s.startswith("I am "):
|
||||
peer_id = s.partition("I am ")[2]
|
||||
if s.startswith("echo-demo -d "):
|
||||
printed_multiaddr = s.partition("echo-demo -d ")[2]
|
||||
if "Waiting for incoming connections..." in s:
|
||||
saw_waiting = True
|
||||
break
|
||||
finally:
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
proc.terminate()
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
proc.kill()
|
||||
|
||||
assert peer_id, "Did not capture peer ID line"
|
||||
assert printed_multiaddr, "Did not capture multiaddr line"
|
||||
assert saw_waiting, "Did not capture waiting-for-connections line"
|
||||
|
||||
# Validate multiaddr structure using py-multiaddr protocol methods
|
||||
ma = Multiaddr(printed_multiaddr) # should parse without error
|
||||
|
||||
# Check that the multiaddr contains the p2p protocol
|
||||
try:
|
||||
peer_id_from_multiaddr = ma.value_for_protocol("p2p")
|
||||
assert peer_id_from_multiaddr is not None, (
|
||||
"Multiaddr missing p2p protocol value"
|
||||
)
|
||||
assert peer_id_from_multiaddr == peer_id, (
|
||||
f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Failed to extract p2p protocol value: {e}")
|
||||
|
||||
# Validate the multiaddr structure by checking protocols
|
||||
protocols = ma.protocols()
|
||||
|
||||
# Should have at least IP, TCP, and P2P protocols
|
||||
assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), (
|
||||
"Missing IP protocol"
|
||||
)
|
||||
assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol"
|
||||
assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol"
|
||||
|
||||
# Extract the p2p part and validate it matches the captured peer ID
|
||||
p2p_part = Multiaddr(f"/p2p/{peer_id}")
|
||||
try:
|
||||
# Decapsulate the p2p part to get the transport address
|
||||
transport_addr = ma.decapsulate(p2p_part)
|
||||
# Verify the decapsulated address doesn't contain p2p
|
||||
transport_protocols = transport_addr.protocols()
|
||||
assert not any(p.code == P_P2P for p in transport_protocols), (
|
||||
"Decapsulation failed - still contains p2p"
|
||||
)
|
||||
# Verify the original multiaddr can be reconstructed
|
||||
reconstructed = transport_addr.encapsulate(p2p_part)
|
||||
assert str(reconstructed) == str(ma), "Reconstruction failed"
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Multiaddr decapsulation failed: {e}")
|
||||
56
tests/utils/test_address_validation.py
Normal file
56
tests/utils/test_address_validation.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.utils.address_validation import (
|
||||
expand_wildcard_address,
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("proto", ["tcp"])
|
||||
def test_get_available_interfaces(proto: str) -> None:
|
||||
interfaces = get_available_interfaces(0, protocol=proto)
|
||||
assert len(interfaces) > 0
|
||||
for addr in interfaces:
|
||||
assert isinstance(addr, Multiaddr)
|
||||
assert f"/{proto}/" in str(addr)
|
||||
|
||||
|
||||
def test_get_optimal_binding_address() -> None:
|
||||
addr = get_optimal_binding_address(0)
|
||||
assert isinstance(addr, Multiaddr)
|
||||
# At least IPv4 or IPv6 prefix present
|
||||
s = str(addr)
|
||||
assert ("/ip4/" in s) or ("/ip6/" in s)
|
||||
|
||||
|
||||
def test_expand_wildcard_address_ipv4() -> None:
|
||||
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0")
|
||||
expanded = expand_wildcard_address(wildcard)
|
||||
assert len(expanded) > 0
|
||||
for e in expanded:
|
||||
assert isinstance(e, Multiaddr)
|
||||
assert "/tcp/" in str(e)
|
||||
|
||||
|
||||
def test_expand_wildcard_address_port_override() -> None:
|
||||
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000")
|
||||
overridden = expand_wildcard_address(wildcard, port=9001)
|
||||
assert len(overridden) > 0
|
||||
for e in overridden:
|
||||
assert str(e).endswith("/tcp/9001")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("NO_IPV6") == "1",
|
||||
reason="Environment disallows IPv6",
|
||||
)
|
||||
def test_expand_wildcard_address_ipv6() -> None:
|
||||
wildcard = Multiaddr("/ip6/::/tcp/0")
|
||||
expanded = expand_wildcard_address(wildcard)
|
||||
assert len(expanded) > 0
|
||||
for e in expanded:
|
||||
assert "/ip6/" in str(e)
|
||||
Reference in New Issue
Block a user