Merge branch 'main' into chore01

Manu Sheel Gupta
2025-09-24 22:40:38 +05:30
committed by GitHub
119 changed files with 16551 additions and 520 deletions


@@ -1,6 +1,5 @@
import pytest
@pytest.fixture
def security_protocol():
return None
return None


@@ -250,10 +250,13 @@ def test_new_swarm_tcp_multiaddr_supported():
assert isinstance(swarm.transport, TCP)
def test_new_swarm_quic_multiaddr_raises():
def test_new_swarm_quic_multiaddr_supported():
from libp2p.transport.quic.transport import QUICTransport
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
with pytest.raises(ValueError, match="QUIC not yet supported"):
new_swarm(listen_addrs=[addr])
swarm = new_swarm(listen_addrs=[addr])
assert isinstance(swarm, Swarm)
assert isinstance(swarm.transport, QUICTransport)
@pytest.mark.trio


@@ -1,4 +1,8 @@
import random
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
import trio
@@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import (
PROTOCOL_ID,
GossipSub,
)
from libp2p.pubsub.pb import (
rpc_pb2,
)
from libp2p.tools.utils import (
connect,
)
@@ -754,3 +761,173 @@ async def test_single_host():
assert connected_peers == 0, (
f"Single host has {connected_peers} connections, expected 0"
)
@pytest.mark.trio
async def test_handle_ihave(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
# Connect Alice and Bob
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1) # Allow connections to establish
# Mock emit_iwant to capture calls
mock_emit_iwant = AsyncMock()
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
# Create a test message ID as a string representation of a (seqno, from) tuple
test_seqno = b"1234"
test_from = id_bob.to_bytes()
test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')"
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id])
# Mock seen_messages.cache to avoid false positives
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
# Simulate Bob sending IHAVE to Alice
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
# Check if emit_iwant was called with the correct message ID
mock_emit_iwant.assert_called_once()
called_args = mock_emit_iwant.call_args[0]
assert called_args[0] == [test_msg_id] # Expected message IDs
assert called_args[1] == id_bob # Sender peer ID
@pytest.mark.trio
async def test_handle_iwant(monkeypatch):
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_alice = pubsubs_gsub[index_alice].my_id
# Connect Alice and Bob
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1) # Allow connections to establish
# Mock mcache.get to return a message
test_message = rpc_pb2.Message(data=b"test_data")
test_seqno = b"1234"
test_from = id_alice.to_bytes()
# Serialize the raw (seqno, from) tuple with str(); do not hex-encode the parts
test_msg_id = str((test_seqno, test_from))
mock_mcache_get = MagicMock(return_value=test_message)
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
# Mock write_msg to capture the sent packet
mock_write_msg = AsyncMock()
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
# Simulate Alice sending IWANT to Bob
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id])
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
# Check if write_msg was called with the correct packet
mock_write_msg.assert_called_once()
packet = mock_write_msg.call_args[0][1]
assert isinstance(packet, rpc_pb2.RPC)
assert len(packet.publish) == 1
assert packet.publish[0] == test_message
# Verify that mcache.get was called with the correct parsed message ID
mock_mcache_get.assert_called_once()
called_msg_id = mock_mcache_get.call_args[0][0]
assert isinstance(called_msg_id, tuple)
assert called_msg_id == (test_seqno, test_from)
@pytest.mark.trio
async def test_handle_iwant_invalid_msg_id(monkeypatch):
"""
Test that handle_iwant raises ValueError for malformed message IDs.
"""
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_alice = pubsubs_gsub[index_alice].my_id
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1)
# Malformed message ID (not a tuple string)
malformed_msg_id = "not_a_valid_msg_id"
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id])
# Mock mcache.get and write_msg to ensure they are not called
mock_mcache_get = MagicMock()
monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get)
mock_write_msg = AsyncMock()
monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg)
with pytest.raises(ValueError):
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
mock_mcache_get.assert_not_called()
mock_write_msg.assert_not_called()
# Message ID that's a tuple string but not (bytes, bytes)
invalid_tuple_msg_id = "('abc', 123)"
iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id])
with pytest.raises(ValueError):
await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice)
mock_mcache_get.assert_not_called()
mock_write_msg.assert_not_called()
@pytest.mark.trio
async def test_handle_ihave_empty_message_ids(monkeypatch):
"""
Test that handle_ihave with an empty messageIDs list does not call emit_iwant.
"""
async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub:
gossipsub_routers = []
for pubsub in pubsubs_gsub:
if isinstance(pubsub.router, GossipSub):
gossipsub_routers.append(pubsub.router)
gossipsubs = tuple(gossipsub_routers)
index_alice = 0
index_bob = 1
id_bob = pubsubs_gsub[index_bob].my_id
# Connect Alice and Bob
await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host)
await trio.sleep(0.1) # Allow connections to establish
# Mock emit_iwant to capture calls
mock_emit_iwant = AsyncMock()
monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant)
# Empty messageIDs list
ihave_msg = rpc_pb2.ControlIHave(messageIDs=[])
# Mock seen_messages.cache to avoid false positives
monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {})
# Simulate Bob sending IHAVE to Alice
await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob)
# emit_iwant should not be called since there are no message IDs
mock_emit_iwant.assert_not_called()
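# Illustrative round-trip for the message-ID convention used above (an
# assumption about how handle_iwant recovers the (seqno, from) tuple from the
# str()-serialized ID; the gossipsub internals may differ):
import ast

def _parse_msg_id(msg_id: str) -> tuple[bytes, bytes]:
    # str((b"1234", b"peer")) renders as "(b'1234', b'peer')", which
    # ast.literal_eval turns back into the original tuple of bytes.
    seqno, sender = ast.literal_eval(msg_id)
    if not (isinstance(seqno, bytes) and isinstance(sender, bytes)):
        raise ValueError(f"invalid message ID: {msg_id!r}")
    return seqno, sender

assert _parse_msg_id(str((b"1234", b"peer"))) == (b"1234", b"peer")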


@@ -65,7 +65,7 @@ async def test_prune_backoff():
@pytest.mark.trio
async def test_unsubscribe_backoff():
async with PubsubFactory.create_batch_with_gossipsub(
2, heartbeat_interval=1, prune_back_off=1, unsubscribe_back_off=2
2, heartbeat_interval=0.5, prune_back_off=2, unsubscribe_back_off=4
) as pubsubs:
gsub0 = pubsubs[0].router
gsub1 = pubsubs[1].router
@@ -107,7 +107,8 @@ async def test_unsubscribe_backoff():
)
# try to graft again (should succeed after backoff)
await trio.sleep(1)
# Wait longer than unsubscribe_back_off (4 seconds) + some buffer
await trio.sleep(4.5)
await gsub0.emit_graft(topic, host_1.get_id())
await trio.sleep(1)
assert host_0.get_id() in gsub1.mesh[topic], (


@@ -0,0 +1,108 @@
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
from libp2p.custom_types import (
TMuxerClass,
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.stream_muxer.muxer_multistream import (
MuxerMultistream,
)
@pytest.mark.trio
async def test_muxer_timeout_configuration():
"""Test that muxer respects timeout configuration."""
muxer = MuxerMultistream({}, negotiate_timeout=1)
assert muxer.negotiate_timeout == 1
@pytest.mark.trio
async def test_select_transport_passes_timeout_to_multiselect():
"""Test that timeout is passed to multiselect client in select_transport."""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = False
# Mock MultiselectClient
muxer = MuxerMultistream({}, negotiate_timeout=10)
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
# Call select_transport
await muxer.select_transport(mock_conn)
# Verify that negotiate was called with the correct timeout
args, _ = muxer.multiselect.negotiate.call_args
assert args[1] == 10
@pytest.mark.trio
async def test_new_conn_passes_timeout_to_multistream_client():
"""Test that timeout is passed to multistream client in new_conn."""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = True
mock_peer_id = ID(b"test_peer")
mock_communicator = MagicMock()
# Mock MultistreamClient and transports
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
# Call new_conn
await muxer.new_conn(mock_conn, mock_peer_id)
# Verify that select_one_of was called with the correct timeout
muxer.multistream_client.select_one_of.assert_called_once()
args, kwargs = muxer.multistream_client.select_one_of.call_args
assert 30 in args or 30 in kwargs.values()
@pytest.mark.trio
async def test_select_transport_no_protocol_selected():
"""
Test that select_transport raises MultiselectError when no protocol is selected.
"""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = False
# Mock Multiselect to return None
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
# Expect MultiselectError to be raised
with pytest.raises(MultiselectError, match="no protocol selected"):
await muxer.select_transport(mock_conn)
@pytest.mark.trio
async def test_add_transport_updates_precedence():
"""Test that adding a transport updates protocol precedence."""
# Mock transport classes
mock_transport1 = MagicMock(spec=TMuxerClass)
mock_transport2 = MagicMock(spec=TMuxerClass)
# Initialize muxer and add transports
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.add_transport(TProtocol("proto1"), mock_transport1)
muxer.add_transport(TProtocol("proto2"), mock_transport2)
# Verify transport order
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
# Re-add proto1 to check if it moves to the end
muxer.add_transport(TProtocol("proto1"), mock_transport1)
assert list(muxer.transports.keys()) == ["proto2", "proto1"]


@@ -0,0 +1,553 @@
"""
Enhanced tests for QUIC connection functionality - Module 3.
Tests all new features including advanced stream management, resource management,
error handling, and concurrent operations.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from multiaddr.multiaddr import Multiaddr
import trio
from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.exceptions import (
QUICConnectionClosedError,
QUICConnectionError,
QUICConnectionTimeoutError,
QUICPeerVerificationError,
QUICStreamLimitError,
QUICStreamTimeoutError,
)
from libp2p.transport.quic.security import QUICTLSConfigManager
from libp2p.transport.quic.stream import QUICStream, StreamDirection
class MockResourceScope:
"""Mock resource scope for testing."""
def __init__(self):
self.memory_reserved = 0
def reserve_memory(self, size):
self.memory_reserved += size
def release_memory(self, size):
self.memory_reserved = max(0, self.memory_reserved - size)
class TestQUICConnection:
"""Test suite for QUIC connection functionality."""
@pytest.fixture
def mock_quic_connection(self):
"""Create mock aioquic QuicConnection."""
mock = Mock()
mock.next_event.return_value = None
mock.datagrams_to_send.return_value = []
mock.get_timer.return_value = None
mock.connect = Mock()
mock.close = Mock()
mock.send_stream_data = Mock()
mock.reset_stream = Mock()
return mock
@pytest.fixture
def mock_quic_transport(self):
mock = Mock()
mock._config = QUICTransportConfig()
return mock
@pytest.fixture
def mock_resource_scope(self):
"""Create mock resource scope."""
return MockResourceScope()
@pytest.fixture
def quic_connection(
self,
mock_quic_connection: Mock,
mock_quic_transport: Mock,
mock_resource_scope: MockResourceScope,
):
"""Create test QUIC connection with enhanced features."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
mock_security_manager = Mock()
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=None,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=mock_quic_transport,
resource_scope=mock_resource_scope,
security_manager=mock_security_manager,
)
@pytest.fixture
def server_connection(self, mock_quic_connection, mock_resource_scope):
"""Create server-side QUIC connection."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=False,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
resource_scope=mock_resource_scope,
)
# Basic functionality tests
def test_connection_initialization_enhanced(
self, quic_connection, mock_resource_scope
):
"""Test enhanced connection initialization."""
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
assert quic_connection.is_initiator is True
assert not quic_connection.is_closed
assert not quic_connection.is_established
assert len(quic_connection._streams) == 0
assert quic_connection._resource_scope == mock_resource_scope
assert quic_connection._outbound_stream_count == 0
assert quic_connection._inbound_stream_count == 0
assert len(quic_connection._stream_accept_queue) == 0
def test_stream_id_calculation_enhanced(self):
"""Test enhanced stream ID calculation for client/server."""
# Client connection (initiator)
client_conn = QUICConnection(
quic_connection=Mock(),
remote_addr=("127.0.0.1", 4001),
remote_peer_id=None,
local_peer_id=Mock(),
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
assert client_conn._next_stream_id == 0 # Client starts with 0
# Server connection (not initiator)
server_conn = QUICConnection(
quic_connection=Mock(),
remote_addr=("127.0.0.1", 4001),
remote_peer_id=None,
local_peer_id=Mock(),
is_initiator=False,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
assert server_conn._next_stream_id == 1 # Server starts with 1
def test_incoming_stream_detection_enhanced(self, quic_connection):
"""Test enhanced incoming stream detection logic."""
# For client (initiator), odd stream IDs are incoming
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
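# Note (illustrative, mirroring what the assertions above check and the
# RFC 9000 stream-ID numbering): bit 0 of a QUIC stream ID identifies the
# initiator, so for bidirectional streams the incoming check reduces to parity:
#
#     def _is_incoming(stream_id: int, is_initiator: bool) -> bool:
#         # client-initiated IDs are even, server-initiated IDs are odd
#         return stream_id % 2 == (1 if is_initiator else 0)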
# Stream management tests
@pytest.mark.trio
async def test_open_stream_basic(self, quic_connection):
"""Test basic stream opening."""
quic_connection._started = True
stream = await quic_connection.open_stream()
assert isinstance(stream, QUICStream)
assert stream.stream_id == "0"
assert stream.direction == StreamDirection.OUTBOUND
assert 0 in quic_connection._streams
assert quic_connection._outbound_stream_count == 1
@pytest.mark.trio
async def test_open_stream_limit_reached(self, quic_connection):
"""Test stream limit enforcement."""
quic_connection._started = True
quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
await quic_connection.open_stream()
@pytest.mark.trio
async def test_open_stream_timeout(self, quic_connection: QUICConnection):
"""Test stream opening timeout."""
quic_connection._started = True
pytest.skip("stream creation timeout check is currently disabled")
# Mock the stream ID lock to simulate slow operation
async def slow_acquire():
await trio.sleep(10) # Longer than timeout
with patch.object(
quic_connection._stream_lock, "acquire", side_effect=slow_acquire
):
with pytest.raises(
QUICStreamTimeoutError, match="Stream creation timed out"
):
await quic_connection.open_stream(timeout=0.1)
@pytest.mark.trio
async def test_accept_stream_basic(self, quic_connection):
"""Test basic stream acceptance."""
# Create a mock inbound stream
mock_stream = Mock(spec=QUICStream)
mock_stream.stream_id = "1"
# Add to accept queue
quic_connection._stream_accept_queue.append(mock_stream)
quic_connection._stream_accept_event.set()
accepted_stream = await quic_connection.accept_stream(timeout=0.1)
assert accepted_stream == mock_stream
assert len(quic_connection._stream_accept_queue) == 0
@pytest.mark.trio
async def test_accept_stream_timeout(self, quic_connection):
"""Test stream acceptance timeout."""
with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"):
await quic_connection.accept_stream(timeout=0.1)
@pytest.mark.trio
async def test_accept_stream_on_closed_connection(self, quic_connection):
"""Test stream acceptance on closed connection."""
await quic_connection.close()
with pytest.raises(QUICConnectionClosedError, match="Connection is closed"):
await quic_connection.accept_stream()
# Stream handler tests
@pytest.mark.trio
async def test_stream_handler_setting(self, quic_connection):
"""Test setting stream handler."""
async def mock_handler(stream):
pass
quic_connection.set_stream_handler(mock_handler)
assert quic_connection._stream_handler == mock_handler
# Connection lifecycle tests
@pytest.mark.trio
async def test_connection_start_client(self, quic_connection):
"""Test client connection start."""
with patch.object(
quic_connection, "_initiate_connection", new_callable=AsyncMock
) as mock_initiate:
await quic_connection.start()
assert quic_connection._started
mock_initiate.assert_called_once()
@pytest.mark.trio
async def test_connection_start_server(self, server_connection):
"""Test server connection start."""
await server_connection.start()
assert server_connection._started
assert server_connection._established
assert server_connection._connected_event.is_set()
@pytest.mark.trio
async def test_connection_start_already_started(self, quic_connection):
"""Test starting already started connection."""
quic_connection._started = True
# Should not raise error, just log warning
await quic_connection.start()
assert quic_connection._started
@pytest.mark.trio
async def test_connection_start_closed(self, quic_connection):
"""Test starting closed connection."""
quic_connection._closed = True
with pytest.raises(
QUICConnectionError, match="Cannot start a closed connection"
):
await quic_connection.start()
@pytest.mark.trio
async def test_connection_connect_with_nursery(
self, quic_connection: QUICConnection
):
"""Test connection establishment with nursery."""
quic_connection._started = True
quic_connection._established = True
quic_connection._connected_event.set()
with patch.object(
quic_connection, "_start_background_tasks", new_callable=AsyncMock
) as mock_start_tasks:
with patch.object(
quic_connection,
"_verify_peer_identity_with_security",
new_callable=AsyncMock,
) as mock_verify:
async with trio.open_nursery() as nursery:
await quic_connection.connect(nursery)
assert quic_connection._nursery == nursery
mock_start_tasks.assert_called_once()
mock_verify.assert_called_once()
@pytest.mark.trio
@pytest.mark.slow
async def test_connection_connect_timeout(
self, quic_connection: QUICConnection
) -> None:
"""Test connection establishment timeout."""
quic_connection._started = True
# Don't set connected event to simulate timeout
with patch.object(
quic_connection, "_start_background_tasks", new_callable=AsyncMock
):
async with trio.open_nursery() as nursery:
with pytest.raises(
QUICConnectionTimeoutError, match="Connection handshake timed out"
):
await quic_connection.connect(nursery)
# Resource management tests
@pytest.mark.trio
async def test_stream_removal_resource_cleanup(
self, quic_connection: QUICConnection, mock_resource_scope
):
"""Test stream removal and resource cleanup."""
quic_connection._started = True
# Create a stream
stream = await quic_connection.open_stream()
# Remove the stream
quic_connection._remove_stream(int(stream.stream_id))
assert int(stream.stream_id) not in quic_connection._streams
# Note: the stream-count update happens asynchronously, so it is not asserted here
# Error handling tests
@pytest.mark.trio
async def test_connection_error_handling(self, quic_connection) -> None:
"""Test connection error handling."""
error = Exception("Test error")
with patch.object(
quic_connection, "close", new_callable=AsyncMock
) as mock_close:
await quic_connection._handle_connection_error(error)
mock_close.assert_called_once()
# Statistics and monitoring tests
@pytest.mark.trio
async def test_connection_stats_enhanced(self, quic_connection) -> None:
"""Test enhanced connection statistics."""
quic_connection._started = True
# Create some streams
_stream1 = await quic_connection.open_stream()
_stream2 = await quic_connection.open_stream()
stats = quic_connection.get_stream_stats()
expected_keys = [
"total_streams",
"outbound_streams",
"inbound_streams",
"max_streams",
"stream_utilization",
"stats",
]
for key in expected_keys:
assert key in stats
assert stats["total_streams"] == 2
assert stats["outbound_streams"] == 2
assert stats["inbound_streams"] == 0
@pytest.mark.trio
async def test_get_active_streams(self, quic_connection) -> None:
"""Test getting active streams."""
quic_connection._started = True
# Create streams
stream1 = await quic_connection.open_stream()
stream2 = await quic_connection.open_stream()
active_streams = quic_connection.get_active_streams()
assert len(active_streams) == 2
assert stream1 in active_streams
assert stream2 in active_streams
@pytest.mark.trio
async def test_get_streams_by_protocol(self, quic_connection) -> None:
"""Test getting streams by protocol."""
quic_connection._started = True
# Create streams with different protocols
stream1 = await quic_connection.open_stream()
stream1.protocol = "/test/1.0.0"
stream2 = await quic_connection.open_stream()
stream2.protocol = "/other/1.0.0"
test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0")
other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0")
assert len(test_streams) == 1
assert len(other_streams) == 1
assert stream1 in test_streams
assert stream2 in other_streams
# Enhanced close tests
@pytest.mark.trio
async def test_connection_close_enhanced(
self, quic_connection: QUICConnection
) -> None:
"""Test enhanced connection close with stream cleanup."""
quic_connection._started = True
# Create some streams
_stream1 = await quic_connection.open_stream()
_stream2 = await quic_connection.open_stream()
await quic_connection.close()
assert quic_connection.is_closed
assert len(quic_connection._streams) == 0
# Concurrent operations tests
@pytest.mark.trio
async def test_concurrent_stream_operations(
self, quic_connection: QUICConnection
) -> None:
"""Test concurrent stream operations."""
quic_connection._started = True
async def create_stream():
return await quic_connection.open_stream()
# Create multiple streams concurrently
async with trio.open_nursery() as nursery:
for i in range(10):
nursery.start_soon(create_stream)
# Wait a bit for all to start
await trio.sleep(0.1)
# Should have created streams without conflicts
assert quic_connection._outbound_stream_count == 10
assert len(quic_connection._streams) == 10
# Connection properties tests
def test_connection_properties(self, quic_connection: QUICConnection) -> None:
"""Test connection property accessors."""
assert quic_connection.multiaddr() == quic_connection._maddr
assert quic_connection.local_peer_id() == quic_connection._local_peer_id
assert quic_connection.remote_peer_id() == quic_connection._remote_peer_id
# IRawConnection interface tests
@pytest.mark.trio
async def test_raw_connection_write(self, quic_connection: QUICConnection) -> None:
"""Test raw connection write interface."""
quic_connection._started = True
with patch.object(quic_connection, "open_stream") as mock_open:
mock_stream = AsyncMock()
mock_open.return_value = mock_stream
await quic_connection.write(b"test data")
mock_open.assert_called_once()
mock_stream.write.assert_called_once_with(b"test data")
mock_stream.close_write.assert_called_once()
@pytest.mark.trio
async def test_raw_connection_read_not_implemented(
self, quic_connection: QUICConnection
) -> None:
"""Test raw connection read raises NotImplementedError."""
with pytest.raises(NotImplementedError):
await quic_connection.read()
# Mock verification helpers
def test_mock_resource_scope_functionality(self, mock_resource_scope) -> None:
"""Test mock resource scope works correctly."""
assert mock_resource_scope.memory_reserved == 0
mock_resource_scope.reserve_memory(1000)
assert mock_resource_scope.memory_reserved == 1000
mock_resource_scope.reserve_memory(500)
assert mock_resource_scope.memory_reserved == 1500
mock_resource_scope.release_memory(600)
assert mock_resource_scope.memory_reserved == 900
mock_resource_scope.release_memory(2000) # Should not go negative
assert mock_resource_scope.memory_reserved == 0
@pytest.mark.trio
async def test_invalid_certificate_verification():
key_pair1 = create_new_key_pair()
key_pair2 = create_new_key_pair()
peer_id1 = ID.from_pubkey(key_pair1.public_key)
peer_id2 = ID.from_pubkey(key_pair2.public_key)
manager = QUICTLSConfigManager(
libp2p_private_key=key_pair1.private_key, peer_id=peer_id1
)
# Match the certificate against a different peer_id
with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"):
manager.verify_peer_identity(manager.tls_config.certificate, peer_id2)
from cryptography.hazmat.primitives.serialization import Encoding
# --- Corrupt the certificate by tampering the DER bytes ---
cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER)
corrupted_bytes = bytearray(cert_bytes)
# Flip some random bytes in the middle of the certificate
corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF
from cryptography import x509
from cryptography.hazmat.backends import default_backend
# This will still parse (structurally valid), but the signature
# or fingerprint will break
corrupted_cert = x509.load_der_x509_certificate(
bytes(corrupted_bytes), backend=default_backend()
)
with pytest.raises(
QUICPeerVerificationError, match="Certificate verification failed"
):
manager.verify_peer_identity(corrupted_cert, peer_id1)


@@ -0,0 +1,624 @@
"""
QUIC Connection ID Management Tests
This module comprehensively tests QUIC connection ID functionality, including
generation, rotation, retirement, and validation according to RFC 9000.
Tests are organized into:
1. Basic Connection ID Management
2. Connection ID Rotation and Updates
3. Connection ID Retirement
4. Error Conditions and Edge Cases
5. Integration Tests with Real Connections
"""
import secrets
import time
from typing import Any
from unittest.mock import Mock
import pytest
from aioquic.buffer import Buffer
# Import aioquic components for low-level testing
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection, QuicConnectionId
from multiaddr import Multiaddr
from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
class ConnectionIdTestHelper:
"""Helper class for connection ID testing utilities."""
@staticmethod
def generate_connection_id(length: int = 8) -> bytes:
"""Generate a random connection ID of specified length."""
return secrets.token_bytes(length)
@staticmethod
def create_quic_connection_id(cid: bytes, sequence: int = 0) -> QuicConnectionId:
"""Create a QuicConnectionId object."""
return QuicConnectionId(
cid=cid,
sequence_number=sequence,
stateless_reset_token=secrets.token_bytes(16),
)
@staticmethod
def extract_connection_ids_from_connection(conn: QUICConnection) -> dict[str, Any]:
"""Extract connection ID information from a QUIC connection."""
quic = conn._quic
return {
"host_cids": [cid.cid.hex() for cid in getattr(quic, "_host_cids", [])],
"peer_cid": getattr(quic, "_peer_cid", None),
"peer_cid_available": [
cid.cid.hex() for cid in getattr(quic, "_peer_cid_available", [])
],
"retire_connection_ids": getattr(quic, "_retire_connection_ids", []),
"host_cid_seq": getattr(quic, "_host_cid_seq", 0),
}
class TestBasicConnectionIdManagement:
"""Test basic connection ID management functionality."""
@pytest.fixture
def mock_quic_connection(self):
"""Create a mock QUIC connection with connection ID support."""
mock_quic = Mock(spec=QuicConnection)
mock_quic._host_cids = []
mock_quic._host_cid_seq = 0
mock_quic._peer_cid = None
mock_quic._peer_cid_available = []
mock_quic._retire_connection_ids = []
mock_quic._configuration = Mock()
mock_quic._configuration.connection_id_length = 8
mock_quic._remote_active_connection_id_limit = 8
return mock_quic
@pytest.fixture
def quic_connection(self, mock_quic_connection):
"""Create a QUICConnection instance for testing."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
def test_connection_id_initialization(self, quic_connection):
"""Test that connection ID tracking is properly initialized."""
# Check that connection ID tracking structures are initialized
assert hasattr(quic_connection, "_available_connection_ids")
assert hasattr(quic_connection, "_current_connection_id")
assert hasattr(quic_connection, "_retired_connection_ids")
assert hasattr(quic_connection, "_connection_id_sequence_numbers")
# Initial state should be empty
assert len(quic_connection._available_connection_ids) == 0
assert quic_connection._current_connection_id is None
assert len(quic_connection._retired_connection_ids) == 0
assert len(quic_connection._connection_id_sequence_numbers) == 0
def test_connection_id_stats_tracking(self, quic_connection):
"""Test connection ID statistics are properly tracked."""
stats = quic_connection.get_connection_id_stats()
# Check that all expected stats are present
expected_keys = [
"available_connection_ids",
"current_connection_id",
"retired_connection_ids",
"connection_ids_issued",
"connection_ids_retired",
"connection_id_changes",
"available_cid_list",
]
for key in expected_keys:
assert key in stats
# Initial values should be zero/empty
assert stats["available_connection_ids"] == 0
assert stats["current_connection_id"] is None
assert stats["retired_connection_ids"] == 0
assert stats["connection_ids_issued"] == 0
assert stats["connection_ids_retired"] == 0
assert stats["connection_id_changes"] == 0
assert stats["available_cid_list"] == []
def test_current_connection_id_getter(self, quic_connection):
"""Test getting current connection ID."""
# Initially no connection ID
assert quic_connection.get_current_connection_id() is None
# Set a connection ID
test_cid = ConnectionIdTestHelper.generate_connection_id()
quic_connection._current_connection_id = test_cid
assert quic_connection.get_current_connection_id() == test_cid
def test_connection_id_generation(self):
"""Test connection ID generation utilities."""
# Test default length
cid1 = ConnectionIdTestHelper.generate_connection_id()
assert len(cid1) == 8
assert isinstance(cid1, bytes)
# Test custom length
cid2 = ConnectionIdTestHelper.generate_connection_id(16)
assert len(cid2) == 16
# Test uniqueness
cid3 = ConnectionIdTestHelper.generate_connection_id()
assert cid1 != cid3
class TestConnectionIdRotationAndUpdates:
"""Test connection ID rotation and update mechanisms."""
@pytest.fixture
def transport_config(self):
"""Create transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=100,
)
@pytest.fixture
def server_key(self):
"""Generate server private key."""
return create_new_key_pair().private_key
@pytest.fixture
def client_key(self):
"""Generate client private key."""
return create_new_key_pair().private_key
def test_connection_id_replenishment(self):
"""Test connection ID replenishment mechanism."""
# Create a real QuicConnection to test replenishment
config = QuicConfiguration(is_client=True)
config.connection_id_length = 8
quic_conn = QuicConnection(configuration=config)
# Initial state - should have some host connection IDs
initial_count = len(quic_conn._host_cids)
assert initial_count > 0
# Remove some connection IDs to trigger replenishment
while len(quic_conn._host_cids) > 2:
quic_conn._host_cids.pop()
# Trigger replenishment
quic_conn._replenish_connection_ids()
# Should have replenished up to the limit
assert len(quic_conn._host_cids) >= initial_count
# All connection IDs should have unique sequence numbers
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
assert len(sequences) == len(set(sequences))
def test_connection_id_sequence_numbers(self):
"""Test connection ID sequence number management."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Get initial sequence number
initial_seq = quic_conn._host_cid_seq
# Trigger replenishment to generate new connection IDs
quic_conn._replenish_connection_ids()
# Sequence numbers should increment
assert quic_conn._host_cid_seq > initial_seq
# All host connection IDs should have sequential numbers
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
sequences.sort()
# Check for proper sequence
for i in range(len(sequences) - 1):
assert sequences[i + 1] > sequences[i]
def test_connection_id_limits(self):
"""Test connection ID limit enforcement."""
config = QuicConfiguration(is_client=True)
config.connection_id_length = 8
quic_conn = QuicConnection(configuration=config)
# Set a reasonable limit
quic_conn._remote_active_connection_id_limit = 4
# Replenish connection IDs
quic_conn._replenish_connection_ids()
# Should not exceed the limit
assert len(quic_conn._host_cids) <= quic_conn._remote_active_connection_id_limit
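# Minimal model of the behaviour the two tests above rely on (an assumption,
# not aioquic's actual implementation): fresh host connection IDs are issued
# with strictly increasing sequence numbers until the peer's
# active_connection_id_limit is reached.

def _replenish(host_cids: list[tuple[int, bytes]], next_seq: int, limit: int) -> int:
    while len(host_cids) < limit:
        host_cids.append((next_seq, secrets.token_bytes(8)))
        next_seq += 1
    return next_seq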
class TestConnectionIdRetirement:
"""Test connection ID retirement functionality."""
def test_connection_id_retirement_basic(self):
"""Test basic connection ID retirement."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Create a test connection ID to retire
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=1
)
# Add it to peer connection IDs
quic_conn._peer_cid_available.append(test_cid)
quic_conn._peer_cid_sequence_numbers.add(1)
# Retire the connection ID
quic_conn._retire_peer_cid(test_cid)
# Should be added to retirement list
assert 1 in quic_conn._retire_connection_ids
def test_connection_id_retirement_limits(self):
"""Test connection ID retirement limits."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Fill up retirement list near the limit
max_retirements = 32 # Based on aioquic's default limit
for i in range(max_retirements):
quic_conn._retire_connection_ids.append(i)
# Should be at limit
assert len(quic_conn._retire_connection_ids) == max_retirements
def test_connection_id_retirement_events(self):
"""Test that retirement generates proper events."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Create and add a host connection ID
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=5
)
quic_conn._host_cids.append(test_cid)
# Create a retirement frame buffer
from aioquic.buffer import Buffer
buf = Buffer(capacity=16)
buf.push_uint_var(5) # sequence number to retire
buf.seek(0)
# Process retirement (this should generate an event)
try:
quic_conn._handle_retire_connection_id_frame(
Mock(), # context
0x19, # RETIRE_CONNECTION_ID frame type
buf,
)
# Check that connection ID was removed
remaining_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
assert 5 not in remaining_sequences
except Exception:
# May fail due to missing context, but that's okay for this test
pass
class TestConnectionIdErrorConditions:
"""Test error conditions and edge cases in connection ID handling."""
def test_invalid_connection_id_length(self):
"""Test handling of invalid connection ID lengths."""
# Connection IDs must be 1-20 bytes according to RFC 9000
# Test too short (0 bytes) - this should be handled gracefully
empty_cid = b""
assert len(empty_cid) == 0
# Test too long (>20 bytes)
long_cid = secrets.token_bytes(21)
assert len(long_cid) == 21
# Test valid lengths
for length in range(1, 21):
valid_cid = secrets.token_bytes(length)
assert len(valid_cid) == length
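# Helper sketch (hypothetical name, for illustration only) capturing the
# RFC 9000 rule exercised above: a connection ID must be 1 to 20 bytes long.
#
#     def _is_valid_connection_id(cid: bytes) -> bool:
#         return 1 <= len(cid) <= 20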
def test_duplicate_sequence_numbers(self):
"""Test handling of duplicate sequence numbers."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Create two connection IDs with same sequence number
cid1 = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=10
)
cid2 = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=10
)
# Add first connection ID
quic_conn._peer_cid_available.append(cid1)
quic_conn._peer_cid_sequence_numbers.add(10)
# Adding second with same sequence should be handled appropriately
# (The implementation should prevent duplicates)
if 10 not in quic_conn._peer_cid_sequence_numbers:
quic_conn._peer_cid_available.append(cid2)
quic_conn._peer_cid_sequence_numbers.add(10)
# Should only have one entry for sequence 10
sequences = [cid.sequence_number for cid in quic_conn._peer_cid_available]
assert sequences.count(10) <= 1
def test_retire_unknown_connection_id(self):
"""Test retiring an unknown connection ID."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Try to create a buffer to retire unknown sequence number
buf = Buffer(capacity=16)
buf.push_uint_var(999) # Unknown sequence number
buf.seek(0)
# This should raise an error when processed
# (Testing the error condition, not the full processing)
unknown_sequence = 999
known_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
assert unknown_sequence not in known_sequences
def test_retire_current_connection_id(self):
"""Test that retiring current connection ID is prevented."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Get current connection ID if available
if quic_conn._host_cids:
current_cid = quic_conn._host_cids[0]
current_sequence = current_cid.sequence_number
# Trying to retire current connection ID should be prevented
# This is tested by checking the sequence number logic
assert current_sequence >= 0
class TestConnectionIdIntegration:
"""Integration tests for connection ID functionality with real connections."""
@pytest.fixture
def server_config(self):
"""Server transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=100,
)
@pytest.fixture
def client_config(self):
"""Client transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
)
@pytest.fixture
def server_key(self):
"""Generate server private key."""
return create_new_key_pair().private_key
@pytest.fixture
def client_key(self):
"""Generate client private key."""
return create_new_key_pair().private_key
@pytest.mark.trio
async def test_connection_id_exchange_during_handshake(
self, server_key, client_key, server_config, client_config
):
"""Test connection ID exchange during connection handshake."""
# This test would require a full connection setup
# For now, we test the setup components
server_transport = QUICTransport(server_key, server_config)
client_transport = QUICTransport(client_key, client_config)
# Verify transports are created with proper configuration
assert server_transport._config == server_config
assert client_transport._config == client_config
# Test that connection ID tracking is available
# (Integration with actual networking would require more setup)
def test_connection_id_extraction_utilities(self):
"""Test connection ID extraction utilities."""
# Create a mock connection with some connection IDs
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
mock_quic = Mock()
mock_quic._host_cids = [
ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), i
)
for i in range(3)
]
mock_quic._peer_cid = None
mock_quic._peer_cid_available = []
mock_quic._retire_connection_ids = []
mock_quic._host_cid_seq = 3
quic_conn = QUICConnection(
quic_connection=mock_quic,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
# Extract connection ID information
cid_info = ConnectionIdTestHelper.extract_connection_ids_from_connection(
quic_conn
)
# Verify extraction works
assert "host_cids" in cid_info
assert "peer_cid" in cid_info
assert "peer_cid_available" in cid_info
assert "retire_connection_ids" in cid_info
assert "host_cid_seq" in cid_info
# Check values
assert len(cid_info["host_cids"]) == 3
assert cid_info["host_cid_seq"] == 3
assert cid_info["peer_cid"] is None
assert len(cid_info["peer_cid_available"]) == 0
assert len(cid_info["retire_connection_ids"]) == 0
class TestConnectionIdStatistics:
"""Test connection ID statistics and monitoring."""
@pytest.fixture
def connection_with_stats(self):
"""Create a connection with connection ID statistics."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
mock_quic = Mock()
mock_quic._host_cids = []
mock_quic._peer_cid = None
mock_quic._peer_cid_available = []
mock_quic._retire_connection_ids = []
return QUICConnection(
quic_connection=mock_quic,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
def test_connection_id_stats_initialization(self, connection_with_stats):
"""Test that connection ID statistics are properly initialized."""
stats = connection_with_stats._stats
# Check that connection ID stats are present
assert "connection_ids_issued" in stats
assert "connection_ids_retired" in stats
assert "connection_id_changes" in stats
# Initial values should be zero
assert stats["connection_ids_issued"] == 0
assert stats["connection_ids_retired"] == 0
assert stats["connection_id_changes"] == 0
def test_connection_id_stats_update(self, connection_with_stats):
"""Test updating connection ID statistics."""
conn = connection_with_stats
# Add some connection IDs to tracking
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(3)]
for cid in test_cids:
conn._available_connection_ids.add(cid)
# Update stats (this would normally be done by the implementation)
conn._stats["connection_ids_issued"] = len(test_cids)
# Verify stats
stats = conn.get_connection_id_stats()
assert stats["connection_ids_issued"] == 3
assert stats["available_connection_ids"] == 3
def test_connection_id_list_representation(self, connection_with_stats):
"""Test connection ID list representation in stats."""
conn = connection_with_stats
# Add some connection IDs
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(2)]
for cid in test_cids:
conn._available_connection_ids.add(cid)
# Get stats
stats = conn.get_connection_id_stats()
# Check that CID list is properly formatted
assert "available_cid_list" in stats
assert len(stats["available_cid_list"]) == 2
# All entries should be hex strings
for cid_hex in stats["available_cid_list"]:
assert isinstance(cid_hex, str)
assert len(cid_hex) == 16 # 8 bytes = 16 hex chars
# Performance and stress tests
class TestConnectionIdPerformance:
"""Test connection ID performance and stress scenarios."""
def test_connection_id_generation_performance(self):
"""Test connection ID generation performance."""
start_time = time.time()
# Generate many connection IDs
cids = []
for _ in range(1000):
cid = ConnectionIdTestHelper.generate_connection_id()
cids.append(cid)
end_time = time.time()
generation_time = end_time - start_time
# Should be reasonably fast (less than 1 second for 1000 IDs)
assert generation_time < 1.0
# All should be unique
assert len(set(cids)) == len(cids)
def test_connection_id_tracking_memory(self):
"""Test memory usage of connection ID tracking."""
conn_ids = set()
# Add many connection IDs
for _ in range(1000):
cid = ConnectionIdTestHelper.generate_connection_id()
conn_ids.add(cid)
# Verify they're all stored
assert len(conn_ids) == 1000
# Clean up
conn_ids.clear()
assert len(conn_ids) == 0
if __name__ == "__main__":
# Run tests if executed directly
pytest.main([__file__, "-v"])


@@ -0,0 +1,418 @@
"""
Basic QUIC Echo Test
Simple test to verify the basic QUIC flow:
1. Client connects to server
2. Client sends data
3. Server receives data and echoes back
4. Client receives the echo
This test focuses on identifying where the accept_stream issue occurs.
"""
import logging
import pytest
import multiaddr
import trio
from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID
from libp2p import new_host
from libp2p.abc import INetStream
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.quic.utils import create_quic_multiaddr
# Set up logging to see what's happening
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class TestBasicQUICFlow:
"""Test basic QUIC client-server communication flow."""
@pytest.fixture
def server_key(self):
"""Generate server key pair."""
return create_new_key_pair()
@pytest.fixture
def client_key(self):
"""Generate client key pair."""
return create_new_key_pair()
@pytest.fixture
def server_config(self):
"""Simple server configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=10,
max_connections=5,
)
@pytest.fixture
def client_config(self):
"""Simple client configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=5,
)
@pytest.mark.trio
async def test_basic_echo_flow(
self, server_key, client_key, server_config, client_config
):
"""Test basic client-server echo flow with detailed logging."""
print("\n=== BASIC QUIC ECHO TEST ===")
# Create server components
server_transport = QUICTransport(server_key.private_key, server_config)
# Track test state
server_received_data = None
server_connection_established = False
echo_sent = False
async def echo_server_handler(connection: QUICConnection) -> None:
"""Simple echo server handler with detailed logging."""
nonlocal server_received_data, server_connection_established, echo_sent
print("🔗 SERVER: Connection handler called")
server_connection_established = True
try:
print("📡 SERVER: Waiting for incoming stream...")
# Accept stream with timeout and detailed logging
print("📡 SERVER: Calling accept_stream...")
stream = await connection.accept_stream(timeout=5.0)
if stream is None:
print("❌ SERVER: accept_stream returned None")
return
print(f"✅ SERVER: Stream accepted! Stream ID: {stream.stream_id}")
# Read data from the stream
print("📖 SERVER: Reading data from stream...")
server_data = await stream.read(1024)
if not server_data:
print("❌ SERVER: No data received from stream")
return
server_received_data = server_data.decode("utf-8", errors="ignore")
print(f"📨 SERVER: Received data: '{server_received_data}'")
# Echo the data back
echo_message = f"ECHO: {server_received_data}"
print(f"📤 SERVER: Sending echo: '{echo_message}'")
await stream.write(echo_message.encode())
echo_sent = True
print("✅ SERVER: Echo sent successfully")
# Close the stream
await stream.close()
print("🔒 SERVER: Stream closed")
except Exception as e:
print(f"❌ SERVER: Error in handler: {e}")
import traceback
traceback.print_exc()
# Create listener
listener = server_transport.create_listener(echo_server_handler)
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
# Variables to track client state
client_connected = False
client_sent_data = False
client_received_echo = None
try:
print("🚀 Starting server...")
async with trio.open_nursery() as nursery:
# Start server listener
success = await listener.listen(listen_addr, nursery)
assert success, "Failed to start server listener"
# Get server address
server_addrs = listener.get_addrs()
server_addr = multiaddr.Multiaddr(
f"{server_addrs[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
)
print(f"🔧 SERVER: Listening on {server_addr}")
# Give server a moment to be ready
await trio.sleep(0.1)
print("🚀 Starting client...")
# Create client transport
client_transport = QUICTransport(client_key.private_key, client_config)
client_transport.set_background_nursery(nursery)
try:
# Connect to server
print(f"📞 CLIENT: Connecting to {server_addr}")
connection = await client_transport.dial(server_addr)
client_connected = True
print("✅ CLIENT: Connected to server")
# Open a stream
print("📤 CLIENT: Opening stream...")
stream = await connection.open_stream()
print(f"✅ CLIENT: Stream opened with ID: {stream.stream_id}")
# Send test data
test_message = "Hello QUIC Server!"
print(f"📨 CLIENT: Sending message: '{test_message}'")
await stream.write(test_message.encode())
client_sent_data = True
print("✅ CLIENT: Message sent")
# Read echo response
print("📖 CLIENT: Waiting for echo response...")
response_data = await stream.read(1024)
if response_data:
client_received_echo = response_data.decode(
"utf-8", errors="ignore"
)
print(f"📬 CLIENT: Received echo: '{client_received_echo}'")
else:
print("❌ CLIENT: No echo response received")
print("🔒 CLIENT: Closing connection")
await connection.close()
print("🔒 CLIENT: Connection closed")
print("🔒 CLIENT: Closing transport")
await client_transport.close()
print("🔒 CLIENT: Transport closed")
except Exception as e:
print(f"❌ CLIENT: Error: {e}")
import traceback
traceback.print_exc()
finally:
await client_transport.close()
print("🔒 CLIENT: Transport closed")
# Give everything time to complete
await trio.sleep(0.5)
# Cancel nursery to stop server
nursery.cancel_scope.cancel()
finally:
# Cleanup
if not listener._closed:
await listener.close()
await server_transport.close()
# Verify the flow worked
print("\n📊 TEST RESULTS:")
print(f" Server connection established: {server_connection_established}")
print(f" Client connected: {client_connected}")
print(f" Client sent data: {client_sent_data}")
print(f" Server received data: '{server_received_data}'")
print(f" Echo sent by server: {echo_sent}")
print(f" Client received echo: '{client_received_echo}'")
# Test assertions
assert server_connection_established, "Server connection handler was not called"
assert client_connected, "Client failed to connect"
assert client_sent_data, "Client failed to send data"
assert server_received_data == "Hello QUIC Server!", (
f"Server received wrong data: '{server_received_data}'"
)
assert echo_sent, "Server failed to send echo"
assert client_received_echo == "ECHO: Hello QUIC Server!", (
f"Client received wrong echo: '{client_received_echo}'"
)
print("✅ BASIC ECHO TEST PASSED!")
@pytest.mark.trio
async def test_server_accept_stream_timeout(
self, server_key, client_key, server_config, client_config
):
"""Test what happens when server accept_stream times out."""
print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===")
server_transport = QUICTransport(server_key.private_key, server_config)
accept_stream_called = False
accept_stream_timeout = False
async def timeout_test_handler(connection: QUICConnection) -> None:
"""Handler that tests accept_stream timeout."""
nonlocal accept_stream_called, accept_stream_timeout
print("🔗 SERVER: Connection established, testing accept_stream timeout")
accept_stream_called = True
try:
print("📡 SERVER: Calling accept_stream with 2 second timeout...")
stream = await connection.accept_stream(timeout=2.0)
print(f"✅ SERVER: accept_stream returned: {stream}")
except Exception as e:
print(f"⏰ SERVER: accept_stream timed out or failed: {e}")
accept_stream_timeout = True
listener = server_transport.create_listener(timeout_test_handler)
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
client_connected = False
try:
async with trio.open_nursery() as nursery:
# Start server
server_transport.set_background_nursery(nursery)
success = await listener.listen(listen_addr, nursery)
assert success
server_addr = multiaddr.Multiaddr(
f"{listener.get_addrs()[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
)
print(f"🔧 SERVER: Listening on {server_addr}")
# Create client but DON'T open a stream
async with trio.open_nursery() as client_nursery:
client_transport = QUICTransport(
client_key.private_key, client_config
)
client_transport.set_background_nursery(client_nursery)
try:
print("📞 CLIENT: Connecting (but NOT opening stream)...")
connection = await client_transport.dial(server_addr)
client_connected = True
print("✅ CLIENT: Connected (no stream opened)")
# Wait for server timeout
await trio.sleep(3.0)
await connection.close()
print("🔒 CLIENT: Connection closed")
finally:
await client_transport.close()
nursery.cancel_scope.cancel()
finally:
await listener.close()
await server_transport.close()
print("\n📊 TIMEOUT TEST RESULTS:")
print(f" Client connected: {client_connected}")
print(f" accept_stream called: {accept_stream_called}")
print(f" accept_stream timeout: {accept_stream_timeout}")
assert client_connected, "Client should have connected"
assert accept_stream_called, "accept_stream should have been called"
assert accept_stream_timeout, (
"accept_stream should have timed out when no stream was opened"
)
print("✅ TIMEOUT TEST PASSED!")
@pytest.mark.trio
async def test_yamux_stress_ping():
STREAM_COUNT = 100
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
latencies = []
failures = []
# === Server Setup ===
server_host = new_host(listen_addrs=[listen_addr])
async def handle_ping(stream: INetStream) -> None:
try:
while True:
payload = await stream.read(PING_LENGTH)
if not payload:
break
await stream.write(payload)
except Exception:
await stream.reset()
server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
async with server_host.run(listen_addrs=[listen_addr]):
# Give server time to start
await trio.sleep(0.1)
# === Client Setup ===
destination = str(server_host.get_addrs()[0])
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
client_host = new_host(listen_addrs=[client_listen_addr])
async with client_host.run(listen_addrs=[client_listen_addr]):
await client_host.connect(info)
async def ping_stream(i: int):
stream = None
try:
start = trio.current_time()
stream = await client_host.new_stream(
info.peer_id, [PING_PROTOCOL_ID]
)
await stream.write(b"\x01" * PING_LENGTH)
with trio.fail_after(5):
response = await stream.read(PING_LENGTH)
if response == b"\x01" * PING_LENGTH:
latency_ms = int((trio.current_time() - start) * 1000)
latencies.append(latency_ms)
print(f"[Ping #{i}] Latency: {latency_ms} ms")
await stream.close()
except Exception as e:
print(f"[Ping #{i}] Failed: {e}")
failures.append(i)
if stream:
await stream.reset()
async with trio.open_nursery() as nursery:
for i in range(STREAM_COUNT):
nursery.start_soon(ping_stream, i)
# === Result Summary ===
print("\n📊 Ping Stress Test Summary")
print(f"Total Streams Launched: {STREAM_COUNT}")
print(f"Successful Pings: {len(latencies)}")
print(f"Failed Pings: {len(failures)}")
if failures:
print(f"❌ Failed stream indices: {failures}")
# === Assertions ===
assert len(latencies) == STREAM_COUNT, (
f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}"
)
assert all(isinstance(x, int) and x >= 0 for x in latencies), (
"Invalid latencies"
)
avg_latency = sum(latencies) / len(latencies)
print(f"✅ Average Latency: {avg_latency:.2f} ms")
assert avg_latency < 1000

View File

@ -0,0 +1,150 @@
from unittest.mock import AsyncMock
import pytest
from multiaddr.multiaddr import Multiaddr
import trio
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.transport.quic.exceptions import (
QUICListenError,
)
from libp2p.transport.quic.listener import QUICListener
from libp2p.transport.quic.transport import (
QUICTransport,
QUICTransportConfig,
)
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
)
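# Listener lifecycle exercised by these tests (sketch of the API under test):
#   listener = transport.create_listener(handler)
#   await listener.listen(create_quic_multiaddr("127.0.0.1", 0, "/quic"), nursery)
#   addrs = listener.get_addrs()
#   await listener.close()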
class TestQUICListener:
"""Test suite for QUIC listener functionality."""
@pytest.fixture
def private_key(self):
"""Generate test private key."""
return create_new_key_pair().private_key
@pytest.fixture
def transport_config(self):
"""Generate test transport configuration."""
return QUICTransportConfig(idle_timeout=10.0)
@pytest.fixture
def transport(self, private_key, transport_config):
"""Create test transport instance."""
return QUICTransport(private_key, transport_config)
@pytest.fixture
def connection_handler(self):
"""Mock connection handler."""
return AsyncMock()
@pytest.fixture
def listener(self, transport, connection_handler):
"""Create test listener."""
return transport.create_listener(connection_handler)
def test_listener_creation(self, transport, connection_handler):
"""Test listener creation."""
listener = transport.create_listener(connection_handler)
assert isinstance(listener, QUICListener)
assert listener._transport == transport
assert listener._handler == connection_handler
assert not listener._listening
assert not listener._closed
@pytest.mark.trio
async def test_listener_invalid_multiaddr(self, listener: QUICListener):
"""Test listener with invalid multiaddr."""
async with trio.open_nursery() as nursery:
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"):
await listener.listen(invalid_addr, nursery)
@pytest.mark.trio
async def test_listener_basic_lifecycle(self, listener: QUICListener):
"""Test basic listener lifecycle."""
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random
async with trio.open_nursery() as nursery:
# Start listening
success = await listener.listen(listen_addr, nursery)
assert success
assert listener.is_listening()
# Check bound addresses
addrs = listener.get_addrs()
assert len(addrs) == 1
# Check stats
stats = listener.get_stats()
assert stats["is_listening"] is True
assert stats["active_connections"] == 0
assert stats["pending_connections"] == 0
# Cancel the nursery scope so the listener's background tasks shut down
nursery.cancel_scope.cancel()
await listener.close()
assert not listener.is_listening()
@pytest.mark.trio
async def test_listener_double_listen(self, listener: QUICListener):
"""Test that double listen raises error."""
listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic")
try:
async with trio.open_nursery() as nursery:
success = await listener.listen(listen_addr, nursery)
assert success
await trio.sleep(0.01)
addrs = listener.get_addrs()
assert len(addrs) > 0
async with trio.open_nursery() as nursery2:
with pytest.raises(QUICListenError, match="Already listening"):
await listener.listen(listen_addr, nursery2)
nursery2.cancel_scope.cancel()
nursery.cancel_scope.cancel()
finally:
await listener.close()
@pytest.mark.trio
async def test_listener_port_binding(self, listener: QUICListener):
"""Test listener port binding and cleanup."""
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
try:
async with trio.open_nursery() as nursery:
success = await listener.listen(listen_addr, nursery)
assert success
await trio.sleep(0.5)
addrs = listener.get_addrs()
assert len(addrs) > 0
nursery.cancel_scope.cancel()
finally:
await listener.close()
# By the time we get here, the listener and its tasks have been fully
# shut down, allowing the nursery to exit without hanging.
print("TEST COMPLETED SUCCESSFULLY.")
@pytest.mark.trio
async def test_listener_stats_tracking(self, listener):
"""Test listener statistics tracking."""
initial_stats = listener.get_stats()
# All counters should start at 0
assert initial_stats["connections_accepted"] == 0
assert initial_stats["connections_rejected"] == 0
assert initial_stats["bytes_received"] == 0
assert initial_stats["packets_processed"] == 0

View File

@ -0,0 +1,123 @@
from unittest.mock import (
Mock,
)
import pytest
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.crypto.keys import PrivateKey
from libp2p.transport.quic.exceptions import (
QUICDialError,
QUICListenError,
)
from libp2p.transport.quic.transport import (
QUICTransport,
QUICTransportConfig,
)
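# Transport API exercised below (sketch): QUICTransport(private_key, config) exposes
# can_dial(maddr) for /udp/.../quic and /udp/.../quic-v1 addresses plus async
# dial()/close(); dialing a closed transport raises QUICDialError.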
class TestQUICTransport:
"""Test suite for QUIC transport using trio."""
@pytest.fixture
def private_key(self):
"""Generate test private key."""
return create_new_key_pair().private_key
@pytest.fixture
def transport_config(self):
"""Generate test transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0, enable_draft29=True, enable_v1=True
)
@pytest.fixture
def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig):
"""Create test transport instance."""
return QUICTransport(private_key, transport_config)
def test_transport_initialization(self, transport):
"""Test transport initialization."""
assert transport._private_key is not None
assert transport._peer_id is not None
assert not transport._closed
assert len(transport._quic_configs) >= 1
def test_supported_protocols(self, transport):
"""Test supported protocol identifiers."""
protocols = transport.protocols()
# TODO: Update when quic-v1 compatible
# assert "quic-v1" in protocols
assert "quic" in protocols # draft-29
def test_can_dial_quic_addresses(self, transport: QUICTransport):
"""Test multiaddr compatibility checking."""
import multiaddr
# Valid QUIC addresses
valid_addrs = [
# TODO: Update Multiaddr package to accept quic-v1
multiaddr.Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
multiaddr.Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
multiaddr.Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
]
for addr in valid_addrs:
assert transport.can_dial(addr)
# Invalid addresses
invalid_addrs = [
multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"),
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001"),
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/ws"),
]
for addr in invalid_addrs:
assert not transport.can_dial(addr)
@pytest.mark.trio
async def test_transport_lifecycle(self, transport):
"""Test transport lifecycle management using trio."""
assert not transport._closed
await transport.close()
assert transport._closed
# Should be safe to close multiple times
await transport.close()
@pytest.mark.trio
async def test_dial_closed_transport(self, transport: QUICTransport) -> None:
"""Test dialing with closed transport raises error."""
import multiaddr
await transport.close()
with pytest.raises(QUICDialError, match="Transport is closed"):
await transport.dial(
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
)
def test_create_listener_closed_transport(self, transport: QUICTransport) -> None:
"""Test creating listener with closed transport raises error."""
transport._closed = True
with pytest.raises(QUICListenError, match="Transport is closed"):
transport.create_listener(Mock())

View File

@ -0,0 +1,321 @@
"""
Test suite for QUIC multiaddr utilities.
Focused tests covering essential functionality required for QUIC transport.
"""
import pytest
from multiaddr import Multiaddr
from libp2p.custom_types import TProtocol
from libp2p.transport.quic.exceptions import (
QUICInvalidMultiaddrError,
QUICUnsupportedVersionError,
)
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
get_alpn_protocols,
is_quic_multiaddr,
multiaddr_to_quic_version,
normalize_quic_multiaddr,
quic_multiaddr_to_endpoint,
quic_version_to_wire_format,
)
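# Round trip exercised by these tests:
#   maddr = create_quic_multiaddr("127.0.0.1", 4001, "quic-v1")   # /ip4/127.0.0.1/udp/4001/quic-v1
#   quic_multiaddr_to_endpoint(maddr)                  -> ("127.0.0.1", 4001)
#   multiaddr_to_quic_version(maddr)                   -> "quic-v1"
#   quic_version_to_wire_format(TProtocol("quic-v1"))  -> 0x00000001 (RFC 9000)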
class TestIsQuicMultiaddr:
"""Test QUIC multiaddr detection."""
def test_valid_quic_v1_multiaddrs(self):
"""Test valid QUIC v1 multiaddrs are detected."""
valid_addrs = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip4/192.168.1.1/udp/8080/quic-v1",
"/ip6/::1/udp/4001/quic-v1",
"/ip6/2001:db8::1/udp/5000/quic-v1",
]
for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
def test_valid_quic_draft29_multiaddrs(self):
"""Test valid QUIC draft-29 multiaddrs are detected."""
valid_addrs = [
"/ip4/127.0.0.1/udp/4001/quic",
"/ip4/10.0.0.1/udp/9000/quic",
"/ip6/::1/udp/4001/quic",
"/ip6/fe80::1/udp/6000/quic",
]
for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
def test_invalid_multiaddrs(self):
"""Test non-QUIC multiaddrs are not detected."""
invalid_addrs = [
"/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
"/ip4/127.0.0.1/udp/4001", # UDP without QUIC
"/ip4/127.0.0.1/udp/4001/ws", # WebSocket
"/ip4/127.0.0.1/quic-v1", # Missing UDP
"/udp/4001/quic-v1", # Missing IP
"/dns4/example.com/tcp/443/tls", # Completely different
]
for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
class TestQuicMultiaddrToEndpoint:
"""Test endpoint extraction from QUIC multiaddrs."""
def test_ipv4_extraction(self):
"""Test IPv4 host/port extraction."""
test_cases = [
("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
]
for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
assert result == expected, f"Failed for {addr_str}"
def test_ipv6_extraction(self):
"""Test IPv6 host/port extraction."""
test_cases = [
("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
]
for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
assert result == expected, f"Failed for {addr_str}"
def test_invalid_multiaddr_raises_error(self):
"""Test invalid multiaddrs raise appropriate errors."""
invalid_addrs = [
"/ip4/127.0.0.1/tcp/4001", # Not QUIC
"/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
]
for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
with pytest.raises(QUICInvalidMultiaddrError):
quic_multiaddr_to_endpoint(maddr)
class TestMultiaddrToQuicVersion:
"""Test QUIC version extraction."""
def test_quic_v1_detection(self):
"""Test QUIC v1 version detection."""
addrs = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip6/::1/udp/5000/quic-v1",
]
for addr_str in addrs:
maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
def test_quic_draft29_detection(self):
"""Test QUIC draft-29 version detection."""
addrs = [
"/ip4/127.0.0.1/udp/4001/quic",
"/ip6/::1/udp/5000/quic",
]
for addr_str in addrs:
maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic", f"Should detect quic for {addr_str}"
def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICInvalidMultiaddrError):
multiaddr_to_quic_version(maddr)
class TestCreateQuicMultiaddr:
"""Test QUIC multiaddr creation."""
def test_ipv4_creation(self):
"""Test IPv4 QUIC multiaddr creation."""
test_cases = [
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
]
for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
def test_ipv6_creation(self):
"""Test IPv6 QUIC multiaddr creation."""
test_cases = [
("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
]
for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
def test_default_version(self):
"""Test default version is quic-v1."""
result = create_quic_multiaddr("127.0.0.1", 4001)
expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
assert str(result) == expected
def test_invalid_inputs_raise_errors(self):
"""Test invalid inputs raise appropriate errors."""
# Invalid IP
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("invalid-ip", 4001)
# Invalid port
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", 70000)
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", -1)
# Invalid version
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
class TestQuicVersionToWireFormat:
"""Test QUIC version to wire format conversion."""
def test_supported_versions(self):
"""Test supported version conversions."""
test_cases = [
("quic-v1", 0x00000001), # RFC 9000
("quic", 0xFF00001D), # draft-29
]
for version, expected_wire in test_cases:
result = quic_version_to_wire_format(TProtocol(version))
assert result == expected_wire, f"Failed for version {version}"
def test_unsupported_version_raises_error(self):
"""Test unsupported versions raise error."""
with pytest.raises(QUICUnsupportedVersionError):
quic_version_to_wire_format(TProtocol("unsupported-version"))
class TestGetAlpnProtocols:
"""Test ALPN protocol retrieval."""
def test_returns_libp2p_protocols(self):
"""Test returns expected libp2p ALPN protocols."""
protocols = get_alpn_protocols()
assert protocols == ["libp2p"]
assert isinstance(protocols, list)
def test_returns_copy(self):
"""Test returns a copy, not the original list."""
protocols1 = get_alpn_protocols()
protocols2 = get_alpn_protocols()
# Modify one list
protocols1.append("test")
# Other list should be unchanged
assert protocols2 == ["libp2p"]
class TestNormalizeQuicMultiaddr:
"""Test QUIC multiaddr normalization."""
def test_already_normalized(self):
"""Test already normalized multiaddrs pass through."""
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
result = normalize_quic_multiaddr(maddr)
assert str(result) == addr_str
def test_normalize_different_versions(self):
"""Test normalization works for different QUIC versions."""
test_cases = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip4/127.0.0.1/udp/4001/quic",
"/ip6/::1/udp/5000/quic-v1",
]
for addr_str in test_cases:
maddr = Multiaddr(addr_str)
result = normalize_quic_multiaddr(maddr)
# Should be valid QUIC multiaddr
assert is_quic_multiaddr(result)
# Should be parseable
host, port = quic_multiaddr_to_endpoint(result)
version = multiaddr_to_quic_version(result)
# Should match original
orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
orig_version = multiaddr_to_quic_version(maddr)
assert host == orig_host
assert port == orig_port
assert version == orig_version
def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICInvalidMultiaddrError):
normalize_quic_multiaddr(maddr)
class TestIntegration:
"""Integration tests for utility functions working together."""
def test_round_trip_conversion(self):
"""Test creating and parsing multiaddrs works correctly."""
test_cases = [
("127.0.0.1", 4001, "quic-v1"),
("::1", 5000, "quic"),
("192.168.1.100", 8080, "quic-v1"),
]
for host, port, version in test_cases:
# Create multiaddr
maddr = create_quic_multiaddr(host, port, version)
# Should be detected as QUIC
assert is_quic_multiaddr(maddr)
# Should extract original values
extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
extracted_version = multiaddr_to_quic_version(maddr)
assert extracted_host == host
assert extracted_port == port
assert extracted_version == version
# Should normalize to same value
normalized = normalize_quic_multiaddr(maddr)
assert str(normalized) == str(maddr)
def test_wire_format_integration(self):
"""Test wire format conversion works with version detection."""
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
# Extract version and convert to wire format
version = multiaddr_to_quic_version(maddr)
wire_format = quic_version_to_wire_format(version)
# Should be QUIC v1 wire format
assert wire_format == 0x00000001

View File

@ -0,0 +1,324 @@
"""
Tests for the transport registry functionality.
"""
from multiaddr import Multiaddr
from libp2p.abc import IListener, IRawConnection, ITransport
from libp2p.custom_types import THandler
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.transport_registry import (
TransportRegistry,
create_transport_for_multiaddr,
get_supported_transport_protocols,
get_transport_registry,
register_transport,
)
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
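# Registry usage exercised below (sketch):
#   registry = TransportRegistry()
#   registry.register_transport("custom", CustomTransport)
#   create_transport_for_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080"), upgrader)     -> TCP
#   create_transport_for_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"), upgrader)  -> WebsocketTransport
#   unsupported protocols                                                              -> None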
class TestTransportRegistry:
"""Test the TransportRegistry class."""
def test_init(self):
"""Test registry initialization."""
registry = TransportRegistry()
assert isinstance(registry, TransportRegistry)
# Check that default transports are registered
supported = registry.get_supported_protocols()
assert "tcp" in supported
assert "ws" in supported
def test_register_transport(self):
"""Test transport registration."""
registry = TransportRegistry()
# Register a custom transport
class CustomTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("CustomTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"CustomTransport create_listener not implemented"
)
registry.register_transport("custom", CustomTransport)
assert registry.get_transport("custom") == CustomTransport
def test_get_transport(self):
"""Test getting registered transports."""
registry = TransportRegistry()
# Test existing transports
assert registry.get_transport("tcp") == TCP
assert registry.get_transport("ws") == WebsocketTransport
# Test non-existent transport
assert registry.get_transport("nonexistent") is None
def test_get_supported_protocols(self):
"""Test getting supported protocols."""
registry = TransportRegistry()
protocols = registry.get_supported_protocols()
assert isinstance(protocols, list)
assert "tcp" in protocols
assert "ws" in protocols
def test_create_transport_tcp(self):
"""Test creating TCP transport."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("tcp", upgrader)
assert isinstance(transport, TCP)
def test_create_transport_websocket(self):
"""Test creating WebSocket transport."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("ws", upgrader)
assert isinstance(transport, WebsocketTransport)
def test_create_transport_invalid_protocol(self):
"""Test creating transport with invalid protocol."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("invalid", upgrader)
assert transport is None
def test_create_transport_websocket_no_upgrader(self):
"""Test that WebSocket transport requires upgrader."""
registry = TransportRegistry()
# This should fail gracefully and return None
transport = registry.create_transport("ws", None)
assert transport is None
class TestGlobalRegistry:
"""Test the global registry functions."""
def test_get_transport_registry(self):
"""Test getting the global registry."""
registry = get_transport_registry()
assert isinstance(registry, TransportRegistry)
def test_register_transport_global(self):
"""Test registering transport globally."""
class GlobalCustomTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("GlobalCustomTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"GlobalCustomTransport create_listener not implemented"
)
# Register globally
register_transport("global_custom", GlobalCustomTransport)
# Check that it's available
registry = get_transport_registry()
assert registry.get_transport("global_custom") == GlobalCustomTransport
def test_get_supported_transport_protocols_global(self):
"""Test getting supported protocols from global registry."""
protocols = get_supported_transport_protocols()
assert isinstance(protocols, list)
assert "tcp" in protocols
assert "ws" in protocols
class TestTransportFactory:
"""Test the transport factory functions."""
def test_create_transport_for_multiaddr_tcp(self):
"""Test creating transport for TCP multiaddr."""
upgrader = TransportUpgrader({}, {})
# TCP multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, TCP)
def test_create_transport_for_multiaddr_websocket(self):
"""Test creating transport for WebSocket multiaddr."""
upgrader = TransportUpgrader({}, {})
# WebSocket multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_websocket_secure(self):
"""Test creating transport for WebSocket multiaddr."""
upgrader = TransportUpgrader({}, {})
# WebSocket multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_ipv6(self):
"""Test creating transport for IPv6 multiaddr."""
upgrader = TransportUpgrader({}, {})
# IPv6 WebSocket multiaddr
maddr = Multiaddr("/ip6/::1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_dns(self):
"""Test creating transport for DNS multiaddr."""
upgrader = TransportUpgrader({}, {})
# DNS WebSocket multiaddr
maddr = Multiaddr("/dns4/example.com/tcp/443/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_unknown(self):
"""Test creating transport for unknown multiaddr."""
upgrader = TransportUpgrader({}, {})
# Unknown multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/udp/8080")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None
def test_create_transport_for_multiaddr_with_upgrader(self):
"""Test creating transport with upgrader."""
upgrader = TransportUpgrader({}, {})
# This should work for both TCP and WebSocket with upgrader
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader)
assert transport_tcp is not None
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader)
assert transport_ws is not None
class TestTransportInterfaceCompliance:
"""Test that all transports implement the required interface."""
def test_tcp_implements_itransport(self):
"""Test that TCP transport implements ITransport."""
transport = TCP()
assert isinstance(transport, ITransport)
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
assert callable(transport.dial)
assert callable(transport.create_listener)
def test_websocket_implements_itransport(self):
"""Test that WebSocket transport implements ITransport."""
upgrader = TransportUpgrader({}, {})
transport = WebsocketTransport(upgrader)
assert isinstance(transport, ITransport)
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
assert callable(transport.dial)
assert callable(transport.create_listener)
class TestErrorHandling:
"""Test error handling in the transport registry."""
def test_create_transport_with_exception(self):
"""Test handling of transport creation exceptions."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
# Register a transport that raises an exception
class ExceptionTransport(ITransport):
def __init__(self, *args, **kwargs):
raise RuntimeError("Transport creation failed")
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("ExceptionTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"ExceptionTransport create_listener not implemented"
)
registry.register_transport("exception", ExceptionTransport)
# Should handle exception gracefully and return None
transport = registry.create_transport("exception", upgrader)
assert transport is None
def test_invalid_multiaddr_handling(self):
"""Test handling of invalid multiaddrs."""
upgrader = TransportUpgrader({}, {})
# Test with a multiaddr that has an unsupported transport protocol
# This should be handled gracefully by our transport registry
# udp is not a supported transport
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None
class TestIntegration:
"""Test integration scenarios."""
def test_multiple_transport_types(self):
"""Test using multiple transport types in the same registry."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
# Create different transport types
tcp_transport = registry.create_transport("tcp", upgrader)
ws_transport = registry.create_transport("ws", upgrader)
# All should be different types
assert isinstance(tcp_transport, TCP)
assert isinstance(ws_transport, WebsocketTransport)
# All should be different instances
assert tcp_transport is not ws_transport
def test_transport_registry_persistence(self):
"""Test that transport registry persists across calls."""
registry1 = get_transport_registry()
registry2 = get_transport_registry()
# Should be the same instance
assert registry1 is registry2
# Register a transport in one
class PersistentTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("PersistentTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"PersistentTransport create_listener not implemented"
)
registry1.register_transport("persistent", PersistentTransport)
# Should be available in the other
assert registry2.get_transport("persistent") == PersistentTransport

View File

@ -0,0 +1,27 @@
import pytest
from libp2p.custom_types import (
TMuxerOptions,
TSecurityOptions,
)
from libp2p.transport.upgrader import (
TransportUpgrader,
)
@pytest.mark.trio
async def test_transport_upgrader_security_and_muxer_initialization():
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
secure_transports: TSecurityOptions = {}
muxer_transports: TMuxerOptions = {}
negotiate_timeout = 15
upgrader = TransportUpgrader(
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
)
# Verify security multistream initialization
assert upgrader.security_multistream.transports == secure_transports
# Verify muxer multistream initialization and timeout
assert upgrader.muxer_multistream.transports == muxer_transports
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout

File diff suppressed because it is too large

View File

@ -0,0 +1,532 @@
#!/usr/bin/env python3
"""
Python-to-Python WebSocket peer-to-peer tests.
This module tests real WebSocket communication between two Python libp2p hosts,
including both WS and WSS (WebSocket Secure) scenarios.
"""
import pytest
from multiaddr import Multiaddr
from libp2p import create_yamux_muxer_option, new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair
from libp2p.custom_types import TProtocol
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
from libp2p.transport.websocket.multiaddr_utils import (
is_valid_websocket_multiaddr,
parse_websocket_multiaddr,
)
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
PING_LENGTH = 32
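# Common flow in these tests: two in-process hosts listen/dial over /ip4/.../tcp/0/ws,
# negotiate security (plaintext or Noise) and yamux, then echo data over a ping-protocol
# stream opened by the dialer.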
@pytest.mark.trio
async def test_websocket_p2p_plaintext():
"""Test Python-to-Python WebSocket communication with plaintext security."""
# Create two hosts with plaintext security
key_pair_a = create_new_key_pair()
key_pair_b = create_new_key_pair()
# Host A (listener) - use only plaintext security
security_options_a = {
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None
)
}
host_a = new_host(
key_pair=key_pair_a,
sec_opt=security_options_a,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
)
# Host B (dialer) - use only plaintext security
security_options_b = {
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None
)
}
host_b = new_host(
key_pair=key_pair_b,
sec_opt=security_options_b,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
# transport
)
# Test data
test_data = b"Hello WebSocket P2P!"
received_data = None
# Set up ping handler on host A
async def ping_handler(stream):
nonlocal received_data
received_data = await stream.read(len(test_data))
await stream.write(received_data) # Echo back
await stream.close()
host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler)
# Start both hosts
async with (
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
host_b.run(listen_addrs=[]),
):
# Get host A's listen address
listen_addrs = host_a.get_addrs()
assert len(listen_addrs) > 0
# Find the WebSocket address
ws_addr = None
for addr in listen_addrs:
if "/ws" in str(addr):
ws_addr = addr
break
assert ws_addr is not None, "No WebSocket listen address found"
assert is_valid_websocket_multiaddr(ws_addr), "Invalid WebSocket multiaddr"
# Parse the WebSocket multiaddr
parsed = parse_websocket_multiaddr(ws_addr)
assert not parsed.is_wss, "Should be plain WebSocket, not WSS"
assert parsed.sni is None, "SNI should be None for plain WebSocket"
# Connect host B to host A
from libp2p.peer.peerinfo import info_from_p2p_addr
peer_info = info_from_p2p_addr(ws_addr)
await host_b.connect(peer_info)
# Create stream and test communication
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
await stream.write(test_data)
response = await stream.read(len(test_data))
await stream.close()
# Verify communication
assert received_data == test_data, f"Expected {test_data}, got {received_data}"
assert response == test_data, f"Expected echo {test_data}, got {response}"
@pytest.mark.trio
async def test_websocket_p2p_noise():
"""Test Python-to-Python WebSocket communication with Noise security."""
# Create two hosts with Noise security
key_pair_a = create_new_key_pair()
key_pair_b = create_new_key_pair()
noise_key_pair_a = create_new_x25519_key_pair()
noise_key_pair_b = create_new_x25519_key_pair()
# Host A (listener)
security_options_a = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_a,
noise_privkey=noise_key_pair_a.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_a = new_host(
key_pair=key_pair_a,
sec_opt=security_options_a,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
)
# Host B (dialer)
security_options_b = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_b,
noise_privkey=noise_key_pair_b.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_b = new_host(
key_pair=key_pair_b,
sec_opt=security_options_b,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
# transport
)
# Test data
test_data = b"Hello WebSocket P2P with Noise!"
received_data = None
# Set up ping handler on host A
async def ping_handler(stream):
nonlocal received_data
received_data = await stream.read(len(test_data))
await stream.write(received_data) # Echo back
await stream.close()
host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler)
# Start both hosts
async with (
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
host_b.run(listen_addrs=[]),
):
# Get host A's listen address
listen_addrs = host_a.get_addrs()
assert len(listen_addrs) > 0
# Find the WebSocket address
ws_addr = None
for addr in listen_addrs:
if "/ws" in str(addr):
ws_addr = addr
break
assert ws_addr is not None, "No WebSocket listen address found"
assert is_valid_websocket_multiaddr(ws_addr), "Invalid WebSocket multiaddr"
# Parse the WebSocket multiaddr
parsed = parse_websocket_multiaddr(ws_addr)
assert not parsed.is_wss, "Should be plain WebSocket, not WSS"
assert parsed.sni is None, "SNI should be None for plain WebSocket"
# Connect host B to host A
from libp2p.peer.peerinfo import info_from_p2p_addr
peer_info = info_from_p2p_addr(ws_addr)
await host_b.connect(peer_info)
# Create stream and test communication
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
await stream.write(test_data)
response = await stream.read(len(test_data))
await stream.close()
# Verify communication
assert received_data == test_data, f"Expected {test_data}, got {received_data}"
assert response == test_data, f"Expected echo {test_data}, got {response}"
@pytest.mark.trio
async def test_websocket_p2p_libp2p_ping():
"""Test Python-to-Python WebSocket communication using libp2p ping protocol."""
# Create two hosts with Noise security
key_pair_a = create_new_key_pair()
key_pair_b = create_new_key_pair()
noise_key_pair_a = create_new_x25519_key_pair()
noise_key_pair_b = create_new_x25519_key_pair()
# Host A (listener)
security_options_a = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_a,
noise_privkey=noise_key_pair_a.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_a = new_host(
key_pair=key_pair_a,
sec_opt=security_options_a,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
)
# Host B (dialer)
security_options_b = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_b,
noise_privkey=noise_key_pair_b.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_b = new_host(
key_pair=key_pair_b,
sec_opt=security_options_b,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
# transport
)
# Set up ping handler on host A (standard libp2p ping protocol)
async def ping_handler(stream):
# Read ping data (32 bytes)
ping_data = await stream.read(PING_LENGTH)
# Echo back the same data (pong)
await stream.write(ping_data)
await stream.close()
host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler)
# Start both hosts
async with (
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
host_b.run(listen_addrs=[]),
):
# Get host A's listen address
listen_addrs = host_a.get_addrs()
assert len(listen_addrs) > 0
# Find the WebSocket address
ws_addr = None
for addr in listen_addrs:
if "/ws" in str(addr):
ws_addr = addr
break
assert ws_addr is not None, "No WebSocket listen address found"
# Connect host B to host A
from libp2p.peer.peerinfo import info_from_p2p_addr
peer_info = info_from_p2p_addr(ws_addr)
await host_b.connect(peer_info)
# Create stream and test libp2p ping protocol
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
# Send ping (32 bytes as per libp2p ping protocol)
ping_data = b"\x01" * PING_LENGTH
await stream.write(ping_data)
# Receive pong (should be same 32 bytes)
pong_data = await stream.read(PING_LENGTH)
await stream.close()
# Verify ping-pong
assert pong_data == ping_data, (
f"Expected ping {ping_data}, got pong {pong_data}"
)
@pytest.mark.trio
async def test_websocket_p2p_multiple_streams():
"""
Test Python-to-Python WebSocket communication with multiple concurrent
streams.
"""
# Create two hosts with Noise security
key_pair_a = create_new_key_pair()
key_pair_b = create_new_key_pair()
noise_key_pair_a = create_new_x25519_key_pair()
noise_key_pair_b = create_new_x25519_key_pair()
# Host A (listener)
security_options_a = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_a,
noise_privkey=noise_key_pair_a.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_a = new_host(
key_pair=key_pair_a,
sec_opt=security_options_a,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
)
# Host B (dialer)
security_options_b = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_b,
noise_privkey=noise_key_pair_b.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_b = new_host(
key_pair=key_pair_b,
sec_opt=security_options_b,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
# transport
)
# Test protocol
test_protocol = TProtocol("/test/multiple/streams/1.0.0")
received_data = []
# Set up handler on host A
async def test_handler(stream):
data = await stream.read(1024)
received_data.append(data)
await stream.write(data) # Echo back
await stream.close()
host_a.set_stream_handler(test_protocol, test_handler)
# Start both hosts
async with (
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
host_b.run(listen_addrs=[]),
):
# Get host A's listen address
listen_addrs = host_a.get_addrs()
ws_addr = None
for addr in listen_addrs:
if "/ws" in str(addr):
ws_addr = addr
break
assert ws_addr is not None, "No WebSocket listen address found"
# Connect host B to host A
from libp2p.peer.peerinfo import info_from_p2p_addr
peer_info = info_from_p2p_addr(ws_addr)
await host_b.connect(peer_info)
# Create multiple concurrent streams
num_streams = 5
test_data_list = [f"Stream {i} data".encode() for i in range(num_streams)]
async def create_stream_and_test(stream_id: int, data: bytes):
stream = await host_b.new_stream(host_a.get_id(), [test_protocol])
await stream.write(data)
response = await stream.read(len(data))
await stream.close()
return response
# Open the streams and await each echo in turn (responses are collected sequentially)
tasks = [
create_stream_and_test(i, test_data_list[i]) for i in range(num_streams)
]
responses = []
for task in tasks:
responses.append(await task)
# Verify all communications
assert len(received_data) == num_streams, (
f"Expected {num_streams} received messages, got {len(received_data)}"
)
for i, (sent, received, response) in enumerate(
zip(test_data_list, received_data, responses)
):
assert received == sent, f"Stream {i}: Expected {sent}, got {received}"
assert response == sent, f"Stream {i}: Expected echo {sent}, got {response}"
@pytest.mark.trio
async def test_websocket_p2p_connection_state():
"""Test WebSocket connection state tracking and metadata."""
# Create two hosts with Noise security
key_pair_a = create_new_key_pair()
key_pair_b = create_new_key_pair()
noise_key_pair_a = create_new_x25519_key_pair()
noise_key_pair_b = create_new_x25519_key_pair()
# Host A (listener)
security_options_a = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_a,
noise_privkey=noise_key_pair_a.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_a = new_host(
key_pair=key_pair_a,
sec_opt=security_options_a,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
)
# Host B (dialer)
security_options_b = {
NOISE_PROTOCOL_ID: NoiseTransport(
libp2p_keypair=key_pair_b,
noise_privkey=noise_key_pair_b.private_key,
early_data=None,
with_noise_pipes=False,
)
}
host_b = new_host(
key_pair=key_pair_b,
sec_opt=security_options_b,
muxer_opt=create_yamux_muxer_option(),
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
# transport
)
# Set up handler on host A
async def test_handler(stream):
# Read some data
await stream.read(1024)
# Write some data back
await stream.write(b"Response data")
await stream.close()
host_a.set_stream_handler(PING_PROTOCOL_ID, test_handler)
# Start both hosts
async with (
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
host_b.run(listen_addrs=[]),
):
# Get host A's listen address
listen_addrs = host_a.get_addrs()
ws_addr = None
for addr in listen_addrs:
if "/ws" in str(addr):
ws_addr = addr
break
assert ws_addr is not None, "No WebSocket listen address found"
# Connect host B to host A
from libp2p.peer.peerinfo import info_from_p2p_addr
peer_info = info_from_p2p_addr(ws_addr)
await host_b.connect(peer_info)
# Create stream and test communication
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
await stream.write(b"Test data for connection state")
response = await stream.read(1024)
await stream.close()
# Verify response
assert response == b"Response data", f"Expected 'Response data', got {response}"
# Test connection state (if available)
# Note: This tests the connection state tracking we implemented
connections = host_b.get_network().connections
assert len(connections) > 0, "Should have at least one connection"
# Get the connection to host A
conn_to_a = None
for peer_id, conn_list in connections.items():
if peer_id == host_a.get_id():
# connections maps peer_id to list of connections, get the first one
conn_to_a = conn_list[0] if conn_list else None
break
assert conn_to_a is not None, "Should have connection to host A"
# Test that the connection has the expected properties
assert hasattr(conn_to_a, "muxed_conn"), "Connection should have muxed_conn"
assert hasattr(conn_to_a.muxed_conn, "secured_conn"), (
"Muxed connection should have underlying secured_conn"
)
# If the underlying connection is our WebSocket connection, test its state
# Use getattr to reach the underlying secured connection for inspection in this test
underlying_conn = getattr(conn_to_a.muxed_conn, "secured_conn")
if hasattr(underlying_conn, "conn_state"):
state = underlying_conn.conn_state()
assert "connection_start_time" in state, (
"Connection state should include start time"
)
assert "bytes_read" in state, "Connection state should include bytes read"
assert "bytes_written" in state, (
"Connection state should include bytes written"
)
assert state["bytes_read"] > 0, "Should have read some bytes"
assert state["bytes_written"] > 0, "Should have written some bytes"

View File

@ -0,0 +1,117 @@
"""
Tests to verify that all examples use the new address paradigm consistently
"""
from pathlib import Path
class TestExamplesAddressParadigm:
"""Test suite to verify all examples use the new address paradigm consistently"""
def get_example_files(self):
"""Get all Python files in the examples directory"""
examples_dir = Path("examples")
return list(examples_dir.rglob("*.py"))
def check_file_for_wildcard_binding(self, filepath):
"""Check if a file contains 0.0.0.0 binding"""
with open(filepath, encoding="utf-8") as f:
content = f.read()
# Check for various forms of wildcard binding
wildcard_patterns = [
"0.0.0.0",
"/ip4/0.0.0.0/",
]
found_wildcards = []
for line_num, line in enumerate(content.splitlines(), 1):
for pattern in wildcard_patterns:
if pattern in line and not line.strip().startswith("#"):
found_wildcards.append((line_num, line.strip()))
return found_wildcards
def test_examples_use_address_paradigm(self):
"""Test that examples use the new address paradigm functions"""
example_files = self.get_example_files()
# Files that should use the new paradigm
networking_examples = [
"echo/echo.py",
"chat/chat.py",
"ping/ping.py",
"bootstrap/bootstrap.py",
"pubsub/pubsub.py",
"identify/identify.py",
]
paradigm_functions = [
"get_available_interfaces",
"get_optimal_binding_address",
]
for filename in networking_examples:
filepath = None
for example_file in example_files:
if filename in str(example_file):
filepath = example_file
break
if filepath is None:
continue
with open(filepath, encoding="utf-8") as f:
content = f.read()
# Check that the file uses the new paradigm functions
for func in paradigm_functions:
assert func in content, (
f"{filepath} should use {func} from the new address paradigm"
)
def test_wildcard_available_as_feature(self):
"""Test that wildcard is available as a feature when needed"""
example_files = self.get_example_files()
# Check that network_discover.py demonstrates wildcard usage
network_discover_file = None
for example_file in example_files:
if "network_discover.py" in str(example_file):
network_discover_file = example_file
break
if network_discover_file:
with open(network_discover_file, encoding="utf-8") as f:
content = f.read()
# Should demonstrate wildcard expansion
assert "0.0.0.0" in content, (
f"{network_discover_file} should demonstrate wildcard usage"
)
assert "expand_wildcard_address" in content, (
f"{network_discover_file} should use expand_wildcard_address"
)
def test_doc_examples_use_paradigm(self):
"""Test that documentation examples use the new address paradigm"""
doc_examples_dir = Path("examples/doc-examples")
if not doc_examples_dir.exists():
return
doc_example_files = list(doc_examples_dir.glob("*.py"))
paradigm_functions = [
"get_available_interfaces",
"get_optimal_binding_address",
]
for filepath in doc_example_files:
with open(filepath, encoding="utf-8") as f:
content = f.read()
# Check that doc examples use the new paradigm
for func in paradigm_functions:
assert func in content, (
f"Documentation example {filepath} should use {func}"
)

View File

@ -0,0 +1,6 @@
def test_echo_quic_example():
"""Test that the QUIC echo example can be imported and has required functions."""
from examples.echo import echo_quic
assert hasattr(echo_quic, "main")
assert hasattr(echo_quic, "run")

View File

@ -0,0 +1,21 @@
{
"name": "src",
"version": "1.0.0",
"main": "ping.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"description": "",
"dependencies": {
"@chainsafe/libp2p-noise": "^9.0.0",
"@chainsafe/libp2p-yamux": "^5.0.1",
"@libp2p/ping": "^2.0.36",
"@libp2p/plaintext": "^2.0.29",
"@libp2p/websockets": "^9.2.18",
"libp2p": "^2.9.0",
"multiaddr": "^10.0.1"
}
}

View File

@ -0,0 +1,122 @@
import { createLibp2p } from 'libp2p'
import { webSockets } from '@libp2p/websockets'
import { ping } from '@libp2p/ping'
import { noise } from '@chainsafe/libp2p-noise'
import { plaintext } from '@libp2p/plaintext'
import { yamux } from '@chainsafe/libp2p-yamux'
// import { identify } from '@libp2p/identify' // Commented out for compatibility
// Configuration from environment (with defaults for compatibility)
const TRANSPORT = process.env.transport || 'ws'
const SECURITY = process.env.security || 'noise'
const MUXER = process.env.muxer || 'yamux'
const IP = process.env.ip || '0.0.0.0'
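// Example invocation (assumed workflow, using the defaults above):
//   transport=ws security=noise muxer=yamux node ping.js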
async function main() {
console.log(`🔧 Configuration: transport=${TRANSPORT}, security=${SECURITY}, muxer=${MUXER}`)
// Build options following the proven pattern from test-plans-fork
const options = {
start: true,
connectionGater: {
denyDialMultiaddr: async () => false
},
connectionMonitor: {
enabled: false
},
services: {
ping: ping()
}
}
// Transport configuration (following get-libp2p.ts pattern)
switch (TRANSPORT) {
case 'ws':
options.transports = [webSockets()]
options.addresses = {
listen: [`/ip4/${IP}/tcp/0/ws`]
}
break
case 'wss':
process.env.NODE_TLS_REJECT_UNAUTHORIZED = '0'
options.transports = [webSockets()]
options.addresses = {
listen: [`/ip4/${IP}/tcp/0/wss`]
}
break
default:
throw new Error(`Unknown transport: ${TRANSPORT}`)
}
// Security configuration
switch (SECURITY) {
case 'noise':
options.connectionEncryption = [noise()]
break
case 'plaintext':
options.connectionEncryption = [plaintext()]
break
default:
throw new Error(`Unknown security: ${SECURITY}`)
}
// Muxer configuration
switch (MUXER) {
case 'yamux':
options.streamMuxers = [yamux()]
break
default:
throw new Error(`Unknown muxer: ${MUXER}`)
}
console.log('🔧 Creating libp2p node with proven interop configuration...')
const node = await createLibp2p(options)
await node.start()
console.log(node.peerId.toString())
for (const addr of node.getMultiaddrs()) {
console.log(addr.toString())
}
// Debug: Print supported protocols
console.log('DEBUG: Supported protocols:')
if (node.services && node.services.registrar) {
const protocols = node.services.registrar.getProtocols()
for (const protocol of protocols) {
console.log('DEBUG: Protocol:', protocol)
}
}
// Debug: Print connection encryption protocols
console.log('DEBUG: Connection encryption protocols:')
try {
if (node.components && node.components.connectionEncryption) {
for (const encrypter of node.components.connectionEncryption) {
console.log('DEBUG: Encrypter:', encrypter.protocol)
}
}
} catch (e) {
console.log('DEBUG: Could not access connectionEncryption:', e.message)
}
// Debug: Print stream muxer protocols
console.log('DEBUG: Stream muxer protocols:')
try {
if (node.components && node.components.streamMuxers) {
for (const muxer of node.components.streamMuxers) {
console.log('DEBUG: Muxer:', muxer.protocol)
}
}
} catch (e) {
console.log('DEBUG: Could not access streamMuxers:', e.message)
}
// Keep the process alive
await new Promise(() => {})
}
main().catch(err => {
console.error(err)
process.exit(1)
})

tests/interop/nim_libp2p/.gitignore
View File

@ -0,0 +1,8 @@
nimble.develop
nimble.paths
*.nimble
nim-libp2p/
nim_echo_server
config.nims

View File

@ -0,0 +1,119 @@
import fcntl
import logging
from pathlib import Path
import shutil
import subprocess
import time
import pytest
logger = logging.getLogger(__name__)
def check_nim_available():
"""Check if nim compiler is available."""
return shutil.which("nim") is not None and shutil.which("nimble") is not None
def check_nim_binary_built():
"""Check if nim echo server binary is built."""
current_dir = Path(__file__).parent
binary_path = current_dir / "nim_echo_server"
return binary_path.exists() and binary_path.stat().st_size > 0
def run_nim_setup_with_lock():
"""Run nim setup with file locking to prevent parallel execution."""
current_dir = Path(__file__).parent
lock_file = current_dir / ".setup_lock"
setup_script = current_dir / "scripts" / "setup_nim_echo.sh"
if not setup_script.exists():
raise RuntimeError(f"Setup script not found: {setup_script}")
# Try to acquire lock
try:
with open(lock_file, "w") as f:
# Non-blocking lock attempt
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
# Double-check binary doesn't exist (another worker might have built it)
if check_nim_binary_built():
logger.info("Binary already exists, skipping setup")
return
logger.info("Acquired setup lock, running nim-libp2p setup...")
# Make setup script executable and run it
setup_script.chmod(0o755)
result = subprocess.run(
[str(setup_script)],
cwd=current_dir,
capture_output=True,
text=True,
timeout=300, # 5 minute timeout
)
if result.returncode != 0:
raise RuntimeError(
f"Setup failed (exit {result.returncode}):\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Verify binary was built
if not check_nim_binary_built():
raise RuntimeError("nim_echo_server binary not found after setup")
logger.info("nim-libp2p setup completed successfully")
except BlockingIOError:
# Another worker is running setup, wait for it to complete
logger.info("Another worker is running setup, waiting...")
# Wait for setup to complete (check every 2 seconds, max 5 minutes)
for _ in range(150): # 150 * 2 = 300 seconds = 5 minutes
if check_nim_binary_built():
logger.info("Setup completed by another worker")
return
time.sleep(2)
raise TimeoutError("Timed out waiting for setup to complete")
finally:
# Clean up lock file
try:
lock_file.unlink(missing_ok=True)
except Exception:
pass
@pytest.fixture(scope="function") # Changed to function scope
def nim_echo_binary():
"""Get nim echo server binary path."""
current_dir = Path(__file__).parent
binary_path = current_dir / "nim_echo_server"
if not binary_path.exists():
pytest.skip(
"nim_echo_server binary not found. "
"Run setup script: ./scripts/setup_nim_echo.sh"
)
return binary_path
@pytest.fixture
async def nim_server(nim_echo_binary):
"""Start and stop nim echo server for tests."""
# Import here to avoid circular imports
# pyrefly: ignore
from test_echo_interop import NimEchoServer
server = NimEchoServer(nim_echo_binary)
try:
peer_id, listen_addr = await server.start()
yield server, peer_id, listen_addr
finally:
await server.stop()

View File

@ -0,0 +1,108 @@
{.used.}
import chronos
import stew/byteutils
import libp2p
##
# Simple Echo Protocol Implementation for py-libp2p Interop Testing
##
const EchoCodec = "/echo/1.0.0"
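# Wire format: each echo message is length-prefixed (readLp/writeLp), matching the
# varint length prefix used by the py-libp2p client in test_echo_interop.py.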
type EchoProto = ref object of LPProtocol
proc new(T: typedesc[EchoProto]): T =
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
try:
echo "Echo server: Received connection from ", conn.peerId
# Read and echo messages in a loop
while not conn.atEof:
try:
# Read length-prefixed message using nim-libp2p's readLp
let message = await conn.readLp(1024 * 1024) # Max 1MB
if message.len == 0:
echo "Echo server: Empty message, closing connection"
break
let messageStr = string.fromBytes(message)
echo "Echo server: Received (", message.len, " bytes): ", messageStr
# Echo back using writeLp
await conn.writeLp(message)
echo "Echo server: Echoed message back"
except CatchableError as e:
echo "Echo server: Error processing message: ", e.msg
break
except CancelledError as e:
echo "Echo server: Connection cancelled"
raise e
except CatchableError as e:
echo "Echo server: Exception in handler: ", e.msg
finally:
echo "Echo server: Connection closed"
await conn.close()
return T.new(codecs = @[EchoCodec], handler = handle)
##
# Create QUIC-enabled switch
##
proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
var switch = SwitchBuilder
.new()
.withRng(rng)
.withAddress(ma)
.withQuicTransport()
.build()
result = switch
##
# Main server
##
proc main() {.async.} =
let
rng = newRng()
localAddr = MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()
echoProto = EchoProto.new()
echo "=== Nim Echo Server for py-libp2p Interop ==="
# Create switch
let switch = createSwitch(localAddr, rng)
switch.mount(echoProto)
# Start server
await switch.start()
# Print connection info
echo "Peer ID: ", $switch.peerInfo.peerId
echo "Listening on:"
for address in switch.peerInfo.addrs:
echo " ", $address, "/p2p/", $switch.peerInfo.peerId
echo "Protocol: ", EchoCodec
echo "Ready for py-libp2p connections!"
echo ""
# Keep running
try:
await sleepAsync(100.hours)
except CancelledError:
echo "Shutting down..."
finally:
await switch.stop()
# Graceful shutdown handler
proc signalHandler() {.noconv.} =
echo "\nShutdown signal received"
quit(0)
when isMainModule:
setControlCHook(signalHandler)
try:
waitFor(main())
except CatchableError as e:
echo "Error: ", e.msg
quit(1)

View File

@ -0,0 +1,74 @@
#!/usr/bin/env bash
# tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
# Cache-aware setup that skips installation if packages exist
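# Usage (assumed): ./scripts/setup_nim_echo.sh
# Requires nim and nimble on PATH; builds nim_echo_server in tests/interop/nim_libp2p/.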
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="${SCRIPT_DIR}/.."
# Colors
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m'
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
main() {
log_info "Setting up nim echo server for interop testing..."
# Check if nim is available
if ! command -v nim &> /dev/null || ! command -v nimble &> /dev/null; then
log_error "Nim not found. Please install nim first."
exit 1
fi
cd "${PROJECT_DIR}"
# Create logs directory
mkdir -p logs
# Check if binary already exists
if [[ -f "nim_echo_server" ]]; then
log_info "nim_echo_server already exists, skipping build"
return 0
fi
# Check if libp2p is already installed (cache-aware)
if nimble list -i | grep -q "libp2p"; then
log_info "libp2p already installed, skipping installation"
else
log_info "Installing nim-libp2p globally..."
nimble install -y libp2p
fi
log_info "Building nim echo server..."
# Compile the echo server
nim c \
-d:release \
-d:chronicles_log_level=INFO \
-d:libp2p_quic_support \
-d:chronos_event_loop=iocp \
-d:ssl \
--opt:speed \
--mm:orc \
--verbosity:1 \
-o:nim_echo_server \
nim_echo_server.nim
# Verify binary was created
if [[ -f "nim_echo_server" ]]; then
log_info "✅ nim_echo_server built successfully"
log_info "Binary size: $(ls -lh nim_echo_server | awk '{print $5}')"
else
log_error "❌ Failed to build nim_echo_server"
exit 1
fi
log_info "🎉 Setup complete!"
}
main "$@"

View File

@ -0,0 +1,195 @@
import logging
from pathlib import Path
import subprocess
import time
import pytest
import multiaddr
import trio
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.utils.varint import encode_varint_prefixed, read_varint_prefixed_bytes
# Configuration
PROTOCOL_ID = TProtocol("/echo/1.0.0")
TEST_TIMEOUT = 30
SERVER_START_TIMEOUT = 10.0
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NimEchoServer:
"""Simple nim echo server manager."""
def __init__(self, binary_path: Path):
self.binary_path = binary_path
self.process: None | subprocess.Popen = None
self.peer_id = None
self.listen_addr = None
async def start(self):
"""Start nim echo server and get connection info."""
logger.info(f"Starting nim echo server: {self.binary_path}")
self.process = subprocess.Popen(
[str(self.binary_path)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
)
# Parse output for connection info
start_time = time.time()
while time.time() - start_time < SERVER_START_TIMEOUT:
            # poll() returns the exit code (which may be 0), so compare against None
            if self.process and self.process.poll() is not None and self.process.stdout:
                output = self.process.stdout.read()
                raise RuntimeError(f"Server exited early: {output}")
reader = self.process.stdout if self.process else None
if reader:
line = reader.readline().strip()
if not line:
continue
logger.info(f"Server: {line}")
if line.startswith("Peer ID:"):
self.peer_id = line.split(":", 1)[1].strip()
elif "/quic-v1/p2p/" in line and self.peer_id:
if line.strip().startswith("/"):
self.listen_addr = line.strip()
logger.info(f"Server ready: {self.listen_addr}")
return self.peer_id, self.listen_addr
await self.stop()
raise TimeoutError(f"Server failed to start within {SERVER_START_TIMEOUT}s")
async def stop(self):
"""Stop the server."""
if self.process:
logger.info("Stopping nim echo server...")
try:
self.process.terminate()
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.process.kill()
self.process.wait()
self.process = None
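# A hedged sketch of the `nim_server` fixture the tests below consume; the real
# fixture is assumed to live in conftest.py. The fixture name comes from the test
# signatures, but the binary path and exact structure here are illustrative
# assumptions only, so the sketch is left commented out.
#
# @pytest.fixture
# async def nim_server():
#     binary = Path(__file__).parent / "nim_libp2p" / "nim_echo_server"  # hypothetical path
#     server = NimEchoServer(binary)
#     peer_id, listen_addr = await server.start()
#     try:
#         yield server, peer_id, listen_addr
#     finally:
#         await server.stop()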
async def run_echo_test(server_addr: str, messages: list[str]):
"""Test echo protocol against nim server with proper timeout handling."""
# Create py-libp2p QUIC client with shorter timeouts
host = new_host(
enable_quic=True,
key_pair=create_new_key_pair(),
)
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/0/quic-v1")
responses = []
try:
async with host.run(listen_addrs=[listen_addr]):
logger.info(f"Connecting to nim server: {server_addr}")
# Connect to nim server
maddr = multiaddr.Multiaddr(server_addr)
info = info_from_p2p_addr(maddr)
await host.connect(info)
# Create stream
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
logger.info("Stream created")
# Test each message
for i, message in enumerate(messages, 1):
logger.info(f"Testing message {i}: {message}")
# Send with varint length prefix
data = message.encode("utf-8")
prefixed_data = encode_varint_prefixed(data)
await stream.write(prefixed_data)
# Read response
response_data = await read_varint_prefixed_bytes(stream)
response = response_data.decode("utf-8")
logger.info(f"Got echo: {response}")
responses.append(response)
# Verify echo
assert message == response, (
f"Echo failed: sent {message!r}, got {response!r}"
)
await stream.close()
logger.info("✅ All messages echoed correctly")
finally:
await host.close()
return responses
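# Framing example (illustrative): assuming encode_varint_prefixed prepends an
# unsigned-varint length, encode_varint_prefixed(b"hi") yields b"\x02hi", which
# matches the length-prefixed messages the nim server reads with readLp.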
@pytest.mark.trio
@pytest.mark.timeout(TEST_TIMEOUT)
async def test_basic_echo_interop(nim_server):
"""Test basic echo functionality between py-libp2p and nim-libp2p."""
server, peer_id, listen_addr = nim_server
test_messages = [
"Hello from py-libp2p!",
"QUIC transport working",
"Echo test successful!",
"Unicode: Ñoël, 测试, Ψυχή",
]
logger.info(f"Testing against nim server: {peer_id}")
# Run test with timeout
    with trio.move_on_after(TEST_TIMEOUT - 2) as cancel_scope:  # Leave 2s buffer for cleanup
        responses = await run_echo_test(listen_addr, test_messages)
        # Verify all messages echoed correctly
        assert len(responses) == len(test_messages)
        for sent, received in zip(test_messages, responses):
            assert sent == received
    # Fail loudly if the timeout fired instead of silently skipping the asserts
    assert not cancel_scope.cancelled_caught, "Echo interop test timed out"
logger.info("✅ Basic echo interop test passed!")
@pytest.mark.trio
@pytest.mark.timeout(TEST_TIMEOUT)
async def test_large_message_echo(nim_server):
"""Test echo with larger messages."""
server, peer_id, listen_addr = nim_server
large_messages = [
"x" * 1024,
"y" * 5000,
]
logger.info("Testing large message echo...")
# Run test with timeout
    with trio.move_on_after(TEST_TIMEOUT - 2) as cancel_scope:  # Leave 2s buffer for cleanup
        responses = await run_echo_test(listen_addr, large_messages)
        assert len(responses) == len(large_messages)
        for sent, received in zip(large_messages, responses):
            assert sent == received
    # Fail loudly if the timeout fired instead of silently skipping the asserts
    assert not cancel_scope.cancelled_caught, "Large message echo test timed out"
logger.info("✅ Large message echo test passed!")
if __name__ == "__main__":
# Run tests directly
pytest.main([__file__, "-v", "--tb=short"])

View File

@ -0,0 +1,127 @@
import os
import signal
import subprocess
import pytest
from multiaddr import Multiaddr
import trio
from trio.lowlevel import open_process
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost
from libp2p.network.exceptions import SwarmException
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0"
@pytest.mark.trio
async def test_ping_with_js_node():
# Skip this test due to JavaScript dependency issues
pytest.skip("Skipping JS interop test due to dependency issues")
js_node_dir = os.path.join(os.path.dirname(__file__), "js_libp2p", "js_node", "src")
script_name = "./ws_ping_node.mjs"
try:
subprocess.run(
["npm", "install"],
cwd=js_node_dir,
check=True,
capture_output=True,
text=True,
)
except (subprocess.CalledProcessError, FileNotFoundError) as e:
pytest.fail(f"Failed to run 'npm install': {e}")
# Launch the JS libp2p node (long-running)
proc = await open_process(
["node", script_name],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=js_node_dir,
)
assert proc.stdout is not None, "stdout pipe missing"
assert proc.stderr is not None, "stderr pipe missing"
stdout = proc.stdout
stderr = proc.stderr
try:
# Read first two lines (PeerID and multiaddr)
buffer = b""
with trio.fail_after(30):
while buffer.count(b"\n") < 2:
chunk = await stdout.receive_some(1024)
if not chunk:
break
buffer += chunk
lines = [line for line in buffer.decode().splitlines() if line.strip()]
if len(lines) < 2:
stderr_output = await stderr.receive_some(2048)
stderr_output = stderr_output.decode()
pytest.fail(
"JS node did not produce expected PeerID and multiaddr.\n"
f"Stdout: {buffer.decode()!r}\n"
f"Stderr: {stderr_output!r}"
)
peer_id_line, addr_line = lines[0], lines[1]
peer_id = ID.from_base58(peer_id_line)
maddr = Multiaddr(addr_line)
# Debug: Print what we're trying to connect to
print(f"JS Node Peer ID: {peer_id_line}")
print(f"JS Node Address: {addr_line}")
print(f"All JS Node lines: {lines}")
# Set up Python host
key_pair = create_new_key_pair()
py_peer_id = ID.from_pubkey(key_pair.public_key)
peer_store = PeerStore()
peer_store.add_key_pair(py_peer_id, key_pair)
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm)
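        # Stack note (inferred, not from the original): this manual assembly —
        # a WebSocket transport upgraded with plaintext security and Yamux muxing,
        # wrapped in a Swarm and exposed via BasicHost — presumably matches the
        # configuration of the JS node in ws_ping_node.mjs.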
# Connect to JS node
peer_info = PeerInfo(peer_id, [maddr])
print(f"Python trying to connect to: {peer_info}")
# Use the host as a context manager
async with host.run(listen_addrs=[]):
await trio.sleep(1)
try:
await host.connect(peer_info)
except SwarmException as e:
underlying_error = e.__cause__
pytest.fail(
"Connection failed with SwarmException.\n"
f"THE REAL ERROR IS: {underlying_error!r}\n"
)
assert host.get_network().connections.get(peer_id) is not None
# Ping protocol
stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")])
await stream.write(b"ping")
data = await stream.read(4)
assert data == b"pong"
    finally:
        proc.send_signal(signal.SIGTERM)
        # Give the JS child a bounded window to exit so it is reaped
        with trio.move_on_after(5):
            await proc.wait()

View File

@ -0,0 +1,206 @@
"""
Tests for the new address paradigm with wildcard support as a feature
"""
import pytest
from multiaddr import Multiaddr
from libp2p import new_host
from libp2p.utils.address_validation import (
get_available_interfaces,
get_optimal_binding_address,
get_wildcard_address,
)
class TestAddressParadigm:
"""
Test suite for verifying the new address paradigm:
- get_available_interfaces() returns all available interfaces
- get_optimal_binding_address() returns optimal address for examples
- get_wildcard_address() provides wildcard as a feature when needed
"""
def test_wildcard_address_function(self):
"""Test that get_wildcard_address() provides wildcard as a feature"""
port = 8000
addr = get_wildcard_address(port)
# Should return wildcard address when explicitly requested
assert "0.0.0.0" in str(addr)
addr_str = str(addr)
assert "/ip4/" in addr_str
assert f"/tcp/{port}" in addr_str
def test_optimal_binding_address_selection(self):
"""Test that optimal binding address uses good heuristics"""
port = 8000
addr = get_optimal_binding_address(port)
# Should return a valid IP address (could be loopback or local network)
addr_str = str(addr)
assert "/ip4/" in addr_str
assert f"/tcp/{port}" in addr_str
# Should be from available interfaces
available_interfaces = get_available_interfaces(port)
assert addr in available_interfaces
def test_available_interfaces_includes_loopback(self):
"""Test that available interfaces always includes loopback address"""
port = 8000
interfaces = get_available_interfaces(port)
# Should have at least one interface
assert len(interfaces) > 0
# Should include loopback address
loopback_found = any("127.0.0.1" in str(addr) for addr in interfaces)
assert loopback_found, "Loopback address not found in available interfaces"
# Available interfaces should not include wildcard by default
# (wildcard is available as a feature through get_wildcard_address())
wildcard_found = any("0.0.0.0" in str(addr) for addr in interfaces)
assert not wildcard_found, (
"Wildcard should not be in default available interfaces"
)
def test_host_default_listen_address(self):
"""Test that new hosts use secure default addresses"""
# Create a host with a specific port
port = 8000
listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}")
host = new_host(listen_addrs=[listen_addr])
# Verify the host configuration
assert host is not None
# Note: We can't test actual binding without running the host,
# but we've verified the address format is correct
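
    # A hedged companion sketch (not part of the original suite): run the host
    # under trio and confirm it reports a loopback listen address. This assumes
    # pytest-trio is available, as it is for the async tests elsewhere in this
    # repository, and that host.get_addrs() reports the bound addresses.
    @pytest.mark.trio
    async def test_host_binds_to_loopback_sketch(self):
        listen_addr = Multiaddr("/ip4/127.0.0.1/tcp/0")  # port 0: let the OS pick a free port
        host = new_host(listen_addrs=[listen_addr])
        async with host.run(listen_addrs=[listen_addr]):
            bound = [str(addr) for addr in host.get_addrs()]
            assert any("127.0.0.1" in addr for addr in bound)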
def test_paradigm_consistency(self):
"""Test that the address paradigm is consistent"""
port = 8000
# get_optimal_binding_address should return a valid address
optimal_addr = get_optimal_binding_address(port)
assert "/ip4/" in str(optimal_addr)
assert f"/tcp/{port}" in str(optimal_addr)
# get_wildcard_address should return wildcard when explicitly needed
wildcard_addr = get_wildcard_address(port)
assert "0.0.0.0" in str(wildcard_addr)
assert f"/tcp/{port}" in str(wildcard_addr)
# Both should be valid Multiaddr objects
assert isinstance(optimal_addr, Multiaddr)
assert isinstance(wildcard_addr, Multiaddr)
@pytest.mark.parametrize("protocol", ["tcp", "udp"])
def test_different_protocols_support(self, protocol):
"""Test that different protocols are supported by the paradigm"""
port = 8000
# Test optimal address with different protocols
optimal_addr = get_optimal_binding_address(port, protocol=protocol)
assert protocol in str(optimal_addr)
assert f"/{protocol}/{port}" in str(optimal_addr)
# Test wildcard address with different protocols
wildcard_addr = get_wildcard_address(port, protocol=protocol)
assert "0.0.0.0" in str(wildcard_addr)
assert protocol in str(wildcard_addr)
assert f"/{protocol}/{port}" in str(wildcard_addr)
# Test available interfaces with different protocols
interfaces = get_available_interfaces(port, protocol=protocol)
assert len(interfaces) > 0
for addr in interfaces:
assert protocol in str(addr)
def test_wildcard_available_as_feature(self):
"""Test that wildcard binding is available as a feature when needed"""
port = 8000
# Wildcard should be available through get_wildcard_address()
wildcard_addr = get_wildcard_address(port)
assert "0.0.0.0" in str(wildcard_addr)
# But should not be in default available interfaces
interfaces = get_available_interfaces(port)
wildcard_in_interfaces = any("0.0.0.0" in str(addr) for addr in interfaces)
assert not wildcard_in_interfaces, (
"Wildcard should not be in default interfaces"
)
# Optimal address should not be wildcard by default
optimal = get_optimal_binding_address(port)
assert "0.0.0.0" not in str(optimal), (
"Optimal address should not be wildcard by default"
)
def test_loopback_is_always_available(self):
"""Test that loopback address is always available as an option"""
port = 8000
interfaces = get_available_interfaces(port)
# Loopback should always be available
loopback_addrs = [addr for addr in interfaces if "127.0.0.1" in str(addr)]
assert len(loopback_addrs) > 0, "Loopback address should always be available"
# At least one loopback address should have the correct port
loopback_with_port = [
addr for addr in loopback_addrs if f"/tcp/{port}" in str(addr)
]
assert len(loopback_with_port) > 0, (
f"Loopback address with port {port} should be available"
)
def test_optimal_address_selection_behavior(self):
"""Test that optimal address selection works correctly"""
port = 8000
interfaces = get_available_interfaces(port)
optimal = get_optimal_binding_address(port)
# Should return one of the available interfaces
optimal_str = str(optimal)
interface_strs = [str(addr) for addr in interfaces]
assert optimal_str in interface_strs, (
f"Optimal address {optimal_str} should be in available interfaces"
)
# Should prefer non-loopback when available, fallback to loopback
non_loopback_interfaces = [
addr for addr in interfaces if "127.0.0.1" not in str(addr)
]
if non_loopback_interfaces:
# Should prefer non-loopback when available
assert "127.0.0.1" not in str(optimal), (
"Should prefer non-loopback when available"
)
else:
# Should use loopback when no other interfaces available
assert "127.0.0.1" in str(optimal), (
"Should use loopback when no other interfaces available"
)
def test_address_paradigm_completeness(self):
"""Test that the address paradigm provides all necessary functionality"""
port = 8000
# Test that we get interface options
interfaces = get_available_interfaces(port)
assert len(interfaces) >= 1, "Should have at least one interface"
# Test that loopback is always included
has_loopback = any("127.0.0.1" in str(addr) for addr in interfaces)
assert has_loopback, "Loopback should always be available"
# Test that wildcard is available as a feature
wildcard_addr = get_wildcard_address(port)
assert "0.0.0.0" in str(wildcard_addr)
# Test optimal selection
optimal = get_optimal_binding_address(port)
assert optimal in interfaces, (
"Optimal address should be from available interfaces"
)