From 54b3055eaaddc03263b6c2da9544560bbe2d4e29 Mon Sep 17 00:00:00 2001 From: Akash Mondal Date: Tue, 10 Jun 2025 21:40:21 +0000 Subject: [PATCH] fix: impl quic listener --- libp2p/custom_types.py | 11 +- libp2p/transport/quic/config.py | 8 + libp2p/transport/quic/connection.py | 335 ++++++++--- libp2p/transport/quic/listener.py | 579 +++++++++++++++++++ libp2p/transport/quic/security.py | 123 ++++ libp2p/transport/quic/stream.py | 15 +- libp2p/transport/quic/transport.py | 122 ++-- libp2p/transport/quic/utils.py | 223 +++++++ pyproject.toml | 1 + tests/core/transport/quic/test_connection.py | 119 ++++ tests/core/transport/quic/test_listener.py | 171 ++++++ tests/core/transport/quic/test_transport.py | 36 +- tests/core/transport/quic/test_utils.py | 94 +++ 13 files changed, 1687 insertions(+), 150 deletions(-) create mode 100644 libp2p/transport/quic/listener.py create mode 100644 libp2p/transport/quic/security.py create mode 100644 libp2p/transport/quic/utils.py create mode 100644 tests/core/transport/quic/test_connection.py create mode 100644 tests/core/transport/quic/test_listener.py create mode 100644 tests/core/transport/quic/test_utils.py diff --git a/libp2p/custom_types.py b/libp2p/custom_types.py index 0b844133..73a65c39 100644 --- a/libp2p/custom_types.py +++ b/libp2p/custom_types.py @@ -5,17 +5,15 @@ from collections.abc import ( ) from typing import TYPE_CHECKING, NewType, Union, cast +from libp2p.transport.quic.stream import QUICStream + if TYPE_CHECKING: - from libp2p.abc import ( - IMuxedConn, - INetStream, - ISecureTransport, - ) + from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport else: IMuxedConn = cast(type, object) INetStream = cast(type, object) ISecureTransport = cast(type, object) - + IMuxedStream = cast(type, object) from libp2p.io.abc import ( ReadWriteCloser, @@ -37,3 +35,4 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] UnsubscribeFn = Callable[[], Awaitable[None]] +TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]] diff --git a/libp2p/transport/quic/config.py b/libp2p/transport/quic/config.py index 75402626..d1ccf335 100644 --- a/libp2p/transport/quic/config.py +++ b/libp2p/transport/quic/config.py @@ -8,6 +8,8 @@ from dataclasses import ( ) import ssl +from libp2p.custom_types import TProtocol + @dataclass class QUICTransportConfig: @@ -39,6 +41,12 @@ class QUICTransportConfig: max_connections: int = 1000 # Maximum number of connections connection_timeout: float = 10.0 # Connection establishment timeout + # Protocol identifiers matching go-libp2p + # TODO: UNTIL MUITIADDR REPO IS UPDATED + # PROTOCOL_QUIC_V1: TProtocol = TProtocol("/quic-v1") # RFC 9000 + PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic") # RFC 9000 + PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29 + def __post_init__(self): """Validate configuration after initialization.""" if not (self.enable_draft29 or self.enable_v1): diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py index fceb9d87..9746d234 100644 --- a/libp2p/transport/quic/connection.py +++ b/libp2p/transport/quic/connection.py @@ -6,6 +6,7 @@ Uses aioquic's sans-IO core with trio for async operations. import logging import socket import time +from typing import TYPE_CHECKING from aioquic.quic import ( events, @@ -21,9 +22,7 @@ from libp2p.abc import ( IMuxedStream, IRawConnection, ) -from libp2p.custom_types import ( - StreamHandlerFn, -) +from libp2p.custom_types import TQUICStreamHandlerFn from libp2p.peer.id import ( ID, ) @@ -35,9 +34,11 @@ from .exceptions import ( from .stream import ( QUICStream, ) -from .transport import ( - QUICTransport, -) + +if TYPE_CHECKING: + from .transport import ( + QUICTransport, + ) logger = logging.getLogger(__name__) @@ -49,76 +50,177 @@ class QUICConnection(IRawConnection, IMuxedConn): Uses aioquic's sans-IO core with trio for native async support. QUIC natively provides stream multiplexing, so this connection acts as both a raw connection (for transport layer) and muxed connection (for upper layers). + + Updated to work properly with the QUIC listener for server-side connections. """ def __init__( self, quic_connection: QuicConnection, remote_addr: tuple[str, int], - peer_id: ID, + peer_id: ID | None, local_peer_id: ID, - initiator: bool, + is_initiator: bool, maddr: multiaddr.Multiaddr, - transport: QUICTransport, + transport: "QUICTransport", ): self._quic = quic_connection self._remote_addr = remote_addr self._peer_id = peer_id self._local_peer_id = local_peer_id - self.__is_initiator = initiator + self.__is_initiator = is_initiator self._maddr = maddr self._transport = transport - # Trio networking + # Trio networking - socket may be provided by listener self._socket: trio.socket.SocketType | None = None self._connected_event = trio.Event() self._closed_event = trio.Event() # Stream management self._streams: dict[int, QUICStream] = {} - self._next_stream_id: int = ( - 0 if initiator else 1 - ) # Even for initiator, odd for responder - self._stream_handler: StreamHandlerFn | None = None + self._next_stream_id: int = self._calculate_initial_stream_id() + self._stream_handler: TQUICStreamHandlerFn | None = None + self._stream_id_lock = trio.Lock() # Connection state self._closed = False - self._timer_task = None + self._established = False + self._started = False - logger.debug(f"Created QUIC connection to {peer_id}") + # Background task management + self._background_tasks_started = False + self._nursery: trio.Nursery | None = None + + logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})") + + def _calculate_initial_stream_id(self) -> int: + """ + Calculate the initial stream ID based on QUIC specification. + + QUIC stream IDs: + - Client-initiated bidirectional: 0, 4, 8, 12, ... + - Server-initiated bidirectional: 1, 5, 9, 13, ... + - Client-initiated unidirectional: 2, 6, 10, 14, ... + - Server-initiated unidirectional: 3, 7, 11, 15, ... + + For libp2p, we primarily use bidirectional streams. + """ + if self.__is_initiator: + return 0 # Client starts with 0, then 4, 8, 12... + else: + return 1 # Server starts with 1, then 5, 9, 13... @property def is_initiator(self) -> bool: # type: ignore return self.__is_initiator - async def connect(self) -> None: - """Establish the QUIC connection using trio.""" + async def start(self) -> None: + """ + Start the connection and its background tasks. + + This method implements the IMuxedConn.start() interface. + It should be called to begin processing connection events. + """ + if self._started: + logger.warning("Connection already started") + return + + if self._closed: + raise QUICConnectionError("Cannot start a closed connection") + + self._started = True + logger.debug(f"Starting QUIC connection to {self._peer_id}") + + # If this is a client connection, we need to establish the connection + if self.__is_initiator: + await self._initiate_connection() + else: + # For server connections, we're already connected via the listener + self._established = True + self._connected_event.set() + + logger.debug(f"QUIC connection to {self._peer_id} started") + + async def _initiate_connection(self) -> None: + """Initiate client-side connection establishment.""" try: # Create UDP socket using trio self._socket = trio.socket.socket( family=socket.AF_INET, type=socket.SOCK_DGRAM ) + # Connect the socket to the remote address + await self._socket.connect(self._remote_addr) + # Start the connection establishment self._quic.connect(self._remote_addr, now=time.time()) # Send initial packet(s) await self._transmit() - # Start background tasks using trio nursery - async with trio.open_nursery() as nursery: - nursery.start_soon( - self._handle_incoming_data, None, "QUIC INCOMING DATA" - ) - nursery.start_soon(self._handle_timer, None, "QUIC TIMER HANDLER") + # For client connections, we need to manage our own background tasks + # In a real implementation, this would be managed by the transport + # For now, we'll start them here + if not self._background_tasks_started: + # We would need a nursery to start background tasks + # This is a limitation of the current design + logger.warning("Background tasks need nursery - connection may not work properly") - # Wait for connection to be established - await self._connected_event.wait() + except Exception as e: + logger.error(f"Failed to initiate connection: {e}") + raise QUICConnectionError(f"Connection initiation failed: {e}") from e + + async def connect(self, nursery: trio.Nursery) -> None: + """ + Establish the QUIC connection using trio. + + Args: + nursery: Trio nursery for background tasks + + """ + if not self.__is_initiator: + raise QUICConnectionError("connect() should only be called by client connections") + + try: + # Store nursery for background tasks + self._nursery = nursery + + # Create UDP socket using trio + self._socket = trio.socket.socket( + family=socket.AF_INET, type=socket.SOCK_DGRAM + ) + + # Connect the socket to the remote address + await self._socket.connect(self._remote_addr) + + # Start the connection establishment + self._quic.connect(self._remote_addr, now=time.time()) + + # Send initial packet(s) + await self._transmit() + + # Start background tasks + await self._start_background_tasks(nursery) + + # Wait for connection to be established + await self._connected_event.wait() except Exception as e: logger.error(f"Failed to connect: {e}") raise QUICConnectionError(f"Connection failed: {e}") from e + async def _start_background_tasks(self, nursery: trio.Nursery) -> None: + """Start background tasks for connection management.""" + if self._background_tasks_started: + return + + self._background_tasks_started = True + + # Start background tasks + nursery.start_soon(self._handle_incoming_data) + nursery.start_soon(self._handle_timer) + async def _handle_incoming_data(self) -> None: """Handle incoming UDP datagrams in trio.""" while not self._closed: @@ -128,6 +230,10 @@ class QUICConnection(IRawConnection, IMuxedConn): self._quic.receive_datagram(data, addr, now=time.time()) await self._process_events() await self._transmit() + + # Small delay to prevent busy waiting + await trio.sleep(0.001) + except trio.ClosedResourceError: break except Exception as e: @@ -137,18 +243,26 @@ class QUICConnection(IRawConnection, IMuxedConn): async def _handle_timer(self) -> None: """Handle QUIC timer events in trio.""" while not self._closed: - timer_at = self._quic.get_timer() - if timer_at is None: - await trio.sleep(1.0) # No timer set, check again later - continue + try: + timer_at = self._quic.get_timer() + if timer_at is None: + await trio.sleep(0.1) # No timer set, check again later + continue - now = time.time() - if timer_at <= now: - self._quic.handle_timer(now=now) - await self._process_events() - await self._transmit() - else: - await trio.sleep(timer_at - now) + now = time.time() + if timer_at <= now: + self._quic.handle_timer(now=now) + await self._process_events() + await self._transmit() + await trio.sleep(0.001) # Small delay + else: + # Sleep until timer fires, but check periodically + sleep_time = min(timer_at - now, 0.1) + await trio.sleep(sleep_time) + + except Exception as e: + logger.error(f"Error in timer handler: {e}") + await trio.sleep(0.1) async def _process_events(self) -> None: """Process QUIC events from aioquic core.""" @@ -165,6 +279,7 @@ class QUICConnection(IRawConnection, IMuxedConn): elif isinstance(event, events.HandshakeCompleted): logger.debug("QUIC handshake completed") + self._established = True self._connected_event.set() elif isinstance(event, events.StreamDataReceived): @@ -177,25 +292,47 @@ class QUICConnection(IRawConnection, IMuxedConn): """Handle incoming stream data.""" stream_id = event.stream_id + # Get or create stream if stream_id not in self._streams: - # Create new stream for incoming data + # Determine if this is an incoming stream + is_incoming = self._is_incoming_stream(stream_id) + stream = QUICStream( connection=self, stream_id=stream_id, - is_initiator=False, # pyrefly: ignore + is_initiator=not is_incoming, ) self._streams[stream_id] = stream - # Notify stream handler if available - if self._stream_handler: - # Use trio nursery to start stream handler - async with trio.open_nursery() as nursery: - nursery.start_soon(self._stream_handler, stream) + # Notify stream handler for incoming streams + if is_incoming and self._stream_handler: + # Start stream handler in background + # In a real implementation, you might want to use the nursery + # passed to the connection, but for now we'll handle it directly + try: + await self._stream_handler(stream) + except Exception as e: + logger.error(f"Error in stream handler: {e}") # Forward data to stream stream = self._streams[stream_id] await stream.handle_data_received(event.data, event.end_stream) + def _is_incoming_stream(self, stream_id: int) -> bool: + """ + Determine if a stream ID represents an incoming stream. + + For bidirectional streams: + - Even IDs are client-initiated + - Odd IDs are server-initiated + """ + if self.__is_initiator: + # We're the client, so odd stream IDs are incoming + return stream_id % 2 == 1 + else: + # We're the server, so even stream IDs are incoming + return stream_id % 2 == 0 + async def _handle_stream_reset(self, event: events.StreamReset) -> None: """Handle stream reset.""" stream_id = event.stream_id @@ -210,15 +347,15 @@ class QUICConnection(IRawConnection, IMuxedConn): if socket is None: return - for data, addr in self._quic.datagrams_to_send(now=time.time()): - try: + try: + for data, addr in self._quic.datagrams_to_send(now=time.time()): await socket.sendto(data, addr) - except Exception as e: - logger.error(f"Failed to send datagram: {e}") + except Exception as e: + logger.error(f"Failed to send datagram: {e}") # IRawConnection interface - async def write(self, data: bytes): + async def write(self, data: bytes) -> None: """ Write data to the connection. For QUIC, this creates a new stream for each write operation. @@ -230,7 +367,7 @@ class QUICConnection(IRawConnection, IMuxedConn): await stream.write(data) await stream.close() - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int | None = -1) -> bytes: """ Read data from the connection. For QUIC, this reads from the next available stream. @@ -252,14 +389,21 @@ class QUICConnection(IRawConnection, IMuxedConn): self._closed = True logger.debug(f"Closing QUIC connection to {self._peer_id}") - # Close all streams using trio nursery - async with trio.open_nursery() as nursery: - for stream in self._streams.values(): - nursery.start_soon(stream.close) + # Close all streams + stream_close_tasks = [] + for stream in list(self._streams.values()): + stream_close_tasks.append(stream.close()) + + if stream_close_tasks: + # Close streams concurrently + async with trio.open_nursery() as nursery: + for task in stream_close_tasks: + nursery.start_soon(lambda t=task: t) # Close QUIC connection self._quic.close() - await self._transmit() # Send close frames + if self._socket: + await self._transmit() # Send close frames # Close socket if self._socket: @@ -275,6 +419,16 @@ class QUICConnection(IRawConnection, IMuxedConn): """Check if connection is closed.""" return self._closed + @property + def is_established(self) -> bool: + """Check if connection is established (handshake completed).""" + return self._established + + @property + def is_started(self) -> bool: + """Check if connection has been started.""" + return self._started + def multiaddr(self) -> multiaddr.Multiaddr: """Get the multiaddr for this connection.""" return self._maddr @@ -283,6 +437,10 @@ class QUICConnection(IRawConnection, IMuxedConn): """Get the local peer ID.""" return self._local_peer_id + def remote_peer_id(self) -> ID | None: + """Get the remote peer ID.""" + return self._peer_id + # IMuxedConn interface async def open_stream(self) -> IMuxedStream: @@ -296,23 +454,27 @@ class QUICConnection(IRawConnection, IMuxedConn): if self._closed: raise QUICStreamError("Connection is closed") - # Generate next stream ID - stream_id = self._next_stream_id - self._next_stream_id += ( - 2 # Increment by 2 to maintain initiator/responder distinction - ) + if not self._started: + raise QUICStreamError("Connection not started") - # Create stream - stream = QUICStream( - connection=self, stream_id=stream_id, is_initiator=True - ) # pyrefly: ignore + async with self._stream_id_lock: + # Generate next stream ID + stream_id = self._next_stream_id + self._next_stream_id += 4 # Increment by 4 for bidirectional streams - self._streams[stream_id] = stream + # Create stream + stream = QUICStream( + connection=self, + stream_id=stream_id, + is_initiator=True + ) + + self._streams[stream_id] = stream logger.debug(f"Opened QUIC stream {stream_id}") return stream - def set_stream_handler(self, handler_function: StreamHandlerFn) -> None: + def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None: """ Set handler for incoming streams. @@ -341,17 +503,22 @@ class QUICConnection(IRawConnection, IMuxedConn): """ # Extract peer ID from TLS certificate # This should match the expected peer ID - cert_peer_id = self._extract_peer_id_from_cert() + try: + cert_peer_id = self._extract_peer_id_from_cert() - if self._peer_id and cert_peer_id != self._peer_id: - raise QUICConnectionError( - f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}" - ) + if self._peer_id and cert_peer_id != self._peer_id: + raise QUICConnectionError( + f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}" + ) - if not self._peer_id: - self._peer_id = cert_peer_id + if not self._peer_id: + self._peer_id = cert_peer_id - logger.debug(f"Verified peer identity: {self._peer_id}") + logger.debug(f"Verified peer identity: {self._peer_id}") + + except NotImplementedError: + logger.warning("Peer identity verification not implemented - skipping") + # For now, we'll skip verification during development def _extract_peer_id_from_cert(self) -> ID: """Extract peer ID from TLS certificate.""" @@ -363,6 +530,22 @@ class QUICConnection(IRawConnection, IMuxedConn): # The certificate should contain the peer ID in a specific extension raise NotImplementedError("Certificate peer ID extraction not implemented") + def get_stats(self) -> dict: + """Get connection statistics.""" + return { + "peer_id": str(self._peer_id), + "remote_addr": self._remote_addr, + "is_initiator": self.__is_initiator, + "is_established": self._established, + "is_closed": self._closed, + "is_started": self._started, + "active_streams": len(self._streams), + "next_stream_id": self._next_stream_id, + } + + def get_remote_address(self): + return self._remote_addr + def __str__(self) -> str: """String representation of the connection.""" - return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)})" + return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, established={self._established}, started={self._started})" diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py new file mode 100644 index 00000000..8757427e --- /dev/null +++ b/libp2p/transport/quic/listener.py @@ -0,0 +1,579 @@ +""" +QUIC Listener implementation for py-libp2p. +Based on go-libp2p and js-libp2p QUIC listener patterns. +Uses aioquic's server-side QUIC implementation with trio. +""" + +import copy +import logging +import socket +import time +from typing import TYPE_CHECKING, Dict + +from aioquic.quic import events +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import QuicConnection +from multiaddr import Multiaddr +import trio + +from libp2p.abc import IListener +from libp2p.custom_types import THandler, TProtocol + +from .config import QUICTransportConfig +from .connection import QUICConnection +from .exceptions import QUICListenError +from .utils import ( + create_quic_multiaddr, + is_quic_multiaddr, + multiaddr_to_quic_version, + quic_multiaddr_to_endpoint, +) + +if TYPE_CHECKING: + from .transport import QUICTransport + +logger = logging.getLogger(__name__) +logger.setLevel("DEBUG") + + +class QUICListener(IListener): + """ + QUIC Listener implementation following libp2p listener interface. + + Handles incoming QUIC connections, manages server-side handshakes, + and integrates with the libp2p connection handler system. + Based on go-libp2p and js-libp2p listener patterns. + """ + + def __init__( + self, + transport: "QUICTransport", + handler_function: THandler, + quic_configs: Dict[TProtocol, QuicConfiguration], + config: QUICTransportConfig, + ): + """ + Initialize QUIC listener. + + Args: + transport: Parent QUIC transport + handler_function: Function to handle new connections + quic_configs: QUIC configurations for different versions + config: QUIC transport configuration + + """ + self._transport = transport + self._handler = handler_function + self._quic_configs = quic_configs + self._config = config + + # Network components + self._socket: trio.socket.SocketType | None = None + self._bound_addresses: list[Multiaddr] = [] + + # Connection management + self._connections: Dict[tuple[str, int], QUICConnection] = {} + self._pending_connections: Dict[tuple[str, int], QuicConnection] = {} + self._connection_lock = trio.Lock() + + # Listener state + self._closed = False + self._listening = False + self._nursery: trio.Nursery | None = None + + # Performance tracking + self._stats = { + "connections_accepted": 0, + "connections_rejected": 0, + "bytes_received": 0, + "packets_processed": 0, + } + + logger.debug("Initialized QUIC listener") + + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + """ + Start listening on the given multiaddr. + + Args: + maddr: Multiaddr to listen on + nursery: Trio nursery for managing background tasks + + Returns: + True if listening started successfully + + Raises: + QUICListenError: If failed to start listening + """ + if not is_quic_multiaddr(maddr): + raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") + + if self._listening: + raise QUICListenError("Already listening") + + try: + # Extract host and port from multiaddr + host, port = quic_multiaddr_to_endpoint(maddr) + quic_version = multiaddr_to_quic_version(maddr) + + # Validate QUIC version support + if quic_version not in self._quic_configs: + raise QUICListenError(f"Unsupported QUIC version: {quic_version}") + + # Create and bind UDP socket + self._socket = await self._create_and_bind_socket(host, port) + actual_port = self._socket.getsockname()[1] + + # Update multiaddr with actual bound port + actual_maddr = create_quic_multiaddr(host, actual_port, f"/{quic_version}") + self._bound_addresses = [actual_maddr] + + # Store nursery reference and set listening state + self._nursery = nursery + self._listening = True + + # Start background tasks directly in the provided nursery + # This ensures proper cancellation when the nursery exits + nursery.start_soon(self._handle_incoming_packets) + nursery.start_soon(self._manage_connections) + + print(f"QUIC listener started on {actual_maddr}") + return True + + except trio.Cancelled: + print("CLOSING LISTENER") + raise + except Exception as e: + logger.error(f"Failed to start QUIC listener on {maddr}: {e}") + await self._cleanup_socket() + raise QUICListenError(f"Listen failed: {e}") from e + + async def _create_and_bind_socket( + self, host: str, port: int + ) -> trio.socket.SocketType: + """Create and bind UDP socket for QUIC.""" + try: + # Determine address family + try: + import ipaddress + + ip = ipaddress.ip_address(host) + family = socket.AF_INET if ip.version == 4 else socket.AF_INET6 + except ValueError: + # Assume IPv4 for hostnames + family = socket.AF_INET + + # Create UDP socket + sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + # Set socket options for better performance + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, "SO_REUSEPORT"): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + # Bind to address + await sock.bind((host, port)) + + logger.debug(f"Created and bound UDP socket to {host}:{port}") + return sock + + except Exception as e: + raise QUICListenError(f"Failed to create socket: {e}") from e + + async def _handle_incoming_packets(self) -> None: + """ + Handle incoming UDP packets and route to appropriate connections. + This is the main packet processing loop. + """ + logger.debug("Started packet handling loop") + + try: + while self._listening and self._socket: + try: + # Receive UDP packet (this blocks until packet arrives or socket closes) + data, addr = await self._socket.recvfrom(65536) + self._stats["bytes_received"] += len(data) + self._stats["packets_processed"] += 1 + + # Process packet asynchronously to avoid blocking + if self._nursery: + self._nursery.start_soon(self._process_packet, data, addr) + + except trio.ClosedResourceError: + # Socket was closed, exit gracefully + logger.debug("Socket closed, exiting packet handler") + break + except Exception as e: + logger.error(f"Error receiving packet: {e}") + # Continue processing other packets + await trio.sleep(0.01) + except trio.Cancelled: + print("PACKET HANDLER CANCELLED - FORCIBLY CLOSING SOCKET") + raise + finally: + print("PACKET HANDLER FINISHED") + logger.debug("Packet handling loop terminated") + + async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: + """ + Process a single incoming packet. + Routes to existing connection or creates new connection. + + Args: + data: Raw UDP packet data + addr: Source address (host, port) + + """ + try: + async with self._connection_lock: + # Check if we have an existing connection for this address + if addr in self._connections: + connection = self._connections[addr] + await self._route_to_connection(connection, data, addr) + elif addr in self._pending_connections: + # Handle packet for pending connection + quic_conn = self._pending_connections[addr] + await self._handle_pending_connection(quic_conn, data, addr) + else: + # New connection + await self._handle_new_connection(data, addr) + + except Exception as e: + logger.error(f"Error processing packet from {addr}: {e}") + + async def _route_to_connection( + self, connection: QUICConnection, data: bytes, addr: tuple[str, int] + ) -> None: + """Route packet to existing connection.""" + try: + # Feed data to the connection's QUIC instance + connection._quic.receive_datagram(data, addr, now=time.time()) + + # Process events and handle responses + await connection._process_events() + await connection._transmit() + + except Exception as e: + logger.error(f"Error routing packet to connection {addr}: {e}") + # Remove problematic connection + await self._remove_connection(addr) + + async def _handle_pending_connection( + self, quic_conn: QuicConnection, data: bytes, addr: tuple[str, int] + ) -> None: + """Handle packet for a pending (handshaking) connection.""" + try: + # Feed data to QUIC connection + quic_conn.receive_datagram(data, addr, now=time.time()) + + # Process events + await self._process_quic_events(quic_conn, addr) + + # Send any outgoing packets + await self._transmit_for_connection(quic_conn) + + except Exception as e: + logger.error(f"Error handling pending connection {addr}: {e}") + # Remove from pending connections + self._pending_connections.pop(addr, None) + + async def _handle_new_connection(self, data: bytes, addr: tuple[str, int]) -> None: + """ + Handle a new incoming connection. + Creates a new QUIC connection and starts handshake. + + Args: + data: Initial packet data + addr: Source address + + """ + try: + # Determine QUIC version from packet + # For now, use the first available configuration + # TODO: Implement proper version negotiation + quic_version = next(iter(self._quic_configs.keys())) + config = self._quic_configs[quic_version] + + # Create server-side QUIC configuration + server_config = copy.deepcopy(config) + server_config.is_client = False + + # Create QUIC connection + quic_conn = QuicConnection(configuration=server_config) + + # Store as pending connection + self._pending_connections[addr] = quic_conn + + # Process initial packet + quic_conn.receive_datagram(data, addr, now=time.time()) + await self._process_quic_events(quic_conn, addr) + await self._transmit_for_connection(quic_conn) + + logger.debug(f"Started handshake for new connection from {addr}") + + except Exception as e: + logger.error(f"Error handling new connection from {addr}: {e}") + self._stats["connections_rejected"] += 1 + + async def _process_quic_events( + self, quic_conn: QuicConnection, addr: tuple[str, int] + ) -> None: + """Process QUIC events for a connection.""" + while True: + event = quic_conn.next_event() + if event is None: + break + + if isinstance(event, events.ConnectionTerminated): + logger.debug( + f"Connection from {addr} terminated: {event.reason_phrase}" + ) + await self._remove_connection(addr) + break + + elif isinstance(event, events.HandshakeCompleted): + logger.debug(f"Handshake completed for {addr}") + await self._promote_pending_connection(quic_conn, addr) + + elif isinstance(event, events.StreamDataReceived): + # Forward to established connection if available + if addr in self._connections: + connection = self._connections[addr] + await connection._handle_stream_data(event) + + elif isinstance(event, events.StreamReset): + # Forward to established connection if available + if addr in self._connections: + connection = self._connections[addr] + await connection._handle_stream_reset(event) + + async def _promote_pending_connection( + self, quic_conn: QuicConnection, addr: tuple[str, int] + ) -> None: + """ + Promote a pending connection to an established connection. + Called after successful handshake completion. + + Args: + quic_conn: Established QUIC connection + addr: Remote address + + """ + try: + # Remove from pending connections + self._pending_connections.pop(addr, None) + + # Create multiaddr for this connection + host, port = addr + # Use the first supported QUIC version for now + quic_version = next(iter(self._quic_configs.keys())) + remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}") + + # Create libp2p connection wrapper + connection = QUICConnection( + quic_connection=quic_conn, + remote_addr=addr, + peer_id=None, # Will be determined during identity verification + local_peer_id=self._transport._peer_id, + is_initiator=False, # We're the server + maddr=remote_maddr, + transport=self._transport, + ) + + # Store the connection + self._connections[addr] = connection + + # Start connection management tasks + if self._nursery: + self._nursery.start_soon(connection._handle_incoming_data) + self._nursery.start_soon(connection._handle_timer) + + # TODO: Verify peer identity + # await connection.verify_peer_identity() + + # Call the connection handler + if self._nursery: + self._nursery.start_soon( + self._handle_new_established_connection, connection + ) + + self._stats["connections_accepted"] += 1 + logger.info(f"Accepted new QUIC connection from {addr}") + + except Exception as e: + logger.error(f"Error promoting connection from {addr}: {e}") + # Clean up + await self._remove_connection(addr) + self._stats["connections_rejected"] += 1 + + async def _handle_new_established_connection( + self, connection: QUICConnection + ) -> None: + """ + Handle a newly established connection by calling the user handler. + + Args: + connection: Established QUIC connection + + """ + try: + # Call the connection handler provided by the transport + await self._handler(connection) + except Exception as e: + logger.error(f"Error in connection handler: {e}") + # Close the problematic connection + await connection.close() + + async def _transmit_for_connection(self, quic_conn: QuicConnection) -> None: + """Send pending datagrams for a QUIC connection.""" + sock = self._socket + if not sock: + return + + for data, addr in quic_conn.datagrams_to_send(now=time.time()): + try: + await sock.sendto(data, addr) + except Exception as e: + logger.error(f"Failed to send datagram to {addr}: {e}") + + async def _manage_connections(self) -> None: + """ + Background task to manage connection lifecycle. + Handles cleanup of closed/idle connections. + """ + try: + while not self._closed: + try: + # Sleep for a short interval + await trio.sleep(1.0) + + # Clean up closed connections + await self._cleanup_closed_connections() + + # Handle connection timeouts + await self._handle_connection_timeouts() + + except Exception as e: + logger.error(f"Error in connection management: {e}") + except trio.Cancelled: + print("CONNECTION MANAGER CANCELLED") + raise + finally: + print("CONNECTION MANAGER FINISHED") + + async def _cleanup_closed_connections(self) -> None: + """Remove closed connections from tracking.""" + async with self._connection_lock: + closed_addrs = [] + + for addr, connection in self._connections.items(): + if connection.is_closed: + closed_addrs.append(addr) + + for addr in closed_addrs: + self._connections.pop(addr, None) + logger.debug(f"Cleaned up closed connection from {addr}") + + async def _handle_connection_timeouts(self) -> None: + """Handle connection timeouts and cleanup.""" + # TODO: Implement connection timeout handling + # Check for idle connections and close them + pass + + async def _remove_connection(self, addr: tuple[str, int]) -> None: + """Remove a connection from tracking.""" + async with self._connection_lock: + # Remove from active connections + connection = self._connections.pop(addr, None) + if connection: + await connection.close() + + # Remove from pending connections + quic_conn = self._pending_connections.pop(addr, None) + if quic_conn: + quic_conn.close() + + async def close(self) -> None: + """Close the listener and cleanup resources.""" + if self._closed: + return + + self._closed = True + self._listening = False + print("Closing QUIC listener") + + # CRITICAL: Close socket FIRST to unblock recvfrom() + await self._cleanup_socket() + + print("SOCKET CLEANUP COMPLETE") + + # Close all connections WITHOUT using the lock during shutdown + # (avoid deadlock if background tasks are cancelled while holding lock) + connections_to_close = list(self._connections.values()) + pending_to_close = list(self._pending_connections.values()) + + print( + f"CLOSING {len(connections_to_close)} connections and {len(pending_to_close)} pending" + ) + + # Close active connections + for connection in connections_to_close: + try: + await connection.close() + except Exception as e: + print(f"Error closing connection: {e}") + + # Close pending connections + for quic_conn in pending_to_close: + try: + quic_conn.close() + except Exception as e: + print(f"Error closing pending connection: {e}") + + # Clear the dictionaries without lock (we're shutting down) + self._connections.clear() + self._pending_connections.clear() + if self._nursery: + print("TASKS", len(self._nursery.child_tasks)) + + print("QUIC listener closed") + + async def _cleanup_socket(self) -> None: + """Clean up the UDP socket.""" + if self._socket: + try: + self._socket.close() + except Exception as e: + logger.error(f"Error closing socket: {e}") + finally: + self._socket = None + + def get_addrs(self) -> tuple[Multiaddr, ...]: + """ + Get the addresses this listener is bound to. + + Returns: + Tuple of bound multiaddrs + + """ + return tuple(self._bound_addresses) + + def is_listening(self) -> bool: + """Check if the listener is actively listening.""" + return self._listening and not self._closed + + def get_stats(self) -> dict: + """Get listener statistics.""" + stats = self._stats.copy() + stats.update( + { + "active_connections": len(self._connections), + "pending_connections": len(self._pending_connections), + "is_listening": self.is_listening(), + } + ) + return stats + + def __str__(self) -> str: + """String representation of the listener.""" + return f"QUICListener(addrs={self._bound_addresses}, connections={len(self._connections)})" diff --git a/libp2p/transport/quic/security.py b/libp2p/transport/quic/security.py new file mode 100644 index 00000000..1a49cf37 --- /dev/null +++ b/libp2p/transport/quic/security.py @@ -0,0 +1,123 @@ +""" +Basic QUIC Security implementation for Module 1. +This provides minimal TLS configuration for QUIC transport. +Full implementation will be in Module 5. +""" + +from dataclasses import dataclass +import os +import tempfile +from typing import Optional + +from libp2p.crypto.keys import PrivateKey +from libp2p.peer.id import ID + +from .exceptions import QUICSecurityError + + +@dataclass +class TLSConfig: + """TLS configuration for QUIC transport.""" + + cert_file: str + key_file: str + ca_file: Optional[str] = None + + +def generate_libp2p_tls_config(private_key: PrivateKey, peer_id: ID) -> TLSConfig: + """ + Generate TLS configuration with libp2p peer identity. + + This is a basic implementation for Module 1. + Full implementation with proper libp2p TLS spec compliance + will be provided in Module 5. + + Args: + private_key: libp2p private key + peer_id: libp2p peer ID + + Returns: + TLS configuration + + Raises: + QUICSecurityError: If TLS configuration generation fails + + """ + try: + # TODO: Implement proper libp2p TLS certificate generation + # This should follow the libp2p TLS specification: + # https://github.com/libp2p/specs/blob/master/tls/tls.md + + # For now, create a basic self-signed certificate + # This is a placeholder implementation + + # Create temporary files for cert and key + with tempfile.NamedTemporaryFile( + mode="w", suffix=".pem", delete=False + ) as cert_file: + cert_path = cert_file.name + # Write placeholder certificate + cert_file.write(_generate_placeholder_cert(peer_id)) + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".key", delete=False + ) as key_file: + key_path = key_file.name + # Write placeholder private key + key_file.write(_generate_placeholder_key(private_key)) + + return TLSConfig(cert_file=cert_path, key_file=key_path) + + except Exception as e: + raise QUICSecurityError(f"Failed to generate TLS config: {e}") from e + + +def _generate_placeholder_cert(peer_id: ID) -> str: + """ + Generate a placeholder certificate. + + This is a temporary implementation for Module 1. + Real implementation will embed the peer ID in the certificate + following the libp2p TLS specification. + """ + # This is a placeholder - real implementation needed + return f"""-----BEGIN CERTIFICATE----- +# Placeholder certificate for peer {peer_id} +# TODO: Implement proper libp2p TLS certificate generation +# This should embed the peer ID in a certificate extension +# according to the libp2p TLS specification +-----END CERTIFICATE-----""" + + +def _generate_placeholder_key(private_key: PrivateKey) -> str: + """ + Generate a placeholder private key. + + This is a temporary implementation for Module 1. + Real implementation will use the actual libp2p private key. + """ + # This is a placeholder - real implementation needed + return """-----BEGIN PRIVATE KEY----- +# Placeholder private key +# TODO: Convert libp2p private key to TLS-compatible format +-----END PRIVATE KEY-----""" + + +def cleanup_tls_config(config: TLSConfig) -> None: + """ + Clean up temporary TLS files. + + Args: + config: TLS configuration to clean up + + """ + try: + if os.path.exists(config.cert_file): + os.unlink(config.cert_file) + if os.path.exists(config.key_file): + os.unlink(config.key_file) + if config.ca_file and os.path.exists(config.ca_file): + os.unlink(config.ca_file) + except Exception: + # Ignore cleanup errors + pass diff --git a/libp2p/transport/quic/stream.py b/libp2p/transport/quic/stream.py index 781cca30..3bff6b4f 100644 --- a/libp2p/transport/quic/stream.py +++ b/libp2p/transport/quic/stream.py @@ -5,16 +5,17 @@ QUIC Stream implementation from types import ( TracebackType, ) +from typing import TYPE_CHECKING, cast import trio -from libp2p.abc import ( - IMuxedStream, -) +if TYPE_CHECKING: + from libp2p.abc import IMuxedStream + + from .connection import QUICConnection +else: + IMuxedStream = cast(type, object) -from .connection import ( - QUICConnection, -) from .exceptions import ( QUICStreamError, ) @@ -41,7 +42,7 @@ class QUICStream(IMuxedStream): self._receive_event = trio.Event() self._close_event = trio.Event() - async def read(self, n: int = -1) -> bytes: + async def read(self, n: int | None = -1) -> bytes: """Read data from the stream.""" if self._closed: raise QUICStreamError("Stream is closed") diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 286c73da..3f8c4004 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -14,9 +14,6 @@ from aioquic.quic.connection import ( QuicConnection, ) import multiaddr -from multiaddr import ( - Multiaddr, -) import trio from libp2p.abc import ( @@ -27,9 +24,15 @@ from libp2p.abc import ( from libp2p.crypto.keys import ( PrivateKey, ) +from libp2p.custom_types import THandler, TProtocol from libp2p.peer.id import ( ID, ) +from libp2p.transport.quic.utils import ( + is_quic_multiaddr, + multiaddr_to_quic_version, + quic_multiaddr_to_endpoint, +) from .config import ( QUICTransportConfig, @@ -41,21 +44,16 @@ from .exceptions import ( QUICDialError, QUICListenError, ) +from .listener import ( + QUICListener, +) + +QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 +QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 logger = logging.getLogger(__name__) -class QUICListener(IListener): - async def close(self): - pass - - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: - return False - - def get_addrs(self) -> tuple[Multiaddr, ...]: - return () - - class QUICTransport(ITransport): """ QUIC Transport implementation following libp2p transport interface. @@ -65,10 +63,6 @@ class QUICTransport(ITransport): go-libp2p and js-libp2p implementations. """ - # Protocol identifiers matching go-libp2p - PROTOCOL_QUIC_V1 = "/quic-v1" # RFC 9000 - PROTOCOL_QUIC_DRAFT29 = "/quic" # draft-29 - def __init__( self, private_key: PrivateKey, config: QUICTransportConfig | None = None ): @@ -89,7 +83,7 @@ class QUICTransport(ITransport): self._listeners: list[QUICListener] = [] # QUIC configurations for different versions - self._quic_configs: dict[str, QuicConfiguration] = {} + self._quic_configs: dict[TProtocol, QuicConfiguration] = {} self._setup_quic_configurations() # Resource management @@ -110,35 +104,36 @@ class QUICTransport(ITransport): ) # Add TLS certificate generated from libp2p private key - self._setup_tls_configuration(base_config) + # self._setup_tls_configuration(base_config) # QUIC v1 (RFC 9000) configuration quic_v1_config = copy.deepcopy(base_config) quic_v1_config.supported_versions = [0x00000001] # QUIC v1 - self._quic_configs[self.PROTOCOL_QUIC_V1] = quic_v1_config + self._quic_configs[QUIC_V1_PROTOCOL] = quic_v1_config # QUIC draft-29 configuration for compatibility if self._config.enable_draft29: draft29_config = copy.deepcopy(base_config) draft29_config.supported_versions = [0xFF00001D] # draft-29 - self._quic_configs[self.PROTOCOL_QUIC_DRAFT29] = draft29_config + self._quic_configs[QUIC_DRAFT29_PROTOCOL] = draft29_config - def _setup_tls_configuration(self, config: QuicConfiguration) -> None: - """ - Setup TLS configuration with libp2p identity integration. - Similar to go-libp2p's certificate generation approach. - """ - from .security import ( - generate_libp2p_tls_config, - ) + # TODO: SETUP TLS LISTENER + # def _setup_tls_configuration(self, config: QuicConfiguration) -> None: + # """ + # Setup TLS configuration with libp2p identity integration. + # Similar to go-libp2p's certificate generation approach. + # """ + # from .security import ( + # generate_libp2p_tls_config, + # ) - # Generate TLS certificate with embedded libp2p peer ID - # This follows the libp2p TLS spec for peer identity verification - tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id) + # # Generate TLS certificate with embedded libp2p peer ID + # # This follows the libp2p TLS spec for peer identity verification + # tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id) - config.load_cert_chain(tls_config.cert_file, tls_config.key_file) - if tls_config.ca_file: - config.load_verify_locations(tls_config.ca_file) + # config.load_cert_chain(certfile=tls_config.cert_file, keyfile=tls_config.key_file) + # if tls_config.ca_file: + # config.load_verify_locations(tls_config.ca_file) async def dial( self, maddr: multiaddr.Multiaddr, peer_id: ID | None = None @@ -196,14 +191,17 @@ class QUICTransport(ITransport): ) # Establish connection using trio - await connection.connect() + # We need a nursery for this - in real usage, this would be provided + # by the caller or we'd use a transport-level nursery + async with trio.open_nursery() as nursery: + await connection.connect(nursery) # Store connection for management conn_id = f"{host}:{port}:{peer_id}" self._connections[conn_id] = connection # Perform libp2p handshake verification - await connection.verify_peer_identity() + # await connection.verify_peer_identity() logger.info(f"Successfully dialed QUIC connection to {peer_id}") return connection @@ -212,9 +210,7 @@ class QUICTransport(ITransport): logger.error(f"Failed to dial QUIC connection to {maddr}: {e}") raise QUICDialError(f"Dial failed: {e}") from e - def create_listener( - self, handler_function: Callable[[ReadWriteCloser], None] - ) -> IListener: + def create_listener(self, handler_function: THandler) -> IListener: """ Create a QUIC listener. @@ -224,20 +220,22 @@ class QUICTransport(ITransport): Returns: QUIC listener instance + Raises: + QUICListenError: If transport is closed + """ if self._closed: raise QUICListenError("Transport is closed") - # TODO: Create QUIC Listener - # listener = QUICListener( - # transport=self, - # handler_function=handler_function, - # quic_configs=self._quic_configs, - # config=self._config, - # ) - listener = QUICListener() + listener = QUICListener( + transport=self, + handler_function=handler_function, + quic_configs=self._quic_configs, + config=self._config, + ) self._listeners.append(listener) + logger.debug("Created QUIC listener") return listener def can_dial(self, maddr: multiaddr.Multiaddr) -> bool: @@ -253,7 +251,7 @@ class QUICTransport(ITransport): """ return is_quic_multiaddr(maddr) - def protocols(self) -> list[str]: + def protocols(self) -> list[TProtocol]: """ Get supported protocol identifiers. @@ -261,9 +259,9 @@ class QUICTransport(ITransport): List of supported protocol strings """ - protocols = [self.PROTOCOL_QUIC_V1] + protocols = [QUIC_V1_PROTOCOL] if self._config.enable_draft29: - protocols.append(self.PROTOCOL_QUIC_DRAFT29) + protocols.append(QUIC_DRAFT29_PROTOCOL) return protocols def listen_order(self) -> int: @@ -300,6 +298,26 @@ class QUICTransport(ITransport): logger.info("QUIC transport closed") + def get_stats(self) -> dict: + """Get transport statistics.""" + stats = { + "active_connections": len(self._connections), + "active_listeners": len(self._listeners), + "supported_protocols": self.protocols(), + } + + # Aggregate listener stats + listener_stats = {} + for i, listener in enumerate(self._listeners): + listener_stats[f"listener_{i}"] = listener.get_stats() + + if listener_stats: + # TODO: Fix type of listener_stats + # type: ignore + stats["listeners"] = listener_stats + + return stats + def __str__(self) -> str: """String representation of the transport.""" return f"QUICTransport(peer_id={self._peer_id}, protocols={self.protocols()})" diff --git a/libp2p/transport/quic/utils.py b/libp2p/transport/quic/utils.py new file mode 100644 index 00000000..97ad8fa8 --- /dev/null +++ b/libp2p/transport/quic/utils.py @@ -0,0 +1,223 @@ +""" +Multiaddr utilities for QUIC transport. +Handles QUIC-specific multiaddr parsing and validation. +""" + +from typing import Tuple + +import multiaddr + +from libp2p.custom_types import TProtocol + +from .config import QUICTransportConfig + +QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 +QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 +UDP_PROTOCOL = "udp" +IP4_PROTOCOL = "ip4" +IP6_PROTOCOL = "ip6" + + +def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool: + """ + Check if a multiaddr represents a QUIC address. + + Valid QUIC multiaddrs: + - /ip4/127.0.0.1/udp/4001/quic-v1 + - /ip4/127.0.0.1/udp/4001/quic + - /ip6/::1/udp/4001/quic-v1 + - /ip6/::1/udp/4001/quic + + Args: + maddr: Multiaddr to check + + Returns: + True if the multiaddr represents a QUIC address + + """ + try: + # Get protocol names from the multiaddr string + addr_str = str(maddr) + + # Check for required components + has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str + has_udp = f"/{UDP_PROTOCOL}/" in addr_str + has_quic = ( + addr_str.endswith(f"/{QUIC_V1_PROTOCOL}") + or addr_str.endswith(f"/{QUIC_DRAFT29_PROTOCOL}") + or addr_str.endswith("/quic") + ) + + return has_ip and has_udp and has_quic + + except Exception: + return False + + +def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]: + """ + Extract host and port from a QUIC multiaddr. + + Args: + maddr: QUIC multiaddr + + Returns: + Tuple of (host, port) + + Raises: + ValueError: If multiaddr is not a valid QUIC address + + """ + if not is_quic_multiaddr(maddr): + raise ValueError(f"Not a valid QUIC multiaddr: {maddr}") + + try: + # Use multiaddr's value_for_protocol method to extract values + host = None + port = None + + # Try to get IPv4 address + try: + host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore + except ValueError: + pass + + # Try to get IPv6 address if IPv4 not found + if host is None: + try: + host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore + except ValueError: + pass + + # Get UDP port + try: + port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) + port = int(port_str) + except ValueError: + pass + + if host is None or port is None: + raise ValueError(f"Could not extract host/port from {maddr}") + + return host, port + + except Exception as e: + raise ValueError(f"Failed to parse QUIC multiaddr {maddr}: {e}") from e + + +def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol: + """ + Determine QUIC version from multiaddr. + + Args: + maddr: QUIC multiaddr + + Returns: + QUIC version identifier ("/quic-v1" or "/quic") + + Raises: + ValueError: If multiaddr doesn't contain QUIC protocol + + """ + try: + addr_str = str(maddr) + + if f"/{QUIC_V1_PROTOCOL}" in addr_str: + return QUIC_V1_PROTOCOL # RFC 9000 + elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str: + return QUIC_DRAFT29_PROTOCOL # draft-29 + else: + raise ValueError(f"No QUIC protocol found in {maddr}") + + except Exception as e: + raise ValueError(f"Failed to determine QUIC version from {maddr}: {e}") from e + + +def create_quic_multiaddr( + host: str, port: int, version: str = "/quic-v1" +) -> multiaddr.Multiaddr: + """ + Create a QUIC multiaddr from host, port, and version. + + Args: + host: IP address (IPv4 or IPv6) + port: UDP port number + version: QUIC version ("/quic-v1" or "/quic") + + Returns: + QUIC multiaddr + + Raises: + ValueError: If invalid parameters provided + + """ + try: + import ipaddress + + # Determine IP version + try: + ip = ipaddress.ip_address(host) + if isinstance(ip, ipaddress.IPv4Address): + ip_proto = IP4_PROTOCOL + else: + ip_proto = IP6_PROTOCOL + except ValueError: + raise ValueError(f"Invalid IP address: {host}") + + # Validate port + if not (0 <= port <= 65535): + raise ValueError(f"Invalid port: {port}") + + # Validate QUIC version + if version not in ["/quic-v1", "/quic"]: + raise ValueError(f"Invalid QUIC version: {version}") + + # Construct multiaddr + quic_proto = ( + QUIC_V1_PROTOCOL if version == "/quic-v1" else QUIC_DRAFT29_PROTOCOL + ) + addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}" + + return multiaddr.Multiaddr(addr_str) + + except Exception as e: + raise ValueError(f"Failed to create QUIC multiaddr: {e}") from e + + +def is_quic_v1_multiaddr(maddr: multiaddr.Multiaddr) -> bool: + """Check if multiaddr uses QUIC v1 (RFC 9000).""" + try: + return multiaddr_to_quic_version(maddr) == "/quic-v1" + except ValueError: + return False + + +def is_quic_draft29_multiaddr(maddr: multiaddr.Multiaddr) -> bool: + """Check if multiaddr uses QUIC draft-29.""" + try: + return multiaddr_to_quic_version(maddr) == "/quic" + except ValueError: + return False + + +def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr: + """ + Normalize a QUIC multiaddr to canonical form. + + Args: + maddr: Input QUIC multiaddr + + Returns: + Normalized multiaddr + + Raises: + ValueError: If not a valid QUIC multiaddr + + """ + if not is_quic_multiaddr(maddr): + raise ValueError(f"Not a QUIC multiaddr: {maddr}") + + host, port = quic_multiaddr_to_endpoint(maddr) + version = multiaddr_to_quic_version(maddr) + + return create_quic_multiaddr(host, port, version) diff --git a/pyproject.toml b/pyproject.toml index 7f08697e..75191548 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ maintainers = [ { name = "Dave Grantham", email = "dwg@linuxprogrammer.org" }, ] dependencies = [ + "aioquic>=1.2.0", "base58>=1.0.3", "coincurve>=10.0.0", "exceptiongroup>=1.2.0; python_version < '3.11'", diff --git a/tests/core/transport/quic/test_connection.py b/tests/core/transport/quic/test_connection.py new file mode 100644 index 00000000..c368aacb --- /dev/null +++ b/tests/core/transport/quic/test_connection.py @@ -0,0 +1,119 @@ +from unittest.mock import ( + Mock, +) + +import pytest +from multiaddr.multiaddr import Multiaddr + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.peer.id import ID +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.exceptions import QUICStreamError + + +class TestQUICConnection: + """Test suite for QUIC connection functionality.""" + + @pytest.fixture + def mock_quic_connection(self): + """Create mock aioquic QuicConnection.""" + mock = Mock() + mock.next_event.return_value = None + mock.datagrams_to_send.return_value = [] + mock.get_timer.return_value = None + return mock + + @pytest.fixture + def quic_connection(self, mock_quic_connection): + """Create test QUIC connection.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + return QUICConnection( + quic_connection=mock_quic_connection, + remote_addr=("127.0.0.1", 4001), + peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + def test_connection_initialization(self, quic_connection): + """Test connection initialization.""" + assert quic_connection._remote_addr == ("127.0.0.1", 4001) + assert quic_connection.is_initiator is True + assert not quic_connection.is_closed + assert not quic_connection.is_established + assert len(quic_connection._streams) == 0 + + def test_stream_id_calculation(self): + """Test stream ID calculation for client/server.""" + # Client connection (initiator) + client_conn = QUICConnection( + quic_connection=Mock(), + remote_addr=("127.0.0.1", 4001), + peer_id=None, + local_peer_id=Mock(), + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + assert client_conn._next_stream_id == 0 # Client starts with 0 + + # Server connection (not initiator) + server_conn = QUICConnection( + quic_connection=Mock(), + remote_addr=("127.0.0.1", 4001), + peer_id=None, + local_peer_id=Mock(), + is_initiator=False, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + assert server_conn._next_stream_id == 1 # Server starts with 1 + + def test_incoming_stream_detection(self, quic_connection): + """Test incoming stream detection logic.""" + # For client (initiator), odd stream IDs are incoming + assert quic_connection._is_incoming_stream(1) is True # Server-initiated + assert quic_connection._is_incoming_stream(0) is False # Client-initiated + assert quic_connection._is_incoming_stream(5) is True # Server-initiated + assert quic_connection._is_incoming_stream(4) is False # Client-initiated + + @pytest.mark.trio + async def test_connection_stats(self, quic_connection): + """Test connection statistics.""" + stats = quic_connection.get_stats() + + expected_keys = [ + "peer_id", + "remote_addr", + "is_initiator", + "is_established", + "is_closed", + "active_streams", + "next_stream_id", + ] + + for key in expected_keys: + assert key in stats + + @pytest.mark.trio + async def test_connection_close(self, quic_connection): + """Test connection close functionality.""" + assert not quic_connection.is_closed + + await quic_connection.close() + + assert quic_connection.is_closed + + @pytest.mark.trio + async def test_stream_operations_on_closed_connection(self, quic_connection): + """Test stream operations on closed connection.""" + await quic_connection.close() + + with pytest.raises(QUICStreamError, match="Connection is closed"): + await quic_connection.open_stream() diff --git a/tests/core/transport/quic/test_listener.py b/tests/core/transport/quic/test_listener.py new file mode 100644 index 00000000..c0874ec4 --- /dev/null +++ b/tests/core/transport/quic/test_listener.py @@ -0,0 +1,171 @@ +from unittest.mock import AsyncMock + +import pytest +from multiaddr.multiaddr import Multiaddr +import trio + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.transport.quic.exceptions import ( + QUICListenError, +) +from libp2p.transport.quic.listener import QUICListener +from libp2p.transport.quic.transport import ( + QUICTransport, + QUICTransportConfig, +) +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, + quic_multiaddr_to_endpoint, +) + + +class TestQUICListener: + """Test suite for QUIC listener functionality.""" + + @pytest.fixture + def private_key(self): + """Generate test private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def transport_config(self): + """Generate test transport configuration.""" + return QUICTransportConfig(idle_timeout=10.0) + + @pytest.fixture + def transport(self, private_key, transport_config): + """Create test transport instance.""" + return QUICTransport(private_key, transport_config) + + @pytest.fixture + def connection_handler(self): + """Mock connection handler.""" + return AsyncMock() + + @pytest.fixture + def listener(self, transport, connection_handler): + """Create test listener.""" + return transport.create_listener(connection_handler) + + def test_listener_creation(self, transport, connection_handler): + """Test listener creation.""" + listener = transport.create_listener(connection_handler) + + assert isinstance(listener, QUICListener) + assert listener._transport == transport + assert listener._handler == connection_handler + assert not listener._listening + assert not listener._closed + + @pytest.mark.trio + async def test_listener_invalid_multiaddr(self, listener: QUICListener): + """Test listener with invalid multiaddr.""" + async with trio.open_nursery() as nursery: + invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + + with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"): + await listener.listen(invalid_addr, nursery) + + @pytest.mark.trio + async def test_listener_basic_lifecycle(self, listener: QUICListener): + """Test basic listener lifecycle.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random + + async with trio.open_nursery() as nursery: + # Start listening + success = await listener.listen(listen_addr, nursery) + assert success + assert listener.is_listening() + + # Check bound addresses + addrs = listener.get_addrs() + assert len(addrs) == 1 + + # Check stats + stats = listener.get_stats() + assert stats["is_listening"] is True + assert stats["active_connections"] == 0 + assert stats["pending_connections"] == 0 + + # Close listener + await listener.close() + assert not listener.is_listening() + + @pytest.mark.trio + async def test_listener_double_listen(self, listener: QUICListener): + """Test that double listen raises error.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic") + + # The nursery is the outer context + async with trio.open_nursery() as nursery: + # The try/finally is now INSIDE the nursery scope + try: + # The listen method creates the socket and starts background tasks + success = await listener.listen(listen_addr, nursery) + assert success + await trio.sleep(0.01) + + addrs = listener.get_addrs() + assert len(addrs) > 0 + print("ADDRS 1: ", len(addrs)) + print("TEST LOGIC FINISHED") + + async with trio.open_nursery() as nursery2: + with pytest.raises(QUICListenError, match="Already listening"): + await listener.listen(listen_addr, nursery2) + finally: + # This block runs BEFORE the 'async with nursery' exits. + print("INNER FINALLY: Closing listener to release socket...") + + # This closes the socket and sets self._listening = False, + # which helps the background tasks terminate cleanly. + await listener.close() + print("INNER FINALLY: Listener closed.") + + # By the time we get here, the listener and its tasks have been fully + # shut down, allowing the nursery to exit without hanging. + print("TEST COMPLETED SUCCESSFULLY.") + + @pytest.mark.trio + async def test_listener_port_binding(self, listener: QUICListener): + """Test listener port binding and cleanup.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + # The nursery is the outer context + async with trio.open_nursery() as nursery: + # The try/finally is now INSIDE the nursery scope + try: + # The listen method creates the socket and starts background tasks + success = await listener.listen(listen_addr, nursery) + assert success + await trio.sleep(0.5) + + addrs = listener.get_addrs() + assert len(addrs) > 0 + print("TEST LOGIC FINISHED") + + finally: + # This block runs BEFORE the 'async with nursery' exits. + print("INNER FINALLY: Closing listener to release socket...") + + # This closes the socket and sets self._listening = False, + # which helps the background tasks terminate cleanly. + await listener.close() + print("INNER FINALLY: Listener closed.") + + # By the time we get here, the listener and its tasks have been fully + # shut down, allowing the nursery to exit without hanging. + print("TEST COMPLETED SUCCESSFULLY.") + + @pytest.mark.trio + async def test_listener_stats_tracking(self, listener): + """Test listener statistics tracking.""" + initial_stats = listener.get_stats() + + # All counters should start at 0 + assert initial_stats["connections_accepted"] == 0 + assert initial_stats["connections_rejected"] == 0 + assert initial_stats["bytes_received"] == 0 + assert initial_stats["packets_processed"] == 0 diff --git a/tests/core/transport/quic/test_transport.py b/tests/core/transport/quic/test_transport.py index fd5e8e88..59623e90 100644 --- a/tests/core/transport/quic/test_transport.py +++ b/tests/core/transport/quic/test_transport.py @@ -7,6 +7,7 @@ import pytest from libp2p.crypto.ed25519 import ( create_new_key_pair, ) +from libp2p.crypto.keys import PrivateKey from libp2p.transport.quic.exceptions import ( QUICDialError, QUICListenError, @@ -23,7 +24,7 @@ class TestQUICTransport: @pytest.fixture def private_key(self): """Generate test private key.""" - return create_new_key_pair() + return create_new_key_pair().private_key @pytest.fixture def transport_config(self): @@ -33,7 +34,7 @@ class TestQUICTransport: ) @pytest.fixture - def transport(self, private_key, transport_config): + def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig): """Create test transport instance.""" return QUICTransport(private_key, transport_config) @@ -47,18 +48,35 @@ class TestQUICTransport: def test_supported_protocols(self, transport): """Test supported protocol identifiers.""" protocols = transport.protocols() - assert "/quic-v1" in protocols - assert "/quic" in protocols # draft-29 + # TODO: Update when quic-v1 compatible + # assert "quic-v1" in protocols + assert "quic" in protocols # draft-29 - def test_can_dial_quic_addresses(self, transport): + def test_can_dial_quic_addresses(self, transport: QUICTransport): """Test multiaddr compatibility checking.""" import multiaddr # Valid QUIC addresses valid_addrs = [ - multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1"), - multiaddr.Multiaddr("/ip4/192.168.1.1/udp/8080/quic"), - multiaddr.Multiaddr("/ip6/::1/udp/4001/quic-v1"), + # TODO: Update Multiaddr package to accept quic-v1 + multiaddr.Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + multiaddr.Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + multiaddr.Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), ] for addr in valid_addrs: @@ -93,7 +111,7 @@ class TestQUICTransport: await transport.close() with pytest.raises(QUICDialError, match="Transport is closed"): - await transport.dial(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")) + await transport.dial(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic")) def test_create_listener_closed_transport(self, transport): """Test creating listener with closed transport raises error.""" diff --git a/tests/core/transport/quic/test_utils.py b/tests/core/transport/quic/test_utils.py new file mode 100644 index 00000000..d67317c7 --- /dev/null +++ b/tests/core/transport/quic/test_utils.py @@ -0,0 +1,94 @@ +import pytest +from multiaddr.multiaddr import Multiaddr + +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, + is_quic_multiaddr, + multiaddr_to_quic_version, + quic_multiaddr_to_endpoint, +) + + +class TestQUICUtils: + """Test suite for QUIC utility functions.""" + + def test_is_quic_multiaddr(self): + """Test QUIC multiaddr validation.""" + # Valid QUIC multiaddrs + valid = [ + # TODO: Update Multiaddr package to accept quic-v1 + Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + ] + + for addr in valid: + assert is_quic_multiaddr(addr) + + # Invalid multiaddrs + invalid = [ + Multiaddr("/ip4/127.0.0.1/tcp/4001"), + Multiaddr("/ip4/127.0.0.1/udp/4001"), + Multiaddr("/ip4/127.0.0.1/udp/4001/ws"), + ] + + for addr in invalid: + assert not is_quic_multiaddr(addr) + + def test_quic_multiaddr_to_endpoint(self): + """Test multiaddr to endpoint conversion.""" + addr = Multiaddr("/ip4/192.168.1.100/udp/4001/quic") + host, port = quic_multiaddr_to_endpoint(addr) + + assert host == "192.168.1.100" + assert port == 4001 + + # Test IPv6 + # TODO: Update Multiaddr project to handle ip6 + # addr6 = Multiaddr("/ip6/::1/udp/8080/quic") + # host6, port6 = quic_multiaddr_to_endpoint(addr6) + + # assert host6 == "::1" + # assert port6 == 8080 + + def test_create_quic_multiaddr(self): + """Test QUIC multiaddr creation.""" + # IPv4 + addr = create_quic_multiaddr("127.0.0.1", 4001, "/quic") + assert str(addr) == "/ip4/127.0.0.1/udp/4001/quic" + + # IPv6 + addr6 = create_quic_multiaddr("::1", 8080, "/quic") + assert str(addr6) == "/ip6/::1/udp/8080/quic" + + def test_multiaddr_to_quic_version(self): + """Test QUIC version extraction.""" + addr = Multiaddr("/ip4/127.0.0.1/udp/4001/quic") + version = multiaddr_to_quic_version(addr) + assert version in ["quic", "quic-v1"] # Depending on implementation + + def test_invalid_multiaddr_operations(self): + """Test error handling for invalid multiaddrs.""" + invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + + with pytest.raises(ValueError): + quic_multiaddr_to_endpoint(invalid_addr) + + with pytest.raises(ValueError): + multiaddr_to_quic_version(invalid_addr)