From 123c86c0915790b4e9e36a640a2d4ebf8122184f Mon Sep 17 00:00:00 2001 From: Akash Mondal Date: Tue, 17 Jun 2025 13:54:32 +0000 Subject: [PATCH] fix: duplication connection creation for same sessions --- examples/echo/test_quic.py | 289 ++++++++++++++++++ libp2p/transport/quic/listener.py | 476 ++++++++++++++++++++++------- libp2p/transport/quic/security.py | 322 +++++++++++++++++-- libp2p/transport/quic/transport.py | 78 +++-- 4 files changed, 982 insertions(+), 183 deletions(-) create mode 100644 examples/echo/test_quic.py diff --git a/examples/echo/test_quic.py b/examples/echo/test_quic.py new file mode 100644 index 00000000..446b8e57 --- /dev/null +++ b/examples/echo/test_quic.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +Fixed QUIC handshake test to debug connection issues. +""" + +import logging +from pathlib import Path +import secrets +import sys + +import trio + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig +from libp2p.transport.quic.utils import create_quic_multiaddr + +# Adjust this path to your project structure +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) +# Setup logging +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) + + +async def test_certificate_generation(): + """Test certificate generation in isolation.""" + print("\n=== TESTING CERTIFICATE GENERATION ===") + + try: + from libp2p.peer.id import ID + from libp2p.transport.quic.security import create_quic_security_transport + + # Create key pair + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + print(f"Generated peer ID: {peer_id}") + + # Create security manager + security_manager = create_quic_security_transport(private_key, peer_id) + print("✅ Security manager created") + + # Test server config + server_config = security_manager.create_server_config() + print("✅ Server config created") + + # Validate certificate + cert = server_config.certificate + private_key_obj = server_config.private_key + + print(f"Certificate type: {type(cert)}") + print(f"Private key type: {type(private_key_obj)}") + print(f"Certificate subject: {cert.subject}") + print(f"Certificate issuer: {cert.issuer}") + + # Check for libp2p extension + has_libp2p_ext = False + for ext in cert.extensions: + if str(ext.oid) == "1.3.6.1.4.1.53594.1.1": + has_libp2p_ext = True + print(f"✅ Found libp2p extension: {ext.oid}") + print(f"Extension critical: {ext.critical}") + print(f"Extension value length: {len(ext.value)} bytes") + break + + if not has_libp2p_ext: + print("❌ No libp2p extension found!") + print("Available extensions:") + for ext in cert.extensions: + print(f" - {ext.oid} (critical: {ext.critical})") + + # Check certificate/key match + from cryptography.hazmat.primitives import serialization + + cert_public_key = cert.public_key() + private_public_key = private_key_obj.public_key() + + cert_pub_bytes = cert_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + private_pub_bytes = private_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + if cert_pub_bytes == private_pub_bytes: + print("✅ Certificate and private key match") + return has_libp2p_ext + else: + print("❌ Certificate and private key DO NOT match") + return False + + except Exception as e: + print(f"❌ Certificate test failed: {e}") + import traceback + + traceback.print_exc() + return False + + +async def test_basic_quic_connection(): + """Test basic QUIC connection with proper server setup.""" + print("\n=== TESTING BASIC QUIC CONNECTION ===") + + try: + from aioquic.quic.configuration import QuicConfiguration + from aioquic.quic.connection import QuicConnection + + from libp2p.peer.id import ID + from libp2p.transport.quic.security import create_quic_security_transport + + # Create certificates + server_key = create_new_key_pair().private_key + server_peer_id = ID.from_pubkey(server_key.get_public_key()) + server_security = create_quic_security_transport(server_key, server_peer_id) + + client_key = create_new_key_pair().private_key + client_peer_id = ID.from_pubkey(client_key.get_public_key()) + client_security = create_quic_security_transport(client_key, client_peer_id) + + # Create server config + server_tls_config = server_security.create_server_config() + server_config = QuicConfiguration( + is_client=False, + certificate=server_tls_config.certificate, + private_key=server_tls_config.private_key, + alpn_protocols=["libp2p"], + ) + + # Create client config + client_tls_config = client_security.create_client_config() + client_config = QuicConfiguration( + is_client=True, + certificate=client_tls_config.certificate, + private_key=client_tls_config.private_key, + alpn_protocols=["libp2p"], + ) + + print("✅ QUIC configurations created") + + # Test creating connections with proper parameters + # For server, we need to provide original_destination_connection_id + original_dcid = secrets.token_bytes(8) + + server_conn = QuicConnection( + configuration=server_config, + original_destination_connection_id=original_dcid, + ) + + # For client, no original_destination_connection_id needed + client_conn = QuicConnection(configuration=client_config) + + print("✅ QUIC connections created") + print(f"Server state: {server_conn._state}") + print(f"Client state: {client_conn._state}") + + # Test that certificates are valid + print(f"Server has certificate: {server_config.certificate is not None}") + print(f"Server has private key: {server_config.private_key is not None}") + print(f"Client has certificate: {client_config.certificate is not None}") + print(f"Client has private key: {client_config.private_key is not None}") + + return True + + except Exception as e: + print(f"❌ Basic QUIC test failed: {e}") + import traceback + + traceback.print_exc() + return False + + +async def test_server_startup(): + """Test server startup with timeout.""" + print("\n=== TESTING SERVER STARTUP ===") + + try: + # Create transport + private_key = create_new_key_pair().private_key + config = QUICTransportConfig( + idle_timeout=10.0, # Reduced timeout for testing + connection_timeout=10.0, + enable_draft29=False, + ) + + transport = QUICTransport(private_key, config) + print("✅ Transport created successfully") + + # Test configuration + print(f"Available configs: {list(transport._quic_configs.keys())}") + + config_valid = True + for config_key, quic_config in transport._quic_configs.items(): + print(f"\n--- Testing config: {config_key} ---") + print(f"is_client: {quic_config.is_client}") + print(f"has_certificate: {quic_config.certificate is not None}") + print(f"has_private_key: {quic_config.private_key is not None}") + print(f"alpn_protocols: {quic_config.alpn_protocols}") + print(f"verify_mode: {quic_config.verify_mode}") + + if quic_config.certificate: + cert = quic_config.certificate + print(f"Certificate subject: {cert.subject}") + + # Check for libp2p extension + has_libp2p_ext = False + for ext in cert.extensions: + if str(ext.oid) == "1.3.6.1.4.1.53594.1.1": + has_libp2p_ext = True + break + print(f"Has libp2p extension: {has_libp2p_ext}") + + if not has_libp2p_ext: + config_valid = False + + if not config_valid: + print("❌ Transport configuration invalid - missing libp2p extensions") + return False + + # Create listener + async def dummy_handler(connection): + print(f"New connection: {connection}") + + listener = transport.create_listener(dummy_handler) + print("✅ Listener created successfully") + + # Try to bind with timeout + maddr = create_quic_multiaddr("127.0.0.1", 0, "quic-v1") + + async with trio.open_nursery() as nursery: + result = await listener.listen(maddr, nursery) + if result: + print("✅ Server bound successfully") + addresses = listener.get_addresses() + print(f"Listening on: {addresses}") + + # Keep running for a short time + with trio.move_on_after(3.0): # 3 second timeout + await trio.sleep(5.0) + + print("✅ Server test completed (timed out normally)") + return True + else: + print("❌ Failed to bind server") + return False + + except Exception as e: + print(f"❌ Server test failed: {e}") + import traceback + + traceback.print_exc() + return False + + +async def main(): + """Run all tests with better error handling.""" + print("Starting QUIC diagnostic tests...") + + # Test 1: Certificate generation + cert_ok = await test_certificate_generation() + if not cert_ok: + print("\n❌ CRITICAL: Certificate generation failed!") + print("Apply the certificate generation fix and try again.") + return + + # Test 2: Basic QUIC connection + quic_ok = await test_basic_quic_connection() + if not quic_ok: + print("\n❌ CRITICAL: Basic QUIC connection test failed!") + return + + # Test 3: Server startup + server_ok = await test_server_startup() + if not server_ok: + print("\n❌ Server startup test failed!") + return + + print("\n✅ ALL TESTS PASSED!") + print("=== DIAGNOSTIC COMPLETE ===") + print("Your QUIC implementation should now work correctly.") + print("Try running your echo example again.") + + +if __name__ == "__main__": + trio.run(main) diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index 76fc18c5..b14efd5e 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -249,23 +249,35 @@ class QUICListener(IListener): async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: """ - Enhanced packet processing with connection ID routing and version negotiation. - FIXED: Added address-based connection reuse to prevent multiple connections. + Enhanced packet processing with better connection ID routing and debugging. """ try: self._stats["packets_processed"] += 1 self._stats["bytes_received"] += len(data) + print(f"🔧 PACKET: Processing {len(data)} bytes from {addr}") + # Parse packet to extract connection information packet_info = self.parse_quic_packet(data) - print(f"🔧 DEBUG: Address mappings: {self._addr_to_cid}") print( - f"🔧 DEBUG: Pending connections: {list(self._pending_connections.keys())}" + f"🔧 DEBUG: Address mappings: {dict((k, v.hex()) for k, v in self._addr_to_cid.items())}" + ) + print( + f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}" + ) + print( + f"🔧 DEBUG: Established connections: {[cid.hex() for cid in self._connections.keys()]}" ) async with self._connection_lock: if packet_info: + print( + f"🔧 PACKET: Parsed packet - version: 0x{packet_info.version:08x}, " + f"dest_cid: {packet_info.destination_cid.hex()}, " + f"src_cid: {packet_info.source_cid.hex()}" + ) + # Check for version negotiation if packet_info.version == 0: logger.warning( @@ -275,6 +287,9 @@ class QUICListener(IListener): # Check if version is supported if packet_info.version not in self._supported_versions: + print( + f"❌ PACKET: Unsupported version 0x{packet_info.version:08x}" + ) await self._send_version_negotiation( addr, packet_info.source_cid ) @@ -283,87 +298,66 @@ class QUICListener(IListener): # Route based on destination connection ID dest_cid = packet_info.destination_cid + # First, try exact connection ID match if dest_cid in self._connections: - # Existing established connection - print(f"🔧 ROUTING: To established connection {dest_cid.hex()}") + print( + f"✅ PACKET: Routing to established connection {dest_cid.hex()}" + ) connection = self._connections[dest_cid] await self._route_to_connection(connection, data, addr) + return elif dest_cid in self._pending_connections: - # Existing pending connection - print(f"🔧 ROUTING: To pending connection {dest_cid.hex()}") + print( + f"✅ PACKET: Routing to pending connection {dest_cid.hex()}" + ) quic_conn = self._pending_connections[dest_cid] await self._handle_pending_connection( quic_conn, data, addr, dest_cid ) + return - else: - # CRITICAL FIX: Check for existing connection by address BEFORE creating new - existing_cid = self._addr_to_cid.get(addr) + # If no exact match, try address-based routing (connection ID might not match) + mapped_cid = self._addr_to_cid.get(addr) + if mapped_cid: + print( + f"🔧 PACKET: Found address mapping {addr} -> {mapped_cid.hex()}" + ) + print( + f"🔧 PACKET: Client dest_cid {dest_cid.hex()} != our cid {mapped_cid.hex()}" + ) - if existing_cid is not None: + if mapped_cid in self._connections: print( - f"✅ FOUND: Existing connection {existing_cid.hex()} for address {addr}" + "✅ PACKET: Using established connection via address mapping" ) + connection = self._connections[mapped_cid] + await self._route_to_connection(connection, data, addr) + return + elif mapped_cid in self._pending_connections: print( - f"🔧 NOTE: Client dest_cid {dest_cid.hex()} != our cid {existing_cid.hex()}" + "✅ PACKET: Using pending connection via address mapping" ) + quic_conn = self._pending_connections[mapped_cid] + await self._handle_pending_connection( + quic_conn, data, addr, mapped_cid + ) + return - # Route to existing connection by address - if existing_cid in self._pending_connections: - print( - "🔧 ROUTING: Using existing pending connection by address" - ) - quic_conn = self._pending_connections[existing_cid] - await self._handle_pending_connection( - quic_conn, data, addr, existing_cid - ) - elif existing_cid in self._connections: - print( - "🔧 ROUTING: Using existing established connection by address" - ) - connection = self._connections[existing_cid] - await self._route_to_connection(connection, data, addr) - else: - print( - f"❌ ERROR: Address mapping exists but connection {existing_cid.hex()} not found!" - ) - # Clean up broken mapping and create new - self._addr_to_cid.pop(addr, None) - if packet_info.packet_type == 0: # Initial packet - print( - "🔧 NEW: Creating new connection after cleanup" - ) - await self._handle_new_connection( - data, addr, packet_info - ) + # No existing connection found, create new one + print(f"🔧 PACKET: Creating new connection for {addr}") + await self._handle_new_connection(data, addr, packet_info) - else: - # Truly new connection - only handle Initial packets - if packet_info.packet_type == 0: # Initial packet - print(f"🔧 NEW: Creating first connection for {addr}") - await self._handle_new_connection( - data, addr, packet_info - ) - - # Debug the newly created connection - new_cid = self._addr_to_cid.get(addr) - if new_cid and new_cid in self._pending_connections: - quic_conn = self._pending_connections[new_cid] - await self._debug_quic_connection_state( - quic_conn, new_cid - ) - else: - logger.debug( - f"Ignoring non-Initial packet for unknown connection ID from {addr}" - ) else: - # Fallback to address-based routing for short header packets + # Failed to parse packet + print(f"❌ PACKET: Failed to parse packet from {addr}") await self._handle_short_header_packet(data, addr) except Exception as e: logger.error(f"Error processing packet from {addr}: {e}") - self._stats["invalid_packets"] += 1 + import traceback + + traceback.print_exc() async def _send_version_negotiation( self, addr: tuple[str, int], source_cid: bytes @@ -404,29 +398,31 @@ class QUICListener(IListener): logger.error(f"Failed to send version negotiation to {addr}: {e}") async def _handle_new_connection( - self, - data: bytes, - addr: tuple[str, int], - packet_info: QUICPacketInfo, + self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo ) -> None: - """ - Handle new connection with proper version negotiation. - """ + """Handle new connection with proper connection ID handling.""" try: + print(f"🔧 NEW_CONN: Starting handshake for {addr}") + + # Find appropriate QUIC configuration quic_config = None + config_key = None + for protocol, config in self._quic_configs.items(): wire_versions = custom_quic_version_to_wire_format(protocol) if wire_versions == packet_info.version: quic_config = config + config_key = protocol break if not quic_config: - logger.warning( - f"No configuration found for version {packet_info.version:08x}" - ) + print(f"❌ NEW_CONN: No configuration found for version 0x{packet_info.version:08x}") + print(f"🔧 NEW_CONN: Available configs: {list(self._quic_configs.keys())}") await self._send_version_negotiation(addr, packet_info.source_cid) return + print(f"✅ NEW_CONN: Using config {config_key} for version 0x{packet_info.version:08x}") + # Create server-side QUIC configuration server_config = create_server_config_from_base( base_config=quic_config, @@ -434,39 +430,158 @@ class QUICListener(IListener): transport_config=self._config, ) - # Generate a new destination connection ID for this connection - # In a real implementation, this should be cryptographically secure - import secrets + # Debug the server configuration + print(f"🔧 NEW_CONN: Server config - is_client: {server_config.is_client}") + print(f"🔧 NEW_CONN: Server config - has_certificate: {server_config.certificate is not None}") + print(f"🔧 NEW_CONN: Server config - has_private_key: {server_config.private_key is not None}") + print(f"🔧 NEW_CONN: Server config - ALPN: {server_config.alpn_protocols}") + print(f"🔧 NEW_CONN: Server config - verify_mode: {server_config.verify_mode}") + # Validate certificate has libp2p extension + if server_config.certificate: + cert = server_config.certificate + has_libp2p_ext = False + for ext in cert.extensions: + if str(ext.oid) == "1.3.6.1.4.1.53594.1.1": + has_libp2p_ext = True + break + print(f"🔧 NEW_CONN: Certificate has libp2p extension: {has_libp2p_ext}") + + if not has_libp2p_ext: + print("❌ NEW_CONN: Certificate missing libp2p extension!") + + # Generate a new destination connection ID for this connection + import secrets destination_cid = secrets.token_bytes(8) - # Create QUIC connection with specific version + print(f"🔧 NEW_CONN: Generated new CID: {destination_cid.hex()}") + print(f"🔧 NEW_CONN: Original destination CID: {packet_info.destination_cid.hex()}") + + # Create QUIC connection with proper parameters for server + # CRITICAL FIX: Pass the original destination connection ID from the initial packet quic_conn = QuicConnection( configuration=server_config, - original_destination_connection_id=packet_info.destination_cid, + original_destination_connection_id=packet_info.destination_cid, # Use the original DCID from packet ) - # Store connection mapping + print("✅ NEW_CONN: QUIC connection created successfully") + + # Store connection mapping using our generated CID self._pending_connections[destination_cid] = quic_conn self._addr_to_cid[addr] = destination_cid self._cid_to_addr[destination_cid] = addr + print(f"🔧 NEW_CONN: Stored mappings for {addr} <-> {destination_cid.hex()}") print("Receiving Datagram") # Process initial packet quic_conn.receive_datagram(data, addr, now=time.time()) + + # Debug connection state after receiving packet + await self._debug_quic_connection_state_detailed(quic_conn, destination_cid) + + # Process events and send response await self._process_quic_events(quic_conn, addr, destination_cid) await self._transmit_for_connection(quic_conn, addr) logger.debug( f"Started handshake for new connection from {addr} " - f"(version: {packet_info.version:08x}, cid: {destination_cid.hex()})" + f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})" ) except Exception as e: logger.error(f"Error handling new connection from {addr}: {e}") + import traceback + traceback.print_exc() self._stats["connections_rejected"] += 1 + async def _debug_quic_connection_state_detailed( + self, quic_conn: QuicConnection, connection_id: bytes + ): + """Enhanced connection state debugging.""" + try: + print(f"🔧 QUIC_STATE: Debugging connection {connection_id.hex()}") + + if not quic_conn: + print("❌ QUIC_STATE: QUIC CONNECTION NOT FOUND") + return + + # Check TLS state + if hasattr(quic_conn, "tls") and quic_conn.tls: + print("✅ QUIC_STATE: TLS context exists") + if hasattr(quic_conn.tls, "state"): + print(f"🔧 QUIC_STATE: TLS state: {quic_conn.tls.state}") + + # Check if we have peer certificate + if ( + hasattr(quic_conn.tls, "_peer_certificate") + and quic_conn.tls._peer_certificate + ): + print("✅ QUIC_STATE: Peer certificate available") + else: + print("🔧 QUIC_STATE: No peer certificate yet") + + # Check TLS handshake completion + if hasattr(quic_conn.tls, "handshake_complete"): + handshake_status = quic_conn._handshake_complete + print( + f"🔧 QUIC_STATE: TLS handshake complete: {handshake_status}" + ) + else: + print("❌ QUIC_STATE: No TLS context!") + + # Check connection state + if hasattr(quic_conn, "_state"): + print(f"🔧 QUIC_STATE: Connection state: {quic_conn._state}") + + # Check if handshake is complete + if hasattr(quic_conn, "_handshake_complete"): + print( + f"🔧 QUIC_STATE: Handshake complete: {quic_conn._handshake_complete}" + ) + + # Check configuration + if hasattr(quic_conn, "configuration"): + config = quic_conn.configuration + print( + f"🔧 QUIC_STATE: Config certificate: {config.certificate is not None}" + ) + print( + f"🔧 QUIC_STATE: Config private_key: {config.private_key is not None}" + ) + print(f"🔧 QUIC_STATE: Config is_client: {config.is_client}") + print(f"🔧 QUIC_STATE: Config verify_mode: {config.verify_mode}") + print(f"🔧 QUIC_STATE: Config ALPN: {config.alpn_protocols}") + + if config.certificate: + cert = config.certificate + print(f"🔧 QUIC_STATE: Certificate subject: {cert.subject}") + print( + f"🔧 QUIC_STATE: Certificate valid from: {cert.not_valid_before}" + ) + print( + f"🔧 QUIC_STATE: Certificate valid until: {cert.not_valid_after}" + ) + + # Check for connection errors + if hasattr(quic_conn, "_close_event") and quic_conn._close_event: + print( + f"❌ QUIC_STATE: Connection has close event: {quic_conn._close_event}" + ) + + # Check for TLS errors + if ( + hasattr(quic_conn, "_handshake_complete") + and not quic_conn._handshake_complete + ): + print("⚠️ QUIC_STATE: Handshake not yet complete") + + except Exception as e: + print(f"❌ QUIC_STATE: Error checking state: {e}") + import traceback + + traceback.print_exc() + async def _handle_short_header_packet( self, data: bytes, addr: tuple[str, int] ) -> None: @@ -515,54 +630,141 @@ class QUICListener(IListener): addr: tuple[str, int], dest_cid: bytes, ) -> None: - """Handle packet for a pending (handshaking) connection.""" + """Handle packet for a pending (handshaking) connection with enhanced debugging.""" try: + print( + f"🔧 PENDING: Handling packet for pending connection {dest_cid.hex()}" + ) + print(f"🔧 PENDING: Packet size: {len(data)} bytes from {addr}") + + # Check connection state before processing + if hasattr(quic_conn, "_state"): + print(f"🔧 PENDING: Connection state before: {quic_conn._state}") + + if ( + hasattr(quic_conn, "tls") + and quic_conn.tls + and hasattr(quic_conn.tls, "state") + ): + print(f"🔧 PENDING: TLS state before: {quic_conn.tls.state}") + # Feed data to QUIC connection quic_conn.receive_datagram(data, addr, now=time.time()) + print("✅ PENDING: Datagram received by QUIC connection") - # Process events + # Check state after receiving packet + if hasattr(quic_conn, "_state"): + print(f"🔧 PENDING: Connection state after: {quic_conn._state}") + + if ( + hasattr(quic_conn, "tls") + and quic_conn.tls + and hasattr(quic_conn.tls, "state") + ): + print(f"🔧 PENDING: TLS state after: {quic_conn.tls.state}") + + # Process events - this is crucial for handshake progression + print("🔧 PENDING: Processing QUIC events...") await self._process_quic_events(quic_conn, addr, dest_cid) - # Send any outgoing packets + # Send any outgoing packets - this is where the response should be sent + print("🔧 PENDING: Transmitting response...") await self._transmit_for_connection(quic_conn, addr) + # Check if handshake completed + if ( + hasattr(quic_conn, "_handshake_complete") + and quic_conn._handshake_complete + ): + print("✅ PENDING: Handshake completed, promoting connection") + await self._promote_pending_connection(quic_conn, addr, dest_cid) + else: + print("🔧 PENDING: Handshake still in progress") + + # Debug why handshake might be stuck + await self._debug_handshake_state(quic_conn, dest_cid) + except Exception as e: logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}") - # Remove from pending connections + import traceback + + traceback.print_exc() + + # Remove problematic pending connection + print(f"❌ PENDING: Removing problematic connection {dest_cid.hex()}") await self._remove_pending_connection(dest_cid) async def _process_quic_events( self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes ) -> None: - """Process QUIC events for a connection with connection ID context.""" - while True: - event = quic_conn.next_event() - if event is None: - break + """Process QUIC events with enhanced debugging.""" + try: + events_processed = 0 + while True: + event = quic_conn.next_event() + if event is None: + break - if isinstance(event, events.ConnectionTerminated): - logger.debug( - f"Connection {dest_cid.hex()} from {addr} " - f"terminated: {event.reason_phrase}" + events_processed += 1 + print( + f"🔧 EVENT: Processing event {events_processed}: {type(event).__name__}" ) - await self._remove_connection(dest_cid) - break - elif isinstance(event, events.HandshakeCompleted): - logger.debug(f"Handshake completed for connection {dest_cid.hex()}") - await self._promote_pending_connection(quic_conn, addr, dest_cid) + if isinstance(event, events.ConnectionTerminated): + print( + f"❌ EVENT: Connection terminated - code: {event.error_code}, reason: {event.reason_phrase}" + ) + logger.debug( + f"Connection {dest_cid.hex()} from {addr} " + f"terminated: {event.reason_phrase}" + ) + await self._remove_connection(dest_cid) + break - elif isinstance(event, events.StreamDataReceived): - # Forward to established connection if available - if dest_cid in self._connections: - connection = self._connections[dest_cid] - await connection._handle_stream_data(event) + elif isinstance(event, events.HandshakeCompleted): + print( + f"✅ EVENT: Handshake completed for connection {dest_cid.hex()}" + ) + logger.debug(f"Handshake completed for connection {dest_cid.hex()}") + await self._promote_pending_connection(quic_conn, addr, dest_cid) - elif isinstance(event, events.StreamReset): - # Forward to established connection if available - if dest_cid in self._connections: - connection = self._connections[dest_cid] - await connection._handle_stream_reset(event) + elif isinstance(event, events.StreamDataReceived): + print(f"🔧 EVENT: Stream data received on stream {event.stream_id}") + # Forward to established connection if available + if dest_cid in self._connections: + connection = self._connections[dest_cid] + await connection._handle_stream_data(event) + + elif isinstance(event, events.StreamReset): + print(f"🔧 EVENT: Stream reset on stream {event.stream_id}") + # Forward to established connection if available + if dest_cid in self._connections: + connection = self._connections[dest_cid] + await connection._handle_stream_reset(event) + + elif isinstance(event, events.ConnectionIdIssued): + print( + f"🔧 EVENT: Connection ID issued: {event.connection_id.hex()}" + ) + + elif isinstance(event, events.ConnectionIdRetired): + print( + f"🔧 EVENT: Connection ID retired: {event.connection_id.hex()}" + ) + + else: + print(f"🔧 EVENT: Unhandled event type: {type(event).__name__}") + + if events_processed == 0: + print("🔧 EVENT: No events to process") + else: + print(f"🔧 EVENT: Processed {events_processed} events total") + + except Exception as e: + print(f"❌ EVENT: Error processing events: {e}") + import traceback + + traceback.print_exc() async def _debug_quic_connection_state( self, quic_conn: QuicConnection, connection_id: bytes @@ -972,3 +1174,61 @@ class QUICListener(IListener): stats["active_connections"] = len(self._connections) stats["pending_connections"] = len(self._pending_connections) return stats + + async def _debug_handshake_state(self, quic_conn: QuicConnection, dest_cid: bytes): + """Debug why handshake might be stuck.""" + try: + print(f"🔧 HANDSHAKE_DEBUG: Analyzing stuck handshake for {dest_cid.hex()}") + + # Check TLS handshake state + if hasattr(quic_conn, "tls") and quic_conn.tls: + tls = quic_conn.tls + print( + f"🔧 HANDSHAKE_DEBUG: TLS state: {getattr(tls, 'state', 'Unknown')}" + ) + + # Check for TLS errors + if hasattr(tls, "_error") and tls._error: + print(f"❌ HANDSHAKE_DEBUG: TLS error: {tls._error}") + + # Check certificate validation + if hasattr(tls, "_peer_certificate"): + if tls._peer_certificate: + print("✅ HANDSHAKE_DEBUG: Peer certificate received") + else: + print("❌ HANDSHAKE_DEBUG: No peer certificate") + + # Check ALPN negotiation + if hasattr(tls, "_alpn_protocols"): + if tls._alpn_protocols: + print( + f"✅ HANDSHAKE_DEBUG: ALPN negotiated: {tls._alpn_protocols}" + ) + else: + print("❌ HANDSHAKE_DEBUG: No ALPN protocol negotiated") + + # Check QUIC connection state + if hasattr(quic_conn, "_state"): + state = quic_conn._state + print(f"🔧 HANDSHAKE_DEBUG: QUIC state: {state}") + + # Check specific states that might indicate problems + if "FIRSTFLIGHT" in str(state): + print("⚠️ HANDSHAKE_DEBUG: Connection stuck in FIRSTFLIGHT state") + elif "CONNECTED" in str(state): + print( + "⚠️ HANDSHAKE_DEBUG: Connection shows CONNECTED but handshake not complete" + ) + + # Check for pending crypto data + if hasattr(quic_conn, "_cryptos") and quic_conn._cryptos: + print(f"🔧 HANDSHAKE_DEBUG: Crypto data present {len(quic_conn._cryptos.keys())}") + + # Check loss detection state + if hasattr(quic_conn, "_loss") and quic_conn._loss: + loss_detection = quic_conn._loss + if hasattr(loss_detection, "_pto_count"): + print(f"🔧 HANDSHAKE_DEBUG: PTO count: {loss_detection._pto_count}") + + except Exception as e: + print(f"❌ HANDSHAKE_DEBUG: Error during debug: {e}") diff --git a/libp2p/transport/quic/security.py b/libp2p/transport/quic/security.py index 1e265241..28abc626 100644 --- a/libp2p/transport/quic/security.py +++ b/libp2p/transport/quic/security.py @@ -4,9 +4,11 @@ Implements libp2p TLS specification for QUIC transport with peer identity integr Based on go-libp2p and js-libp2p security patterns. """ -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta import logging +import ssl +from typing import List, Optional, Union from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization @@ -25,11 +27,6 @@ from .exceptions import ( QUICPeerVerificationError, ) -TSecurityConfig = dict[ - str, - Certificate | EllipticCurvePrivateKey | RSAPrivateKey | bool | list[str], -] - logger = logging.getLogger(__name__) # libp2p TLS Extension OID - Official libp2p specification @@ -312,7 +309,7 @@ class CertificateGenerator: x509.UnrecognizedExtension( oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data ), - critical=True, # This extension is critical for libp2p + critical=False, ) .sign(cert_private_key, hashes.SHA256()) ) @@ -407,6 +404,269 @@ class PeerAuthenticator: ) from e +@dataclass +class QUICTLSSecurityConfig: + """ + Type-safe TLS security configuration for QUIC transport. + """ + + # Core TLS components (required) + certificate: Certificate + private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey] + + # Certificate chain (optional) + certificate_chain: List[Certificate] = field(default_factory=list) + + # ALPN protocols + alpn_protocols: List[str] = field(default_factory=lambda: ["libp2p"]) + + # TLS verification settings + verify_mode: Union[bool, ssl.VerifyMode] = False + check_hostname: bool = False + + # Optional peer ID for validation + peer_id: Optional[ID] = None + + # Configuration metadata + is_client_config: bool = False + config_name: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate the TLS configuration.""" + if self.certificate is None: + raise ValueError("Certificate is required") + + if self.private_key is None: + raise ValueError("Private key is required") + + if not isinstance(self.certificate, x509.Certificate): + raise TypeError( + f"Certificate must be x509.Certificate, got {type(self.certificate)}" + ) + + if not isinstance( + self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey) + ): + raise TypeError( + f"Private key must be EC or RSA key, got {type(self.private_key)}" + ) + + if not self.alpn_protocols: + raise ValueError("At least one ALPN protocol is required") + + def to_dict(self) -> dict: + """ + Convert to dictionary format for compatibility with existing code. + + Returns: + Dictionary compatible with the original TSecurityConfig format + + """ + return { + "certificate": self.certificate, + "private_key": self.private_key, + "certificate_chain": self.certificate_chain.copy(), + "alpn_protocols": self.alpn_protocols.copy(), + "verify_mode": self.verify_mode, + "check_hostname": self.check_hostname, + } + + @classmethod + def from_dict(cls, config_dict: dict, **kwargs) -> "QUICTLSSecurityConfig": + """ + Create instance from dictionary format. + + Args: + config_dict: Dictionary in TSecurityConfig format + **kwargs: Additional parameters for the config + + Returns: + QUICTLSSecurityConfig instance + + """ + return cls( + certificate=config_dict["certificate"], + private_key=config_dict["private_key"], + certificate_chain=config_dict.get("certificate_chain", []), + alpn_protocols=config_dict.get("alpn_protocols", ["libp2p"]), + verify_mode=config_dict.get("verify_mode", False), + check_hostname=config_dict.get("check_hostname", False), + **kwargs, + ) + + def validate_certificate_key_match(self) -> bool: + """ + Validate that the certificate and private key match. + + Returns: + True if certificate and private key match + + """ + try: + from cryptography.hazmat.primitives import serialization + + # Get public keys from both certificate and private key + cert_public_key = self.certificate.public_key() + private_public_key = self.private_key.public_key() + + # Compare their PEM representations + cert_pub_pem = cert_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + private_pub_pem = private_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return cert_pub_pem == private_pub_pem + + except Exception: + return False + + def has_libp2p_extension(self) -> bool: + """ + Check if the certificate has the required libp2p extension. + + Returns: + True if libp2p extension is present + + """ + try: + libp2p_oid = "1.3.6.1.4.1.53594.1.1" + for ext in self.certificate.extensions: + if str(ext.oid) == libp2p_oid: + return True + return False + except Exception: + return False + + def is_certificate_valid(self) -> bool: + """ + Check if the certificate is currently valid (not expired). + + Returns: + True if certificate is valid + + """ + try: + from datetime import datetime + + now = datetime.utcnow() + return ( + self.certificate.not_valid_before + <= now + <= self.certificate.not_valid_after + ) + except Exception: + return False + + def get_certificate_info(self) -> dict: + """ + Get certificate information for debugging. + + Returns: + Dictionary with certificate details + + """ + try: + return { + "subject": str(self.certificate.subject), + "issuer": str(self.certificate.issuer), + "serial_number": self.certificate.serial_number, + "not_valid_before": self.certificate.not_valid_before, + "not_valid_after": self.certificate.not_valid_after, + "has_libp2p_extension": self.has_libp2p_extension(), + "is_valid": self.is_certificate_valid(), + "certificate_key_match": self.validate_certificate_key_match(), + } + except Exception as e: + return {"error": str(e)} + + def debug_print(self) -> None: + """Print debugging information about this configuration.""" + print(f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===") + print(f"Is client config: {self.is_client_config}") + print(f"ALPN protocols: {self.alpn_protocols}") + print(f"Verify mode: {self.verify_mode}") + print(f"Check hostname: {self.check_hostname}") + print(f"Certificate chain length: {len(self.certificate_chain)}") + + cert_info = self.get_certificate_info() + for key, value in cert_info.items(): + print(f"Certificate {key}: {value}") + + print(f"Private key type: {type(self.private_key).__name__}") + if hasattr(self.private_key, "key_size"): + print(f"Private key size: {self.private_key.key_size}") + + +def create_server_tls_config( + certificate: Certificate, + private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey], + peer_id: Optional[ID] = None, + **kwargs, +) -> QUICTLSSecurityConfig: + """ + Create a server TLS configuration. + + Args: + certificate: X.509 certificate + private_key: Private key corresponding to certificate + peer_id: Optional peer ID for validation + **kwargs: Additional configuration parameters + + Returns: + Server TLS configuration + + """ + return QUICTLSSecurityConfig( + certificate=certificate, + private_key=private_key, + peer_id=peer_id, + is_client_config=False, + config_name="server", + verify_mode=False, # Server doesn't verify client certs in libp2p + check_hostname=False, + **kwargs, + ) + + +def create_client_tls_config( + certificate: Certificate, + private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey], + peer_id: Optional[ID] = None, + **kwargs, +) -> QUICTLSSecurityConfig: + """ + Create a client TLS configuration. + + Args: + certificate: X.509 certificate + private_key: Private key corresponding to certificate + peer_id: Optional peer ID for validation + **kwargs: Additional configuration parameters + + Returns: + Client TLS configuration + + """ + return QUICTLSSecurityConfig( + certificate=certificate, + private_key=private_key, + peer_id=peer_id, + is_client_config=True, + config_name="client", + verify_mode=False, # Client doesn't verify server certs in libp2p + check_hostname=False, + **kwargs, + ) + + class QUICTLSConfigManager: """ Manages TLS configuration for QUIC transport with libp2p security. @@ -424,44 +684,40 @@ class QUICTLSConfigManager: libp2p_private_key, peer_id ) - def create_server_config( - self, - ) -> TSecurityConfig: + def create_server_config(self) -> QUICTLSSecurityConfig: """ - Create aioquic server configuration with libp2p TLS settings. - Returns cryptography objects instead of DER bytes. + Create server configuration using the new class-based approach. Returns: - Configuration dictionary for aioquic QuicConfiguration + QUICTLSSecurityConfig instance for server """ - config: TSecurityConfig = { - "certificate": self.tls_config.certificate, - "private_key": self.tls_config.private_key, - "certificate_chain": [], - "alpn_protocols": ["libp2p"], - "verify_mode": False, - "check_hostname": False, - } + config = create_server_tls_config( + certificate=self.tls_config.certificate, + private_key=self.tls_config.private_key, + peer_id=self.peer_id, + ) + + print("🔧 SECURITY: Created server config") + config.debug_print() return config - def create_client_config(self) -> TSecurityConfig: + def create_client_config(self) -> QUICTLSSecurityConfig: """ - Create aioquic client configuration with libp2p TLS settings. - Returns cryptography objects instead of DER bytes. + Create client configuration using the new class-based approach. Returns: - Configuration dictionary for aioquic QuicConfiguration + QUICTLSSecurityConfig instance for client """ - config: TSecurityConfig = { - "certificate": self.tls_config.certificate, - "private_key": self.tls_config.private_key, - "certificate_chain": [], - "alpn_protocols": ["libp2p"], - "verify_mode": False, - "check_hostname": False, - } + config = create_client_tls_config( + certificate=self.tls_config.certificate, + private_key=self.tls_config.private_key, + peer_id=self.peer_id, + ) + + print("🔧 SECURITY: Created client config") + config.debug_print() return config def verify_peer_identity( diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py index 30218a12..8aed36f0 100644 --- a/libp2p/transport/quic/transport.py +++ b/libp2p/transport/quic/transport.py @@ -5,7 +5,6 @@ Based on aioquic library with interface consistency to go-libp2p and js-libp2p. Updated to include Module 5 security integration. """ -from collections.abc import Iterable import copy import logging import sys @@ -31,7 +30,7 @@ from libp2p.custom_types import THandler, TProtocol from libp2p.peer.id import ( ID, ) -from libp2p.transport.quic.security import TSecurityConfig +from libp2p.transport.quic.security import QUICTLSSecurityConfig from libp2p.transport.quic.utils import ( get_alpn_protocols, is_quic_multiaddr, @@ -192,7 +191,7 @@ class QUICTransport(ITransport): ) from e def _apply_tls_configuration( - self, config: QuicConfiguration, tls_config: TSecurityConfig + self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig ) -> None: """ Apply TLS configuration to a QUIC configuration using aioquic's actual API. @@ -203,52 +202,47 @@ class QUICTransport(ITransport): """ try: - # Set certificate and private key directly on the configuration - # aioquic expects cryptography objects, not DER bytes - if "certificate" in tls_config and "private_key" in tls_config: - # The security manager should return cryptography objects - # not DER bytes, but if it returns DER bytes, we need to handle that - certificate = tls_config["certificate"] - private_key = tls_config["private_key"] - # Check if we received DER bytes and need - # to convert to cryptography objects - if isinstance(certificate, bytes): + # The security manager should return cryptography objects + # not DER bytes, but if it returns DER bytes, we need to handle that + certificate = tls_config.certificate + private_key = tls_config.private_key + + # Check if we received DER bytes and need + # to convert to cryptography objects + if isinstance(certificate, bytes): + from cryptography import x509 + + certificate = x509.load_der_x509_certificate(certificate) + + if isinstance(private_key, bytes): + from cryptography.hazmat.primitives import serialization + + private_key = serialization.load_der_private_key( # type: ignore + private_key, password=None + ) + + # Set directly on the configuration object + config.certificate = certificate + config.private_key = private_key + + # Handle certificate chain if provided + certificate_chain = tls_config.certificate_chain + # Convert DER bytes to cryptography objects if needed + chain_objects = [] + for cert in certificate_chain: + if isinstance(cert, bytes): from cryptography import x509 - certificate = x509.load_der_x509_certificate(certificate) - - if isinstance(private_key, bytes): - from cryptography.hazmat.primitives import serialization - - private_key = serialization.load_der_private_key( # type: ignore - private_key, password=None - ) - - # Set directly on the configuration object - config.certificate = certificate - config.private_key = private_key - - # Handle certificate chain if provided - certificate_chain = tls_config.get("certificate_chain", []) - if certificate_chain and isinstance(certificate_chain, Iterable): - # Convert DER bytes to cryptography objects if needed - chain_objects = [] - for cert in certificate_chain: - if isinstance(cert, bytes): - from cryptography import x509 - - cert = x509.load_der_x509_certificate(cert) - chain_objects.append(cert) - config.certificate_chain = chain_objects + cert = x509.load_der_x509_certificate(cert) + chain_objects.append(cert) + config.certificate_chain = chain_objects # Set ALPN protocols - if "alpn_protocols" in tls_config: - config.alpn_protocols = tls_config["alpn_protocols"] # type: ignore + config.alpn_protocols = tls_config.alpn_protocols # Set certificate verification mode - if "verify_mode" in tls_config: - config.verify_mode = tls_config["verify_mode"] # type: ignore + config.verify_mode = tls_config.verify_mode logger.debug("Successfully applied TLS configuration to QUIC config")