diff --git a/examples/echo/echo_quic.py b/examples/echo/echo_quic.py new file mode 100644 index 00000000..a2f8ffd0 --- /dev/null +++ b/examples/echo/echo_quic.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +QUIC Echo Example - Direct replacement for examples/echo/echo.py + +This program demonstrates a simple echo protocol using QUIC transport where a peer +listens for connections and copies back any input received on a stream. + +Modified from the original TCP version to use QUIC transport, providing: +- Built-in TLS security +- Native stream multiplexing +- Better performance over UDP +- Modern QUIC protocol features +""" + +import argparse + +import multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.network.stream.net_stream import INetStream +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.transport.quic.config import QUICTransportConfig + +PROTOCOL_ID = TProtocol("/echo/1.0.0") + + +async def _echo_stream_handler(stream: INetStream) -> None: + """ + Echo stream handler - unchanged from TCP version. + + Demonstrates transport abstraction: same handler works for both TCP and QUIC. + """ + # Wait until EOF + msg = await stream.read() + await stream.write(msg) + await stream.close() + + +async def run(port: int, destination: str, seed: int | None = None) -> None: + """ + Run echo server or client with QUIC transport. + + Key changes from TCP version: + 1. UDP multiaddr instead of TCP + 2. QUIC transport configuration + 3. Everything else remains the same! + """ + # CHANGED: UDP + QUIC instead of TCP + listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic") + + if seed: + import random + + random.seed(seed) + secret_number = random.getrandbits(32 * 8) + secret = secret_number.to_bytes(length=32, byteorder="big") + else: + import secrets + + secret = secrets.token_bytes(32) + + # NEW: QUIC transport configuration + quic_config = QUICTransportConfig( + idle_timeout=30.0, + max_concurrent_streams=1000, + connection_timeout=10.0, + ) + + # CHANGED: Add QUIC transport options + host = new_host( + key_pair=create_new_key_pair(secret), + transport_opt={"quic_config": quic_config}, + ) + + async with host.run(listen_addrs=[listen_addr]): + print(f"I am {host.get_id().to_string()}") + + if not destination: # Server mode + host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + + print( + "Run this from the same folder in another console:\n\n" + f"python3 ./examples/echo/echo_quic.py " + f"-d {host.get_addrs()[0]}\n" + ) + print("Waiting for incoming QUIC connections...") + await trio.sleep_forever() + + else: # Client mode + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + # Associate the peer with local ip address + await host.connect(info) + + # Start a stream with the destination. + # Multiaddress of the destination peer is fetched from the peerstore + # using 'peerId'. + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + + msg = b"hi, there!\n" + + await stream.write(msg) + # Notify the other side about EOF + await stream.close() + response = await stream.read() + + print(f"Sent: {msg.decode('utf-8')}") + print(f"Got: {response.decode('utf-8')}") + + +def main() -> None: + """Main function - help text updated for QUIC.""" + description = """ + This program demonstrates a simple echo protocol using QUIC + transport where a peer listens for connections and copies back + any input received on a stream. + + QUIC provides built-in TLS security and stream multiplexing over UDP. + + To use it, first run 'python ./echo.py -p ', where is + the UDP port number.Then, run another host with , + 'python ./echo.py -p -d ' + where is the QUIC multiaddress of the previous listener host. + """ + + example_maddr = "/ip4/127.0.0.1/udp/8000/quic/p2p/QmQn4SwGkDZKkUEpBRBv" + + parser = argparse.ArgumentParser(description=description) + parser.add_argument("-p", "--port", default=8000, type=int, help="UDP port number") + parser.add_argument( + "-d", + "--destination", + type=str, + help=f"destination multiaddr string, e.g. {example_maddr}", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + help="provide a seed to the random number generator", + ) + args = parser.parse_args() + try: + trio.run(run, args.port, args.destination, args.seed) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 350ae46b..59a42ff6 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,3 +1,7 @@ +from libp2p.transport.quic.utils import is_quic_multiaddr +from typing import Any +from libp2p.transport.quic.transport import QUICTransport +from libp2p.transport.quic.config import QUICTransportConfig from collections.abc import ( Mapping, Sequence, @@ -5,16 +9,12 @@ from collections.abc import ( from importlib.metadata import version as __version from typing import ( Literal, - Optional, - Type, - cast, ) import multiaddr from libp2p.abc import ( IHost, - IMuxedConn, INetworkService, IPeerRouting, IPeerStore, @@ -163,6 +163,7 @@ def new_swarm( peerstore_opt: IPeerStore | None = None, muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, + transport_opt: dict[Any, Any] | None = None, ) -> INetworkService: """ Create a swarm instance based on the parameters. @@ -173,6 +174,7 @@ def new_swarm( :param peerstore_opt: optional peerstore :param muxer_preference: optional explicit muxer preference :param listen_addrs: optional list of multiaddrs to listen on + :param transport_opt: options for transport :return: return a default swarm instance Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer @@ -185,14 +187,24 @@ def new_swarm( id_opt = generate_peer_id_from(key_pair) + transport: TCP | QUICTransport + if listen_addrs is None: - transport = TCP() + transport_opt = transport_opt or {} + quic_config: QUICTransportConfig | None = transport_opt.get('quic_config') + + if quic_config: + transport = QUICTransport(key_pair.private_key, quic_config) + else: + transport = TCP() else: addr = listen_addrs[0] if addr.__contains__("tcp"): transport = TCP() elif addr.__contains__("quic"): - raise ValueError("QUIC not yet supported") + transport_opt = transport_opt or {} + quic_config = transport_opt.get('quic_config', QUICTransportConfig()) + transport = QUICTransport(key_pair.private_key, quic_config) else: raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}") @@ -253,6 +265,7 @@ def new_host( enable_mDNS: bool = False, bootstrap: list[str] | None = None, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, + transport_opt: dict[Any, Any] | None = None, ) -> IHost: """ Create a new libp2p host based on the given parameters. @@ -266,8 +279,10 @@ def new_host( :param listen_addrs: optional list of multiaddrs to listen on :param enable_mDNS: whether to enable mDNS discovery :param bootstrap: optional list of bootstrap peer addresses as strings + :param transport_opt: optional dictionary of properties of transport :return: return a host instance """ + print("INIT") swarm = new_swarm( key_pair=key_pair, muxer_opt=muxer_opt, @@ -275,6 +290,7 @@ def new_host( peerstore_opt=peerstore_opt, muxer_preference=muxer_preference, listen_addrs=listen_addrs, + transport_opt=transport_opt ) if disc_opt is not None: diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 67d46279..331a0ce4 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -170,14 +170,7 @@ class Swarm(Service, INetworkService): async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn: """ Try to create a connection to peer_id with addr. - - :param addr: the address we want to connect with - :param peer_id: the peer we want to connect to - :raises SwarmException: raised when an error occurs - :return: network connection """ - # Dial peer (connection to peer does not yet exist) - # Transport dials peer (gets back a raw conn) try: raw_conn = await self.transport.dial(addr) except OpenConnectionError as error: @@ -188,8 +181,15 @@ class Swarm(Service, INetworkService): logger.debug("dialed peer %s over base transport", peer_id) - # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure - # the conn and then mux the conn + # NEW: Check if this is a QUIC connection (already secure and muxed) + if isinstance(raw_conn, IMuxedConn): + # QUIC connections are already secure and muxed, skip upgrade steps + logger.debug("detected QUIC connection, skipping upgrade steps") + swarm_conn = await self.add_conn(raw_conn) + logger.debug("successfully dialed peer %s via QUIC", peer_id) + return swarm_conn + + # Standard TCP flow - security then mux upgrade try: secured_conn = await self.upgrader.upgrade_security(raw_conn, True, peer_id) except SecurityUpgradeFailure as error: @@ -211,9 +211,7 @@ class Swarm(Service, INetworkService): logger.debug("upgraded mux for peer %s", peer_id) swarm_conn = await self.add_conn(muxed_conn) - logger.debug("successfully dialed peer %s", peer_id) - return swarm_conn async def new_stream(self, peer_id: ID) -> INetStream: diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py index d6b53519..abdb3d8f 100644 --- a/libp2p/transport/quic/connection.py +++ b/libp2p/transport/quic/connection.py @@ -34,6 +34,11 @@ if TYPE_CHECKING: from .security import QUICTLSConfigManager from .transport import QUICTransport +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], +) logger = logging.getLogger(__name__) @@ -286,11 +291,13 @@ class QUICConnection(IRawConnection, IMuxedConn): try: with QUICErrorContext("connection_establishment", "connection"): # Start the connection if not already started + print("STARTING TO CONNECT") if not self._started: await self.start() # Start background event processing if not self._background_tasks_started: + print("STARTING BACKGROUND TASK") await self._start_background_tasks() # Wait for handshake completion with timeout @@ -324,16 +331,17 @@ class QUICConnection(IRawConnection, IMuxedConn): self._background_tasks_started = True # Start event processing task - self._nursery.start_soon(self._event_processing_loop) + self._nursery.start_soon(async_fn=self._event_processing_loop) # Start periodic tasks - self._nursery.start_soon(self._periodic_maintenance) + # self._nursery.start_soon(async_fn=self._periodic_maintenance) logger.debug("Started background tasks for QUIC connection") async def _event_processing_loop(self) -> None: """Main event processing loop for the connection.""" logger.debug("Started QUIC event processing loop") + print("Started QUIC event processing loop") try: while not self._closed: @@ -347,7 +355,7 @@ class QUICConnection(IRawConnection, IMuxedConn): await self._transmit() # Short sleep to prevent busy waiting - await trio.sleep(0.001) # 1ms + await trio.sleep(0.01) except Exception as e: logger.error(f"Error in event processing loop: {e}") @@ -381,6 +389,7 @@ class QUICConnection(IRawConnection, IMuxedConn): QUICPeerVerificationError: If peer verification fails """ + print("VERIFYING PEER IDENTITY") if not self._security_manager: logger.warning("No security manager available for peer verification") return @@ -719,6 +728,7 @@ class QUICConnection(IRawConnection, IMuxedConn): async def _handle_quic_event(self, event: events.QuicEvent) -> None: """Handle a single QUIC event.""" + print(f"QUIC event: {type(event).__name__}") if isinstance(event, events.ConnectionTerminated): await self._handle_connection_terminated(event) elif isinstance(event, events.HandshakeCompleted): @@ -731,6 +741,7 @@ class QUICConnection(IRawConnection, IMuxedConn): await self._handle_datagram_received(event) else: logger.debug(f"Unhandled QUIC event: {type(event).__name__}") + print(f"Unhandled QUIC event: {type(event).__name__}") async def _handle_handshake_completed( self, event: events.HandshakeCompleted @@ -897,6 +908,7 @@ class QUICConnection(IRawConnection, IMuxedConn): """Send pending datagrams using trio.""" sock = self._socket if not sock: + print("No socket to transmit") return try: diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index 91a9c007..4cbc8e74 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -1,14 +1,12 @@ """ -QUIC Listener implementation for py-libp2p. -Based on go-libp2p and js-libp2p QUIC listener patterns. -Uses aioquic's server-side QUIC implementation with trio. +QUIC Listener """ -import copy import logging import socket +import struct import time -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from aioquic.quic import events from aioquic.quic.configuration import QuicConfiguration @@ -19,12 +17,14 @@ import trio from libp2p.abc import IListener from libp2p.custom_types import THandler, TProtocol from libp2p.transport.quic.security import QUICTLSConfigManager +from libp2p.transport.quic.utils import custom_quic_version_to_wire_format from .config import QUICTransportConfig from .connection import QUICConnection from .exceptions import QUICListenError from .utils import ( create_quic_multiaddr, + create_server_config_from_base, is_quic_multiaddr, multiaddr_to_quic_version, quic_multiaddr_to_endpoint, @@ -33,17 +33,41 @@ from .utils import ( if TYPE_CHECKING: from .transport import QUICTransport +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], +) logger = logging.getLogger(__name__) -logger.setLevel("DEBUG") + + +class QUICPacketInfo: + """Information extracted from a QUIC packet header.""" + + def __init__( + self, + version: int, + destination_cid: bytes, + source_cid: bytes, + packet_type: int, + token: bytes | None = None, + ): + self.version = version + self.destination_cid = destination_cid + self.source_cid = source_cid + self.packet_type = packet_type + self.token = token class QUICListener(IListener): """ - QUIC Listener implementation following libp2p listener interface. + Enhanced QUIC Listener with proper connection ID handling and protocol negotiation. - Handles incoming QUIC connections, manages server-side handshakes, - and integrates with the libp2p connection handler system. - Based on go-libp2p and js-libp2p listener patterns. + Key improvements: + - Proper QUIC packet parsing to extract connection IDs + - Version negotiation following RFC 9000 + - Connection routing based on destination connection ID + - Support for connection migration """ def __init__( @@ -54,17 +78,7 @@ class QUICListener(IListener): config: QUICTransportConfig, security_manager: QUICTLSConfigManager | None = None, ): - """ - Initialize QUIC listener. - - Args: - transport: Parent QUIC transport - handler_function: Function to handle new connections - quic_configs: QUIC configurations for different versions - config: QUIC transport configuration - security_manager: Security manager for TLS/certificate handling - - """ + """Initialize enhanced QUIC listener.""" self._transport = transport self._handler = handler_function self._quic_configs = quic_configs @@ -75,11 +89,24 @@ class QUICListener(IListener): self._socket: trio.socket.SocketType | None = None self._bound_addresses: list[Multiaddr] = [] - # Connection management - self._connections: dict[tuple[str, int], QUICConnection] = {} - self._pending_connections: dict[tuple[str, int], QuicConnection] = {} + # Enhanced connection management with connection ID routing + self._connections: dict[ + bytes, QUICConnection + ] = {} # destination_cid -> connection + self._pending_connections: dict[ + bytes, QuicConnection + ] = {} # destination_cid -> quic_conn + self._addr_to_cid: dict[ + tuple[str, int], bytes + ] = {} # (host, port) -> destination_cid + self._cid_to_addr: dict[ + bytes, tuple[str, int] + ] = {} # destination_cid -> (host, port) self._connection_lock = trio.Lock() + # Version negotiation support + self._supported_versions = self._get_supported_versions() + # Listener state self._closed = False self._listening = False @@ -89,164 +116,321 @@ class QUICListener(IListener): self._stats = { "connections_accepted": 0, "connections_rejected": 0, + "version_negotiations": 0, "bytes_received": 0, "packets_processed": 0, + "invalid_packets": 0, } - logger.debug("Initialized QUIC listener") + logger.debug("Initialized enhanced QUIC listener with connection ID support") - async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: - """ - Start listening on the given multiaddr. - - Args: - maddr: Multiaddr to listen on - nursery: Trio nursery for managing background tasks - - Returns: - True if listening started successfully - - Raises: - QUICListenError: If failed to start listening - - """ - if not is_quic_multiaddr(maddr): - raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") - - if self._listening: - raise QUICListenError("Already listening") - - try: - # Extract host and port from multiaddr - host, port = quic_multiaddr_to_endpoint(maddr) - quic_version = multiaddr_to_quic_version(maddr) - - protocol = f"{quic_version}_server" - - # Validate QUIC version support - if protocol not in self._quic_configs: - raise QUICListenError(f"Unsupported QUIC version: {quic_version}") - - # Create and bind UDP socket - self._socket = await self._create_and_bind_socket(host, port) - actual_port = self._socket.getsockname()[1] - - # Update multiaddr with actual bound port - actual_maddr = create_quic_multiaddr(host, actual_port, f"/{quic_version}") - self._bound_addresses = [actual_maddr] - - # Store nursery reference and set listening state - self._nursery = nursery - self._listening = True - - # Start background tasks directly in the provided nursery - # This e per cancellation when the nursery exits - nursery.start_soon(self._handle_incoming_packets) - nursery.start_soon(self._manage_connections) - - logger.info(f"QUIC listener started on {actual_maddr}") - return True - - except trio.Cancelled: - print("CLOSING LISTENER") - raise - except Exception as e: - logger.error(f"Failed to start QUIC listener on {maddr}: {e}") - await self._cleanup_socket() - raise QUICListenError(f"Listen failed: {e}") from e - - async def _create_and_bind_socket( - self, host: str, port: int - ) -> trio.socket.SocketType: - """Create and bind UDP socket for QUIC.""" - try: - # Determine address family + def _get_supported_versions(self) -> set[int]: + """Get wire format versions for all supported QUIC configurations.""" + versions: set[int] = set() + for protocol in self._quic_configs: try: - import ipaddress + config = self._quic_configs[protocol] + wire_versions = config.supported_versions + for version in wire_versions: + versions.add(version) + except Exception as e: + logger.warning(f"Failed to get wire version for {protocol}: {e}") + return versions - ip = ipaddress.ip_address(host) - family = socket.AF_INET if ip.version == 4 else socket.AF_INET6 - except ValueError: - # Assume IPv4 for hostnames - family = socket.AF_INET + def parse_quic_packet(self, data: bytes) -> QUICPacketInfo | None: + """ + Parse QUIC packet header to extract connection IDs and version. + Based on RFC 9000 packet format. + """ + try: + if len(data) < 1: + return None - # Create UDP socket - sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + # Read first byte to get packet type and flags + first_byte = data[0] - # Set socket options for better performance - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if hasattr(socket, "SO_REUSEPORT"): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + # Check if this is a long header packet (version negotiation, initial, etc.) + is_long_header = (first_byte & 0x80) != 0 - # Bind to address - await sock.bind((host, port)) + if not is_long_header: + # Short header packet - extract destination connection ID + # For short headers, we need to know the connection ID length + # This is typically managed by the connection state + # For now, we'll handle this in the connection routing logic + return None - logger.debug(f"Created and bound UDP socket to {host}:{port}") - return sock + # Long header packet parsing + offset = 1 + + # Extract version (4 bytes) + if len(data) < offset + 4: + return None + version = struct.unpack("!I", data[offset : offset + 4])[0] + offset += 4 + + # Extract destination connection ID length and value + if len(data) < offset + 1: + return None + dest_cid_len = data[offset] + offset += 1 + + if len(data) < offset + dest_cid_len: + return None + dest_cid = data[offset : offset + dest_cid_len] + offset += dest_cid_len + + # Extract source connection ID length and value + if len(data) < offset + 1: + return None + src_cid_len = data[offset] + offset += 1 + + if len(data) < offset + src_cid_len: + return None + src_cid = data[offset : offset + src_cid_len] + offset += src_cid_len + + # Determine packet type from first byte + packet_type = (first_byte & 0x30) >> 4 + + # For Initial packets, extract token + token = b"" + if packet_type == 0: # Initial packet + if len(data) < offset + 1: + return None + # Token length is variable-length integer + token_len, token_len_bytes = self._decode_varint(data[offset:]) + offset += token_len_bytes + + if len(data) < offset + token_len: + return None + token = data[offset : offset + token_len] + + return QUICPacketInfo( + version=version, + destination_cid=dest_cid, + source_cid=src_cid, + packet_type=packet_type, + token=token, + ) except Exception as e: - raise QUICListenError(f"Failed to create socket: {e}") from e + logger.debug(f"Failed to parse QUIC packet: {e}") + return None - async def _handle_incoming_packets(self) -> None: - """ - Handle incoming UDP packets and route to appropriate connections. - This is the main packet processing loop. - """ - logger.debug("Started packet handling loop") + def _decode_varint(self, data: bytes) -> tuple[int, int]: + """Decode QUIC variable-length integer.""" + if len(data) < 1: + return 0, 0 - try: - while self._listening and self._socket: - try: - # Receive UDP packet - # (this blocks until packet arrives or socket closes) - data, addr = await self._socket.recvfrom(65536) - self._stats["bytes_received"] += len(data) - self._stats["packets_processed"] += 1 + first_byte = data[0] + length_bits = (first_byte & 0xC0) >> 6 - # Process packet asynchronously to avoid blocking - if self._nursery: - self._nursery.start_soon(self._process_packet, data, addr) - - except trio.ClosedResourceError: - # Socket was closed, exit gracefully - logger.debug("Socket closed, exiting packet handler") - break - except Exception as e: - logger.error(f"Error receiving packet: {e}") - # Continue processing other packets - await trio.sleep(0.01) - except trio.Cancelled: - logger.info("Received Cancel, stopping handling incoming packets") - raise - finally: - logger.debug("Packet handling loop terminated") + if length_bits == 0: + return first_byte & 0x3F, 1 + elif length_bits == 1: + if len(data) < 2: + return 0, 0 + return ((first_byte & 0x3F) << 8) | data[1], 2 + elif length_bits == 2: + if len(data) < 4: + return 0, 0 + return ((first_byte & 0x3F) << 24) | (data[1] << 16) | ( + data[2] << 8 + ) | data[3], 4 + else: # length_bits == 3 + if len(data) < 8: + return 0, 0 + value = (first_byte & 0x3F) << 56 + for i in range(1, 8): + value |= data[i] << (8 * (7 - i)) + return value, 8 async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: """ - Process a single incoming packet. - Routes to existing connection or creates new connection. - - Args: - data: Raw UDP packet data - addr: Source address (host, port) - + Enhanced packet processing with connection ID routing and version negotiation. """ try: + self._stats["packets_processed"] += 1 + self._stats["bytes_received"] += len(data) + + # Parse packet to extract connection information + packet_info = self.parse_quic_packet(data) + async with self._connection_lock: - # Check if we have an existing connection for this address - if addr in self._connections: - connection = self._connections[addr] - await self._route_to_connection(connection, data, addr) - elif addr in self._pending_connections: - # Handle packet for pending connection - quic_conn = self._pending_connections[addr] - await self._handle_pending_connection(quic_conn, data, addr) + if packet_info: + # Check for version negotiation + if packet_info.version == 0: + # Version negotiation packet - this shouldn't happen on server + logger.warning( + f"Received version negotiation packet from {addr}" + ) + return + + # Check if version is supported + if packet_info.version not in self._supported_versions: + await self._send_version_negotiation( + addr, packet_info.source_cid + ) + return + + # Route based on destination connection ID + dest_cid = packet_info.destination_cid + + if dest_cid in self._connections: + # Existing connection + connection = self._connections[dest_cid] + await self._route_to_connection(connection, data, addr) + elif dest_cid in self._pending_connections: + # Pending connection + quic_conn = self._pending_connections[dest_cid] + await self._handle_pending_connection( + quic_conn, data, addr, dest_cid + ) + else: + # New connection - only handle Initial packets for new conn + if packet_info.packet_type == 0: # Initial packet + await self._handle_new_connection(data, addr, packet_info) + else: + logger.debug( + "Ignoring non-Initial packet for unknown " + f"connection ID from {addr}" + ) else: - # New connection - await self._handle_new_connection(data, addr) + # Fallback to address-based routing for short header packets + await self._handle_short_header_packet(data, addr) except Exception as e: logger.error(f"Error processing packet from {addr}: {e}") + self._stats["invalid_packets"] += 1 + + async def _send_version_negotiation( + self, addr: tuple[str, int], source_cid: bytes + ) -> None: + """Send version negotiation packet to client.""" + try: + self._stats["version_negotiations"] += 1 + + # Construct version negotiation packet + packet = bytearray() + + # First byte: long header (1) + unused bits (0111) + packet.append(0x80 | 0x70) + + # Version: 0 for version negotiation + packet.extend(struct.pack("!I", 0)) + + # Destination connection ID (echo source CID from client) + packet.append(len(source_cid)) + packet.extend(source_cid) + + # Source connection ID (empty for version negotiation) + packet.append(0) + + # Supported versions + for version in sorted(self._supported_versions): + packet.extend(struct.pack("!I", version)) + + # Send the packet + if self._socket: + await self._socket.sendto(bytes(packet), addr) + logger.debug( + f"Sent version negotiation to {addr} " + f"with versions {sorted(self._supported_versions)}" + ) + + except Exception as e: + logger.error(f"Failed to send version negotiation to {addr}: {e}") + + async def _handle_new_connection( + self, + data: bytes, + addr: tuple[str, int], + packet_info: QUICPacketInfo, + ) -> None: + """ + Handle new connection with proper version negotiation. + """ + try: + quic_config = None + for protocol, config in self._quic_configs.items(): + wire_versions = custom_quic_version_to_wire_format(protocol) + if wire_versions == packet_info.version: + print("PROTOCOL:", protocol) + quic_config = config + break + + if not quic_config: + logger.warning( + f"No configuration found for version {packet_info.version:08x}" + ) + await self._send_version_negotiation(addr, packet_info.source_cid) + return + + # Create server-side QUIC configuration + server_config = create_server_config_from_base( + base_config=quic_config, + security_manager=self._security_manager, + transport_config=self._config, + ) + + # Generate a new destination connection ID for this connection + # In a real implementation, this should be cryptographically secure + import secrets + + destination_cid = secrets.token_bytes(8) + + # Create QUIC connection with specific version + quic_conn = QuicConnection( + configuration=server_config, + original_destination_connection_id=packet_info.destination_cid, + ) + + # Store connection mapping + self._pending_connections[destination_cid] = quic_conn + self._addr_to_cid[addr] = destination_cid + self._cid_to_addr[destination_cid] = addr + + print("Receiving Datagram") + + # Process initial packet + quic_conn.receive_datagram(data, addr, now=time.time()) + print("Processing quic events") + await self._process_quic_events(quic_conn, addr, destination_cid) + await self._transmit_for_connection(quic_conn, addr) + + logger.debug( + f"Started handshake for new connection from {addr} " + f"(version: {packet_info.version:08x}, cid: {destination_cid.hex()})" + ) + + except Exception as e: + logger.error(f"Error handling new connection from {addr}: {e}") + self._stats["connections_rejected"] += 1 + + async def _handle_short_header_packet( + self, data: bytes, addr: tuple[str, int] + ) -> None: + """Handle short header packets using address-based fallback routing.""" + try: + # Check if we have a connection for this address + dest_cid = self._addr_to_cid.get(addr) + if dest_cid: + if dest_cid in self._connections: + connection = self._connections[dest_cid] + await self._route_to_connection(connection, data, addr) + elif dest_cid in self._pending_connections: + quic_conn = self._pending_connections[dest_cid] + await self._handle_pending_connection( + quic_conn, data, addr, dest_cid + ) + else: + logger.debug( + f"Received short header packet from unknown address {addr}" + ) + + except Exception as e: + logger.error(f"Error handling short header packet from {addr}: {e}") async def _route_to_connection( self, connection: QUICConnection, data: bytes, addr: tuple[str, int] @@ -263,10 +447,14 @@ class QUICListener(IListener): except Exception as e: logger.error(f"Error routing packet to connection {addr}: {e}") # Remove problematic connection - await self._remove_connection(addr) + await self._remove_connection_by_addr(addr) async def _handle_pending_connection( - self, quic_conn: QuicConnection, data: bytes, addr: tuple[str, int] + self, + quic_conn: QuicConnection, + data: bytes, + addr: tuple[str, int], + dest_cid: bytes, ) -> None: """Handle packet for a pending (handshaking) connection.""" try: @@ -274,58 +462,20 @@ class QUICListener(IListener): quic_conn.receive_datagram(data, addr, now=time.time()) # Process events - await self._process_quic_events(quic_conn, addr) + await self._process_quic_events(quic_conn, addr, dest_cid) # Send any outgoing packets - await self._transmit_for_connection(quic_conn) + await self._transmit_for_connection(quic_conn, addr) except Exception as e: - logger.error(f"Error handling pending connection {addr}: {e}") + logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}") # Remove from pending connections - self._pending_connections.pop(addr, None) - - async def _handle_new_connection(self, data: bytes, addr: tuple[str, int]) -> None: - """ - Handle a new incoming connection. - Creates a new QUIC connection and starts handshake. - - Args: - data: Initial packet data - addr: Source address - - """ - try: - # Determine QUIC version from packet - # For now, use the first available configuration - # TODO: Implement proper version negotiation - quic_version = next(iter(self._quic_configs.keys())) - config = self._quic_configs[quic_version] - - # Create server-side QUIC configuration - server_config = copy.deepcopy(config) - server_config.is_client = False - - # Create QUIC connection - quic_conn = QuicConnection(configuration=server_config) - - # Store as pending connection - self._pending_connections[addr] = quic_conn - - # Process initial packet - quic_conn.receive_datagram(data, addr, now=time.time()) - await self._process_quic_events(quic_conn, addr) - await self._transmit_for_connection(quic_conn) - - logger.debug(f"Started handshake for new connection from {addr}") - - except Exception as e: - logger.error(f"Error handling new connection from {addr}: {e}") - self._stats["connections_rejected"] += 1 + await self._remove_pending_connection(dest_cid) async def _process_quic_events( - self, quic_conn: QuicConnection, addr: tuple[str, int] + self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes ) -> None: - """Process QUIC events for a connection.""" + """Process QUIC events for a connection with connection ID context.""" while True: event = quic_conn.next_event() if event is None: @@ -333,46 +483,39 @@ class QUICListener(IListener): if isinstance(event, events.ConnectionTerminated): logger.debug( - f"Connection from {addr} terminated: {event.reason_phrase}" + f"Connection {dest_cid.hex()} from {addr} " + f"terminated: {event.reason_phrase}" ) - await self._remove_connection(addr) + await self._remove_connection(dest_cid) break elif isinstance(event, events.HandshakeCompleted): - logger.debug(f"Handshake completed for {addr}") - await self._promote_pending_connection(quic_conn, addr) + logger.debug(f"Handshake completed for connection {dest_cid.hex()}") + await self._promote_pending_connection(quic_conn, addr, dest_cid) elif isinstance(event, events.StreamDataReceived): # Forward to established connection if available - if addr in self._connections: - connection = self._connections[addr] + if dest_cid in self._connections: + connection = self._connections[dest_cid] await connection._handle_stream_data(event) elif isinstance(event, events.StreamReset): # Forward to established connection if available - if addr in self._connections: - connection = self._connections[addr] + if dest_cid in self._connections: + connection = self._connections[dest_cid] await connection._handle_stream_reset(event) async def _promote_pending_connection( - self, quic_conn: QuicConnection, addr: tuple[str, int] + self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes ) -> None: - """ - Promote a pending connection to an established connection. - Called after successful handshake completion. - - Args: - quic_conn: Established QUIC connection - addr: Remote address - - """ + """Promote a pending connection to an established connection.""" try: # Remove from pending connections - self._pending_connections.pop(addr, None) + self._pending_connections.pop(dest_cid, None) # Create multiaddr for this connection host, port = addr - # Use the first supported QUIC version for now + # Use the appropriate QUIC version quic_version = next(iter(self._quic_configs.keys())) remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}") @@ -388,22 +531,25 @@ class QUICListener(IListener): security_manager=self._security_manager, ) - # Store the connection - self._connections[addr] = connection + # Store the connection with connection ID + self._connections[dest_cid] = connection # Start connection management tasks if self._nursery: self._nursery.start_soon(connection._handle_datagram_received) self._nursery.start_soon(connection._handle_timer_events) + # Handle security verification if self._security_manager: try: await connection._verify_peer_identity_with_security() - logger.info(f"Security verification successful for {addr}") + logger.info( + f"Security verification successful for {dest_cid.hex()}" + ) except Exception as e: - logger.error(f"Security verification failed for {addr}: {e}") - self._stats["security_failures"] += 1 - # Close the connection due to security failure + logger.error( + f"Security verification failed for {dest_cid.hex()}: {e}" + ) await connection.close() return @@ -414,188 +560,203 @@ class QUICListener(IListener): ) self._stats["connections_accepted"] += 1 - logger.info(f"Accepted new QUIC connection from {addr}") + logger.info(f"Accepted new QUIC connection {dest_cid.hex()} from {addr}") except Exception as e: - logger.error(f"Error promoting connection from {addr}: {e}") - # Clean up - await self._remove_connection(addr) + logger.error(f"Error promoting connection {dest_cid.hex()}: {e}") + await self._remove_connection(dest_cid) self._stats["connections_rejected"] += 1 - async def _handle_new_established_connection( - self, connection: QUICConnection - ) -> None: - """ - Handle a newly established connection by calling the user handler. - - Args: - connection: Established QUIC connection - - """ + async def _remove_connection(self, dest_cid: bytes) -> None: + """Remove connection by connection ID.""" try: - # Call the connection handler provided by the transport - await self._handler(connection) - except Exception as e: - logger.error(f"Error in connection handler: {e}") - # Close the problematic connection - await connection.close() - - async def _transmit_for_connection(self, quic_conn: QuicConnection) -> None: - """Send pending datagrams for a QUIC connection.""" - sock = self._socket - if not sock: - return - - for data, addr in quic_conn.datagrams_to_send(now=time.time()): - try: - await sock.sendto(data, addr) - except Exception as e: - logger.error(f"Failed to send datagram to {addr}: {e}") - - async def _manage_connections(self) -> None: - """ - Background task to manage connection lifecycle. - Handles cleanup of closed/idle connections. - """ - try: - while not self._closed: - try: - # Sleep for a short interval - await trio.sleep(1.0) - - # Clean up closed connections - await self._cleanup_closed_connections() - - # Handle connection timeouts - await self._handle_connection_timeouts() - - except Exception as e: - logger.error(f"Error in connection management: {e}") - except trio.Cancelled: - raise - - async def _cleanup_closed_connections(self) -> None: - """Remove closed connections from tracking.""" - async with self._connection_lock: - closed_addrs = [] - - for addr, connection in self._connections.items(): - if connection.is_closed: - closed_addrs.append(addr) - - for addr in closed_addrs: - self._connections.pop(addr, None) - logger.debug(f"Cleaned up closed connection from {addr}") - - async def _handle_connection_timeouts(self) -> None: - """Handle connection timeouts and cleanup.""" - # TODO: Implement connection timeout handling - # Check for idle connections and close them - pass - - async def _remove_connection(self, addr: tuple[str, int]) -> None: - """Remove a connection from tracking.""" - async with self._connection_lock: - # Remove from active connections - connection = self._connections.pop(addr, None) + # Remove connection + connection = self._connections.pop(dest_cid, None) if connection: await connection.close() - # Remove from pending connections - quic_conn = self._pending_connections.pop(addr, None) - if quic_conn: - quic_conn.close() + # Clean up mappings + addr = self._cid_to_addr.pop(dest_cid, None) + if addr: + self._addr_to_cid.pop(addr, None) + + logger.debug(f"Removed connection {dest_cid.hex()}") + + except Exception as e: + logger.error(f"Error removing connection {dest_cid.hex()}: {e}") + + async def _remove_pending_connection(self, dest_cid: bytes) -> None: + """Remove pending connection by connection ID.""" + try: + self._pending_connections.pop(dest_cid, None) + addr = self._cid_to_addr.pop(dest_cid, None) + if addr: + self._addr_to_cid.pop(addr, None) + logger.debug(f"Removed pending connection {dest_cid.hex()}") + except Exception as e: + logger.error(f"Error removing pending connection {dest_cid.hex()}: {e}") + + async def _remove_connection_by_addr(self, addr: tuple[str, int]) -> None: + """Remove connection by address (fallback method).""" + dest_cid = self._addr_to_cid.get(addr) + if dest_cid: + await self._remove_connection(dest_cid) + + async def _transmit_for_connection( + self, quic_conn: QuicConnection, addr: tuple[str, int] + ) -> None: + """Send outgoing packets for a QUIC connection.""" + try: + while True: + datagrams = quic_conn.datagrams_to_send(now=time.time()) + if not datagrams: + break + + for datagram, _ in datagrams: + if self._socket: + await self._socket.sendto(datagram, addr) + + except Exception as e: + logger.error(f"Error transmitting packets to {addr}: {e}") + + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + """Start listening on the given multiaddr with enhanced connection handling.""" + if self._listening: + raise QUICListenError("Already listening") + + if not is_quic_multiaddr(maddr): + raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") + + try: + host, port = quic_multiaddr_to_endpoint(maddr) + + # Create and configure socket + self._socket = await self._create_socket(host, port) + self._nursery = nursery + + # Get the actual bound address + bound_host, bound_port = self._socket.getsockname() + quic_version = multiaddr_to_quic_version(maddr) + bound_maddr = create_quic_multiaddr(bound_host, bound_port, quic_version) + self._bound_addresses = [bound_maddr] + + self._listening = True + + # Start packet handling loop + nursery.start_soon(self._handle_incoming_packets) + + logger.info( + f"QUIC listener started on {bound_maddr} with connection ID support" + ) + return True + + except Exception as e: + await self.close() + raise QUICListenError(f"Failed to start listening: {e}") from e + + async def _create_socket(self, host: str, port: int) -> trio.socket.SocketType: + """Create and configure UDP socket.""" + try: + # Determine address family + try: + import ipaddress + + ip = ipaddress.ip_address(host) + family = socket.AF_INET if ip.version == 4 else socket.AF_INET6 + except ValueError: + family = socket.AF_INET + + # Create UDP socket + sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + # Set socket options + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, "SO_REUSEPORT"): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + # Bind to address + await sock.bind((host, port)) + + logger.debug(f"Created and bound UDP socket to {host}:{port}") + return sock + + except Exception as e: + raise QUICListenError(f"Failed to create socket: {e}") from e + + async def _handle_incoming_packets(self) -> None: + """Handle incoming UDP packets with enhanced routing.""" + logger.debug("Started enhanced packet handling loop") + + try: + while self._listening and self._socket: + try: + # Receive UDP packet + data, addr = await self._socket.recvfrom(65536) + + # Process packet asynchronously + if self._nursery: + self._nursery.start_soon(self._process_packet, data, addr) + + except trio.ClosedResourceError: + logger.debug("Socket closed, exiting packet handler") + break + except Exception as e: + logger.error(f"Error receiving packet: {e}") + await trio.sleep(0.01) + except trio.Cancelled: + logger.info("Packet handling cancelled") + raise + finally: + logger.debug("Enhanced packet handling loop terminated") async def close(self) -> None: - """Close the listener and cleanup resources.""" + """Close the listener and clean up resources.""" if self._closed: return self._closed = True self._listening = False - logger.debug("Closing QUIC listener") - # CRITICAL: Close socket FIRST to unblock recvfrom() - await self._cleanup_socket() + try: + # Close all connections + async with self._connection_lock: + for dest_cid in list(self._connections.keys()): + await self._remove_connection(dest_cid) - logger.debug("SOCKET CLEANUP COMPLETE") + for dest_cid in list(self._pending_connections.keys()): + await self._remove_pending_connection(dest_cid) - # Close all connections WITHOUT using the lock during shutdown - # (avoid deadlock if background tasks are cancelled while holding lock) - connections_to_close = list(self._connections.values()) - pending_to_close = list(self._pending_connections.values()) - - logger.debug( - f"CLOSING {connections_to_close} connections and {pending_to_close} pending" - ) - - # Close active connections - for connection in connections_to_close: - try: - await connection.close() - except Exception as e: - print(f"Error closing connection: {e}") - - # Close pending connections - for quic_conn in pending_to_close: - try: - quic_conn.close() - except Exception as e: - print(f"Error closing pending connection: {e}") - - # Clear the dictionaries without lock (we're shutting down) - self._connections.clear() - self._pending_connections.clear() - logger.debug("QUIC listener closed") - - async def _cleanup_socket(self) -> None: - """Clean up the UDP socket.""" - if self._socket: - try: + # Close socket + if self._socket: self._socket.close() - except Exception as e: - logger.error(f"Error closing socket: {e}") - finally: self._socket = None - def get_addrs(self) -> tuple[Multiaddr, ...]: - """ - Get the addresses this listener is bound to. + self._bound_addresses.clear() - Returns: - Tuple of bound multiaddrs + logger.info("QUIC listener closed") - """ - return tuple(self._bound_addresses) + except Exception as e: + logger.error(f"Error closing listener: {e}") - def is_listening(self) -> bool: - """Check if the listener is actively listening.""" - return self._listening and not self._closed + def get_addresses(self) -> list[Multiaddr]: + """Get the bound addresses.""" + return self._bound_addresses.copy() + + async def _handle_new_established_connection( + self, connection: QUICConnection + ) -> None: + """Handle a newly established connection.""" + try: + await self._handler(connection) + except Exception as e: + logger.error(f"Error in connection handler: {e}") + await connection.close() + + def get_addrs(self) -> tuple[Multiaddr]: + return tuple(self.get_addresses()) def get_stats(self) -> dict[str, int]: - """Get listener statistics.""" - stats = self._stats.copy() - stats.update( - { - "active_connections": len(self._connections), - "pending_connections": len(self._pending_connections), - "is_listening": self.is_listening(), - } - ) - return stats + return self._stats - def get_security_manager(self) -> Optional["QUICTLSConfigManager"]: - """ - Get the security manager for this listener. - - Returns: - The QUIC TLS configuration manager, or None if not configured - - """ - return self._security_manager - - def __str__(self) -> str: - """String representation of the listener.""" - addr = self._bound_addresses - conn_count = len(self._connections) - return f"QUICListener(addrs={addr}, connections={conn_count})" + def is_listening(self) -> bool: + raise NotImplementedError() diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 59d62715..71d4891e 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -13,7 +13,7 @@ from aioquic.quic.configuration import ( QuicConfiguration, ) from aioquic.quic.connection import ( - QuicConnection, + QuicConnection as NativeQUICConnection, ) import multiaddr import trio @@ -60,6 +60,11 @@ from .security import ( QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], +) logger = logging.getLogger(__name__) @@ -279,20 +284,24 @@ class QUICTransport(ITransport): # Get appropriate QUIC client configuration config_key = TProtocol(f"{quic_version}_client") + print("config_key", config_key, self._quic_configs.keys()) config = self._quic_configs.get(config_key) if not config: raise QUICDialError(f"Unsupported QUIC version: {quic_version}") + config.is_client = True logger.debug( f"Dialing QUIC connection to {host}:{port} (version: {quic_version})" ) + print("Start QUIC Connection") # Create QUIC connection using aioquic's sans-IO core - quic_connection = QuicConnection(configuration=config) + native_quic_connection = NativeQUICConnection(configuration=config) + print("QUIC Connection Created") # Create trio-based QUIC connection wrapper with security connection = QUICConnection( - quic_connection=quic_connection, + quic_connection=native_quic_connection, remote_addr=(host, port), peer_id=peer_id, local_peer_id=self._peer_id, @@ -354,6 +363,7 @@ class QUICTransport(ITransport): ) logger.info(f"Peer identity verified: {verified_peer_id}") + print(f"Peer identity verified: {verified_peer_id}") except Exception as e: raise QUICSecurityError(f"Peer identity verification failed: {e}") from e diff --git a/libp2p/transport/quic/utils.py b/libp2p/transport/quic/utils.py index c9db6fa9..97634a91 100644 --- a/libp2p/transport/quic/utils.py +++ b/libp2p/transport/quic/utils.py @@ -5,14 +5,19 @@ Based on go-libp2p and js-libp2p QUIC implementations. """ import ipaddress +import logging +from aioquic.quic.configuration import QuicConfiguration import multiaddr from libp2p.custom_types import TProtocol +from libp2p.transport.quic.security import QUICTLSConfigManager from .config import QUICTransportConfig from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError +logger = logging.getLogger(__name__) + # Protocol constants QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 @@ -20,6 +25,18 @@ UDP_PROTOCOL = "udp" IP4_PROTOCOL = "ip4" IP6_PROTOCOL = "ip6" +SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_SERVER" +SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_V1_PROTOCOL}_SERVER" +CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_DRAFT29_PROTOCOL}_SERVER" +CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_SERVER" + +CUSTOM_QUIC_VERSION_MAPPING = { + SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000 + CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000 + SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29 + CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29 +} + # QUIC version to wire format mappings (required for aioquic) QUIC_VERSION_MAPPINGS = { QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000 @@ -218,6 +235,27 @@ def quic_version_to_wire_format(version: TProtocol) -> int: return wire_version +def custom_quic_version_to_wire_format(version: TProtocol) -> int: + """ + Convert QUIC version string to wire format integer for aioquic. + + Args: + version: QUIC version string ("quic-v1" or "quic") + + Returns: + Wire format version number + + Raises: + QUICUnsupportedVersionError: If version is not supported + + """ + wire_version = QUIC_VERSION_MAPPINGS.get(version) + if wire_version is None: + raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}") + + return wire_version + + def get_alpn_protocols() -> list[str]: """ Get ALPN protocols for libp2p over QUIC. @@ -250,3 +288,94 @@ def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr: version = multiaddr_to_quic_version(maddr) return create_quic_multiaddr(host, port, version) + + +def create_server_config_from_base( + base_config: QuicConfiguration, + security_manager: QUICTLSConfigManager | None = None, + transport_config: QUICTransportConfig | None = None, +) -> QuicConfiguration: + """ + Create a server configuration without using deepcopy. + Manually copies attributes while handling cryptography objects properly. + """ + try: + # Create new server configuration from scratch + server_config = QuicConfiguration(is_client=False) + + # Copy basic configuration attributes (these are safe to copy) + copyable_attrs = [ + "alpn_protocols", + "verify_mode", + "max_datagram_frame_size", + "idle_timeout", + "max_concurrent_streams", + "supported_versions", + "max_data", + "max_stream_data", + "stateless_retry", + "quantum_readiness_test", + ] + + for attr in copyable_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(server_config, attr, value) + + # Handle cryptography objects - these need direct reference, not copying + crypto_attrs = [ + "certificate", + "private_key", + "certificate_chain", + "ca_certs", + ] + + for attr in crypto_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(server_config, attr, value) + + # Apply security manager configuration if available + if security_manager: + try: + server_tls_config = security_manager.create_server_config() + + # Override with security manager's TLS configuration + if "certificate" in server_tls_config: + server_config.certificate = server_tls_config["certificate"] + if "private_key" in server_tls_config: + server_config.private_key = server_tls_config["private_key"] + if "certificate_chain" in server_tls_config: + # type: ignore + server_config.certificate_chain = server_tls_config[ # type: ignore + "certificate_chain" # type: ignore + ] + if "alpn_protocols" in server_tls_config: + # type: ignore + server_config.alpn_protocols = server_tls_config["alpn_protocols"] # type: ignore + + except Exception as e: + logger.warning(f"Failed to apply security manager config: {e}") + + # Set transport-specific defaults if provided + if transport_config: + if server_config.idle_timeout == 0: + server_config.idle_timeout = getattr( + transport_config, "idle_timeout", 30.0 + ) + if server_config.max_datagram_frame_size is None: + server_config.max_datagram_frame_size = getattr( + transport_config, "max_datagram_size", 1200 + ) + # Ensure we have ALPN protocols + if server_config.alpn_protocols: + server_config.alpn_protocols = ["libp2p"] + + logger.debug("Successfully created server config without deepcopy") + return server_config + + except Exception as e: + logger.error(f"Failed to create server config: {e}") + raise diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 605913ec..e8e59c8d 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -183,10 +183,13 @@ def test_new_swarm_tcp_multiaddr_supported(): assert isinstance(swarm.transport, TCP) -def test_new_swarm_quic_multiaddr_raises(): +def test_new_swarm_quic_multiaddr_supported(): + from libp2p.transport.quic.transport import QUICTransport + addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic") - with pytest.raises(ValueError, match="QUIC not yet supported"): - new_swarm(listen_addrs=[addr]) + swarm = new_swarm(listen_addrs=[addr]) + assert isinstance(swarm, Swarm) + assert isinstance(swarm.transport, QUICTransport) @pytest.mark.trio