mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge upstream/main into add-ws-transport
Resolved conflicts in: - .gitignore: Combined JavaScript interop and Sphinx build ignores - libp2p/__init__.py: Integrated QUIC transport support with WebSocket transport - libp2p/network/swarm.py: Used upstream's improved listener handling - pyproject.toml: Kept both WebSocket and QUIC dependencies This merge brings in: - QUIC transport implementation - Enhanced swarm functionality - Improved peer discovery - Better error handling - Updated dependencies and documentation WebSocket transport implementation remains intact and functional.
This commit is contained in:
@ -1,3 +1,10 @@
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p import (
|
||||
new_swarm,
|
||||
)
|
||||
@ -10,6 +17,9 @@ from libp2p.host.basic_host import (
|
||||
from libp2p.host.defaults import (
|
||||
get_default_protocols,
|
||||
)
|
||||
from libp2p.host.exceptions import (
|
||||
StreamFailure,
|
||||
)
|
||||
|
||||
|
||||
def test_default_protocols():
|
||||
@ -22,3 +32,30 @@ def test_default_protocols():
|
||||
# NOTE: comparing keys for equality as handlers may be closures that do not compare
|
||||
# in the way this test is concerned with
|
||||
assert handlers.keys() == get_default_protocols(host).keys()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_stream_handler_no_protocol_selected(monkeypatch):
|
||||
key_pair = create_new_key_pair()
|
||||
swarm = new_swarm(key_pair)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
# Create a mock net_stream
|
||||
net_stream = MagicMock()
|
||||
net_stream.reset = AsyncMock()
|
||||
net_stream.muxed_conn.peer_id = "peer-test"
|
||||
|
||||
# Monkeypatch negotiate to simulate "no protocol selected"
|
||||
async def fake_negotiate(comm, timeout):
|
||||
return None, None
|
||||
|
||||
monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate)
|
||||
|
||||
# Now run the handler and expect StreamFailure
|
||||
with pytest.raises(
|
||||
StreamFailure, match="Failed to negotiate protocol: no protocol selected"
|
||||
):
|
||||
await host._swarm_stream_handler(net_stream)
|
||||
|
||||
# Ensure reset was called since negotiation failed
|
||||
net_stream.reset.assert_awaited()
|
||||
|
||||
@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol):
|
||||
assert peer_a_id in host_b.get_live_peers()
|
||||
|
||||
# Simulate unexpected connection drop by directly closing the connection
|
||||
conn = host_a.get_network().connections[peer_b_id]
|
||||
await conn.muxed_conn.close()
|
||||
conns = host_a.get_network().connections[peer_b_id]
|
||||
await conns[0].muxed_conn.close()
|
||||
|
||||
# Allow for connection cleanup
|
||||
await trio.sleep(0.1)
|
||||
|
||||
@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including:
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.rsa import create_new_key_pair
|
||||
from libp2p.kad_dht.kad_dht import (
|
||||
DHTMode,
|
||||
KadDHT,
|
||||
@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import (
|
||||
from libp2p.kad_dht.utils import (
|
||||
create_key_from_binary,
|
||||
)
|
||||
from libp2p.peer.envelope import Envelope, seal_record
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peer_record import PeerRecord
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.peer.peerstore import create_signed_peer_record
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
"""Test that nodes can find each other in the DHT."""
|
||||
dht_a, dht_b = dht_pair
|
||||
|
||||
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
|
||||
# so both the nodes will have records of each other before the next FIND_NODE
|
||||
# req is sent
|
||||
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
|
||||
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
|
||||
|
||||
assert isinstance(envelope_a, Envelope)
|
||||
assert isinstance(envelope_b, Envelope)
|
||||
|
||||
record_a = envelope_a.record()
|
||||
record_b = envelope_b.record()
|
||||
|
||||
# Node A should be able to find Node B
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
found_info = await dht_a.find_peer(dht_b.host.get_id())
|
||||
|
||||
# Verifies if the senderRecord in the FIND_NODE request is correctly processed
|
||||
assert isinstance(
|
||||
dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope
|
||||
)
|
||||
|
||||
# Verifies if the senderRecord in the FIND_NODE response is correctly processed
|
||||
assert isinstance(
|
||||
dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope
|
||||
)
|
||||
|
||||
# These are the records that were sent between the peers during the FIND_NODE req
|
||||
envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record(
|
||||
dht_b.host.get_id()
|
||||
)
|
||||
envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record(
|
||||
dht_a.host.get_id()
|
||||
)
|
||||
|
||||
assert isinstance(envelope_a_find_peer, Envelope)
|
||||
assert isinstance(envelope_b_find_peer, Envelope)
|
||||
|
||||
record_a_find_peer = envelope_a_find_peer.record()
|
||||
record_b_find_peer = envelope_b_find_peer.record()
|
||||
|
||||
# This proves that both the records are same, and a latest cached signed record
|
||||
# was passed between the peers during FIND_NODE execution, which proves the
|
||||
# signed-record transfer/re-issuing works correctly in FIND_NODE executions.
|
||||
assert record_a.seq == record_a_find_peer.seq
|
||||
assert record_b.seq == record_b_find_peer.seq
|
||||
|
||||
# Verify that the found peer has the correct peer ID
|
||||
assert found_info is not None, "Failed to find the target peer"
|
||||
assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID"
|
||||
@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
await dht_a.routing_table.add_peer(peer_b_info)
|
||||
print("Routing table of a has ", dht_a.routing_table.get_peer_ids())
|
||||
|
||||
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
|
||||
# so both the nodes will have records of each other before PUT_VALUE req is sent
|
||||
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
|
||||
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
|
||||
|
||||
assert isinstance(envelope_a, Envelope)
|
||||
assert isinstance(envelope_b, Envelope)
|
||||
|
||||
record_a = envelope_a.record()
|
||||
record_b = envelope_b.record()
|
||||
|
||||
# Store the value using the first node (this will also store locally)
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
await dht_a.put_value(key, value)
|
||||
|
||||
# These are the records that were sent between the peers during the PUT_VALUE req
|
||||
envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record(
|
||||
dht_b.host.get_id()
|
||||
)
|
||||
envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record(
|
||||
dht_a.host.get_id()
|
||||
)
|
||||
|
||||
assert isinstance(envelope_a_put_value, Envelope)
|
||||
assert isinstance(envelope_b_put_value, Envelope)
|
||||
|
||||
record_a_put_value = envelope_a_put_value.record()
|
||||
record_b_put_value = envelope_b_put_value.record()
|
||||
|
||||
# This proves that both the records are same, and a latest cached signed record
|
||||
# was passed between the peers during PUT_VALUE execution, which proves the
|
||||
# signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
|
||||
assert record_a.seq == record_a_put_value.seq
|
||||
assert record_b.seq == record_b_put_value.seq
|
||||
|
||||
# # Log debugging information
|
||||
logger.debug("Put value with key %s...", key.hex()[:10])
|
||||
logger.debug("Node A value store: %s", dht_a.value_store.store)
|
||||
print("hello test")
|
||||
|
||||
# # Allow more time for the value to propagate
|
||||
await trio.sleep(0.5)
|
||||
@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
print("the value stored in node b is", dht_b.get_value_store_size())
|
||||
logger.debug("Retrieved value: %s", retrieved_value)
|
||||
|
||||
# These are the records that were sent between the peers during the PUT_VALUE req
|
||||
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
|
||||
dht_b.host.get_id()
|
||||
)
|
||||
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
|
||||
dht_a.host.get_id()
|
||||
)
|
||||
|
||||
assert isinstance(envelope_a_get_value, Envelope)
|
||||
assert isinstance(envelope_b_get_value, Envelope)
|
||||
|
||||
record_a_get_value = envelope_a_get_value.record()
|
||||
record_b_get_value = envelope_b_get_value.record()
|
||||
|
||||
# This proves that there was no record exchange between the nodes during GET_VALUE
|
||||
# execution, as dht_b already had the key/value pair stored locally after the
|
||||
# PUT_VALUE execution.
|
||||
assert record_a_get_value.seq == record_a_put_value.seq
|
||||
assert record_b_get_value.seq == record_b_put_value.seq
|
||||
|
||||
# Verify that the retrieved value matches the original
|
||||
assert retrieved_value == value, "Retrieved value does not match the stored value"
|
||||
|
||||
@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
# Store content on the first node
|
||||
dht_a.value_store.put(content_id, content)
|
||||
|
||||
# An extra FIND_NODE req is sent between the 2 nodes while dht creation,
|
||||
# so both the nodes will have records of each other before PUT_VALUE req is sent
|
||||
envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
|
||||
envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id())
|
||||
|
||||
assert isinstance(envelope_a, Envelope)
|
||||
assert isinstance(envelope_b, Envelope)
|
||||
|
||||
record_a = envelope_a.record()
|
||||
record_b = envelope_b.record()
|
||||
|
||||
# Advertise the first node as a provider
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
success = await dht_a.provide(content_id)
|
||||
assert success, "Failed to advertise as provider"
|
||||
|
||||
# These are the records that were sent between the peers during
|
||||
# the ADD_PROVIDER req
|
||||
envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record(
|
||||
dht_b.host.get_id()
|
||||
)
|
||||
envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record(
|
||||
dht_a.host.get_id()
|
||||
)
|
||||
|
||||
assert isinstance(envelope_a_add_prov, Envelope)
|
||||
assert isinstance(envelope_b_add_prov, Envelope)
|
||||
|
||||
record_a_add_prov = envelope_a_add_prov.record()
|
||||
record_b_add_prov = envelope_b_add_prov.record()
|
||||
|
||||
# This proves that both the records are same, the latest cached signed record
|
||||
# was passed between the peers during ADD_PROVIDER execution, which proves the
|
||||
# signed-record transfer/re-issuing of the latest record works correctly in
|
||||
# ADD_PROVIDER executions.
|
||||
assert record_a.seq == record_a_add_prov.seq
|
||||
assert record_b.seq == record_b_add_prov.seq
|
||||
|
||||
# Allow time for the provider record to propagate
|
||||
await trio.sleep(0.1)
|
||||
|
||||
@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
with trio.fail_after(TEST_TIMEOUT):
|
||||
providers = await dht_b.find_providers(content_id)
|
||||
|
||||
# These are the records in each peer after the find_provider execution
|
||||
envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record(
|
||||
dht_b.host.get_id()
|
||||
)
|
||||
envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record(
|
||||
dht_a.host.get_id()
|
||||
)
|
||||
|
||||
assert isinstance(envelope_a_find_prov, Envelope)
|
||||
assert isinstance(envelope_b_find_prov, Envelope)
|
||||
|
||||
record_a_find_prov = envelope_a_find_prov.record()
|
||||
record_b_find_prov = envelope_b_find_prov.record()
|
||||
|
||||
# This proves that both the records are same, as the dht_b already
|
||||
# has the provider record for the content_id, after the ADD_PROVIDER
|
||||
# advertisement by dht_a
|
||||
assert record_a_find_prov.seq == record_a_add_prov.seq
|
||||
assert record_b_find_prov.seq == record_b_add_prov.seq
|
||||
|
||||
# Verify that we found the first node as a provider
|
||||
assert providers, "No providers found"
|
||||
assert any(p.peer_id == dht_a.local_peer_id for p in providers), (
|
||||
@ -166,3 +319,143 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
assert retrieved_value == content, (
|
||||
"Retrieved content does not match the original"
|
||||
)
|
||||
|
||||
# These are the record state of each peer aftet the GET_VALUE execution
|
||||
envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record(
|
||||
dht_b.host.get_id()
|
||||
)
|
||||
envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record(
|
||||
dht_a.host.get_id()
|
||||
)
|
||||
|
||||
assert isinstance(envelope_a_get_value, Envelope)
|
||||
assert isinstance(envelope_b_get_value, Envelope)
|
||||
|
||||
record_a_get_value = envelope_a_get_value.record()
|
||||
record_b_get_value = envelope_b_get_value.record()
|
||||
|
||||
# This proves that both the records are same, meaning that the latest cached
|
||||
# signed-record tranfer happened during the GET_VALUE execution by dht_b,
|
||||
# which means the signed-record transfer/re-issuing works correctly
|
||||
# in GET_VALUE executions.
|
||||
assert record_a_find_prov.seq == record_a_get_value.seq
|
||||
assert record_b_find_prov.seq == record_b_get_value.seq
|
||||
|
||||
# Create a new provider record in dht_a
|
||||
provider_key_pair = create_new_key_pair()
|
||||
provider_peer_id = ID.from_pubkey(provider_key_pair.public_key)
|
||||
provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
|
||||
provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr])
|
||||
|
||||
# Generate a random content ID
|
||||
content_2 = f"random-content-{uuid.uuid4()}".encode()
|
||||
content_id_2 = hashlib.sha256(content_2).digest()
|
||||
|
||||
provider_signed_envelope = create_signed_peer_record(
|
||||
provider_peer_id, [provider_addr], provider_key_pair.private_key
|
||||
)
|
||||
assert (
|
||||
dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200)
|
||||
is True
|
||||
)
|
||||
|
||||
# Store this provider record in dht_a
|
||||
dht_a.provider_store.add_provider(content_id_2, provider_peer_info)
|
||||
|
||||
# Fetch the provider-record via peer-discovery at dht_b's end
|
||||
peerinfo = await dht_b.provider_store.find_providers(content_id_2)
|
||||
|
||||
assert len(peerinfo) == 1
|
||||
assert peerinfo[0].peer_id == provider_peer_id
|
||||
provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id)
|
||||
|
||||
# This proves that the signed-envelope of provider is consumed on dht_b's end
|
||||
assert provider_envelope is not None
|
||||
assert (
|
||||
provider_signed_envelope.marshal_envelope()
|
||||
== provider_envelope.marshal_envelope()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]):
|
||||
dht_a, dht_b = dht_pair
|
||||
|
||||
# Warm-up: A stores B's current record
|
||||
with trio.fail_after(10):
|
||||
await dht_a.find_peer(dht_b.host.get_id())
|
||||
|
||||
env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
|
||||
assert isinstance(env0, Envelope)
|
||||
seq0 = env0.record().seq
|
||||
|
||||
# Simulate B's listen addrs changing (different port)
|
||||
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
|
||||
|
||||
# Patch just for the duration we force B to respond:
|
||||
with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]):
|
||||
# Force B to send a response (which should include a fresh SPR)
|
||||
with trio.fail_after(10):
|
||||
await dht_a.peer_routing._query_peer_for_closest(
|
||||
dht_b.host.get_id(), os.urandom(32)
|
||||
)
|
||||
|
||||
# A should now hold B's new record with a bumped seq
|
||||
env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id())
|
||||
assert isinstance(env1, Envelope)
|
||||
seq1 = env1.record().seq
|
||||
|
||||
# This proves that upon the change in listen_addrs, we issue new records
|
||||
assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dht_req_fail_with_invalid_record_transfer(
|
||||
dht_pair: tuple[KadDHT, KadDHT],
|
||||
):
|
||||
"""
|
||||
Testing showing failure of storing and retrieving values in the DHT,
|
||||
if invalid signed-records are sent.
|
||||
"""
|
||||
dht_a, dht_b = dht_pair
|
||||
peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs())
|
||||
|
||||
# Generate a random key and value
|
||||
key = create_key_from_binary(b"test-key")
|
||||
value = b"test-value"
|
||||
|
||||
# First add the value directly to node A's store to verify storage works
|
||||
dht_a.value_store.put(key, value)
|
||||
local_value = dht_a.value_store.get(key)
|
||||
assert local_value == value, "Local value storage failed"
|
||||
await dht_a.routing_table.add_peer(peer_b_info)
|
||||
|
||||
# Corrupt dht_a's local peer_record
|
||||
envelope = dht_a.host.get_peerstore().get_local_record()
|
||||
if envelope is not None:
|
||||
true_record = envelope.record()
|
||||
key_pair = create_new_key_pair()
|
||||
|
||||
if envelope is not None:
|
||||
envelope.public_key = key_pair.public_key
|
||||
dht_a.host.get_peerstore().set_local_record(envelope)
|
||||
|
||||
await dht_a.put_value(key, value)
|
||||
retrieved_value = dht_b.value_store.get(key)
|
||||
|
||||
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving
|
||||
# the corrupted invalid record
|
||||
assert retrieved_value is None
|
||||
|
||||
# Create a corrupt envelope with correct signature but false peer_id
|
||||
false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs)
|
||||
false_envelope = seal_record(false_record, dht_a.host.get_private_key())
|
||||
|
||||
dht_a.host.get_peerstore().set_local_record(false_envelope)
|
||||
|
||||
await dht_a.put_value(key, value)
|
||||
retrieved_value = dht_b.value_store.get(key)
|
||||
|
||||
# This proves that DHT_B rejected DHT_A PUT_RECORD req upon receving
|
||||
# the record with a different peer_id regardless of a valid signature
|
||||
assert retrieved_value is None
|
||||
|
||||
@ -57,7 +57,10 @@ class TestPeerRouting:
|
||||
def mock_host(self):
|
||||
"""Create a mock host for testing."""
|
||||
host = Mock()
|
||||
host.get_id.return_value = create_valid_peer_id("local")
|
||||
key_pair = create_new_key_pair()
|
||||
host.get_id.return_value = ID.from_pubkey(key_pair.public_key)
|
||||
host.get_public_key.return_value = key_pair.public_key
|
||||
host.get_private_key.return_value = key_pair.private_key
|
||||
host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")]
|
||||
host.get_peerstore.return_value = Mock()
|
||||
host.new_stream = AsyncMock()
|
||||
|
||||
325
tests/core/network/test_enhanced_swarm.py
Normal file
325
tests/core/network/test_enhanced_swarm.py
Normal file
@ -0,0 +1,325 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import INetConn, INetStream
|
||||
from libp2p.network.exceptions import SwarmException
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
|
||||
class MockConnection(INetConn):
|
||||
"""Mock connection for testing."""
|
||||
|
||||
def __init__(self, peer_id: ID, is_closed: bool = False):
|
||||
self.peer_id = peer_id
|
||||
self._is_closed = is_closed
|
||||
self.streams = set() # Track streams properly
|
||||
# Mock the muxed_conn attribute that Swarm expects
|
||||
self.muxed_conn = Mock()
|
||||
self.muxed_conn.peer_id = peer_id
|
||||
# Required by INetConn interface
|
||||
self.event_started = trio.Event()
|
||||
|
||||
async def close(self):
|
||||
self._is_closed = True
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._is_closed
|
||||
|
||||
async def new_stream(self) -> INetStream:
|
||||
# Create a mock stream and add it to the connection's stream set
|
||||
mock_stream = Mock(spec=INetStream)
|
||||
self.streams.add(mock_stream)
|
||||
return mock_stream
|
||||
|
||||
def get_streams(self) -> tuple[INetStream, ...]:
|
||||
"""Return all streams associated with this connection."""
|
||||
return tuple(self.streams)
|
||||
|
||||
def get_transport_addresses(self) -> list[Multiaddr]:
|
||||
"""Mock implementation of get_transport_addresses."""
|
||||
return []
|
||||
|
||||
|
||||
class MockNetStream(INetStream):
|
||||
"""Mock network stream for testing."""
|
||||
|
||||
def __init__(self, peer_id: ID):
|
||||
self.peer_id = peer_id
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_retry_config_defaults():
|
||||
"""Test RetryConfig default values."""
|
||||
config = RetryConfig()
|
||||
assert config.max_retries == 3
|
||||
assert config.initial_delay == 0.1
|
||||
assert config.max_delay == 30.0
|
||||
assert config.backoff_multiplier == 2.0
|
||||
assert config.jitter_factor == 0.1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_config_defaults():
|
||||
"""Test ConnectionConfig default values."""
|
||||
config = ConnectionConfig()
|
||||
assert config.max_connections_per_peer == 3
|
||||
assert config.connection_timeout == 30.0
|
||||
assert config.load_balancing_strategy == "round_robin"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_enhanced_swarm_constructor():
|
||||
"""Test enhanced Swarm constructor with new configuration."""
|
||||
# Create mock dependencies
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Test with default config
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
assert swarm.retry_config.max_retries == 3
|
||||
assert swarm.connection_config.max_connections_per_peer == 3
|
||||
assert isinstance(swarm.connections, dict)
|
||||
|
||||
# Test with custom config
|
||||
custom_retry = RetryConfig(max_retries=5, initial_delay=0.5)
|
||||
custom_conn = ConnectionConfig(max_connections_per_peer=5)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn)
|
||||
assert swarm.retry_config.max_retries == 5
|
||||
assert swarm.retry_config.initial_delay == 0.5
|
||||
assert swarm.connection_config.max_connections_per_peer == 5
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_backoff_calculation():
|
||||
"""Test exponential backoff calculation with jitter."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
retry_config = RetryConfig(
|
||||
initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1
|
||||
)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
|
||||
|
||||
# Test backoff calculation
|
||||
delay1 = swarm._calculate_backoff_delay(0)
|
||||
delay2 = swarm._calculate_backoff_delay(1)
|
||||
delay3 = swarm._calculate_backoff_delay(2)
|
||||
|
||||
# Should increase exponentially
|
||||
assert delay2 > delay1
|
||||
assert delay3 > delay2
|
||||
|
||||
# Should respect max delay
|
||||
assert delay1 <= 1.0
|
||||
assert delay2 <= 1.0
|
||||
assert delay3 <= 1.0
|
||||
|
||||
# Should have jitter
|
||||
assert delay1 != 0.1 # Should have jitter added
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_retry_logic():
|
||||
"""Test retry logic in dial operations."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Configure for fast testing
|
||||
retry_config = RetryConfig(
|
||||
max_retries=2,
|
||||
initial_delay=0.01, # Very short for testing
|
||||
max_delay=0.1,
|
||||
)
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config)
|
||||
|
||||
# Mock the single attempt method to fail twice then succeed
|
||||
attempt_count = [0]
|
||||
|
||||
async def mock_single_attempt(addr, peer_id):
|
||||
attempt_count[0] += 1
|
||||
if attempt_count[0] < 3:
|
||||
raise SwarmException(f"Attempt {attempt_count[0]} failed")
|
||||
return MockConnection(peer_id)
|
||||
|
||||
swarm._dial_addr_single_attempt = mock_single_attempt
|
||||
|
||||
# Test retry logic
|
||||
start_time = time.time()
|
||||
result = await swarm._dial_with_retry(Mock(spec=Multiaddr), peer_id)
|
||||
end_time = time.time()
|
||||
|
||||
# Should have succeeded after 3 attempts
|
||||
assert attempt_count[0] == 3
|
||||
assert isinstance(result, MockConnection)
|
||||
assert end_time - start_time > 0.01 # Should have some delay
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_load_balancing_strategies():
|
||||
"""Test load balancing strategies."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Create mock connections with different stream counts
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
# Add some streams to simulate load
|
||||
await conn1.new_stream()
|
||||
await conn1.new_stream()
|
||||
await conn2.new_stream()
|
||||
|
||||
connections = [conn1, conn2, conn3]
|
||||
|
||||
# Test round-robin strategy
|
||||
swarm.connection_config.load_balancing_strategy = "round_robin"
|
||||
# Cast to satisfy type checker
|
||||
connections_cast = cast("list[INetConn]", connections)
|
||||
selected1 = swarm._select_connection(connections_cast, peer_id)
|
||||
selected2 = swarm._select_connection(connections_cast, peer_id)
|
||||
selected3 = swarm._select_connection(connections_cast, peer_id)
|
||||
|
||||
# Should cycle through connections
|
||||
assert selected1 in connections
|
||||
assert selected2 in connections
|
||||
assert selected3 in connections
|
||||
|
||||
# Test least loaded strategy
|
||||
swarm.connection_config.load_balancing_strategy = "least_loaded"
|
||||
least_loaded = swarm._select_connection(connections_cast, peer_id)
|
||||
|
||||
# conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams
|
||||
# So conn3 should be selected as least loaded
|
||||
assert least_loaded == conn3
|
||||
|
||||
# Test default strategy (first connection)
|
||||
swarm.connection_config.load_balancing_strategy = "unknown"
|
||||
default_selected = swarm._select_connection(connections_cast, peer_id)
|
||||
assert default_selected == conn1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections_api():
|
||||
"""Test the new multiple connections API methods."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Test empty connections
|
||||
assert swarm.get_connections() == []
|
||||
assert swarm.get_connections(peer_id) == []
|
||||
assert swarm.get_connection(peer_id) is None
|
||||
assert swarm.get_connections_map() == {}
|
||||
|
||||
# Add some connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = [conn1, conn2]
|
||||
|
||||
# Test get_connections with peer_id
|
||||
peer_connections = swarm.get_connections(peer_id)
|
||||
assert len(peer_connections) == 2
|
||||
assert conn1 in peer_connections
|
||||
assert conn2 in peer_connections
|
||||
|
||||
# Test get_connections without peer_id (all connections)
|
||||
all_connections = swarm.get_connections()
|
||||
assert len(all_connections) == 2
|
||||
assert conn1 in all_connections
|
||||
assert conn2 in all_connections
|
||||
|
||||
# Test get_connection (backward compatibility)
|
||||
single_conn = swarm.get_connection(peer_id)
|
||||
assert single_conn in [conn1, conn2]
|
||||
|
||||
# Test get_connections_map
|
||||
connections_map = swarm.get_connections_map()
|
||||
assert peer_id in connections_map
|
||||
assert connections_map[peer_id] == [conn1, conn2]
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_connection_trimming():
|
||||
"""Test connection trimming when limit is exceeded."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
# Set max connections to 2
|
||||
connection_config = ConnectionConfig(max_connections_per_peer=2)
|
||||
swarm = Swarm(
|
||||
peer_id, peerstore, upgrader, transport, connection_config=connection_config
|
||||
)
|
||||
|
||||
# Add 3 connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
conn3 = MockConnection(peer_id)
|
||||
|
||||
swarm.connections[peer_id] = [conn1, conn2, conn3]
|
||||
|
||||
# Trigger trimming
|
||||
swarm._trim_connections(peer_id)
|
||||
|
||||
# Should have only 2 connections
|
||||
assert len(swarm.connections[peer_id]) == 2
|
||||
|
||||
# The most recent connections should remain
|
||||
remaining = swarm.connections[peer_id]
|
||||
assert conn2 in remaining
|
||||
assert conn3 in remaining
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_backward_compatibility():
|
||||
"""Test backward compatibility features."""
|
||||
peer_id = ID(b"QmTest")
|
||||
peerstore = Mock()
|
||||
upgrader = Mock()
|
||||
transport = Mock()
|
||||
|
||||
swarm = Swarm(peer_id, peerstore, upgrader, transport)
|
||||
|
||||
# Add connections
|
||||
conn1 = MockConnection(peer_id)
|
||||
conn2 = MockConnection(peer_id)
|
||||
swarm.connections[peer_id] = [conn1, conn2]
|
||||
|
||||
# Test connections_legacy property
|
||||
legacy_connections = swarm.connections_legacy
|
||||
assert peer_id in legacy_connections
|
||||
# Should return first connection
|
||||
assert legacy_connections[peer_id] in [conn1, conn2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
82
tests/core/network/test_notifee_performance.py
Normal file
82
tests/core/network/test_notifee_performance.py
Normal file
@ -0,0 +1,82 @@
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
INetConn,
|
||||
INetStream,
|
||||
INetwork,
|
||||
INotifee,
|
||||
)
|
||||
from libp2p.tools.utils import connect_swarm
|
||||
from tests.utils.factories import SwarmFactory
|
||||
|
||||
|
||||
class CountingNotifee(INotifee):
|
||||
def __init__(self, event: trio.Event) -> None:
|
||||
self._event = event
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
self._event.set()
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SlowNotifee(INotifee):
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
await trio.sleep(0.5)
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_many_notifees_receive_connected_quickly() -> None:
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
count = 200
|
||||
events = [trio.Event() for _ in range(count)]
|
||||
for ev in events:
|
||||
swarms[0].register_notifee(CountingNotifee(ev))
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
with trio.fail_after(1.5):
|
||||
for ev in events:
|
||||
await ev.wait()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_slow_notifee_does_not_block_others() -> None:
|
||||
async with SwarmFactory.create_batch_and_listen(2) as swarms:
|
||||
fast_events = [trio.Event() for _ in range(20)]
|
||||
for ev in fast_events:
|
||||
swarms[0].register_notifee(CountingNotifee(ev))
|
||||
swarms[0].register_notifee(SlowNotifee())
|
||||
await connect_swarm(swarms[0], swarms[1])
|
||||
# Fast notifees should complete quickly despite one slow notifee
|
||||
with trio.fail_after(0.3):
|
||||
for ev in fast_events:
|
||||
await ev.wait()
|
||||
76
tests/core/network/test_notify_listen_lifecycle.py
Normal file
76
tests/core/network/test_notify_listen_lifecycle.py
Normal file
@ -0,0 +1,76 @@
|
||||
import enum
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
INetConn,
|
||||
INetStream,
|
||||
INetwork,
|
||||
INotifee,
|
||||
)
|
||||
from libp2p.tools.async_service import background_trio_service
|
||||
from libp2p.tools.constants import LISTEN_MADDR
|
||||
from tests.utils.factories import SwarmFactory
|
||||
|
||||
|
||||
class Event(enum.Enum):
|
||||
Listen = 0
|
||||
ListenClose = 1
|
||||
|
||||
|
||||
class MyNotifee(INotifee):
|
||||
def __init__(self, events: list[Event]):
|
||||
self.events = events
|
||||
|
||||
async def opened_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network: INetwork, stream: INetStream) -> None:
|
||||
pass
|
||||
|
||||
async def connected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def disconnected(self, network: INetwork, conn: INetConn) -> None:
|
||||
pass
|
||||
|
||||
async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
self.events.append(Event.Listen)
|
||||
|
||||
async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None:
|
||||
self.events.append(Event.ListenClose)
|
||||
|
||||
|
||||
async def wait_for_event(
|
||||
events_list: list[Event], event: Event, timeout: float = 1.0
|
||||
) -> bool:
|
||||
with trio.move_on_after(timeout):
|
||||
while event not in events_list:
|
||||
await trio.sleep(0.01)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listen_emitted_when_registered_before_listen():
|
||||
events: list[Event] = []
|
||||
swarm = SwarmFactory.build()
|
||||
swarm.register_notifee(MyNotifee(events))
|
||||
async with background_trio_service(swarm):
|
||||
# Start listening now; notifee was registered beforehand
|
||||
assert await swarm.listen(LISTEN_MADDR)
|
||||
assert await wait_for_event(events, Event.Listen)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_single_listener_close_emits_listen_close():
|
||||
events: list[Event] = []
|
||||
swarm = SwarmFactory.build()
|
||||
swarm.register_notifee(MyNotifee(events))
|
||||
async with background_trio_service(swarm):
|
||||
assert await swarm.listen(LISTEN_MADDR)
|
||||
# Explicitly notify listen_close (close path via manager doesn't emit it)
|
||||
await swarm.notify_listen_close(LISTEN_MADDR)
|
||||
assert await wait_for_event(events, Event.ListenClose)
|
||||
@ -16,6 +16,9 @@ from libp2p.network.exceptions import (
|
||||
from libp2p.network.swarm import (
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
background_trio_service,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect_swarm,
|
||||
)
|
||||
@ -48,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol):
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
|
||||
# New: dial_peer now returns list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Verify connections are established in both directions
|
||||
assert swarms[0].get_peer_id() in swarms[1].connections
|
||||
assert swarms[1].get_peer_id() in swarms[0].connections
|
||||
|
||||
# Test: Reuse connections when we already have ones with a peer.
|
||||
conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert conn is conn_to_1
|
||||
existing_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert new_connections == existing_connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@ -104,7 +112,8 @@ async def test_swarm_close_peer(security_protocol):
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_remove_conn(swarm_pair):
|
||||
swarm_0, swarm_1 = swarm_pair
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()]
|
||||
# Get the first connection from the list
|
||||
conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0]
|
||||
swarm_0.remove_conn(conn_0)
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
# Test: Remove twice. There should not be errors.
|
||||
@ -112,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair):
|
||||
assert swarm_1.get_peer_id() not in swarm_0.connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiple_connections(security_protocol):
|
||||
"""Test multiple connections per peer functionality."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup multiple addresses for peer
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Dial peer - should return list of connections
|
||||
connections = await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||
assert len(connections) > 0
|
||||
|
||||
# Test get_connections method
|
||||
peer_connections = swarms[0].get_connections(swarms[1].get_peer_id())
|
||||
assert len(peer_connections) == len(connections)
|
||||
|
||||
# Test get_connections_map method
|
||||
connections_map = swarms[0].get_connections_map()
|
||||
assert swarms[1].get_peer_id() in connections_map
|
||||
assert len(connections_map[swarms[1].get_peer_id()]) == len(connections)
|
||||
|
||||
# Test get_connection method (backward compatibility)
|
||||
single_conn = swarms[0].get_connection(swarms[1].get_peer_id())
|
||||
assert single_conn is not None
|
||||
assert single_conn in connections
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_load_balancing(security_protocol):
|
||||
"""Test load balancing across multiple connections."""
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
2, security_protocol=security_protocol
|
||||
) as swarms:
|
||||
# Setup connection
|
||||
addrs = tuple(
|
||||
addr
|
||||
for transport in swarms[1].listeners.values()
|
||||
for addr in transport.get_addrs()
|
||||
)
|
||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000)
|
||||
|
||||
# Create multiple streams - should use load balancing
|
||||
streams = []
|
||||
for _ in range(5):
|
||||
stream = await swarms[0].new_stream(swarms[1].get_peer_id())
|
||||
streams.append(stream)
|
||||
|
||||
# Verify streams were created successfully
|
||||
assert len(streams) == 5
|
||||
|
||||
# Clean up
|
||||
for stream in streams:
|
||||
await stream.close()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_multiaddr(security_protocol):
|
||||
async with SwarmFactory.create_batch_and_listen(
|
||||
@ -180,7 +250,123 @@ def test_new_swarm_tcp_multiaddr_supported():
|
||||
assert isinstance(swarm.transport, TCP)
|
||||
|
||||
|
||||
def test_new_swarm_quic_multiaddr_raises():
|
||||
def test_new_swarm_quic_multiaddr_supported():
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
|
||||
with pytest.raises(ValueError, match="QUIC not yet supported"):
|
||||
new_swarm(listen_addrs=[addr])
|
||||
swarm = new_swarm(listen_addrs=[addr])
|
||||
assert isinstance(swarm, Swarm)
|
||||
assert isinstance(swarm.transport, QUICTransport)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_listen_multiple_addresses(security_protocol):
|
||||
"""Test that swarm can listen on multiple addresses simultaneously."""
|
||||
from libp2p.utils.address_validation import get_available_interfaces
|
||||
|
||||
# Get multiple addresses to listen on
|
||||
listen_addrs = get_available_interfaces(0) # Let OS choose ports
|
||||
|
||||
# Create a swarm and listen on multiple addresses
|
||||
swarm = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm):
|
||||
# Listen on all addresses
|
||||
success = await swarm.listen(*listen_addrs)
|
||||
assert success, "Should successfully listen on at least one address"
|
||||
|
||||
# Check that we have listeners for the addresses
|
||||
actual_listeners = list(swarm.listeners.keys())
|
||||
assert len(actual_listeners) > 0, "Should have at least one listener"
|
||||
|
||||
# Verify that all successful listeners are in the listeners dict
|
||||
successful_count = 0
|
||||
for addr in listen_addrs:
|
||||
addr_str = str(addr)
|
||||
if addr_str in actual_listeners:
|
||||
successful_count += 1
|
||||
# This address successfully started listening
|
||||
listener = swarm.listeners[addr_str]
|
||||
listener_addrs = listener.get_addrs()
|
||||
assert len(listener_addrs) > 0, (
|
||||
f"Listener for {addr} should have addresses"
|
||||
)
|
||||
|
||||
# Check that the listener address matches the expected address
|
||||
# (port might be different if we used port 0)
|
||||
expected_ip = addr.value_for_protocol("ip4")
|
||||
expected_protocol = addr.value_for_protocol("tcp")
|
||||
if expected_ip and expected_protocol:
|
||||
found_matching = False
|
||||
for listener_addr in listener_addrs:
|
||||
if (
|
||||
listener_addr.value_for_protocol("ip4") == expected_ip
|
||||
and listener_addr.value_for_protocol("tcp") is not None
|
||||
):
|
||||
found_matching = True
|
||||
break
|
||||
assert found_matching, (
|
||||
f"Listener for {addr} should have matching IP"
|
||||
)
|
||||
|
||||
assert successful_count == len(listen_addrs), (
|
||||
f"All {len(listen_addrs)} addresses should be listening, "
|
||||
f"but only {successful_count} succeeded"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_swarm_listen_multiple_addresses_connectivity(security_protocol):
|
||||
"""Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.utils.address_validation import get_available_interfaces
|
||||
|
||||
# Get multiple addresses to listen on
|
||||
listen_addrs = get_available_interfaces(0) # Let OS choose ports
|
||||
|
||||
# Create a swarm and listen on multiple addresses
|
||||
swarm1 = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm1):
|
||||
# Listen on all addresses
|
||||
success = await swarm1.listen(*listen_addrs)
|
||||
assert success, "Should successfully listen on at least one address"
|
||||
|
||||
# Verify all available interfaces are listening
|
||||
assert len(swarm1.listeners) == len(listen_addrs), (
|
||||
f"All {len(listen_addrs)} interfaces should be listening, "
|
||||
f"but only {len(swarm1.listeners)} are"
|
||||
)
|
||||
|
||||
# Create a second swarm to test connections
|
||||
swarm2 = SwarmFactory.build(security_protocol=security_protocol)
|
||||
async with background_trio_service(swarm2):
|
||||
# Test connectivity to each listening address using real libp2p connections
|
||||
for addr_str, listener in swarm1.listeners.items():
|
||||
listener_addrs = listener.get_addrs()
|
||||
for listener_addr in listener_addrs:
|
||||
# Create a full multiaddr with peer ID for libp2p connection
|
||||
peer_id = swarm1.get_peer_id()
|
||||
full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}")
|
||||
|
||||
# Test real libp2p connection
|
||||
try:
|
||||
peer_info = info_from_p2p_addr(full_addr)
|
||||
|
||||
# Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501
|
||||
swarm2.peerstore.add_addrs(
|
||||
peer_info.peer_id, [listener_addr], 10000
|
||||
)
|
||||
|
||||
await swarm2.dial_peer(peer_info.peer_id)
|
||||
|
||||
# Verify connection was established
|
||||
assert peer_info.peer_id in swarm2.connections, (
|
||||
f"Connection to {full_addr} should be established"
|
||||
)
|
||||
assert swarm2.get_peer_id() in swarm1.connections, (
|
||||
f"Connection from {full_addr} should be established"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(
|
||||
f"Failed to establish libp2p connection to {full_addr}: {e}"
|
||||
)
|
||||
|
||||
@ -1,4 +1,8 @@
|
||||
import random
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import (
|
||||
PROTOCOL_ID,
|
||||
GossipSub,
|
||||
)
|
||||
from libp2p.pubsub.pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
from libp2p.tools.utils import (
|
||||
connect,
|
||||
)
|
||||
@ -754,3 +761,173 @@ async def test_single_host():
|
||||
assert connected_peers == 0, (
|
||||
f"Single host has {connected_peers} connections, expected 0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_ihave(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||
gossipsub_routers = []
|
||||
for pubsub in pubsubs_gsub:
|
||||
if isinstance(pubsub.router, GossipSub):
|
||||
gossipsub_routers.append(pubsub.router)
|
||||
gossipsubs = tuple(gossipsub_routers)
|
||||
|
||||
index_alice = 0
|
||||
index_bob = 1
|
||||
id_bob = pubsubs_gsub[index_bob].my_id
|
||||
|
||||
# Connect Alice and Bob
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
await trio.sleep(0.1) # Allow connections to establish
|
||||
|
||||
# Mock emit_iwant to capture calls
|
||||
mock_emit_iwant = AsyncMock()
|
||||
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
|
||||
|
||||
# Create a test message ID as a string representation of a (seqno, from) tuple
|
||||
test_seqno = b"1234"
|
||||
test_from = id_bob.to_bytes()
|
||||
test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
|
||||
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])
|
||||
|
||||
# Mock seen_messages.cache to avoid false positives
|
||||
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
|
||||
|
||||
# Simulate Bob sending IHAVE to Alice
|
||||
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
|
||||
|
||||
# Check if emit_iwant was called with the correct message ID
|
||||
mock_emit_iwant.assert_called_once()
|
||||
called_args = mock_emit_iwant.call_args[0]
|
||||
assert called_args[0] == [test_msg_id] # Expected message IDs
|
||||
assert called_args[1] == id_bob # Sender peer ID
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_iwant(monkeypatch):
|
||||
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||
gossipsub_routers = []
|
||||
for pubsub in pubsubs_gsub:
|
||||
if isinstance(pubsub.router, GossipSub):
|
||||
gossipsub_routers.append(pubsub.router)
|
||||
gossipsubs = tuple(gossipsub_routers)
|
||||
|
||||
index_alice = 0
|
||||
index_bob = 1
|
||||
id_alice = pubsubs_gsub[index_alice].my_id
|
||||
|
||||
# Connect Alice and Bob
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
await trio.sleep(0.1) # Allow connections to establish
|
||||
|
||||
# Mock mcache.get to return a message
|
||||
test_message = rpc_pb2.Message(data=b"test_data")
|
||||
test_seqno = b"1234"
|
||||
test_from = id_alice.to_bytes()
|
||||
|
||||
# ✅ Correct: use raw tuple and str() to serialize, no hex()
|
||||
test_msg_id = str((test_seqno, test_from))
|
||||
|
||||
mock_mcache_get = MagicMock(return_value=test_message)
|
||||
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
|
||||
|
||||
# Mock write_msg to capture the sent packet
|
||||
mock_write_msg = AsyncMock()
|
||||
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
|
||||
|
||||
# Simulate Alice sending IWANT to Bob
|
||||
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
|
||||
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
|
||||
|
||||
# Check if write_msg was called with the correct packet
|
||||
mock_write_msg.assert_called_once()
|
||||
packet = mock_write_msg.call_args[0][1]
|
||||
assert isinstance(packet, rpc_pb2.RPC)
|
||||
assert len(packet.publish) == 1
|
||||
assert packet.publish[0] == test_message
|
||||
|
||||
# Verify that mcache.get was called with the correct parsed message ID
|
||||
mock_mcache_get.assert_called_once()
|
||||
called_msg_id = mock_mcache_get.call_args[0][0]
|
||||
assert isinstance(called_msg_id, tuple)
|
||||
assert called_msg_id == (test_seqno, test_from)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_iwant_invalid_msg_id(monkeypatch):
|
||||
"""
|
||||
Test that handle_iwant raises ValueError for malformed message IDs.
|
||||
"""
|
||||
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||
gossipsub_routers = []
|
||||
for pubsub in pubsubs_gsub:
|
||||
if isinstance(pubsub.router, GossipSub):
|
||||
gossipsub_routers.append(pubsub.router)
|
||||
gossipsubs = tuple(gossipsub_routers)
|
||||
|
||||
index_alice = 0
|
||||
index_bob = 1
|
||||
id_alice = pubsubs_gsub[index_alice].my_id
|
||||
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Malformed message ID (not a tuple string)
|
||||
malformed_msg_id = "not_a_valid_msg_id"
|
||||
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])
|
||||
|
||||
# Mock mcache.get and write_msg to ensure they are not called
|
||||
mock_mcache_get = MagicMock()
|
||||
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
|
||||
mock_write_msg = AsyncMock()
|
||||
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
|
||||
mock_mcache_get.assert_not_called()
|
||||
mock_write_msg.assert_not_called()
|
||||
|
||||
# Message ID that's a tuple string but not (bytes, bytes)
|
||||
invalid_tuple_msg_id = "('abc', 123)"
|
||||
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
|
||||
with pytest.raises(ValueError):
|
||||
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
|
||||
mock_mcache_get.assert_not_called()
|
||||
mock_write_msg.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_handle_ihave_empty_message_ids(monkeypatch):
|
||||
"""
|
||||
Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
|
||||
"""
|
||||
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
|
||||
gossipsub_routers = []
|
||||
for pubsub in pubsubs_gsub:
|
||||
if isinstance(pubsub.router, GossipSub):
|
||||
gossipsub_routers.append(pubsub.router)
|
||||
gossipsubs = tuple(gossipsub_routers)
|
||||
|
||||
index_alice = 0
|
||||
index_bob = 1
|
||||
id_bob = pubsubs_gsub[index_bob].my_id
|
||||
|
||||
# Connect Alice and Bob
|
||||
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
|
||||
await trio.sleep(0.1) # Allow connections to establish
|
||||
|
||||
# Mock emit_iwant to capture calls
|
||||
mock_emit_iwant = AsyncMock()
|
||||
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
|
||||
|
||||
# Empty messageIDs list
|
||||
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])
|
||||
|
||||
# Mock seen_messages.cache to avoid false positives
|
||||
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
|
||||
|
||||
# Simulate Bob sending IHAVE to Alice
|
||||
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
|
||||
|
||||
# emit_iwant should not be called since there are no message IDs
|
||||
mock_emit_iwant.assert_not_called()
|
||||
|
||||
@ -8,8 +8,10 @@ from typing import (
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.rsa import create_new_key_pair
|
||||
from libp2p.custom_types import AsyncValidatorFn
|
||||
from libp2p.exceptions import (
|
||||
ValidationError,
|
||||
@ -17,9 +19,11 @@ from libp2p.exceptions import (
|
||||
from libp2p.network.stream.exceptions import (
|
||||
StreamEOF,
|
||||
)
|
||||
from libp2p.peer.envelope import Envelope, seal_record
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peer_record import PeerRecord
|
||||
from libp2p.pubsub.pb import (
|
||||
rpc_pb2,
|
||||
)
|
||||
@ -87,6 +91,45 @@ async def test_re_unsubscribe():
|
||||
assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_reissue_when_listen_addrs_change():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
# Check whether signed-records were transfered properly in the subscribe call
|
||||
envelope_b_sub = (
|
||||
pubsubs_fsub[1]
|
||||
.host.get_peerstore()
|
||||
.get_peer_record(pubsubs_fsub[0].host.get_id())
|
||||
)
|
||||
assert isinstance(envelope_b_sub, Envelope)
|
||||
|
||||
# Simulate pubsubs_fsub[1].host listen addrs changing (different port)
|
||||
new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123")
|
||||
|
||||
# Patch just for the duration we force A to unsubscribe
|
||||
with patch.object(pubsubs_fsub[0].host, "get_addrs", return_value=[new_addr]):
|
||||
# Unsubscribe from A's side so that a new_record is issued
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
await trio.sleep(1)
|
||||
|
||||
# B should be holding A's new record with bumped seq
|
||||
envelope_b_unsub = (
|
||||
pubsubs_fsub[1]
|
||||
.host.get_peerstore()
|
||||
.get_peer_record(pubsubs_fsub[0].host.get_id())
|
||||
)
|
||||
assert isinstance(envelope_b_unsub, Envelope)
|
||||
|
||||
# This proves that a freshly signed record was issued rather than
|
||||
# the latest-cached-one creating one.
|
||||
assert envelope_b_sub.record().seq < envelope_b_unsub.record().seq
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_peers_subscribe():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
@ -95,11 +138,71 @@ async def test_peers_subscribe():
|
||||
# Yield to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
# Check whether signed-records were transfered properly in the subscribe call
|
||||
envelope_b_sub = (
|
||||
pubsubs_fsub[1]
|
||||
.host.get_peerstore()
|
||||
.get_peer_record(pubsubs_fsub[0].host.get_id())
|
||||
)
|
||||
assert isinstance(envelope_b_sub, Envelope)
|
||||
|
||||
await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
|
||||
# Yield to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
|
||||
|
||||
envelope_b_unsub = (
|
||||
pubsubs_fsub[1]
|
||||
.host.get_peerstore()
|
||||
.get_peer_record(pubsubs_fsub[0].host.get_id())
|
||||
)
|
||||
assert isinstance(envelope_b_unsub, Envelope)
|
||||
|
||||
# This proves that the latest-cached-record was re-issued rather than
|
||||
# freshly creating one.
|
||||
assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_peer_subscribe_fail_upon_invald_record_transfer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
|
||||
|
||||
# Corrupt host_a's local peer record
|
||||
envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record()
|
||||
if envelope is not None:
|
||||
true_record = envelope.record()
|
||||
key_pair = create_new_key_pair()
|
||||
|
||||
if envelope is not None:
|
||||
envelope.public_key = key_pair.public_key
|
||||
pubsubs_fsub[0].host.get_peerstore().set_local_record(envelope)
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yeild to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
|
||||
TESTING_TOPIC, set()
|
||||
)
|
||||
|
||||
# Create a corrupt envelope with correct signature but false peer-id
|
||||
false_record = PeerRecord(
|
||||
ID.from_pubkey(key_pair.public_key), true_record.addrs
|
||||
)
|
||||
false_envelope = seal_record(
|
||||
false_record, pubsubs_fsub[0].host.get_private_key()
|
||||
)
|
||||
|
||||
pubsubs_fsub[0].host.get_peerstore().set_local_record(false_envelope)
|
||||
|
||||
await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
|
||||
# Yeild to let 0 notify 1
|
||||
await trio.sleep(1)
|
||||
assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
|
||||
TESTING_TOPIC, set()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_hello_packet():
|
||||
|
||||
90
tests/core/pubsub/test_pubsub_notifee_integration.py
Normal file
90
tests/core/pubsub/test_pubsub_notifee_integration.py
Normal file
@ -0,0 +1,90 @@
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.tools.utils import connect
|
||||
from tests.utils.factories import PubsubFactory
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connected_enqueues_and_adds_peer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Wait until peer is added via queue processing
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
assert p1.my_id in p0.peers
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_disconnected_enqueues_and_removes_peer():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Ensure present first
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
# Now disconnect and expect removal via dead peer queue
|
||||
await p0.host.get_network().close_peer(p1.host.get_id())
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
assert p1.my_id not in p0.peers
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None:
|
||||
# Ensure PubsubNotifee catches BrokenResourceError from its send channel
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
# Find the PubsubNotifee registered on the network
|
||||
from libp2p.pubsub.pubsub_notifee import PubsubNotifee
|
||||
|
||||
network = p0.host.get_network()
|
||||
notifees = getattr(network, "notifees", [])
|
||||
target = None
|
||||
for nf in notifees:
|
||||
if isinstance(nf, cast(type, PubsubNotifee)):
|
||||
target = nf
|
||||
break
|
||||
assert target is not None, "PubsubNotifee not found on network"
|
||||
|
||||
async def failing_send(_peer_id): # type: ignore[no-redef]
|
||||
raise trio.BrokenResourceError
|
||||
|
||||
# Make initiator queue send fail; PubsubNotifee should swallow
|
||||
monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send)
|
||||
|
||||
# Connect peers; if exceptions are swallowed, service stays running
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_duplicate_connection_does_not_duplicate_peer_state():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
with trio.fail_after(1.0):
|
||||
while p1.my_id not in p0.peers:
|
||||
await trio.sleep(0.01)
|
||||
# Connect again should not add duplicates
|
||||
await connect(p0.host, p1.host)
|
||||
await trio.sleep(0.1)
|
||||
assert list(p0.peers.keys()).count(p1.my_id) == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_blocks_peer_added_by_notifee():
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
|
||||
# Blacklist before connecting
|
||||
p0.add_to_blacklist(p1.my_id)
|
||||
await connect(p0.host, p1.host)
|
||||
await p0.wait_until_ready()
|
||||
# Give handler a chance to run
|
||||
await trio.sleep(0.1)
|
||||
assert p1.my_id not in p0.peers
|
||||
@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol):
|
||||
|
||||
# Extract the secured connection from either Mplex or Yamux implementation
|
||||
def get_secured_conn(conn):
|
||||
# conn is now a list, get the first connection
|
||||
if isinstance(conn, list):
|
||||
conn = conn[0]
|
||||
muxed_conn = conn.muxed_conn
|
||||
# Direct attribute access for known implementations
|
||||
has_secured_conn = hasattr(muxed_conn, "secured_conn")
|
||||
|
||||
@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default):
|
||||
assert len(connections) > 0, "Connection not established"
|
||||
|
||||
# Get the first connection
|
||||
conn = list(connections.values())[0]
|
||||
conns = list(connections.values())[0]
|
||||
conn = conns[0] # Get first connection from the list
|
||||
muxed_conn = conn.muxed_conn
|
||||
|
||||
# Define a simple echo protocol
|
||||
|
||||
0
tests/core/transport/quic/test_concurrency.py
Normal file
0
tests/core/transport/quic/test_concurrency.py
Normal file
553
tests/core/transport/quic/test_connection.py
Normal file
553
tests/core/transport/quic/test_connection.py
Normal file
@ -0,0 +1,553 @@
|
||||
"""
|
||||
Enhanced tests for QUIC connection functionality - Module 3.
|
||||
Tests all new features including advanced stream management, resource management,
|
||||
error handling, and concurrent operations.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICConnectionClosedError,
|
||||
QUICConnectionError,
|
||||
QUICConnectionTimeoutError,
|
||||
QUICPeerVerificationError,
|
||||
QUICStreamLimitError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||
from libp2p.transport.quic.stream import QUICStream, StreamDirection
|
||||
|
||||
|
||||
class MockResourceScope:
|
||||
"""Mock resource scope for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.memory_reserved = 0
|
||||
|
||||
def reserve_memory(self, size):
|
||||
self.memory_reserved += size
|
||||
|
||||
def release_memory(self, size):
|
||||
self.memory_reserved = max(0, self.memory_reserved - size)
|
||||
|
||||
|
||||
class TestQUICConnection:
|
||||
"""Test suite for QUIC connection functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
"""Create mock aioquic QuicConnection."""
|
||||
mock = Mock()
|
||||
mock.next_event.return_value = None
|
||||
mock.datagrams_to_send.return_value = []
|
||||
mock.get_timer.return_value = None
|
||||
mock.connect = Mock()
|
||||
mock.close = Mock()
|
||||
mock.send_stream_data = Mock()
|
||||
mock.reset_stream = Mock()
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_transport(self):
|
||||
mock = Mock()
|
||||
mock._config = QUICTransportConfig()
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_resource_scope(self):
|
||||
"""Create mock resource scope."""
|
||||
return MockResourceScope()
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(
|
||||
self,
|
||||
mock_quic_connection: Mock,
|
||||
mock_quic_transport: Mock,
|
||||
mock_resource_scope: MockResourceScope,
|
||||
):
|
||||
"""Create test QUIC connection with enhanced features."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
mock_security_manager = Mock()
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=None,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=mock_quic_transport,
|
||||
resource_scope=mock_resource_scope,
|
||||
security_manager=mock_security_manager,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_connection(self, mock_quic_connection, mock_resource_scope):
|
||||
"""Create server-side QUIC connection."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=False,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
resource_scope=mock_resource_scope,
|
||||
)
|
||||
|
||||
# Basic functionality tests
|
||||
|
||||
def test_connection_initialization_enhanced(
|
||||
self, quic_connection, mock_resource_scope
|
||||
):
|
||||
"""Test enhanced connection initialization."""
|
||||
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
|
||||
assert quic_connection.is_initiator is True
|
||||
assert not quic_connection.is_closed
|
||||
assert not quic_connection.is_established
|
||||
assert len(quic_connection._streams) == 0
|
||||
assert quic_connection._resource_scope == mock_resource_scope
|
||||
assert quic_connection._outbound_stream_count == 0
|
||||
assert quic_connection._inbound_stream_count == 0
|
||||
assert len(quic_connection._stream_accept_queue) == 0
|
||||
|
||||
def test_stream_id_calculation_enhanced(self):
|
||||
"""Test enhanced stream ID calculation for client/server."""
|
||||
# Client connection (initiator)
|
||||
client_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=None,
|
||||
local_peer_id=Mock(),
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
assert client_conn._next_stream_id == 0 # Client starts with 0
|
||||
|
||||
# Server connection (not initiator)
|
||||
server_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=None,
|
||||
local_peer_id=Mock(),
|
||||
is_initiator=False,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
assert server_conn._next_stream_id == 1 # Server starts with 1
|
||||
|
||||
def test_incoming_stream_detection_enhanced(self, quic_connection):
|
||||
"""Test enhanced incoming stream detection logic."""
|
||||
# For client (initiator), odd stream IDs are incoming
|
||||
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
|
||||
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
|
||||
|
||||
# Stream management tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_basic(self, quic_connection):
|
||||
"""Test basic stream opening."""
|
||||
quic_connection._started = True
|
||||
|
||||
stream = await quic_connection.open_stream()
|
||||
|
||||
assert isinstance(stream, QUICStream)
|
||||
assert stream.stream_id == "0"
|
||||
assert stream.direction == StreamDirection.OUTBOUND
|
||||
assert 0 in quic_connection._streams
|
||||
assert quic_connection._outbound_stream_count == 1
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_limit_reached(self, quic_connection):
|
||||
"""Test stream limit enforcement."""
|
||||
quic_connection._started = True
|
||||
quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
|
||||
|
||||
with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
|
||||
await quic_connection.open_stream()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_timeout(self, quic_connection: QUICConnection):
|
||||
"""Test stream opening timeout."""
|
||||
quic_connection._started = True
|
||||
return
|
||||
|
||||
# Mock the stream ID lock to simulate slow operation
|
||||
async def slow_acquire():
|
||||
await trio.sleep(10) # Longer than timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection._stream_lock, "acquire", side_effect=slow_acquire
|
||||
):
|
||||
with pytest.raises(
|
||||
QUICStreamTimeoutError, match="Stream creation timed out"
|
||||
):
|
||||
await quic_connection.open_stream(timeout=0.1)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_basic(self, quic_connection):
|
||||
"""Test basic stream acceptance."""
|
||||
# Create a mock inbound stream
|
||||
mock_stream = Mock(spec=QUICStream)
|
||||
mock_stream.stream_id = "1"
|
||||
|
||||
# Add to accept queue
|
||||
quic_connection._stream_accept_queue.append(mock_stream)
|
||||
quic_connection._stream_accept_event.set()
|
||||
|
||||
accepted_stream = await quic_connection.accept_stream(timeout=0.1)
|
||||
|
||||
assert accepted_stream == mock_stream
|
||||
assert len(quic_connection._stream_accept_queue) == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_timeout(self, quic_connection):
|
||||
"""Test stream acceptance timeout."""
|
||||
with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"):
|
||||
await quic_connection.accept_stream(timeout=0.1)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_on_closed_connection(self, quic_connection):
|
||||
"""Test stream acceptance on closed connection."""
|
||||
await quic_connection.close()
|
||||
|
||||
with pytest.raises(QUICConnectionClosedError, match="Connection is closed"):
|
||||
await quic_connection.accept_stream()
|
||||
|
||||
# Stream handler tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_handler_setting(self, quic_connection):
|
||||
"""Test setting stream handler."""
|
||||
|
||||
async def mock_handler(stream):
|
||||
pass
|
||||
|
||||
quic_connection.set_stream_handler(mock_handler)
|
||||
assert quic_connection._stream_handler == mock_handler
|
||||
|
||||
# Connection lifecycle tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_client(self, quic_connection):
|
||||
"""Test client connection start."""
|
||||
with patch.object(
|
||||
quic_connection, "_initiate_connection", new_callable=AsyncMock
|
||||
) as mock_initiate:
|
||||
await quic_connection.start()
|
||||
|
||||
assert quic_connection._started
|
||||
mock_initiate.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_server(self, server_connection):
|
||||
"""Test server connection start."""
|
||||
await server_connection.start()
|
||||
|
||||
assert server_connection._started
|
||||
assert server_connection._established
|
||||
assert server_connection._connected_event.is_set()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_already_started(self, quic_connection):
|
||||
"""Test starting already started connection."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Should not raise error, just log warning
|
||||
await quic_connection.start()
|
||||
assert quic_connection._started
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_closed(self, quic_connection):
|
||||
"""Test starting closed connection."""
|
||||
quic_connection._closed = True
|
||||
|
||||
with pytest.raises(
|
||||
QUICConnectionError, match="Cannot start a closed connection"
|
||||
):
|
||||
await quic_connection.start()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_connect_with_nursery(
|
||||
self, quic_connection: QUICConnection
|
||||
):
|
||||
"""Test connection establishment with nursery."""
|
||||
quic_connection._started = True
|
||||
quic_connection._established = True
|
||||
quic_connection._connected_event.set()
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
) as mock_start_tasks:
|
||||
with patch.object(
|
||||
quic_connection,
|
||||
"_verify_peer_identity_with_security",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_verify:
|
||||
async with trio.open_nursery() as nursery:
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
assert quic_connection._nursery == nursery
|
||||
mock_start_tasks.assert_called_once()
|
||||
mock_verify.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_connection_connect_timeout(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test connection establishment timeout."""
|
||||
quic_connection._started = True
|
||||
# Don't set connected event to simulate timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
with pytest.raises(
|
||||
QUICConnectionTimeoutError, match="Connection handshake timed out"
|
||||
):
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
# Resource management tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_removal_resource_cleanup(
|
||||
self, quic_connection: QUICConnection, mock_resource_scope
|
||||
):
|
||||
"""Test stream removal and resource cleanup."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create a stream
|
||||
stream = await quic_connection.open_stream()
|
||||
|
||||
# Remove the stream
|
||||
quic_connection._remove_stream(int(stream.stream_id))
|
||||
|
||||
assert int(stream.stream_id) not in quic_connection._streams
|
||||
# Note: Count updates is async, so we can't test it directly here
|
||||
|
||||
# Error handling tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_error_handling(self, quic_connection) -> None:
|
||||
"""Test connection error handling."""
|
||||
error = Exception("Test error")
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "close", new_callable=AsyncMock
|
||||
) as mock_close:
|
||||
await quic_connection._handle_connection_error(error)
|
||||
mock_close.assert_called_once()
|
||||
|
||||
# Statistics and monitoring tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_stats_enhanced(self, quic_connection) -> None:
|
||||
"""Test enhanced connection statistics."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create some streams
|
||||
_stream1 = await quic_connection.open_stream()
|
||||
_stream2 = await quic_connection.open_stream()
|
||||
|
||||
stats = quic_connection.get_stream_stats()
|
||||
|
||||
expected_keys = [
|
||||
"total_streams",
|
||||
"outbound_streams",
|
||||
"inbound_streams",
|
||||
"max_streams",
|
||||
"stream_utilization",
|
||||
"stats",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in stats
|
||||
|
||||
assert stats["total_streams"] == 2
|
||||
assert stats["outbound_streams"] == 2
|
||||
assert stats["inbound_streams"] == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_active_streams(self, quic_connection) -> None:
|
||||
"""Test getting active streams."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create streams
|
||||
stream1 = await quic_connection.open_stream()
|
||||
stream2 = await quic_connection.open_stream()
|
||||
|
||||
active_streams = quic_connection.get_active_streams()
|
||||
|
||||
assert len(active_streams) == 2
|
||||
assert stream1 in active_streams
|
||||
assert stream2 in active_streams
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_streams_by_protocol(self, quic_connection) -> None:
|
||||
"""Test getting streams by protocol."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create streams with different protocols
|
||||
stream1 = await quic_connection.open_stream()
|
||||
stream1.protocol = "/test/1.0.0"
|
||||
|
||||
stream2 = await quic_connection.open_stream()
|
||||
stream2.protocol = "/other/1.0.0"
|
||||
|
||||
test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0")
|
||||
other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0")
|
||||
|
||||
assert len(test_streams) == 1
|
||||
assert len(other_streams) == 1
|
||||
assert stream1 in test_streams
|
||||
assert stream2 in other_streams
|
||||
|
||||
# Enhanced close tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_close_enhanced(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test enhanced connection close with stream cleanup."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create some streams
|
||||
_stream1 = await quic_connection.open_stream()
|
||||
_stream2 = await quic_connection.open_stream()
|
||||
|
||||
await quic_connection.close()
|
||||
|
||||
assert quic_connection.is_closed
|
||||
assert len(quic_connection._streams) == 0
|
||||
|
||||
# Concurrent operations tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_concurrent_stream_operations(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test concurrent stream operations."""
|
||||
quic_connection._started = True
|
||||
|
||||
async def create_stream():
|
||||
return await quic_connection.open_stream()
|
||||
|
||||
# Create multiple streams concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(10):
|
||||
nursery.start_soon(create_stream)
|
||||
|
||||
# Wait a bit for all to start
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Should have created streams without conflicts
|
||||
assert quic_connection._outbound_stream_count == 10
|
||||
assert len(quic_connection._streams) == 10
|
||||
|
||||
# Connection properties tests
|
||||
|
||||
def test_connection_properties(self, quic_connection: QUICConnection) -> None:
|
||||
"""Test connection property accessors."""
|
||||
assert quic_connection.multiaddr() == quic_connection._maddr
|
||||
assert quic_connection.local_peer_id() == quic_connection._local_peer_id
|
||||
assert quic_connection.remote_peer_id() == quic_connection._remote_peer_id
|
||||
|
||||
# IRawConnection interface tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_raw_connection_write(self, quic_connection: QUICConnection) -> None:
|
||||
"""Test raw connection write interface."""
|
||||
quic_connection._started = True
|
||||
|
||||
with patch.object(quic_connection, "open_stream") as mock_open:
|
||||
mock_stream = AsyncMock()
|
||||
mock_open.return_value = mock_stream
|
||||
|
||||
await quic_connection.write(b"test data")
|
||||
|
||||
mock_open.assert_called_once()
|
||||
mock_stream.write.assert_called_once_with(b"test data")
|
||||
mock_stream.close_write.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_raw_connection_read_not_implemented(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test raw connection read raises NotImplementedError."""
|
||||
with pytest.raises(NotImplementedError):
|
||||
await quic_connection.read()
|
||||
|
||||
# Mock verification helpers
|
||||
|
||||
def test_mock_resource_scope_functionality(self, mock_resource_scope) -> None:
|
||||
"""Test mock resource scope works correctly."""
|
||||
assert mock_resource_scope.memory_reserved == 0
|
||||
|
||||
mock_resource_scope.reserve_memory(1000)
|
||||
assert mock_resource_scope.memory_reserved == 1000
|
||||
|
||||
mock_resource_scope.reserve_memory(500)
|
||||
assert mock_resource_scope.memory_reserved == 1500
|
||||
|
||||
mock_resource_scope.release_memory(600)
|
||||
assert mock_resource_scope.memory_reserved == 900
|
||||
|
||||
mock_resource_scope.release_memory(2000) # Should not go negative
|
||||
assert mock_resource_scope.memory_reserved == 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_invalid_certificate_verification():
|
||||
key_pair1 = create_new_key_pair()
|
||||
key_pair2 = create_new_key_pair()
|
||||
|
||||
peer_id1 = ID.from_pubkey(key_pair1.public_key)
|
||||
peer_id2 = ID.from_pubkey(key_pair2.public_key)
|
||||
|
||||
manager = QUICTLSConfigManager(
|
||||
libp2p_private_key=key_pair1.private_key, peer_id=peer_id1
|
||||
)
|
||||
|
||||
# Match the certificate against a different peer_id
|
||||
with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"):
|
||||
manager.verify_peer_identity(manager.tls_config.certificate, peer_id2)
|
||||
|
||||
from cryptography.hazmat.primitives.serialization import Encoding
|
||||
|
||||
# --- Corrupt the certificate by tampering the DER bytes ---
|
||||
cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER)
|
||||
corrupted_bytes = bytearray(cert_bytes)
|
||||
|
||||
# Flip some random bytes in the middle of the certificate
|
||||
corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# This will still parse (structurally valid), but the signature
|
||||
# or fingerprint will break
|
||||
corrupted_cert = x509.load_der_x509_certificate(
|
||||
bytes(corrupted_bytes), backend=default_backend()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
QUICPeerVerificationError, match="Certificate verification failed"
|
||||
):
|
||||
manager.verify_peer_identity(corrupted_cert, peer_id1)
|
||||
624
tests/core/transport/quic/test_connection_id.py
Normal file
624
tests/core/transport/quic/test_connection_id.py
Normal file
@ -0,0 +1,624 @@
|
||||
"""
|
||||
QUIC Connection ID Management Tests
|
||||
|
||||
This test module covers comprehensive testing of QUIC connection ID functionality
|
||||
including generation, rotation, retirement, and validation according to RFC 9000.
|
||||
|
||||
Tests are organized into:
|
||||
1. Basic Connection ID Management
|
||||
2. Connection ID Rotation and Updates
|
||||
3. Connection ID Retirement
|
||||
4. Error Conditions and Edge Cases
|
||||
5. Integration Tests with Real Connections
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from aioquic.buffer import Buffer
|
||||
|
||||
# Import aioquic components for low-level testing
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection, QuicConnectionId
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
|
||||
class ConnectionIdTestHelper:
|
||||
"""Helper class for connection ID testing utilities."""
|
||||
|
||||
@staticmethod
|
||||
def generate_connection_id(length: int = 8) -> bytes:
|
||||
"""Generate a random connection ID of specified length."""
|
||||
return secrets.token_bytes(length)
|
||||
|
||||
@staticmethod
|
||||
def create_quic_connection_id(cid: bytes, sequence: int = 0) -> QuicConnectionId:
|
||||
"""Create a QuicConnectionId object."""
|
||||
return QuicConnectionId(
|
||||
cid=cid,
|
||||
sequence_number=sequence,
|
||||
stateless_reset_token=secrets.token_bytes(16),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_connection_ids_from_connection(conn: QUICConnection) -> dict[str, Any]:
|
||||
"""Extract connection ID information from a QUIC connection."""
|
||||
quic = conn._quic
|
||||
return {
|
||||
"host_cids": [cid.cid.hex() for cid in getattr(quic, "_host_cids", [])],
|
||||
"peer_cid": getattr(quic, "_peer_cid", None),
|
||||
"peer_cid_available": [
|
||||
cid.cid.hex() for cid in getattr(quic, "_peer_cid_available", [])
|
||||
],
|
||||
"retire_connection_ids": getattr(quic, "_retire_connection_ids", []),
|
||||
"host_cid_seq": getattr(quic, "_host_cid_seq", 0),
|
||||
}
|
||||
|
||||
|
||||
class TestBasicConnectionIdManagement:
|
||||
"""Test basic connection ID management functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
"""Create a mock QUIC connection with connection ID support."""
|
||||
mock_quic = Mock(spec=QuicConnection)
|
||||
mock_quic._host_cids = []
|
||||
mock_quic._host_cid_seq = 0
|
||||
mock_quic._peer_cid = None
|
||||
mock_quic._peer_cid_available = []
|
||||
mock_quic._retire_connection_ids = []
|
||||
mock_quic._configuration = Mock()
|
||||
mock_quic._configuration.connection_id_length = 8
|
||||
mock_quic._remote_active_connection_id_limit = 8
|
||||
return mock_quic
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(self, mock_quic_connection):
|
||||
"""Create a QUICConnection instance for testing."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
def test_connection_id_initialization(self, quic_connection):
|
||||
"""Test that connection ID tracking is properly initialized."""
|
||||
# Check that connection ID tracking structures are initialized
|
||||
assert hasattr(quic_connection, "_available_connection_ids")
|
||||
assert hasattr(quic_connection, "_current_connection_id")
|
||||
assert hasattr(quic_connection, "_retired_connection_ids")
|
||||
assert hasattr(quic_connection, "_connection_id_sequence_numbers")
|
||||
|
||||
# Initial state should be empty
|
||||
assert len(quic_connection._available_connection_ids) == 0
|
||||
assert quic_connection._current_connection_id is None
|
||||
assert len(quic_connection._retired_connection_ids) == 0
|
||||
assert len(quic_connection._connection_id_sequence_numbers) == 0
|
||||
|
||||
def test_connection_id_stats_tracking(self, quic_connection):
|
||||
"""Test connection ID statistics are properly tracked."""
|
||||
stats = quic_connection.get_connection_id_stats()
|
||||
|
||||
# Check that all expected stats are present
|
||||
expected_keys = [
|
||||
"available_connection_ids",
|
||||
"current_connection_id",
|
||||
"retired_connection_ids",
|
||||
"connection_ids_issued",
|
||||
"connection_ids_retired",
|
||||
"connection_id_changes",
|
||||
"available_cid_list",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in stats
|
||||
|
||||
# Initial values should be zero/empty
|
||||
assert stats["available_connection_ids"] == 0
|
||||
assert stats["current_connection_id"] is None
|
||||
assert stats["retired_connection_ids"] == 0
|
||||
assert stats["connection_ids_issued"] == 0
|
||||
assert stats["connection_ids_retired"] == 0
|
||||
assert stats["connection_id_changes"] == 0
|
||||
assert stats["available_cid_list"] == []
|
||||
|
||||
def test_current_connection_id_getter(self, quic_connection):
|
||||
"""Test getting current connection ID."""
|
||||
# Initially no connection ID
|
||||
assert quic_connection.get_current_connection_id() is None
|
||||
|
||||
# Set a connection ID
|
||||
test_cid = ConnectionIdTestHelper.generate_connection_id()
|
||||
quic_connection._current_connection_id = test_cid
|
||||
|
||||
assert quic_connection.get_current_connection_id() == test_cid
|
||||
|
||||
def test_connection_id_generation(self):
|
||||
"""Test connection ID generation utilities."""
|
||||
# Test default length
|
||||
cid1 = ConnectionIdTestHelper.generate_connection_id()
|
||||
assert len(cid1) == 8
|
||||
assert isinstance(cid1, bytes)
|
||||
|
||||
# Test custom length
|
||||
cid2 = ConnectionIdTestHelper.generate_connection_id(16)
|
||||
assert len(cid2) == 16
|
||||
|
||||
# Test uniqueness
|
||||
cid3 = ConnectionIdTestHelper.generate_connection_id()
|
||||
assert cid1 != cid3
|
||||
|
||||
|
||||
class TestConnectionIdRotationAndUpdates:
|
||||
"""Test connection ID rotation and update mechanisms."""
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Create transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
def test_connection_id_replenishment(self):
|
||||
"""Test connection ID replenishment mechanism."""
|
||||
# Create a real QuicConnection to test replenishment
|
||||
config = QuicConfiguration(is_client=True)
|
||||
config.connection_id_length = 8
|
||||
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Initial state - should have some host connection IDs
|
||||
initial_count = len(quic_conn._host_cids)
|
||||
assert initial_count > 0
|
||||
|
||||
# Remove some connection IDs to trigger replenishment
|
||||
while len(quic_conn._host_cids) > 2:
|
||||
quic_conn._host_cids.pop()
|
||||
|
||||
# Trigger replenishment
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
# Should have replenished up to the limit
|
||||
assert len(quic_conn._host_cids) >= initial_count
|
||||
|
||||
# All connection IDs should have unique sequence numbers
|
||||
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
assert len(sequences) == len(set(sequences))
|
||||
|
||||
def test_connection_id_sequence_numbers(self):
|
||||
"""Test connection ID sequence number management."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Get initial sequence number
|
||||
initial_seq = quic_conn._host_cid_seq
|
||||
|
||||
# Trigger replenishment to generate new connection IDs
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
# Sequence numbers should increment
|
||||
assert quic_conn._host_cid_seq > initial_seq
|
||||
|
||||
# All host connection IDs should have sequential numbers
|
||||
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
sequences.sort()
|
||||
|
||||
# Check for proper sequence
|
||||
for i in range(len(sequences) - 1):
|
||||
assert sequences[i + 1] > sequences[i]
|
||||
|
||||
def test_connection_id_limits(self):
|
||||
"""Test connection ID limit enforcement."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
config.connection_id_length = 8
|
||||
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Set a reasonable limit
|
||||
quic_conn._remote_active_connection_id_limit = 4
|
||||
|
||||
# Replenish connection IDs
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
# Should not exceed the limit
|
||||
assert len(quic_conn._host_cids) <= quic_conn._remote_active_connection_id_limit
|
||||
|
||||
|
||||
class TestConnectionIdRetirement:
|
||||
"""Test connection ID retirement functionality."""
|
||||
|
||||
def test_connection_id_retirement_basic(self):
|
||||
"""Test basic connection ID retirement."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Create a test connection ID to retire
|
||||
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=1
|
||||
)
|
||||
|
||||
# Add it to peer connection IDs
|
||||
quic_conn._peer_cid_available.append(test_cid)
|
||||
quic_conn._peer_cid_sequence_numbers.add(1)
|
||||
|
||||
# Retire the connection ID
|
||||
quic_conn._retire_peer_cid(test_cid)
|
||||
|
||||
# Should be added to retirement list
|
||||
assert 1 in quic_conn._retire_connection_ids
|
||||
|
||||
def test_connection_id_retirement_limits(self):
|
||||
"""Test connection ID retirement limits."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Fill up retirement list near the limit
|
||||
max_retirements = 32 # Based on aioquic's default limit
|
||||
|
||||
for i in range(max_retirements):
|
||||
quic_conn._retire_connection_ids.append(i)
|
||||
|
||||
# Should be at limit
|
||||
assert len(quic_conn._retire_connection_ids) == max_retirements
|
||||
|
||||
def test_connection_id_retirement_events(self):
|
||||
"""Test that retirement generates proper events."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Create and add a host connection ID
|
||||
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=5
|
||||
)
|
||||
quic_conn._host_cids.append(test_cid)
|
||||
|
||||
# Create a retirement frame buffer
|
||||
from aioquic.buffer import Buffer
|
||||
|
||||
buf = Buffer(capacity=16)
|
||||
buf.push_uint_var(5) # sequence number to retire
|
||||
buf.seek(0)
|
||||
|
||||
# Process retirement (this should generate an event)
|
||||
try:
|
||||
quic_conn._handle_retire_connection_id_frame(
|
||||
Mock(), # context
|
||||
0x19, # RETIRE_CONNECTION_ID frame type
|
||||
buf,
|
||||
)
|
||||
|
||||
# Check that connection ID was removed
|
||||
remaining_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
assert 5 not in remaining_sequences
|
||||
|
||||
except Exception:
|
||||
# May fail due to missing context, but that's okay for this test
|
||||
pass
|
||||
|
||||
|
||||
class TestConnectionIdErrorConditions:
|
||||
"""Test error conditions and edge cases in connection ID handling."""
|
||||
|
||||
def test_invalid_connection_id_length(self):
|
||||
"""Test handling of invalid connection ID lengths."""
|
||||
# Connection IDs must be 1-20 bytes according to RFC 9000
|
||||
|
||||
# Test too short (0 bytes) - this should be handled gracefully
|
||||
empty_cid = b""
|
||||
assert len(empty_cid) == 0
|
||||
|
||||
# Test too long (>20 bytes)
|
||||
long_cid = secrets.token_bytes(21)
|
||||
assert len(long_cid) == 21
|
||||
|
||||
# Test valid lengths
|
||||
for length in range(1, 21):
|
||||
valid_cid = secrets.token_bytes(length)
|
||||
assert len(valid_cid) == length
|
||||
|
||||
def test_duplicate_sequence_numbers(self):
|
||||
"""Test handling of duplicate sequence numbers."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Create two connection IDs with same sequence number
|
||||
cid1 = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=10
|
||||
)
|
||||
cid2 = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=10
|
||||
)
|
||||
|
||||
# Add first connection ID
|
||||
quic_conn._peer_cid_available.append(cid1)
|
||||
quic_conn._peer_cid_sequence_numbers.add(10)
|
||||
|
||||
# Adding second with same sequence should be handled appropriately
|
||||
# (The implementation should prevent duplicates)
|
||||
if 10 not in quic_conn._peer_cid_sequence_numbers:
|
||||
quic_conn._peer_cid_available.append(cid2)
|
||||
quic_conn._peer_cid_sequence_numbers.add(10)
|
||||
|
||||
# Should only have one entry for sequence 10
|
||||
sequences = [cid.sequence_number for cid in quic_conn._peer_cid_available]
|
||||
assert sequences.count(10) <= 1
|
||||
|
||||
def test_retire_unknown_connection_id(self):
|
||||
"""Test retiring an unknown connection ID."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Try to create a buffer to retire unknown sequence number
|
||||
buf = Buffer(capacity=16)
|
||||
buf.push_uint_var(999) # Unknown sequence number
|
||||
buf.seek(0)
|
||||
|
||||
# This should raise an error when processed
|
||||
# (Testing the error condition, not the full processing)
|
||||
unknown_sequence = 999
|
||||
known_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
|
||||
assert unknown_sequence not in known_sequences
|
||||
|
||||
def test_retire_current_connection_id(self):
|
||||
"""Test that retiring current connection ID is prevented."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Get current connection ID if available
|
||||
if quic_conn._host_cids:
|
||||
current_cid = quic_conn._host_cids[0]
|
||||
current_sequence = current_cid.sequence_number
|
||||
|
||||
# Trying to retire current connection ID should be prevented
|
||||
# This is tested by checking the sequence number logic
|
||||
assert current_sequence >= 0
|
||||
|
||||
|
||||
class TestConnectionIdIntegration:
|
||||
"""Integration tests for connection ID functionality with real connections."""
|
||||
|
||||
@pytest.fixture
|
||||
def server_config(self):
|
||||
"""Server transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def client_config(self):
|
||||
"""Client transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_id_exchange_during_handshake(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test connection ID exchange during connection handshake."""
|
||||
# This test would require a full connection setup
|
||||
# For now, we test the setup components
|
||||
|
||||
server_transport = QUICTransport(server_key, server_config)
|
||||
client_transport = QUICTransport(client_key, client_config)
|
||||
|
||||
# Verify transports are created with proper configuration
|
||||
assert server_transport._config == server_config
|
||||
assert client_transport._config == client_config
|
||||
|
||||
# Test that connection ID tracking is available
|
||||
# (Integration with actual networking would require more setup)
|
||||
|
||||
def test_connection_id_extraction_utilities(self):
|
||||
"""Test connection ID extraction utilities."""
|
||||
# Create a mock connection with some connection IDs
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
mock_quic = Mock()
|
||||
mock_quic._host_cids = [
|
||||
ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), i
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_quic._peer_cid = None
|
||||
mock_quic._peer_cid_available = []
|
||||
mock_quic._retire_connection_ids = []
|
||||
mock_quic._host_cid_seq = 3
|
||||
|
||||
quic_conn = QUICConnection(
|
||||
quic_connection=mock_quic,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
# Extract connection ID information
|
||||
cid_info = ConnectionIdTestHelper.extract_connection_ids_from_connection(
|
||||
quic_conn
|
||||
)
|
||||
|
||||
# Verify extraction works
|
||||
assert "host_cids" in cid_info
|
||||
assert "peer_cid" in cid_info
|
||||
assert "peer_cid_available" in cid_info
|
||||
assert "retire_connection_ids" in cid_info
|
||||
assert "host_cid_seq" in cid_info
|
||||
|
||||
# Check values
|
||||
assert len(cid_info["host_cids"]) == 3
|
||||
assert cid_info["host_cid_seq"] == 3
|
||||
assert cid_info["peer_cid"] is None
|
||||
assert len(cid_info["peer_cid_available"]) == 0
|
||||
assert len(cid_info["retire_connection_ids"]) == 0
|
||||
|
||||
|
||||
class TestConnectionIdStatistics:
|
||||
"""Test connection ID statistics and monitoring."""
|
||||
|
||||
@pytest.fixture
|
||||
def connection_with_stats(self):
|
||||
"""Create a connection with connection ID statistics."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
mock_quic = Mock()
|
||||
mock_quic._host_cids = []
|
||||
mock_quic._peer_cid = None
|
||||
mock_quic._peer_cid_available = []
|
||||
mock_quic._retire_connection_ids = []
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
def test_connection_id_stats_initialization(self, connection_with_stats):
|
||||
"""Test that connection ID statistics are properly initialized."""
|
||||
stats = connection_with_stats._stats
|
||||
|
||||
# Check that connection ID stats are present
|
||||
assert "connection_ids_issued" in stats
|
||||
assert "connection_ids_retired" in stats
|
||||
assert "connection_id_changes" in stats
|
||||
|
||||
# Initial values should be zero
|
||||
assert stats["connection_ids_issued"] == 0
|
||||
assert stats["connection_ids_retired"] == 0
|
||||
assert stats["connection_id_changes"] == 0
|
||||
|
||||
def test_connection_id_stats_update(self, connection_with_stats):
|
||||
"""Test updating connection ID statistics."""
|
||||
conn = connection_with_stats
|
||||
|
||||
# Add some connection IDs to tracking
|
||||
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(3)]
|
||||
|
||||
for cid in test_cids:
|
||||
conn._available_connection_ids.add(cid)
|
||||
|
||||
# Update stats (this would normally be done by the implementation)
|
||||
conn._stats["connection_ids_issued"] = len(test_cids)
|
||||
|
||||
# Verify stats
|
||||
stats = conn.get_connection_id_stats()
|
||||
assert stats["connection_ids_issued"] == 3
|
||||
assert stats["available_connection_ids"] == 3
|
||||
|
||||
def test_connection_id_list_representation(self, connection_with_stats):
|
||||
"""Test connection ID list representation in stats."""
|
||||
conn = connection_with_stats
|
||||
|
||||
# Add some connection IDs
|
||||
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(2)]
|
||||
|
||||
for cid in test_cids:
|
||||
conn._available_connection_ids.add(cid)
|
||||
|
||||
# Get stats
|
||||
stats = conn.get_connection_id_stats()
|
||||
|
||||
# Check that CID list is properly formatted
|
||||
assert "available_cid_list" in stats
|
||||
assert len(stats["available_cid_list"]) == 2
|
||||
|
||||
# All entries should be hex strings
|
||||
for cid_hex in stats["available_cid_list"]:
|
||||
assert isinstance(cid_hex, str)
|
||||
assert len(cid_hex) == 16 # 8 bytes = 16 hex chars
|
||||
|
||||
|
||||
# Performance and stress tests
|
||||
class TestConnectionIdPerformance:
|
||||
"""Test connection ID performance and stress scenarios."""
|
||||
|
||||
def test_connection_id_generation_performance(self):
|
||||
"""Test connection ID generation performance."""
|
||||
start_time = time.time()
|
||||
|
||||
# Generate many connection IDs
|
||||
cids = []
|
||||
for _ in range(1000):
|
||||
cid = ConnectionIdTestHelper.generate_connection_id()
|
||||
cids.append(cid)
|
||||
|
||||
end_time = time.time()
|
||||
generation_time = end_time - start_time
|
||||
|
||||
# Should be reasonably fast (less than 1 second for 1000 IDs)
|
||||
assert generation_time < 1.0
|
||||
|
||||
# All should be unique
|
||||
assert len(set(cids)) == len(cids)
|
||||
|
||||
def test_connection_id_tracking_memory(self):
|
||||
"""Test memory usage of connection ID tracking."""
|
||||
conn_ids = set()
|
||||
|
||||
# Add many connection IDs
|
||||
for _ in range(1000):
|
||||
cid = ConnectionIdTestHelper.generate_connection_id()
|
||||
conn_ids.add(cid)
|
||||
|
||||
# Verify they're all stored
|
||||
assert len(conn_ids) == 1000
|
||||
|
||||
# Clean up
|
||||
conn_ids.clear()
|
||||
assert len(conn_ids) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests if executed directly
|
||||
pytest.main([__file__, "-v"])
|
||||
418
tests/core/transport/quic/test_integration.py
Normal file
418
tests/core/transport/quic/test_integration.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""
|
||||
Basic QUIC Echo Test
|
||||
|
||||
Simple test to verify the basic QUIC flow:
|
||||
1. Client connects to server
|
||||
2. Client sends data
|
||||
3. Server receives data and echoes back
|
||||
4. Client receives the echo
|
||||
|
||||
This test focuses on identifying where the accept_stream issue occurs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID
|
||||
from libp2p import new_host
|
||||
from libp2p.abc import INetStream
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||
|
||||
# Set up logging to see what's happening
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestBasicQUICFlow:
|
||||
"""Test basic QUIC client-server communication flow."""
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server key pair."""
|
||||
return create_new_key_pair()
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client key pair."""
|
||||
return create_new_key_pair()
|
||||
|
||||
@pytest.fixture
|
||||
def server_config(self):
|
||||
"""Simple server configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=10,
|
||||
max_connections=5,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def client_config(self):
|
||||
"""Simple client configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=5,
|
||||
)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_basic_echo_flow(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test basic client-server echo flow with detailed logging."""
|
||||
print("\n=== BASIC QUIC ECHO TEST ===")
|
||||
|
||||
# Create server components
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
|
||||
# Track test state
|
||||
server_received_data = None
|
||||
server_connection_established = False
|
||||
echo_sent = False
|
||||
|
||||
async def echo_server_handler(connection: QUICConnection) -> None:
|
||||
"""Simple echo server handler with detailed logging."""
|
||||
nonlocal server_received_data, server_connection_established, echo_sent
|
||||
|
||||
print("🔗 SERVER: Connection handler called")
|
||||
server_connection_established = True
|
||||
|
||||
try:
|
||||
print("📡 SERVER: Waiting for incoming stream...")
|
||||
|
||||
# Accept stream with timeout and detailed logging
|
||||
print("📡 SERVER: Calling accept_stream...")
|
||||
stream = await connection.accept_stream(timeout=5.0)
|
||||
|
||||
if stream is None:
|
||||
print("❌ SERVER: accept_stream returned None")
|
||||
return
|
||||
|
||||
print(f"✅ SERVER: Stream accepted! Stream ID: {stream.stream_id}")
|
||||
|
||||
# Read data from the stream
|
||||
print("📖 SERVER: Reading data from stream...")
|
||||
server_data = await stream.read(1024)
|
||||
|
||||
if not server_data:
|
||||
print("❌ SERVER: No data received from stream")
|
||||
return
|
||||
|
||||
server_received_data = server_data.decode("utf-8", errors="ignore")
|
||||
print(f"📨 SERVER: Received data: '{server_received_data}'")
|
||||
|
||||
# Echo the data back
|
||||
echo_message = f"ECHO: {server_received_data}"
|
||||
print(f"📤 SERVER: Sending echo: '{echo_message}'")
|
||||
|
||||
await stream.write(echo_message.encode())
|
||||
echo_sent = True
|
||||
print("✅ SERVER: Echo sent successfully")
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
print("🔒 SERVER: Stream closed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ SERVER: Error in handler: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Create listener
|
||||
listener = server_transport.create_listener(echo_server_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
# Variables to track client state
|
||||
client_connected = False
|
||||
client_sent_data = False
|
||||
client_received_echo = None
|
||||
|
||||
try:
|
||||
print("🚀 Starting server...")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server listener
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success, "Failed to start server listener"
|
||||
|
||||
# Get server address
|
||||
server_addrs = listener.get_addrs()
|
||||
server_addr = multiaddr.Multiaddr(
|
||||
f"{server_addrs[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
|
||||
)
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Give server a moment to be ready
|
||||
await trio.sleep(0.1)
|
||||
|
||||
print("🚀 Starting client...")
|
||||
|
||||
# Create client transport
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
client_transport.set_background_nursery(nursery)
|
||||
|
||||
try:
|
||||
# Connect to server
|
||||
print(f"📞 CLIENT: Connecting to {server_addr}")
|
||||
connection = await client_transport.dial(server_addr)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected to server")
|
||||
|
||||
# Open a stream
|
||||
print("📤 CLIENT: Opening stream...")
|
||||
stream = await connection.open_stream()
|
||||
print(f"✅ CLIENT: Stream opened with ID: {stream.stream_id}")
|
||||
|
||||
# Send test data
|
||||
test_message = "Hello QUIC Server!"
|
||||
print(f"📨 CLIENT: Sending message: '{test_message}'")
|
||||
await stream.write(test_message.encode())
|
||||
client_sent_data = True
|
||||
print("✅ CLIENT: Message sent")
|
||||
|
||||
# Read echo response
|
||||
print("📖 CLIENT: Waiting for echo response...")
|
||||
response_data = await stream.read(1024)
|
||||
|
||||
if response_data:
|
||||
client_received_echo = response_data.decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
print(f"📬 CLIENT: Received echo: '{client_received_echo}'")
|
||||
else:
|
||||
print("❌ CLIENT: No echo response received")
|
||||
|
||||
print("🔒 CLIENT: Closing connection")
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
print("🔒 CLIENT: Closing transport")
|
||||
await client_transport.close()
|
||||
print("🔒 CLIENT: Transport closed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ CLIENT: Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
print("🔒 CLIENT: Transport closed")
|
||||
|
||||
# Give everything time to complete
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# Cancel nursery to stop server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if not listener._closed:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
# Verify the flow worked
|
||||
print("\n📊 TEST RESULTS:")
|
||||
print(f" Server connection established: {server_connection_established}")
|
||||
print(f" Client connected: {client_connected}")
|
||||
print(f" Client sent data: {client_sent_data}")
|
||||
print(f" Server received data: '{server_received_data}'")
|
||||
print(f" Echo sent by server: {echo_sent}")
|
||||
print(f" Client received echo: '{client_received_echo}'")
|
||||
|
||||
# Test assertions
|
||||
assert server_connection_established, "Server connection handler was not called"
|
||||
assert client_connected, "Client failed to connect"
|
||||
assert client_sent_data, "Client failed to send data"
|
||||
assert server_received_data == "Hello QUIC Server!", (
|
||||
f"Server received wrong data: '{server_received_data}'"
|
||||
)
|
||||
assert echo_sent, "Server failed to send echo"
|
||||
assert client_received_echo == "ECHO: Hello QUIC Server!", (
|
||||
f"Client received wrong echo: '{client_received_echo}'"
|
||||
)
|
||||
|
||||
print("✅ BASIC ECHO TEST PASSED!")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_server_accept_stream_timeout(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test what happens when server accept_stream times out."""
|
||||
print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===")
|
||||
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
|
||||
accept_stream_called = False
|
||||
accept_stream_timeout = False
|
||||
|
||||
async def timeout_test_handler(connection: QUICConnection) -> None:
|
||||
"""Handler that tests accept_stream timeout."""
|
||||
nonlocal accept_stream_called, accept_stream_timeout
|
||||
|
||||
print("🔗 SERVER: Connection established, testing accept_stream timeout")
|
||||
accept_stream_called = True
|
||||
|
||||
try:
|
||||
print("📡 SERVER: Calling accept_stream with 2 second timeout...")
|
||||
stream = await connection.accept_stream(timeout=2.0)
|
||||
print(f"✅ SERVER: accept_stream returned: {stream}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⏰ SERVER: accept_stream timed out or failed: {e}")
|
||||
accept_stream_timeout = True
|
||||
|
||||
listener = server_transport.create_listener(timeout_test_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
client_connected = False
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server
|
||||
server_transport.set_background_nursery(nursery)
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
|
||||
server_addr = multiaddr.Multiaddr(
|
||||
f"{listener.get_addrs()[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
|
||||
)
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Create client but DON'T open a stream
|
||||
async with trio.open_nursery() as client_nursery:
|
||||
client_transport = QUICTransport(
|
||||
client_key.private_key, client_config
|
||||
)
|
||||
client_transport.set_background_nursery(client_nursery)
|
||||
|
||||
try:
|
||||
print("📞 CLIENT: Connecting (but NOT opening stream)...")
|
||||
connection = await client_transport.dial(server_addr)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected (no stream opened)")
|
||||
|
||||
# Wait for server timeout
|
||||
await trio.sleep(3.0)
|
||||
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
print("\n📊 TIMEOUT TEST RESULTS:")
|
||||
print(f" Client connected: {client_connected}")
|
||||
print(f" accept_stream called: {accept_stream_called}")
|
||||
print(f" accept_stream timeout: {accept_stream_timeout}")
|
||||
|
||||
assert client_connected, "Client should have connected"
|
||||
assert accept_stream_called, "accept_stream should have been called"
|
||||
assert accept_stream_timeout, (
|
||||
"accept_stream should have timed out when no stream was opened"
|
||||
)
|
||||
|
||||
print("✅ TIMEOUT TEST PASSED!")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_stress_ping():
|
||||
STREAM_COUNT = 100
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
latencies = []
|
||||
failures = []
|
||||
|
||||
# === Server Setup ===
|
||||
server_host = new_host(listen_addrs=[listen_addr])
|
||||
|
||||
async def handle_ping(stream: INetStream) -> None:
|
||||
try:
|
||||
while True:
|
||||
payload = await stream.read(PING_LENGTH)
|
||||
if not payload:
|
||||
break
|
||||
await stream.write(payload)
|
||||
except Exception:
|
||||
await stream.reset()
|
||||
|
||||
server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||
|
||||
async with server_host.run(listen_addrs=[listen_addr]):
|
||||
# Give server time to start
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# === Client Setup ===
|
||||
destination = str(server_host.get_addrs()[0])
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
|
||||
client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
client_host = new_host(listen_addrs=[client_listen_addr])
|
||||
|
||||
async with client_host.run(listen_addrs=[client_listen_addr]):
|
||||
await client_host.connect(info)
|
||||
|
||||
async def ping_stream(i: int):
|
||||
stream = None
|
||||
try:
|
||||
start = trio.current_time()
|
||||
stream = await client_host.new_stream(
|
||||
info.peer_id, [PING_PROTOCOL_ID]
|
||||
)
|
||||
|
||||
await stream.write(b"\x01" * PING_LENGTH)
|
||||
|
||||
with trio.fail_after(5):
|
||||
response = await stream.read(PING_LENGTH)
|
||||
|
||||
if response == b"\x01" * PING_LENGTH:
|
||||
latency_ms = int((trio.current_time() - start) * 1000)
|
||||
latencies.append(latency_ms)
|
||||
print(f"[Ping #{i}] Latency: {latency_ms} ms")
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
print(f"[Ping #{i}] Failed: {e}")
|
||||
failures.append(i)
|
||||
if stream:
|
||||
await stream.reset()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(STREAM_COUNT):
|
||||
nursery.start_soon(ping_stream, i)
|
||||
|
||||
# === Result Summary ===
|
||||
print("\n📊 Ping Stress Test Summary")
|
||||
print(f"Total Streams Launched: {STREAM_COUNT}")
|
||||
print(f"Successful Pings: {len(latencies)}")
|
||||
print(f"Failed Pings: {len(failures)}")
|
||||
if failures:
|
||||
print(f"❌ Failed stream indices: {failures}")
|
||||
|
||||
# === Assertions ===
|
||||
assert len(latencies) == STREAM_COUNT, (
|
||||
f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}"
|
||||
)
|
||||
assert all(isinstance(x, int) and x >= 0 for x in latencies), (
|
||||
"Invalid latencies"
|
||||
)
|
||||
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
print(f"✅ Average Latency: {avg_latency:.2f} ms")
|
||||
assert avg_latency < 1000
|
||||
150
tests/core/transport/quic/test_listener.py
Normal file
150
tests/core/transport/quic/test_listener.py
Normal file
@ -0,0 +1,150 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICListenError,
|
||||
)
|
||||
from libp2p.transport.quic.listener import QUICListener
|
||||
from libp2p.transport.quic.transport import (
|
||||
QUICTransport,
|
||||
QUICTransportConfig,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
)
|
||||
|
||||
|
||||
class TestQUICListener:
|
||||
"""Test suite for QUIC listener functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def private_key(self):
|
||||
"""Generate test private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Generate test transport configuration."""
|
||||
return QUICTransportConfig(idle_timeout=10.0)
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self, private_key, transport_config):
|
||||
"""Create test transport instance."""
|
||||
return QUICTransport(private_key, transport_config)
|
||||
|
||||
@pytest.fixture
|
||||
def connection_handler(self):
|
||||
"""Mock connection handler."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self, transport, connection_handler):
|
||||
"""Create test listener."""
|
||||
return transport.create_listener(connection_handler)
|
||||
|
||||
def test_listener_creation(self, transport, connection_handler):
|
||||
"""Test listener creation."""
|
||||
listener = transport.create_listener(connection_handler)
|
||||
|
||||
assert isinstance(listener, QUICListener)
|
||||
assert listener._transport == transport
|
||||
assert listener._handler == connection_handler
|
||||
assert not listener._listening
|
||||
assert not listener._closed
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_invalid_multiaddr(self, listener: QUICListener):
|
||||
"""Test listener with invalid multiaddr."""
|
||||
async with trio.open_nursery() as nursery:
|
||||
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
|
||||
with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"):
|
||||
await listener.listen(invalid_addr, nursery)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_basic_lifecycle(self, listener: QUICListener):
|
||||
"""Test basic listener lifecycle."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start listening
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
assert listener.is_listening()
|
||||
|
||||
# Check bound addresses
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) == 1
|
||||
|
||||
# Check stats
|
||||
stats = listener.get_stats()
|
||||
assert stats["is_listening"] is True
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["pending_connections"] == 0
|
||||
|
||||
# Sender Cancel Signal
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
await listener.close()
|
||||
assert not listener.is_listening()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_double_listen(self, listener: QUICListener):
|
||||
"""Test that double listen raises error."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic")
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
await trio.sleep(0.01)
|
||||
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) > 0
|
||||
async with trio.open_nursery() as nursery2:
|
||||
with pytest.raises(QUICListenError, match="Already listening"):
|
||||
await listener.listen(listen_addr, nursery2)
|
||||
nursery2.cancel_scope.cancel()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
await listener.close()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_port_binding(self, listener: QUICListener):
|
||||
"""Test listener port binding and cleanup."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
await trio.sleep(0.5)
|
||||
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) > 0
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
await listener.close()
|
||||
|
||||
# By the time we get here, the listener and its tasks have been fully
|
||||
# shut down, allowing the nursery to exit without hanging.
|
||||
print("TEST COMPLETED SUCCESSFULLY.")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_stats_tracking(self, listener):
|
||||
"""Test listener statistics tracking."""
|
||||
initial_stats = listener.get_stats()
|
||||
|
||||
# All counters should start at 0
|
||||
assert initial_stats["connections_accepted"] == 0
|
||||
assert initial_stats["connections_rejected"] == 0
|
||||
assert initial_stats["bytes_received"] == 0
|
||||
assert initial_stats["packets_processed"] == 0
|
||||
123
tests/core/transport/quic/test_transport.py
Normal file
123
tests/core/transport/quic/test_transport.py
Normal file
@ -0,0 +1,123 @@
|
||||
from unittest.mock import (
|
||||
Mock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
)
|
||||
from libp2p.transport.quic.transport import (
|
||||
QUICTransport,
|
||||
QUICTransportConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestQUICTransport:
|
||||
"""Test suite for QUIC transport using trio."""
|
||||
|
||||
@pytest.fixture
|
||||
def private_key(self):
|
||||
"""Generate test private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Generate test transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0, enable_draft29=True, enable_v1=True
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig):
|
||||
"""Create test transport instance."""
|
||||
return QUICTransport(private_key, transport_config)
|
||||
|
||||
def test_transport_initialization(self, transport):
|
||||
"""Test transport initialization."""
|
||||
assert transport._private_key is not None
|
||||
assert transport._peer_id is not None
|
||||
assert not transport._closed
|
||||
assert len(transport._quic_configs) >= 1
|
||||
|
||||
def test_supported_protocols(self, transport):
|
||||
"""Test supported protocol identifiers."""
|
||||
protocols = transport.protocols()
|
||||
# TODO: Update when quic-v1 compatible
|
||||
# assert "quic-v1" in protocols
|
||||
assert "quic" in protocols # draft-29
|
||||
|
||||
def test_can_dial_quic_addresses(self, transport: QUICTransport):
|
||||
"""Test multiaddr compatibility checking."""
|
||||
import multiaddr
|
||||
|
||||
# Valid QUIC addresses
|
||||
valid_addrs = [
|
||||
# TODO: Update Multiaddr package to accept quic-v1
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
]
|
||||
|
||||
for addr in valid_addrs:
|
||||
assert transport.can_dial(addr)
|
||||
|
||||
# Invalid addresses
|
||||
invalid_addrs = [
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"),
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001"),
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/ws"),
|
||||
]
|
||||
|
||||
for addr in invalid_addrs:
|
||||
assert not transport.can_dial(addr)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_transport_lifecycle(self, transport):
|
||||
"""Test transport lifecycle management using trio."""
|
||||
assert not transport._closed
|
||||
|
||||
await transport.close()
|
||||
assert transport._closed
|
||||
|
||||
# Should be safe to close multiple times
|
||||
await transport.close()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dial_closed_transport(self, transport: QUICTransport) -> None:
|
||||
"""Test dialing with closed transport raises error."""
|
||||
import multiaddr
|
||||
|
||||
await transport.close()
|
||||
|
||||
with pytest.raises(QUICDialError, match="Transport is closed"):
|
||||
await transport.dial(
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
)
|
||||
|
||||
def test_create_listener_closed_transport(self, transport: QUICTransport) -> None:
|
||||
"""Test creating listener with closed transport raises error."""
|
||||
transport._closed = True
|
||||
|
||||
with pytest.raises(QUICListenError, match="Transport is closed"):
|
||||
transport.create_listener(Mock())
|
||||
321
tests/core/transport/quic/test_utils.py
Normal file
321
tests/core/transport/quic/test_utils.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""
|
||||
Test suite for QUIC multiaddr utilities.
|
||||
Focused tests covering essential functionality required for QUIC transport.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICInvalidMultiaddrError,
|
||||
QUICUnsupportedVersionError,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
normalize_quic_multiaddr,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
)
|
||||
|
||||
|
||||
class TestIsQuicMultiaddr:
|
||||
"""Test QUIC multiaddr detection."""
|
||||
|
||||
def test_valid_quic_v1_multiaddrs(self):
|
||||
"""Test valid QUIC v1 multiaddrs are detected."""
|
||||
valid_addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip4/192.168.1.1/udp/8080/quic-v1",
|
||||
"/ip6/::1/udp/4001/quic-v1",
|
||||
"/ip6/2001:db8::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in valid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
def test_valid_quic_draft29_multiaddrs(self):
|
||||
"""Test valid QUIC draft-29 multiaddrs are detected."""
|
||||
valid_addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip4/10.0.0.1/udp/9000/quic",
|
||||
"/ip6/::1/udp/4001/quic",
|
||||
"/ip6/fe80::1/udp/6000/quic",
|
||||
]
|
||||
|
||||
for addr_str in valid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
def test_invalid_multiaddrs(self):
|
||||
"""Test non-QUIC multiaddrs are not detected."""
|
||||
invalid_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
|
||||
"/ip4/127.0.0.1/udp/4001", # UDP without QUIC
|
||||
"/ip4/127.0.0.1/udp/4001/ws", # WebSocket
|
||||
"/ip4/127.0.0.1/quic-v1", # Missing UDP
|
||||
"/udp/4001/quic-v1", # Missing IP
|
||||
"/dns4/example.com/tcp/443/tls", # Completely different
|
||||
]
|
||||
|
||||
for addr_str in invalid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
|
||||
|
||||
|
||||
class TestQuicMultiaddrToEndpoint:
|
||||
"""Test endpoint extraction from QUIC multiaddrs."""
|
||||
|
||||
def test_ipv4_extraction(self):
|
||||
"""Test IPv4 host/port extraction."""
|
||||
test_cases = [
|
||||
("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
|
||||
("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
|
||||
("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
|
||||
]
|
||||
|
||||
for addr_str, expected in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = quic_multiaddr_to_endpoint(maddr)
|
||||
assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
def test_ipv6_extraction(self):
|
||||
"""Test IPv6 host/port extraction."""
|
||||
test_cases = [
|
||||
("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
|
||||
("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
|
||||
]
|
||||
|
||||
for addr_str, expected in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = quic_multiaddr_to_endpoint(maddr)
|
||||
assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
def test_invalid_multiaddr_raises_error(self):
|
||||
"""Test invalid multiaddrs raise appropriate errors."""
|
||||
invalid_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/4001", # Not QUIC
|
||||
"/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
|
||||
]
|
||||
|
||||
for addr_str in invalid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
quic_multiaddr_to_endpoint(maddr)
|
||||
|
||||
|
||||
class TestMultiaddrToQuicVersion:
|
||||
"""Test QUIC version extraction."""
|
||||
|
||||
def test_quic_v1_detection(self):
|
||||
"""Test QUIC v1 version detection."""
|
||||
addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip6/::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
|
||||
|
||||
def test_quic_draft29_detection(self):
|
||||
"""Test QUIC draft-29 version detection."""
|
||||
addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip6/::1/udp/5000/quic",
|
||||
]
|
||||
|
||||
for addr_str in addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
assert version == "quic", f"Should detect quic for {addr_str}"
|
||||
|
||||
def test_non_quic_raises_error(self):
|
||||
"""Test non-QUIC multiaddrs raise error."""
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
multiaddr_to_quic_version(maddr)
|
||||
|
||||
|
||||
class TestCreateQuicMultiaddr:
|
||||
"""Test QUIC multiaddr creation."""
|
||||
|
||||
def test_ipv4_creation(self):
|
||||
"""Test IPv4 QUIC multiaddr creation."""
|
||||
test_cases = [
|
||||
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
|
||||
("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
|
||||
("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
|
||||
]
|
||||
|
||||
for host, port, version, expected in test_cases:
|
||||
result = create_quic_multiaddr(host, port, version)
|
||||
assert str(result) == expected
|
||||
|
||||
def test_ipv6_creation(self):
|
||||
"""Test IPv6 QUIC multiaddr creation."""
|
||||
test_cases = [
|
||||
("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
|
||||
("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
|
||||
]
|
||||
|
||||
for host, port, version, expected in test_cases:
|
||||
result = create_quic_multiaddr(host, port, version)
|
||||
assert str(result) == expected
|
||||
|
||||
def test_default_version(self):
|
||||
"""Test default version is quic-v1."""
|
||||
result = create_quic_multiaddr("127.0.0.1", 4001)
|
||||
expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
assert str(result) == expected
|
||||
|
||||
def test_invalid_inputs_raise_errors(self):
|
||||
"""Test invalid inputs raise appropriate errors."""
|
||||
# Invalid IP
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("invalid-ip", 4001)
|
||||
|
||||
# Invalid port
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", 70000)
|
||||
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", -1)
|
||||
|
||||
# Invalid version
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
|
||||
|
||||
|
||||
class TestQuicVersionToWireFormat:
|
||||
"""Test QUIC version to wire format conversion."""
|
||||
|
||||
def test_supported_versions(self):
|
||||
"""Test supported version conversions."""
|
||||
test_cases = [
|
||||
("quic-v1", 0x00000001), # RFC 9000
|
||||
("quic", 0xFF00001D), # draft-29
|
||||
]
|
||||
|
||||
for version, expected_wire in test_cases:
|
||||
result = quic_version_to_wire_format(TProtocol(version))
|
||||
assert result == expected_wire, f"Failed for version {version}"
|
||||
|
||||
def test_unsupported_version_raises_error(self):
|
||||
"""Test unsupported versions raise error."""
|
||||
with pytest.raises(QUICUnsupportedVersionError):
|
||||
quic_version_to_wire_format(TProtocol("unsupported-version"))
|
||||
|
||||
|
||||
class TestGetAlpnProtocols:
|
||||
"""Test ALPN protocol retrieval."""
|
||||
|
||||
def test_returns_libp2p_protocols(self):
|
||||
"""Test returns expected libp2p ALPN protocols."""
|
||||
protocols = get_alpn_protocols()
|
||||
assert protocols == ["libp2p"]
|
||||
assert isinstance(protocols, list)
|
||||
|
||||
def test_returns_copy(self):
|
||||
"""Test returns a copy, not the original list."""
|
||||
protocols1 = get_alpn_protocols()
|
||||
protocols2 = get_alpn_protocols()
|
||||
|
||||
# Modify one list
|
||||
protocols1.append("test")
|
||||
|
||||
# Other list should be unchanged
|
||||
assert protocols2 == ["libp2p"]
|
||||
|
||||
|
||||
class TestNormalizeQuicMultiaddr:
|
||||
"""Test QUIC multiaddr normalization."""
|
||||
|
||||
def test_already_normalized(self):
|
||||
"""Test already normalized multiaddrs pass through."""
|
||||
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
maddr = Multiaddr(addr_str)
|
||||
|
||||
result = normalize_quic_multiaddr(maddr)
|
||||
assert str(result) == addr_str
|
||||
|
||||
def test_normalize_different_versions(self):
|
||||
"""Test normalization works for different QUIC versions."""
|
||||
test_cases = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip6/::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = normalize_quic_multiaddr(maddr)
|
||||
|
||||
# Should be valid QUIC multiaddr
|
||||
assert is_quic_multiaddr(result)
|
||||
|
||||
# Should be parseable
|
||||
host, port = quic_multiaddr_to_endpoint(result)
|
||||
version = multiaddr_to_quic_version(result)
|
||||
|
||||
# Should match original
|
||||
orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
|
||||
orig_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
assert host == orig_host
|
||||
assert port == orig_port
|
||||
assert version == orig_version
|
||||
|
||||
def test_non_quic_raises_error(self):
|
||||
"""Test non-QUIC multiaddrs raise error."""
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
normalize_quic_multiaddr(maddr)
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for utility functions working together."""
|
||||
|
||||
def test_round_trip_conversion(self):
|
||||
"""Test creating and parsing multiaddrs works correctly."""
|
||||
test_cases = [
|
||||
("127.0.0.1", 4001, "quic-v1"),
|
||||
("::1", 5000, "quic"),
|
||||
("192.168.1.100", 8080, "quic-v1"),
|
||||
]
|
||||
|
||||
for host, port, version in test_cases:
|
||||
# Create multiaddr
|
||||
maddr = create_quic_multiaddr(host, port, version)
|
||||
|
||||
# Should be detected as QUIC
|
||||
assert is_quic_multiaddr(maddr)
|
||||
|
||||
# Should extract original values
|
||||
extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
|
||||
extracted_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
assert extracted_host == host
|
||||
assert extracted_port == port
|
||||
assert extracted_version == version
|
||||
|
||||
# Should normalize to same value
|
||||
normalized = normalize_quic_multiaddr(maddr)
|
||||
assert str(normalized) == str(maddr)
|
||||
|
||||
def test_wire_format_integration(self):
|
||||
"""Test wire format conversion works with version detection."""
|
||||
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
maddr = Multiaddr(addr_str)
|
||||
|
||||
# Extract version and convert to wire format
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
wire_format = quic_version_to_wire_format(version)
|
||||
|
||||
# Should be QUIC v1 wire format
|
||||
assert wire_format == 0x00000001
|
||||
109
tests/examples/test_echo_thin_waist.py
Normal file
109
tests/examples/test_echo_thin_waist.py
Normal file
@ -0,0 +1,109 @@
|
||||
import contextlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP
|
||||
|
||||
# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging
|
||||
|
||||
# This test is intentionally lightweight and can be marked as 'integration'.
|
||||
# It ensures the echo example runs and prints the new Thin Waist lines using
|
||||
# Trio primitives.
|
||||
|
||||
current_file = Path(__file__)
|
||||
project_root = current_file.parent.parent.parent
|
||||
EXAMPLES_DIR: Path = project_root / "examples" / "echo"
|
||||
|
||||
|
||||
def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path):
|
||||
"""Run echo server and validate printed multiaddr and peer id."""
|
||||
# Run echo example as server
|
||||
cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"]
|
||||
env = {**os.environ, "PYTHONUNBUFFERED": "1"}
|
||||
proc: subprocess.Popen[str] = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if proc.stdout is None:
|
||||
proc.terminate()
|
||||
raise RuntimeError("Process stdout is None")
|
||||
out_stream = proc.stdout
|
||||
|
||||
peer_id: str | None = None
|
||||
printed_multiaddr: str | None = None
|
||||
saw_waiting = False
|
||||
|
||||
start = time.time()
|
||||
timeout_s = 8.0
|
||||
try:
|
||||
while time.time() - start < timeout_s:
|
||||
line = out_stream.readline()
|
||||
if not line:
|
||||
time.sleep(0.05)
|
||||
continue
|
||||
s = line.strip()
|
||||
if s.startswith("I am "):
|
||||
peer_id = s.partition("I am ")[2]
|
||||
if s.startswith("echo-demo -d "):
|
||||
printed_multiaddr = s.partition("echo-demo -d ")[2]
|
||||
if "Waiting for incoming connections..." in s:
|
||||
saw_waiting = True
|
||||
break
|
||||
finally:
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
proc.terminate()
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
proc.kill()
|
||||
|
||||
assert peer_id, "Did not capture peer ID line"
|
||||
assert printed_multiaddr, "Did not capture multiaddr line"
|
||||
assert saw_waiting, "Did not capture waiting-for-connections line"
|
||||
|
||||
# Validate multiaddr structure using py-multiaddr protocol methods
|
||||
ma = Multiaddr(printed_multiaddr) # should parse without error
|
||||
|
||||
# Check that the multiaddr contains the p2p protocol
|
||||
try:
|
||||
peer_id_from_multiaddr = ma.value_for_protocol("p2p")
|
||||
assert peer_id_from_multiaddr is not None, (
|
||||
"Multiaddr missing p2p protocol value"
|
||||
)
|
||||
assert peer_id_from_multiaddr == peer_id, (
|
||||
f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Failed to extract p2p protocol value: {e}")
|
||||
|
||||
# Validate the multiaddr structure by checking protocols
|
||||
protocols = ma.protocols()
|
||||
|
||||
# Should have at least IP, TCP, and P2P protocols
|
||||
assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), (
|
||||
"Missing IP protocol"
|
||||
)
|
||||
assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol"
|
||||
assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol"
|
||||
|
||||
# Extract the p2p part and validate it matches the captured peer ID
|
||||
p2p_part = Multiaddr(f"/p2p/{peer_id}")
|
||||
try:
|
||||
# Decapsulate the p2p part to get the transport address
|
||||
transport_addr = ma.decapsulate(p2p_part)
|
||||
# Verify the decapsulated address doesn't contain p2p
|
||||
transport_protocols = transport_addr.protocols()
|
||||
assert not any(p.code == P_P2P for p in transport_protocols), (
|
||||
"Decapsulation failed - still contains p2p"
|
||||
)
|
||||
# Verify the original multiaddr can be reconstructed
|
||||
reconstructed = transport_addr.encapsulate(p2p_part)
|
||||
assert str(reconstructed) == str(ma), "Reconstruction failed"
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Multiaddr decapsulation failed: {e}")
|
||||
6
tests/examples/test_quic_echo_example.py
Normal file
6
tests/examples/test_quic_echo_example.py
Normal file
@ -0,0 +1,6 @@
|
||||
def test_echo_quic_example():
|
||||
"""Test that the QUIC echo example can be imported and has required functions."""
|
||||
from examples.echo import echo_quic
|
||||
|
||||
assert hasattr(echo_quic, "main")
|
||||
assert hasattr(echo_quic, "run")
|
||||
8
tests/interop/nim_libp2p/.gitignore
vendored
Normal file
8
tests/interop/nim_libp2p/.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
nimble.develop
|
||||
nimble.paths
|
||||
|
||||
*.nimble
|
||||
nim-libp2p/
|
||||
|
||||
nim_echo_server
|
||||
config.nims
|
||||
119
tests/interop/nim_libp2p/conftest.py
Normal file
119
tests/interop/nim_libp2p/conftest.py
Normal file
@ -0,0 +1,119 @@
|
||||
import fcntl
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_nim_available():
|
||||
"""Check if nim compiler is available."""
|
||||
return shutil.which("nim") is not None and shutil.which("nimble") is not None
|
||||
|
||||
|
||||
def check_nim_binary_built():
|
||||
"""Check if nim echo server binary is built."""
|
||||
current_dir = Path(__file__).parent
|
||||
binary_path = current_dir / "nim_echo_server"
|
||||
return binary_path.exists() and binary_path.stat().st_size > 0
|
||||
|
||||
|
||||
def run_nim_setup_with_lock():
|
||||
"""Run nim setup with file locking to prevent parallel execution."""
|
||||
current_dir = Path(__file__).parent
|
||||
lock_file = current_dir / ".setup_lock"
|
||||
setup_script = current_dir / "scripts" / "setup_nim_echo.sh"
|
||||
|
||||
if not setup_script.exists():
|
||||
raise RuntimeError(f"Setup script not found: {setup_script}")
|
||||
|
||||
# Try to acquire lock
|
||||
try:
|
||||
with open(lock_file, "w") as f:
|
||||
# Non-blocking lock attempt
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
|
||||
# Double-check binary doesn't exist (another worker might have built it)
|
||||
if check_nim_binary_built():
|
||||
logger.info("Binary already exists, skipping setup")
|
||||
return
|
||||
|
||||
logger.info("Acquired setup lock, running nim-libp2p setup...")
|
||||
|
||||
# Make setup script executable and run it
|
||||
setup_script.chmod(0o755)
|
||||
result = subprocess.run(
|
||||
[str(setup_script)],
|
||||
cwd=current_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minute timeout
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Setup failed (exit {result.returncode}):\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Verify binary was built
|
||||
if not check_nim_binary_built():
|
||||
raise RuntimeError("nim_echo_server binary not found after setup")
|
||||
|
||||
logger.info("nim-libp2p setup completed successfully")
|
||||
|
||||
except BlockingIOError:
|
||||
# Another worker is running setup, wait for it to complete
|
||||
logger.info("Another worker is running setup, waiting...")
|
||||
|
||||
# Wait for setup to complete (check every 2 seconds, max 5 minutes)
|
||||
for _ in range(150): # 150 * 2 = 300 seconds = 5 minutes
|
||||
if check_nim_binary_built():
|
||||
logger.info("Setup completed by another worker")
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
raise TimeoutError("Timed out waiting for setup to complete")
|
||||
|
||||
finally:
|
||||
# Clean up lock file
|
||||
try:
|
||||
lock_file.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="function") # Changed to function scope
|
||||
def nim_echo_binary():
|
||||
"""Get nim echo server binary path."""
|
||||
current_dir = Path(__file__).parent
|
||||
binary_path = current_dir / "nim_echo_server"
|
||||
|
||||
if not binary_path.exists():
|
||||
pytest.skip(
|
||||
"nim_echo_server binary not found. "
|
||||
"Run setup script: ./scripts/setup_nim_echo.sh"
|
||||
)
|
||||
|
||||
return binary_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def nim_server(nim_echo_binary):
|
||||
"""Start and stop nim echo server for tests."""
|
||||
# Import here to avoid circular imports
|
||||
# pyrefly: ignore
|
||||
from test_echo_interop import NimEchoServer
|
||||
|
||||
server = NimEchoServer(nim_echo_binary)
|
||||
|
||||
try:
|
||||
peer_id, listen_addr = await server.start()
|
||||
yield server, peer_id, listen_addr
|
||||
finally:
|
||||
await server.stop()
|
||||
108
tests/interop/nim_libp2p/nim_echo_server.nim
Normal file
108
tests/interop/nim_libp2p/nim_echo_server.nim
Normal file
@ -0,0 +1,108 @@
|
||||
{.used.}
|
||||
|
||||
import chronos
|
||||
import stew/byteutils
|
||||
import libp2p
|
||||
|
||||
##
|
||||
# Simple Echo Protocol Implementation for py-libp2p Interop Testing
|
||||
##
|
||||
const EchoCodec = "/echo/1.0.0"
|
||||
|
||||
type EchoProto = ref object of LPProtocol
|
||||
|
||||
proc new(T: typedesc[EchoProto]): T =
|
||||
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
echo "Echo server: Received connection from ", conn.peerId
|
||||
|
||||
# Read and echo messages in a loop
|
||||
while not conn.atEof:
|
||||
try:
|
||||
# Read length-prefixed message using nim-libp2p's readLp
|
||||
let message = await conn.readLp(1024 * 1024) # Max 1MB
|
||||
if message.len == 0:
|
||||
echo "Echo server: Empty message, closing connection"
|
||||
break
|
||||
|
||||
let messageStr = string.fromBytes(message)
|
||||
echo "Echo server: Received (", message.len, " bytes): ", messageStr
|
||||
|
||||
# Echo back using writeLp
|
||||
await conn.writeLp(message)
|
||||
echo "Echo server: Echoed message back"
|
||||
|
||||
except CatchableError as e:
|
||||
echo "Echo server: Error processing message: ", e.msg
|
||||
break
|
||||
|
||||
except CancelledError as e:
|
||||
echo "Echo server: Connection cancelled"
|
||||
raise e
|
||||
except CatchableError as e:
|
||||
echo "Echo server: Exception in handler: ", e.msg
|
||||
finally:
|
||||
echo "Echo server: Connection closed"
|
||||
await conn.close()
|
||||
|
||||
return T.new(codecs = @[EchoCodec], handler = handle)
|
||||
|
||||
##
|
||||
# Create QUIC-enabled switch
|
||||
##
|
||||
proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
|
||||
var switch = SwitchBuilder
|
||||
.new()
|
||||
.withRng(rng)
|
||||
.withAddress(ma)
|
||||
.withQuicTransport()
|
||||
.build()
|
||||
result = switch
|
||||
|
||||
##
|
||||
# Main server
|
||||
##
|
||||
proc main() {.async.} =
|
||||
let
|
||||
rng = newRng()
|
||||
localAddr = MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()
|
||||
echoProto = EchoProto.new()
|
||||
|
||||
echo "=== Nim Echo Server for py-libp2p Interop ==="
|
||||
|
||||
# Create switch
|
||||
let switch = createSwitch(localAddr, rng)
|
||||
switch.mount(echoProto)
|
||||
|
||||
# Start server
|
||||
await switch.start()
|
||||
|
||||
# Print connection info
|
||||
echo "Peer ID: ", $switch.peerInfo.peerId
|
||||
echo "Listening on:"
|
||||
for addr in switch.peerInfo.addrs:
|
||||
echo " ", $addr, "/p2p/", $switch.peerInfo.peerId
|
||||
echo "Protocol: ", EchoCodec
|
||||
echo "Ready for py-libp2p connections!"
|
||||
echo ""
|
||||
|
||||
# Keep running
|
||||
try:
|
||||
await sleepAsync(100.hours)
|
||||
except CancelledError:
|
||||
echo "Shutting down..."
|
||||
finally:
|
||||
await switch.stop()
|
||||
|
||||
# Graceful shutdown handler
|
||||
proc signalHandler() {.noconv.} =
|
||||
echo "\nShutdown signal received"
|
||||
quit(0)
|
||||
|
||||
when isMainModule:
|
||||
setControlCHook(signalHandler)
|
||||
try:
|
||||
waitFor(main())
|
||||
except CatchableError as e:
|
||||
echo "Error: ", e.msg
|
||||
quit(1)
|
||||
74
tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
Executable file
74
tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
Executable file
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env bash
|
||||
# tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
|
||||
# Cache-aware setup that skips installation if packages exist
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="${SCRIPT_DIR}/.."
|
||||
|
||||
# Colors
|
||||
GREEN='\033[0;32m'
|
||||
RED='\033[0;31m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||
|
||||
main() {
|
||||
log_info "Setting up nim echo server for interop testing..."
|
||||
|
||||
# Check if nim is available
|
||||
if ! command -v nim &> /dev/null || ! command -v nimble &> /dev/null; then
|
||||
log_error "Nim not found. Please install nim first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd "${PROJECT_DIR}"
|
||||
|
||||
# Create logs directory
|
||||
mkdir -p logs
|
||||
|
||||
# Check if binary already exists
|
||||
if [[ -f "nim_echo_server" ]]; then
|
||||
log_info "nim_echo_server already exists, skipping build"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Check if libp2p is already installed (cache-aware)
|
||||
if nimble list -i | grep -q "libp2p"; then
|
||||
log_info "libp2p already installed, skipping installation"
|
||||
else
|
||||
log_info "Installing nim-libp2p globally..."
|
||||
nimble install -y libp2p
|
||||
fi
|
||||
|
||||
log_info "Building nim echo server..."
|
||||
# Compile the echo server
|
||||
nim c \
|
||||
-d:release \
|
||||
-d:chronicles_log_level=INFO \
|
||||
-d:libp2p_quic_support \
|
||||
-d:chronos_event_loop=iocp \
|
||||
-d:ssl \
|
||||
--opt:speed \
|
||||
--mm:orc \
|
||||
--verbosity:1 \
|
||||
-o:nim_echo_server \
|
||||
nim_echo_server.nim
|
||||
|
||||
# Verify binary was created
|
||||
if [[ -f "nim_echo_server" ]]; then
|
||||
log_info "✅ nim_echo_server built successfully"
|
||||
log_info "Binary size: $(ls -lh nim_echo_server | awk '{print $5}')"
|
||||
else
|
||||
log_error "❌ Failed to build nim_echo_server"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "🎉 Setup complete!"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
195
tests/interop/nim_libp2p/test_echo_interop.py
Normal file
195
tests/interop/nim_libp2p/test_echo_interop.py
Normal file
@ -0,0 +1,195 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.utils.varint import encode_varint_prefixed, read_varint_prefixed_bytes
|
||||
|
||||
# Configuration
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
TEST_TIMEOUT = 30
|
||||
SERVER_START_TIMEOUT = 10.0
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NimEchoServer:
|
||||
"""Simple nim echo server manager."""
|
||||
|
||||
def __init__(self, binary_path: Path):
|
||||
self.binary_path = binary_path
|
||||
self.process: None | subprocess.Popen = None
|
||||
self.peer_id = None
|
||||
self.listen_addr = None
|
||||
|
||||
async def start(self):
|
||||
"""Start nim echo server and get connection info."""
|
||||
logger.info(f"Starting nim echo server: {self.binary_path}")
|
||||
|
||||
self.process = subprocess.Popen(
|
||||
[str(self.binary_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Parse output for connection info
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < SERVER_START_TIMEOUT:
|
||||
if self.process and self.process.poll() and self.process.stdout:
|
||||
output = self.process.stdout.read()
|
||||
raise RuntimeError(f"Server exited early: {output}")
|
||||
|
||||
reader = self.process.stdout if self.process else None
|
||||
if reader:
|
||||
line = reader.readline().strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
logger.info(f"Server: {line}")
|
||||
|
||||
if line.startswith("Peer ID:"):
|
||||
self.peer_id = line.split(":", 1)[1].strip()
|
||||
|
||||
elif "/quic-v1/p2p/" in line and self.peer_id:
|
||||
if line.strip().startswith("/"):
|
||||
self.listen_addr = line.strip()
|
||||
logger.info(f"Server ready: {self.listen_addr}")
|
||||
return self.peer_id, self.listen_addr
|
||||
|
||||
await self.stop()
|
||||
raise TimeoutError(f"Server failed to start within {SERVER_START_TIMEOUT}s")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the server."""
|
||||
if self.process:
|
||||
logger.info("Stopping nim echo server...")
|
||||
try:
|
||||
self.process.terminate()
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
self.process.wait()
|
||||
self.process = None
|
||||
|
||||
|
||||
async def run_echo_test(server_addr: str, messages: list[str]):
|
||||
"""Test echo protocol against nim server with proper timeout handling."""
|
||||
# Create py-libp2p QUIC client with shorter timeouts
|
||||
|
||||
host = new_host(
|
||||
enable_quic=True,
|
||||
key_pair=create_new_key_pair(),
|
||||
)
|
||||
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/0/quic-v1")
|
||||
responses = []
|
||||
|
||||
try:
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
logger.info(f"Connecting to nim server: {server_addr}")
|
||||
|
||||
# Connect to nim server
|
||||
maddr = multiaddr.Multiaddr(server_addr)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
await host.connect(info)
|
||||
|
||||
# Create stream
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
logger.info("Stream created")
|
||||
|
||||
# Test each message
|
||||
for i, message in enumerate(messages, 1):
|
||||
logger.info(f"Testing message {i}: {message}")
|
||||
|
||||
# Send with varint length prefix
|
||||
data = message.encode("utf-8")
|
||||
prefixed_data = encode_varint_prefixed(data)
|
||||
await stream.write(prefixed_data)
|
||||
|
||||
# Read response
|
||||
response_data = await read_varint_prefixed_bytes(stream)
|
||||
response = response_data.decode("utf-8")
|
||||
|
||||
logger.info(f"Got echo: {response}")
|
||||
responses.append(response)
|
||||
|
||||
# Verify echo
|
||||
assert message == response, (
|
||||
f"Echo failed: sent {message!r}, got {response!r}"
|
||||
)
|
||||
|
||||
await stream.close()
|
||||
logger.info("✅ All messages echoed correctly")
|
||||
|
||||
finally:
|
||||
await host.close()
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.timeout(TEST_TIMEOUT)
|
||||
async def test_basic_echo_interop(nim_server):
|
||||
"""Test basic echo functionality between py-libp2p and nim-libp2p."""
|
||||
server, peer_id, listen_addr = nim_server
|
||||
|
||||
test_messages = [
|
||||
"Hello from py-libp2p!",
|
||||
"QUIC transport working",
|
||||
"Echo test successful!",
|
||||
"Unicode: Ñoël, 测试, Ψυχή",
|
||||
]
|
||||
|
||||
logger.info(f"Testing against nim server: {peer_id}")
|
||||
|
||||
# Run test with timeout
|
||||
with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup
|
||||
responses = await run_echo_test(listen_addr, test_messages)
|
||||
|
||||
# Verify all messages echoed correctly
|
||||
assert len(responses) == len(test_messages)
|
||||
for sent, received in zip(test_messages, responses):
|
||||
assert sent == received
|
||||
|
||||
logger.info("✅ Basic echo interop test passed!")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.timeout(TEST_TIMEOUT)
|
||||
async def test_large_message_echo(nim_server):
|
||||
"""Test echo with larger messages."""
|
||||
server, peer_id, listen_addr = nim_server
|
||||
|
||||
large_messages = [
|
||||
"x" * 1024,
|
||||
"y" * 5000,
|
||||
]
|
||||
|
||||
logger.info("Testing large message echo...")
|
||||
|
||||
# Run test with timeout
|
||||
with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup
|
||||
responses = await run_echo_test(listen_addr, large_messages)
|
||||
|
||||
assert len(responses) == len(large_messages)
|
||||
for sent, received in zip(large_messages, responses):
|
||||
assert sent == received
|
||||
|
||||
logger.info("✅ Large message echo test passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@ -669,8 +669,8 @@ async def swarm_conn_pair_factory(
|
||||
async with swarm_pair_factory(
|
||||
security_protocol=security_protocol, muxer_opt=muxer_opt
|
||||
) as swarms:
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()]
|
||||
conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0]
|
||||
conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0]
|
||||
yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1)
|
||||
|
||||
|
||||
|
||||
56
tests/utils/test_address_validation.py
Normal file
56
tests/utils/test_address_validation.py
Normal file
@ -0,0 +1,56 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.utils.address_validation import (
|
||||
expand_wildcard_address,
|
||||
get_available_interfaces,
|
||||
get_optimal_binding_address,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("proto", ["tcp"])
|
||||
def test_get_available_interfaces(proto: str) -> None:
|
||||
interfaces = get_available_interfaces(0, protocol=proto)
|
||||
assert len(interfaces) > 0
|
||||
for addr in interfaces:
|
||||
assert isinstance(addr, Multiaddr)
|
||||
assert f"/{proto}/" in str(addr)
|
||||
|
||||
|
||||
def test_get_optimal_binding_address() -> None:
|
||||
addr = get_optimal_binding_address(0)
|
||||
assert isinstance(addr, Multiaddr)
|
||||
# At least IPv4 or IPv6 prefix present
|
||||
s = str(addr)
|
||||
assert ("/ip4/" in s) or ("/ip6/" in s)
|
||||
|
||||
|
||||
def test_expand_wildcard_address_ipv4() -> None:
|
||||
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0")
|
||||
expanded = expand_wildcard_address(wildcard)
|
||||
assert len(expanded) > 0
|
||||
for e in expanded:
|
||||
assert isinstance(e, Multiaddr)
|
||||
assert "/tcp/" in str(e)
|
||||
|
||||
|
||||
def test_expand_wildcard_address_port_override() -> None:
|
||||
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000")
|
||||
overridden = expand_wildcard_address(wildcard, port=9001)
|
||||
assert len(overridden) > 0
|
||||
for e in overridden:
|
||||
assert str(e).endswith("/tcp/9001")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("NO_IPV6") == "1",
|
||||
reason="Environment disallows IPv6",
|
||||
)
|
||||
def test_expand_wildcard_address_ipv6() -> None:
|
||||
wildcard = Multiaddr("/ip6/::/tcp/0")
|
||||
expanded = expand_wildcard_address(wildcard)
|
||||
assert len(expanded) > 0
|
||||
for e in expanded:
|
||||
assert "/ip6/" in str(e)
|
||||
@ -15,6 +15,7 @@ import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.utils.logging import (
|
||||
_current_handlers,
|
||||
_current_listener,
|
||||
_listener_ready,
|
||||
log_queue,
|
||||
@ -24,13 +25,19 @@ from libp2p.utils.logging import (
|
||||
|
||||
def _reset_logging():
|
||||
"""Reset all logging state."""
|
||||
global _current_listener, _listener_ready
|
||||
global _current_listener, _listener_ready, _current_handlers
|
||||
|
||||
# Stop existing listener if any
|
||||
if _current_listener is not None:
|
||||
_current_listener.stop()
|
||||
_current_listener = None
|
||||
|
||||
# Close all file handlers to ensure proper cleanup on Windows
|
||||
for handler in _current_handlers:
|
||||
if isinstance(handler, logging.FileHandler):
|
||||
handler.close()
|
||||
_current_handlers.clear()
|
||||
|
||||
# Reset the event
|
||||
_listener_ready = threading.Event()
|
||||
|
||||
@ -174,6 +181,15 @@ async def test_custom_log_file(clean_env):
|
||||
if _current_listener is not None:
|
||||
_current_listener.stop()
|
||||
|
||||
# Give a moment for the listener to fully stop
|
||||
await trio.sleep(0.05)
|
||||
|
||||
# Close all file handlers to release the file
|
||||
for handler in _current_handlers:
|
||||
if isinstance(handler, logging.FileHandler):
|
||||
handler.flush() # Ensure all writes are flushed
|
||||
handler.close()
|
||||
|
||||
# Check if the file exists and contains our message
|
||||
assert log_file.exists()
|
||||
content = log_file.read_text()
|
||||
@ -185,16 +201,15 @@ async def test_default_log_file(clean_env):
|
||||
"""Test logging to the default file path."""
|
||||
os.environ["LIBP2P_DEBUG"] = "INFO"
|
||||
|
||||
with patch("libp2p.utils.logging.datetime") as mock_datetime:
|
||||
# Mock the timestamp to have a predictable filename
|
||||
mock_datetime.now.return_value.strftime.return_value = "20240101_120000"
|
||||
with patch("libp2p.utils.paths.create_temp_file") as mock_create_temp:
|
||||
# Mock the temp file creation to return a predictable path
|
||||
mock_temp_file = (
|
||||
Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log"
|
||||
)
|
||||
mock_create_temp.return_value = mock_temp_file
|
||||
|
||||
# Remove the log file if it exists
|
||||
if os.name == "nt": # Windows
|
||||
log_file = Path("C:/Windows/Temp/20240101_120000_py-libp2p.log")
|
||||
else: # Unix-like
|
||||
log_file = Path("/tmp/20240101_120000_py-libp2p.log")
|
||||
log_file.unlink(missing_ok=True)
|
||||
mock_temp_file.unlink(missing_ok=True)
|
||||
|
||||
setup_logging()
|
||||
|
||||
@ -211,9 +226,18 @@ async def test_default_log_file(clean_env):
|
||||
if _current_listener is not None:
|
||||
_current_listener.stop()
|
||||
|
||||
# Check the default log file
|
||||
if log_file.exists(): # Only check content if we have write permission
|
||||
content = log_file.read_text()
|
||||
# Give a moment for the listener to fully stop
|
||||
await trio.sleep(0.05)
|
||||
|
||||
# Close all file handlers to release the file
|
||||
for handler in _current_handlers:
|
||||
if isinstance(handler, logging.FileHandler):
|
||||
handler.flush() # Ensure all writes are flushed
|
||||
handler.close()
|
||||
|
||||
# Check the mocked temp file
|
||||
if mock_temp_file.exists():
|
||||
content = mock_temp_file.read_text()
|
||||
assert "Test message" in content
|
||||
|
||||
|
||||
|
||||
290
tests/utils/test_paths.py
Normal file
290
tests/utils/test_paths.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""
|
||||
Tests for cross-platform path utilities.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.utils.paths import (
|
||||
create_temp_file,
|
||||
ensure_dir_exists,
|
||||
find_executable,
|
||||
get_binary_path,
|
||||
get_config_dir,
|
||||
get_project_root,
|
||||
get_python_executable,
|
||||
get_script_binary_path,
|
||||
get_script_dir,
|
||||
get_temp_dir,
|
||||
get_venv_path,
|
||||
join_paths,
|
||||
normalize_path,
|
||||
resolve_relative_path,
|
||||
)
|
||||
|
||||
|
||||
class TestPathUtilities:
|
||||
"""Test cross-platform path utilities."""
|
||||
|
||||
def test_get_temp_dir(self):
|
||||
"""Test that temp directory is accessible and exists."""
|
||||
temp_dir = get_temp_dir()
|
||||
assert isinstance(temp_dir, Path)
|
||||
assert temp_dir.exists()
|
||||
assert temp_dir.is_dir()
|
||||
# Should match system temp directory
|
||||
assert temp_dir == Path(tempfile.gettempdir())
|
||||
|
||||
def test_get_project_root(self):
|
||||
"""Test that project root is correctly determined."""
|
||||
project_root = get_project_root()
|
||||
assert isinstance(project_root, Path)
|
||||
assert project_root.exists()
|
||||
# Should contain pyproject.toml
|
||||
assert (project_root / "pyproject.toml").exists()
|
||||
# Should contain libp2p directory
|
||||
assert (project_root / "libp2p").exists()
|
||||
|
||||
def test_join_paths(self):
|
||||
"""Test cross-platform path joining."""
|
||||
# Test with strings
|
||||
result = join_paths("a", "b", "c")
|
||||
expected = Path("a") / "b" / "c"
|
||||
assert result == expected
|
||||
|
||||
# Test with mixed types
|
||||
result = join_paths("a", Path("b"), "c")
|
||||
expected = Path("a") / "b" / "c"
|
||||
assert result == expected
|
||||
|
||||
# Test with absolute path
|
||||
result = join_paths("/absolute", "path")
|
||||
expected = Path("/absolute") / "path"
|
||||
assert result == expected
|
||||
|
||||
def test_ensure_dir_exists(self, tmp_path):
|
||||
"""Test directory creation and existence checking."""
|
||||
# Test creating new directory
|
||||
new_dir = tmp_path / "new_dir"
|
||||
result = ensure_dir_exists(new_dir)
|
||||
assert result == new_dir
|
||||
assert new_dir.exists()
|
||||
assert new_dir.is_dir()
|
||||
|
||||
# Test creating nested directory
|
||||
nested_dir = tmp_path / "parent" / "child" / "grandchild"
|
||||
result = ensure_dir_exists(nested_dir)
|
||||
assert result == nested_dir
|
||||
assert nested_dir.exists()
|
||||
assert nested_dir.is_dir()
|
||||
|
||||
# Test with existing directory
|
||||
result = ensure_dir_exists(new_dir)
|
||||
assert result == new_dir
|
||||
assert new_dir.exists()
|
||||
|
||||
def test_get_config_dir(self):
|
||||
"""Test platform-specific config directory."""
|
||||
config_dir = get_config_dir()
|
||||
assert isinstance(config_dir, Path)
|
||||
|
||||
if os.name == "nt": # Windows
|
||||
# Should be in AppData/Roaming or user home
|
||||
assert "AppData" in str(config_dir) or "py-libp2p" in str(config_dir)
|
||||
else: # Unix-like
|
||||
# Should be in ~/.config
|
||||
assert ".config" in str(config_dir)
|
||||
assert "py-libp2p" in str(config_dir)
|
||||
|
||||
def test_get_script_dir(self):
|
||||
"""Test script directory detection."""
|
||||
# Test with current file
|
||||
script_dir = get_script_dir(__file__)
|
||||
assert isinstance(script_dir, Path)
|
||||
assert script_dir.exists()
|
||||
assert script_dir.is_dir()
|
||||
# Should contain this test file
|
||||
assert (script_dir / "test_paths.py").exists()
|
||||
|
||||
def test_create_temp_file(self):
|
||||
"""Test temporary file creation."""
|
||||
temp_file = create_temp_file()
|
||||
assert isinstance(temp_file, Path)
|
||||
assert temp_file.parent == get_temp_dir()
|
||||
assert temp_file.name.startswith("py-libp2p_")
|
||||
assert temp_file.name.endswith(".log")
|
||||
|
||||
# Test with custom prefix and suffix
|
||||
temp_file = create_temp_file(prefix="test_", suffix=".txt")
|
||||
assert temp_file.name.startswith("test_")
|
||||
assert temp_file.name.endswith(".txt")
|
||||
|
||||
def test_resolve_relative_path(self, tmp_path):
|
||||
"""Test relative path resolution."""
|
||||
base_path = tmp_path / "base"
|
||||
base_path.mkdir()
|
||||
|
||||
# Test relative path
|
||||
relative_path = "subdir/file.txt"
|
||||
result = resolve_relative_path(base_path, relative_path)
|
||||
expected = (base_path / "subdir" / "file.txt").resolve()
|
||||
assert result == expected
|
||||
|
||||
# Test absolute path (platform-agnostic)
|
||||
if os.name == "nt": # Windows
|
||||
absolute_path = "C:\\absolute\\path"
|
||||
else: # Unix-like
|
||||
absolute_path = "/absolute/path"
|
||||
result = resolve_relative_path(base_path, absolute_path)
|
||||
assert result == Path(absolute_path)
|
||||
|
||||
def test_normalize_path(self, tmp_path):
|
||||
"""Test path normalization."""
|
||||
# Test with relative path
|
||||
relative_path = tmp_path / ".." / "normalize_test"
|
||||
result = normalize_path(relative_path)
|
||||
assert result.is_absolute()
|
||||
assert "normalize_test" in str(result)
|
||||
|
||||
# Test with absolute path
|
||||
absolute_path = tmp_path / "test_file"
|
||||
result = normalize_path(absolute_path)
|
||||
assert result.is_absolute()
|
||||
assert result == absolute_path.resolve()
|
||||
|
||||
def test_get_venv_path(self, monkeypatch):
|
||||
"""Test virtual environment path detection."""
|
||||
# Test when no virtual environment is active
|
||||
# Temporarily clear VIRTUAL_ENV to test the "no venv" case
|
||||
monkeypatch.delenv("VIRTUAL_ENV", raising=False)
|
||||
result = get_venv_path()
|
||||
assert result is None
|
||||
|
||||
# Test when virtual environment is active
|
||||
test_venv_path = "/path/to/venv"
|
||||
monkeypatch.setenv("VIRTUAL_ENV", test_venv_path)
|
||||
result = get_venv_path()
|
||||
assert result == Path(test_venv_path)
|
||||
|
||||
def test_get_python_executable(self):
|
||||
"""Test Python executable path detection."""
|
||||
result = get_python_executable()
|
||||
assert isinstance(result, Path)
|
||||
assert result.exists()
|
||||
assert result.name.startswith("python")
|
||||
|
||||
def test_find_executable(self):
|
||||
"""Test executable finding in PATH."""
|
||||
# Test with non-existent executable
|
||||
result = find_executable("nonexistent_executable")
|
||||
assert result is None
|
||||
|
||||
# Test with existing executable (python should be available)
|
||||
result = find_executable("python")
|
||||
if result:
|
||||
assert isinstance(result, Path)
|
||||
assert result.exists()
|
||||
|
||||
def test_get_script_binary_path(self):
|
||||
"""Test script binary path detection."""
|
||||
result = get_script_binary_path()
|
||||
assert isinstance(result, Path)
|
||||
assert result.exists()
|
||||
assert result.is_dir()
|
||||
|
||||
def test_get_binary_path(self, monkeypatch):
|
||||
"""Test binary path resolution with virtual environment."""
|
||||
# Test when no virtual environment is active
|
||||
result = get_binary_path("python")
|
||||
if result:
|
||||
assert isinstance(result, Path)
|
||||
assert result.exists()
|
||||
|
||||
# Test when virtual environment is active
|
||||
test_venv_path = "/path/to/venv"
|
||||
monkeypatch.setenv("VIRTUAL_ENV", test_venv_path)
|
||||
# This test is more complex as it depends on the actual venv structure
|
||||
# We'll just verify the function doesn't crash
|
||||
result = get_binary_path("python")
|
||||
# Result can be None if binary not found in venv
|
||||
if result:
|
||||
assert isinstance(result, Path)
|
||||
|
||||
|
||||
class TestCrossPlatformCompatibility:
|
||||
"""Test cross-platform compatibility."""
|
||||
|
||||
def test_config_dir_platform_specific_windows(self, monkeypatch):
|
||||
"""Test config directory respects Windows conventions."""
|
||||
import platform
|
||||
|
||||
# Only run this test on Windows systems
|
||||
if platform.system() != "Windows":
|
||||
pytest.skip("This test only runs on Windows systems")
|
||||
|
||||
monkeypatch.setattr("os.name", "nt")
|
||||
monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
|
||||
config_dir = get_config_dir()
|
||||
assert "AppData" in str(config_dir)
|
||||
assert "py-libp2p" in str(config_dir)
|
||||
|
||||
def test_path_separators_consistent(self):
|
||||
"""Test that path separators are handled consistently."""
|
||||
# Test that join_paths uses platform-appropriate separators
|
||||
result = join_paths("dir1", "dir2", "file.txt")
|
||||
expected = Path("dir1") / "dir2" / "file.txt"
|
||||
assert result == expected
|
||||
|
||||
# Test that the result uses correct separators for the platform
|
||||
if os.name == "nt": # Windows
|
||||
assert "\\" in str(result) or "/" in str(result)
|
||||
else: # Unix-like
|
||||
assert "/" in str(result)
|
||||
|
||||
def test_temp_file_uniqueness(self):
|
||||
"""Test that temporary files have unique names."""
|
||||
files = set()
|
||||
for _ in range(10):
|
||||
temp_file = create_temp_file()
|
||||
assert temp_file not in files
|
||||
files.add(temp_file)
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""Test backward compatibility with existing code patterns."""
|
||||
|
||||
def test_path_operations_equivalent(self):
|
||||
"""Test that new path operations are equivalent to old os.path operations."""
|
||||
# Test join_paths vs os.path.join
|
||||
parts = ["a", "b", "c"]
|
||||
new_result = join_paths(*parts)
|
||||
old_result = Path(os.path.join(*parts))
|
||||
assert new_result == old_result
|
||||
|
||||
# Test get_script_dir vs os.path.dirname(os.path.abspath(__file__))
|
||||
new_script_dir = get_script_dir(__file__)
|
||||
old_script_dir = Path(os.path.dirname(os.path.abspath(__file__)))
|
||||
assert new_script_dir == old_script_dir
|
||||
|
||||
def test_existing_functionality_preserved(self):
|
||||
"""Ensure no existing functionality is broken."""
|
||||
# Test that all functions return Path objects
|
||||
assert isinstance(get_temp_dir(), Path)
|
||||
assert isinstance(get_project_root(), Path)
|
||||
assert isinstance(join_paths("a", "b"), Path)
|
||||
assert isinstance(ensure_dir_exists(tempfile.gettempdir()), Path)
|
||||
assert isinstance(get_config_dir(), Path)
|
||||
assert isinstance(get_script_dir(__file__), Path)
|
||||
assert isinstance(create_temp_file(), Path)
|
||||
assert isinstance(resolve_relative_path(".", "test"), Path)
|
||||
assert isinstance(normalize_path("."), Path)
|
||||
assert isinstance(get_python_executable(), Path)
|
||||
assert isinstance(get_script_binary_path(), Path)
|
||||
|
||||
# Test optional return types
|
||||
venv_path = get_venv_path()
|
||||
if venv_path is not None:
|
||||
assert isinstance(venv_path, Path)
|
||||
Reference in New Issue
Block a user