diff --git a/libp2p/transport/quic/config.py b/libp2p/transport/quic/config.py index d1ccf335..c2fa90ae 100644 --- a/libp2p/transport/quic/config.py +++ b/libp2p/transport/quic/config.py @@ -7,10 +7,45 @@ from dataclasses import ( field, ) import ssl +from typing import TypedDict from libp2p.custom_types import TProtocol +class QUICTransportKwargs(TypedDict, total=False): + """Type definition for kwargs accepted by new_transport function.""" + + # Connection settings + idle_timeout: float + max_datagram_size: int + local_port: int | None + + # Protocol version support + enable_draft29: bool + enable_v1: bool + + # TLS settings + verify_mode: ssl.VerifyMode + alpn_protocols: list[str] + + # Performance settings + max_concurrent_streams: int + connection_window: int + stream_window: int + + # Logging and debugging + enable_qlog: bool + qlog_dir: str | None + + # Connection management + max_connections: int + connection_timeout: float + + # Protocol identifiers + PROTOCOL_QUIC_V1: TProtocol + PROTOCOL_QUIC_DRAFT29: TProtocol + + @dataclass class QUICTransportConfig: """Configuration for QUIC transport.""" @@ -47,7 +82,7 @@ class QUICTransportConfig: PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic") # RFC 9000 PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29 - def __post_init__(self): + def __post_init__(self) -> None: """Validate configuration after initialization.""" if not (self.enable_draft29 or self.enable_v1): raise ValueError("At least one QUIC version must be enabled") diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py index 9746d234..d93ccf31 100644 --- a/libp2p/transport/quic/connection.py +++ b/libp2p/transport/quic/connection.py @@ -50,7 +50,7 @@ class QUICConnection(IRawConnection, IMuxedConn): Uses aioquic's sans-IO core with trio for native async support. QUIC natively provides stream multiplexing, so this connection acts as both a raw connection (for transport layer) and muxed connection (for upper layers). - + Updated to work properly with the QUIC listener for server-side connections. """ @@ -92,18 +92,20 @@ class QUICConnection(IRawConnection, IMuxedConn): self._background_tasks_started = False self._nursery: trio.Nursery | None = None - logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})") + logger.debug( + f"Created QUIC connection to {peer_id} (initiator: {is_initiator})" + ) def _calculate_initial_stream_id(self) -> int: """ Calculate the initial stream ID based on QUIC specification. - + QUIC stream IDs: - Client-initiated bidirectional: 0, 4, 8, 12, ... - Server-initiated bidirectional: 1, 5, 9, 13, ... - Client-initiated unidirectional: 2, 6, 10, 14, ... - Server-initiated unidirectional: 3, 7, 11, 15, ... - + For libp2p, we primarily use bidirectional streams. """ if self.__is_initiator: @@ -118,7 +120,7 @@ class QUICConnection(IRawConnection, IMuxedConn): async def start(self) -> None: """ Start the connection and its background tasks. - + This method implements the IMuxedConn.start() interface. It should be called to begin processing connection events. """ @@ -165,7 +167,9 @@ class QUICConnection(IRawConnection, IMuxedConn): if not self._background_tasks_started: # We would need a nursery to start background tasks # This is a limitation of the current design - logger.warning("Background tasks need nursery - connection may not work properly") + logger.warning( + "Background tasks need nursery - connection may not work properly" + ) except Exception as e: logger.error(f"Failed to initiate connection: {e}") @@ -174,13 +178,15 @@ class QUICConnection(IRawConnection, IMuxedConn): async def connect(self, nursery: trio.Nursery) -> None: """ Establish the QUIC connection using trio. - + Args: nursery: Trio nursery for background tasks """ if not self.__is_initiator: - raise QUICConnectionError("connect() should only be called by client connections") + raise QUICConnectionError( + "connect() should only be called by client connections" + ) try: # Store nursery for background tasks @@ -321,7 +327,7 @@ class QUICConnection(IRawConnection, IMuxedConn): def _is_incoming_stream(self, stream_id: int) -> bool: """ Determine if a stream ID represents an incoming stream. - + For bidirectional streams: - Even IDs are client-initiated - Odd IDs are server-initiated @@ -463,11 +469,7 @@ class QUICConnection(IRawConnection, IMuxedConn): self._next_stream_id += 4 # Increment by 4 for bidirectional streams # Create stream - stream = QUICStream( - connection=self, - stream_id=stream_id, - is_initiator=True - ) + stream = QUICStream(connection=self, stream_id=stream_id, is_initiator=True) self._streams[stream_id] = stream @@ -530,9 +532,10 @@ class QUICConnection(IRawConnection, IMuxedConn): # The certificate should contain the peer ID in a specific extension raise NotImplementedError("Certificate peer ID extraction not implemented") - def get_stats(self) -> dict: + # TODO: Define type for stats + def get_stats(self) -> dict[str, object]: """Get connection statistics.""" - return { + stats: dict[str, object] = { "peer_id": str(self._peer_id), "remote_addr": self._remote_addr, "is_initiator": self.__is_initiator, @@ -542,10 +545,16 @@ class QUICConnection(IRawConnection, IMuxedConn): "active_streams": len(self._streams), "next_stream_id": self._next_stream_id, } + return stats - def get_remote_address(self): + def get_remote_address(self) -> tuple[str, int]: return self._remote_addr def __str__(self) -> str: """String representation of the connection.""" - return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, established={self._established}, started={self._started})" + id = self._peer_id + estb = self._established + stream_len = len(self._streams) + return f"QUICConnection(peer={id}, streams={stream_len}".__add__( + f"established={estb}, started={self._started})" + ) diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index 8757427e..b02251f9 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -8,7 +8,7 @@ import copy import logging import socket import time -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from aioquic.quic import events from aioquic.quic.configuration import QuicConfiguration @@ -49,7 +49,7 @@ class QUICListener(IListener): self, transport: "QUICTransport", handler_function: THandler, - quic_configs: Dict[TProtocol, QuicConfiguration], + quic_configs: dict[TProtocol, QuicConfiguration], config: QUICTransportConfig, ): """ @@ -72,8 +72,8 @@ class QUICListener(IListener): self._bound_addresses: list[Multiaddr] = [] # Connection management - self._connections: Dict[tuple[str, int], QUICConnection] = {} - self._pending_connections: Dict[tuple[str, int], QuicConnection] = {} + self._connections: dict[tuple[str, int], QUICConnection] = {} + self._pending_connections: dict[tuple[str, int], QuicConnection] = {} self._connection_lock = trio.Lock() # Listener state @@ -104,6 +104,7 @@ class QUICListener(IListener): Raises: QUICListenError: If failed to start listening + """ if not is_quic_multiaddr(maddr): raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") @@ -133,11 +134,11 @@ class QUICListener(IListener): self._listening = True # Start background tasks directly in the provided nursery - # This ensures proper cancellation when the nursery exits + # This e per cancellation when the nursery exits nursery.start_soon(self._handle_incoming_packets) nursery.start_soon(self._manage_connections) - print(f"QUIC listener started on {actual_maddr}") + logger.info(f"QUIC listener started on {actual_maddr}") return True except trio.Cancelled: @@ -190,7 +191,8 @@ class QUICListener(IListener): try: while self._listening and self._socket: try: - # Receive UDP packet (this blocks until packet arrives or socket closes) + # Receive UDP packet + # (this blocks until packet arrives or socket closes) data, addr = await self._socket.recvfrom(65536) self._stats["bytes_received"] += len(data) self._stats["packets_processed"] += 1 @@ -208,10 +210,9 @@ class QUICListener(IListener): # Continue processing other packets await trio.sleep(0.01) except trio.Cancelled: - print("PACKET HANDLER CANCELLED - FORCIBLY CLOSING SOCKET") + logger.info("Received Cancel, stopping handling incoming packets") raise finally: - print("PACKET HANDLER FINISHED") logger.debug("Packet handling loop terminated") async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: @@ -456,10 +457,7 @@ class QUICListener(IListener): except Exception as e: logger.error(f"Error in connection management: {e}") except trio.Cancelled: - print("CONNECTION MANAGER CANCELLED") raise - finally: - print("CONNECTION MANAGER FINISHED") async def _cleanup_closed_connections(self) -> None: """Remove closed connections from tracking.""" @@ -500,20 +498,20 @@ class QUICListener(IListener): self._closed = True self._listening = False - print("Closing QUIC listener") + logger.debug("Closing QUIC listener") # CRITICAL: Close socket FIRST to unblock recvfrom() await self._cleanup_socket() - print("SOCKET CLEANUP COMPLETE") + logger.debug("SOCKET CLEANUP COMPLETE") # Close all connections WITHOUT using the lock during shutdown # (avoid deadlock if background tasks are cancelled while holding lock) connections_to_close = list(self._connections.values()) pending_to_close = list(self._pending_connections.values()) - print( - f"CLOSING {len(connections_to_close)} connections and {len(pending_to_close)} pending" + logger.debug( + f"CLOSING {connections_to_close} connections and {pending_to_close} pending" ) # Close active connections @@ -533,10 +531,7 @@ class QUICListener(IListener): # Clear the dictionaries without lock (we're shutting down) self._connections.clear() self._pending_connections.clear() - if self._nursery: - print("TASKS", len(self._nursery.child_tasks)) - - print("QUIC listener closed") + logger.debug("QUIC listener closed") async def _cleanup_socket(self) -> None: """Clean up the UDP socket.""" @@ -562,7 +557,7 @@ class QUICListener(IListener): """Check if the listener is actively listening.""" return self._listening and not self._closed - def get_stats(self) -> dict: + def get_stats(self) -> dict[str, int]: """Get listener statistics.""" stats = self._stats.copy() stats.update( @@ -576,4 +571,6 @@ class QUICListener(IListener): def __str__(self) -> str: """String representation of the listener.""" - return f"QUICListener(addrs={self._bound_addresses}, connections={len(self._connections)})" + addr = self._bound_addresses + conn_count = len(self._connections) + return f"QUICListener(addrs={addr}, connections={conn_count})" diff --git a/libp2p/transport/quic/security.py b/libp2p/transport/quic/security.py index 1a49cf37..c1b947e1 100644 --- a/libp2p/transport/quic/security.py +++ b/libp2p/transport/quic/security.py @@ -7,7 +7,6 @@ Full implementation will be in Module 5. from dataclasses import dataclass import os import tempfile -from typing import Optional from libp2p.crypto.keys import PrivateKey from libp2p.peer.id import ID @@ -21,7 +20,7 @@ class TLSConfig: cert_file: str key_file: str - ca_file: Optional[str] = None + ca_file: str | None = None def generate_libp2p_tls_config(private_key: PrivateKey, peer_id: ID) -> TLSConfig: diff --git a/libp2p/transport/quic/stream.py b/libp2p/transport/quic/stream.py index 3bff6b4f..e43a00cb 100644 --- a/libp2p/transport/quic/stream.py +++ b/libp2p/transport/quic/stream.py @@ -116,7 +116,8 @@ class QUICStream(IMuxedStream): """ Reset the stream """ - self.handle_reset(0) + await self.handle_reset(0) + return def get_remote_address(self) -> tuple[str, int] | None: return self._connection._remote_addr diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 3f8c4004..ae361706 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -15,9 +15,9 @@ from aioquic.quic.connection import ( ) import multiaddr import trio +from typing_extensions import Unpack from libp2p.abc import ( - IListener, IRawConnection, ITransport, ) @@ -28,6 +28,7 @@ from libp2p.custom_types import THandler, TProtocol from libp2p.peer.id import ( ID, ) +from libp2p.transport.quic.config import QUICTransportKwargs from libp2p.transport.quic.utils import ( is_quic_multiaddr, multiaddr_to_quic_version, @@ -131,7 +132,10 @@ class QUICTransport(ITransport): # # This follows the libp2p TLS spec for peer identity verification # tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id) - # config.load_cert_chain(certfile=tls_config.cert_file, keyfile=tls_config.key_file) + # config.load_cert_chain( + # certfile=tls_config.cert_file, + # keyfile=tls_config.key_file + # ) # if tls_config.ca_file: # config.load_verify_locations(tls_config.ca_file) @@ -210,7 +214,7 @@ class QUICTransport(ITransport): logger.error(f"Failed to dial QUIC connection to {maddr}: {e}") raise QUICDialError(f"Dial failed: {e}") from e - def create_listener(self, handler_function: THandler) -> IListener: + def create_listener(self, handler_function: THandler) -> QUICListener: """ Create a QUIC listener. @@ -298,12 +302,18 @@ class QUICTransport(ITransport): logger.info("QUIC transport closed") - def get_stats(self) -> dict: + def get_stats(self) -> dict[str, int | list[str] | object]: """Get transport statistics.""" - stats = { + protocols = self.protocols() + str_protocols = [] + + for proto in protocols: + str_protocols.append(str(proto)) + + stats: dict[str, int | list[str] | object] = { "active_connections": len(self._connections), "active_listeners": len(self._listeners), - "supported_protocols": self.protocols(), + "supported_protocols": str_protocols, } # Aggregate listener stats @@ -324,7 +334,9 @@ class QUICTransport(ITransport): def new_transport( - private_key: PrivateKey, config: QUICTransportConfig | None = None, **kwargs + private_key: PrivateKey, + config: QUICTransportConfig | None = None, + **kwargs: Unpack[QUICTransportKwargs], ) -> QUICTransport: """ Factory function to create a new QUIC transport. diff --git a/libp2p/transport/quic/utils.py b/libp2p/transport/quic/utils.py index 97ad8fa8..20f85e8c 100644 --- a/libp2p/transport/quic/utils.py +++ b/libp2p/transport/quic/utils.py @@ -3,8 +3,6 @@ Multiaddr utilities for QUIC transport. Handles QUIC-specific multiaddr parsing and validation. """ -from typing import Tuple - import multiaddr from libp2p.custom_types import TProtocol @@ -54,7 +52,7 @@ def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool: return False -def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]: +def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]: """ Extract host and port from a QUIC multiaddr. @@ -78,20 +76,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]: # Try to get IPv4 address try: - host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore + host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore except ValueError: pass # Try to get IPv6 address if IPv4 not found if host is None: try: - host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore + host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore except ValueError: pass # Get UDP port try: - port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) + # The the package is exposed by types not availble + port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore port = int(port_str) except ValueError: pass diff --git a/tests/core/transport/quic/test_integration.py b/tests/core/transport/quic/test_integration.py new file mode 100644 index 00000000..5279de12 --- /dev/null +++ b/tests/core/transport/quic/test_integration.py @@ -0,0 +1,765 @@ +""" +Integration tests for QUIC transport that test actual networking. +These tests require network access and test real socket operations. +""" + +import logging +import random +import socket +import time + +import pytest +import trio + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.transport import QUICTransport +from libp2p.transport.quic.utils import create_quic_multiaddr + +logger = logging.getLogger(__name__) + + +class TestQUICNetworking: + """Integration tests that use actual networking.""" + + @pytest.fixture + def server_config(self): + """Server configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=100, + ) + + @pytest.fixture + def client_config(self): + """Client configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + ) + + @pytest.fixture + def server_key(self): + """Generate server key pair.""" + return create_new_key_pair().private_key + + @pytest.fixture + def client_key(self): + """Generate client key pair.""" + return create_new_key_pair().private_key + + @pytest.mark.trio + async def test_listener_binding_real_socket(self, server_key, server_config): + """Test that listener can bind to real socket.""" + transport = QUICTransport(server_key, server_config) + + async def connection_handler(connection): + logger.info(f"Received connection: {connection}") + + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + async with trio.open_nursery() as nursery: + try: + success = await listener.listen(listen_addr, nursery) + assert success + + # Verify we got a real port + addrs = listener.get_addrs() + assert len(addrs) == 1 + + # Port should be non-zero (was assigned) + from libp2p.transport.quic.utils import quic_multiaddr_to_endpoint + + host, port = quic_multiaddr_to_endpoint(addrs[0]) + assert host == "127.0.0.1" + assert port > 0 + + logger.info(f"Listener bound to {host}:{port}") + + # Listener should be active + assert listener.is_listening() + + # Test basic stats + stats = listener.get_stats() + assert stats["active_connections"] == 0 + assert stats["pending_connections"] == 0 + + # Close listener + await listener.close() + assert not listener.is_listening() + + finally: + await transport.close() + + @pytest.mark.trio + async def test_multiple_listeners_different_ports(self, server_key, server_config): + """Test multiple listeners on different ports.""" + transport = QUICTransport(server_key, server_config) + + async def connection_handler(connection): + pass + + listeners = [] + bound_ports = [] + + # Create multiple listeners + for i in range(3): + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + # Get bound port + addrs = listener.get_addrs() + from libp2p.transport.quic.utils import quic_multiaddr_to_endpoint + + host, port = quic_multiaddr_to_endpoint(addrs[0]) + + bound_ports.append(port) + listeners.append(listener) + + logger.info(f"Listener {i} bound to port {port}") + nursery.cancel_scope.cancel() + finally: + await listener.close() + + # All ports should be different + assert len(set(bound_ports)) == len(bound_ports) + + @pytest.mark.trio + async def test_port_already_in_use(self, server_key, server_config): + """Test handling of port already in use.""" + transport1 = QUICTransport(server_key, server_config) + transport2 = QUICTransport(server_key, server_config) + + async def connection_handler(connection): + pass + + listener1 = transport1.create_listener(connection_handler) + listener2 = transport2.create_listener(connection_handler) + + # Bind first listener to a specific port + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + async with trio.open_nursery() as nursery: + success1 = await listener1.listen(listen_addr, nursery) + assert success1 + + # Get the actual bound port + addrs = listener1.get_addrs() + from libp2p.transport.quic.utils import quic_multiaddr_to_endpoint + + host, port = quic_multiaddr_to_endpoint(addrs[0]) + + # Try to bind second listener to same port + # Should fail or get different port + same_port_addr = create_quic_multiaddr("127.0.0.1", port, "/quic") + + # This might either fail or succeed with SO_REUSEPORT + # The exact behavior depends on the system + try: + success2 = await listener2.listen(same_port_addr, nursery) + if success2: + # If it succeeds, verify different behavior + logger.info("Second listener bound successfully (SO_REUSEPORT)") + except Exception as e: + logger.info(f"Second listener failed as expected: {e}") + + await listener1.close() + await listener2.close() + await transport1.close() + await transport2.close() + + @pytest.mark.trio + async def test_listener_connection_tracking(self, server_key, server_config): + """Test that listener properly tracks connection state.""" + transport = QUICTransport(server_key, server_config) + + received_connections = [] + + async def connection_handler(connection): + received_connections.append(connection) + logger.info(f"Handler received connection: {connection}") + + # Keep connection alive briefly + await trio.sleep(0.1) + + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + # Initially no connections + stats = listener.get_stats() + assert stats["active_connections"] == 0 + assert stats["pending_connections"] == 0 + + # Simulate some packet processing + await trio.sleep(0.1) + + # Verify listener is still healthy + assert listener.is_listening() + + await listener.close() + await transport.close() + + @pytest.mark.trio + async def test_listener_error_recovery(self, server_key, server_config): + """Test listener error handling and recovery.""" + transport = QUICTransport(server_key, server_config) + + # Handler that raises an exception + async def failing_handler(connection): + raise ValueError("Simulated handler error") + + listener = transport.create_listener(failing_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + # Even with failing handler, listener should remain stable + await trio.sleep(0.1) + assert listener.is_listening() + + # Test complete, stop listening + nursery.cancel_scope.cancel() + finally: + await listener.close() + await transport.close() + + @pytest.mark.trio + async def test_transport_resource_cleanup_v1(self, server_key, server_config): + """Test with single parent nursery managing all listeners.""" + transport = QUICTransport(server_key, server_config) + + async def connection_handler(connection): + pass + + listeners = [] + + try: + async with trio.open_nursery() as parent_nursery: + # Start all listeners in parallel within the same nursery + for i in range(3): + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + listeners.append(listener) + + parent_nursery.start_soon( + listener.listen, listen_addr, parent_nursery + ) + + # Give listeners time to start + await trio.sleep(0.2) + + # Verify all listeners are active + for i, listener in enumerate(listeners): + assert listener.is_listening() + + # Close transport should close all listeners + await transport.close() + + # The nursery will exit cleanly because listeners are closed + + finally: + # Cleanup verification outside nursery + assert transport._closed + assert len(transport._listeners) == 0 + + # All listeners should be closed + for listener in listeners: + assert not listener.is_listening() + + @pytest.mark.trio + async def test_concurrent_listener_operations(self, server_key, server_config): + """Test concurrent listener operations.""" + transport = QUICTransport(server_key, server_config) + + async def connection_handler(connection): + await trio.sleep(0.01) # Simulate some work + + async def create_and_run_listener(listener_id): + """Create, run, and close a listener.""" + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + logger.info(f"Listener {listener_id} started") + + # Run for a short time + await trio.sleep(0.1) + + await listener.close() + logger.info(f"Listener {listener_id} closed") + + try: + # Run multiple listeners concurrently + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(create_and_run_listener, i) + + finally: + await transport.close() + + +class TestQUICConcurrency: + """Fixed tests with proper nursery management.""" + + @pytest.fixture + def server_key(self): + """Generate server key pair.""" + return create_new_key_pair().private_key + + @pytest.fixture + def server_config(self): + """Server configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=100, + ) + + @pytest.mark.trio + async def test_concurrent_listener_operations(self, server_key, server_config): + """Test concurrent listener operations - FIXED VERSION.""" + transport = QUICTransport(server_key, server_config) + + async def connection_handler(connection): + await trio.sleep(0.01) # Simulate some work + + listeners = [] + + async def create_and_run_listener(listener_id): + """Create and run a listener - fixed to avoid deadlock.""" + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + listeners.append(listener) + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + logger.info(f"Listener {listener_id} started") + + # Run for a short time + await trio.sleep(0.1) + + # Close INSIDE the nursery scope to allow clean exit + await listener.close() + logger.info(f"Listener {listener_id} closed") + + except Exception as e: + logger.error(f"Listener {listener_id} error: {e}") + if not listener._closed: + await listener.close() + raise + + try: + # Run multiple listeners concurrently + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(create_and_run_listener, i) + + # Verify all listeners were created and closed properly + assert len(listeners) == 5 + for listener in listeners: + assert not listener.is_listening() # Should all be closed + + finally: + await transport.close() + + @pytest.mark.trio + @pytest.mark.slow + async def test_listener_under_simulated_load(self, server_key, server_config): + """REAL load test with actual packet simulation.""" + print("=== REAL LOAD TEST ===") + + config = QUICTransportConfig( + idle_timeout=30.0, + connection_timeout=10.0, + max_concurrent_streams=1000, + max_connections=500, + ) + + transport = QUICTransport(server_key, config) + connection_count = 0 + + async def connection_handler(connection): + nonlocal connection_count + # TODO: Remove type ignore when pyrefly fixes nonlocal bug + connection_count += 1 # type: ignore + print(f"Real connection established: {connection_count}") + # Simulate connection work + await trio.sleep(0.01) + + listener = transport.create_listener(connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + async def generate_udp_traffic(target_host, target_port, num_packets=100): + """Generate fake UDP traffic to simulate load.""" + print( + f"Generating {num_packets} UDP packets to {target_host}:{target_port}" + ) + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + for i in range(num_packets): + # Send random UDP packets + # (Won't be valid QUIC, but will exercise packet handler) + fake_packet = ( + f"FAKE_PACKET_{i}_{random.randint(1000, 9999)}".encode() + ) + sock.sendto(fake_packet, (target_host, int(target_port))) + + # Small delay between packets + await trio.sleep(0.001) + + if i % 20 == 0: + print(f"Sent {i + 1}/{num_packets} packets") + + except Exception as e: + print(f"Error sending packets: {e}") + finally: + sock.close() + + print(f"Finished sending {num_packets} packets") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + # Get the actual bound port + bound_addrs = listener.get_addrs() + bound_addr = bound_addrs[0] + print(bound_addr) + host, port = ( + bound_addr.value_for_protocol("ip4"), + bound_addr.value_for_protocol("udp"), + ) + + print(f"Listener bound to {host}:{port}") + + # Start load generation + nursery.start_soon(generate_udp_traffic, host, port, 50) + + # Let the load test run + start_time = time.time() + await trio.sleep(2.0) # Let traffic flow for 2 seconds + end_time = time.time() + + # Check that listener handled the load + stats = listener.get_stats() + print(f"Final stats: {stats}") + + # Should have received packets (even if they're invalid QUIC) + assert stats["packets_processed"] > 0 + assert stats["bytes_received"] > 0 + + duration = end_time - start_time + print(f"Load test ran for {duration:.2f}s") + print(f"Processed {stats['packets_processed']} packets") + print(f"Received {stats['bytes_received']} bytes") + + await listener.close() + + finally: + if not listener._closed: + await listener.close() + await transport.close() + + +class TestQUICRealWorldScenarios: + """Test real-world usage scenarios - FIXED VERSIONS.""" + + @pytest.mark.trio + async def test_echo_server_pattern(self): + """Test a basic echo server pattern - FIXED VERSION.""" + server_key = create_new_key_pair().private_key + config = QUICTransportConfig(idle_timeout=5.0) + transport = QUICTransport(server_key, config) + + echo_data = [] + + async def echo_connection_handler(connection): + """Echo server that handles one connection.""" + logger.info(f"Echo server got connection: {connection}") + + async def stream_handler(stream): + try: + # Read data and echo it back + while True: + data = await stream.read(1024) + if not data: + break + + echo_data.append(data) + await stream.write(b"ECHO: " + data) + + except Exception as e: + logger.error(f"Stream error: {e}") + finally: + await stream.close() + + connection.set_stream_handler(stream_handler) + + # Keep connection alive until closed + while not connection.is_closed: + await trio.sleep(0.1) + + listener = transport.create_listener(echo_connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + # Let server initialize + await trio.sleep(0.1) + + # Verify server is ready + assert listener.is_listening() + + # Run server for a bit + await trio.sleep(0.5) + + # Close inside nursery for clean exit + await listener.close() + + finally: + # Ensure cleanup + if not listener._closed: + await listener.close() + await transport.close() + + @pytest.mark.trio + async def test_connection_lifecycle_monitoring(self): + """Test monitoring connection lifecycle events - FIXED VERSION.""" + server_key = create_new_key_pair().private_key + config = QUICTransportConfig(idle_timeout=5.0) + transport = QUICTransport(server_key, config) + + lifecycle_events = [] + + async def monitoring_handler(connection): + lifecycle_events.append(("connection_started", connection.get_stats())) + + try: + # Monitor connection + while not connection.is_closed: + stats = connection.get_stats() + lifecycle_events.append(("connection_stats", stats)) + await trio.sleep(0.1) + + except Exception as e: + lifecycle_events.append(("connection_error", str(e))) + finally: + lifecycle_events.append(("connection_ended", connection.get_stats())) + + listener = transport.create_listener(monitoring_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + # Run monitoring for a bit + await trio.sleep(0.5) + + # Check that monitoring infrastructure is working + assert listener.is_listening() + + # Close inside nursery + await listener.close() + + finally: + # Ensure cleanup + if not listener._closed: + await listener.close() + await transport.close() + + # Should have some lifecycle events from setup + logger.info(f"Recorded {len(lifecycle_events)} lifecycle events") + + @pytest.mark.trio + async def test_multi_listener_echo_servers(self): + """Test multiple echo servers running in parallel.""" + server_key = create_new_key_pair().private_key + config = QUICTransportConfig(idle_timeout=5.0) + transport = QUICTransport(server_key, config) + + all_echo_data = {} + listeners = [] + + async def create_echo_server(server_id): + """Create and run one echo server.""" + echo_data = [] + all_echo_data[server_id] = echo_data + + async def echo_handler(connection): + logger.info(f"Echo server {server_id} got connection") + + async def stream_handler(stream): + try: + while True: + data = await stream.read(1024) + if not data: + break + echo_data.append(data) + await stream.write(f"ECHO-{server_id}: ".encode() + data) + except Exception as e: + logger.error(f"Stream error in server {server_id}: {e}") + finally: + await stream.close() + + connection.set_stream_handler(stream_handler) + while not connection.is_closed: + await trio.sleep(0.1) + + listener = transport.create_listener(echo_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + listeners.append(listener) + + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + logger.info(f"Echo server {server_id} started") + + # Run for a bit + await trio.sleep(0.3) + + # Close this server + await listener.close() + logger.info(f"Echo server {server_id} closed") + + try: + # Run multiple echo servers in parallel + async with trio.open_nursery() as nursery: + for i in range(3): + nursery.start_soon(create_echo_server, i) + + # Verify all servers ran + assert len(listeners) == 3 + assert len(all_echo_data) == 3 + + for listener in listeners: + assert not listener.is_listening() # Should all be closed + + finally: + await transport.close() + + @pytest.mark.trio + async def test_graceful_shutdown_sequence(self): + """Test graceful shutdown of multiple components.""" + server_key = create_new_key_pair().private_key + config = QUICTransportConfig(idle_timeout=5.0) + transport = QUICTransport(server_key, config) + + shutdown_events = [] + listeners = [] + + async def tracked_connection_handler(connection): + """Connection handler that tracks shutdown.""" + try: + while not connection.is_closed: + await trio.sleep(0.1) + finally: + shutdown_events.append(f"connection_closed_{id(connection)}") + + async def create_tracked_listener(listener_id): + """Create a listener that tracks its lifecycle.""" + try: + listener = transport.create_listener(tracked_connection_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + listeners.append(listener) + + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + shutdown_events.append(f"listener_{listener_id}_started") + + # Run for a bit + await trio.sleep(0.2) + + # Graceful close + await listener.close() + shutdown_events.append(f"listener_{listener_id}_closed") + + except Exception as e: + shutdown_events.append(f"listener_{listener_id}_error_{e}") + raise + + try: + # Start multiple listeners + async with trio.open_nursery() as nursery: + for i in range(3): + nursery.start_soon(create_tracked_listener, i) + + # Verify shutdown sequence + start_events = [e for e in shutdown_events if "started" in e] + close_events = [e for e in shutdown_events if "closed" in e] + + assert len(start_events) == 3 + assert len(close_events) == 3 + + logger.info(f"Shutdown sequence: {shutdown_events}") + + finally: + shutdown_events.append("transport_closing") + await transport.close() + shutdown_events.append("transport_closed") + + +# HELPER FUNCTIONS FOR CLEANER TESTS + + +async def run_listener_for_duration(transport, handler, duration=0.5): + """Helper to run a single listener for a specific duration.""" + listener = transport.create_listener(handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + + # Run for specified duration + await trio.sleep(duration) + + # Clean close + await listener.close() + + return listener + + +async def run_multiple_listeners_parallel(transport, handler, count=3, duration=0.5): + """Helper to run multiple listeners in parallel.""" + listeners = [] + + async def single_listener_task(listener_id): + listener = await run_listener_for_duration(transport, handler, duration) + listeners.append(listener) + logger.info(f"Listener {listener_id} completed") + + async with trio.open_nursery() as nursery: + for i in range(count): + nursery.start_soon(single_listener_task, i) + + return listeners + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/core/transport/quic/test_listener.py b/tests/core/transport/quic/test_listener.py index c0874ec4..840f7218 100644 --- a/tests/core/transport/quic/test_listener.py +++ b/tests/core/transport/quic/test_listener.py @@ -17,7 +17,6 @@ from libp2p.transport.quic.transport import ( ) from libp2p.transport.quic.utils import ( create_quic_multiaddr, - quic_multiaddr_to_endpoint, ) @@ -89,71 +88,51 @@ class TestQUICListener: assert stats["active_connections"] == 0 assert stats["pending_connections"] == 0 - # Close listener - await listener.close() - assert not listener.is_listening() + # Sender Cancel Signal + nursery.cancel_scope.cancel() + + await listener.close() + assert not listener.is_listening() @pytest.mark.trio async def test_listener_double_listen(self, listener: QUICListener): """Test that double listen raises error.""" listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic") - # The nursery is the outer context - async with trio.open_nursery() as nursery: - # The try/finally is now INSIDE the nursery scope - try: - # The listen method creates the socket and starts background tasks + try: + async with trio.open_nursery() as nursery: success = await listener.listen(listen_addr, nursery) assert success await trio.sleep(0.01) addrs = listener.get_addrs() assert len(addrs) > 0 - print("ADDRS 1: ", len(addrs)) - print("TEST LOGIC FINISHED") - async with trio.open_nursery() as nursery2: with pytest.raises(QUICListenError, match="Already listening"): await listener.listen(listen_addr, nursery2) - finally: - # This block runs BEFORE the 'async with nursery' exits. - print("INNER FINALLY: Closing listener to release socket...") + nursery2.cancel_scope.cancel() - # This closes the socket and sets self._listening = False, - # which helps the background tasks terminate cleanly. - await listener.close() - print("INNER FINALLY: Listener closed.") - - # By the time we get here, the listener and its tasks have been fully - # shut down, allowing the nursery to exit without hanging. - print("TEST COMPLETED SUCCESSFULLY.") + nursery.cancel_scope.cancel() + finally: + await listener.close() @pytest.mark.trio async def test_listener_port_binding(self, listener: QUICListener): """Test listener port binding and cleanup.""" listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") - # The nursery is the outer context - async with trio.open_nursery() as nursery: - # The try/finally is now INSIDE the nursery scope - try: - # The listen method creates the socket and starts background tasks + try: + async with trio.open_nursery() as nursery: success = await listener.listen(listen_addr, nursery) assert success await trio.sleep(0.5) addrs = listener.get_addrs() assert len(addrs) > 0 - print("TEST LOGIC FINISHED") - finally: - # This block runs BEFORE the 'async with nursery' exits. - print("INNER FINALLY: Closing listener to release socket...") - - # This closes the socket and sets self._listening = False, - # which helps the background tasks terminate cleanly. - await listener.close() - print("INNER FINALLY: Listener closed.") + nursery.cancel_scope.cancel() + finally: + await listener.close() # By the time we get here, the listener and its tasks have been fully # shut down, allowing the nursery to exit without hanging. diff --git a/tests/core/transport/quic/test_utils.py b/tests/core/transport/quic/test_utils.py index d67317c7..d2dacdcf 100644 --- a/tests/core/transport/quic/test_utils.py +++ b/tests/core/transport/quic/test_utils.py @@ -24,18 +24,14 @@ class TestQUICUtils: Multiaddr( f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" ), - Multiaddr( - f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" - ), + Multiaddr(f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"), Multiaddr( f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" ), Multiaddr( f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}" ), - Multiaddr( - f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" - ), + Multiaddr(f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"), ] for addr in valid: