diff --git a/examples/echo/debug_handshake.py b/examples/echo/debug_handshake.py new file mode 100644 index 00000000..fb823d0b --- /dev/null +++ b/examples/echo/debug_handshake.py @@ -0,0 +1,371 @@ +def debug_quic_connection_state(conn, name="Connection"): + """Enhanced debugging function for QUIC connection state.""" + print(f"\nšŸ” === {name} Debug Info ===") + + # Basic connection state + print(f"State: {getattr(conn, '_state', 'unknown')}") + print(f"Handshake complete: {getattr(conn, '_handshake_complete', False)}") + + # Connection IDs + if hasattr(conn, "_host_connection_id"): + print( + f"Host CID: {conn._host_connection_id.hex() if conn._host_connection_id else 'None'}" + ) + if hasattr(conn, "_peer_connection_id"): + print( + f"Peer CID: {conn._peer_connection_id.hex() if conn._peer_connection_id else 'None'}" + ) + + # Check for connection ID sequences + if hasattr(conn, "_local_connection_ids"): + print( + f"Local CID sequence: {[cid.cid.hex() for cid in conn._local_connection_ids]}" + ) + if hasattr(conn, "_remote_connection_ids"): + print( + f"Remote CID sequence: {[cid.cid.hex() for cid in conn._remote_connection_ids]}" + ) + + # TLS state + if hasattr(conn, "tls") and conn.tls: + tls_state = getattr(conn.tls, "state", "unknown") + print(f"TLS state: {tls_state}") + + # Check for certificates + peer_cert = getattr(conn.tls, "_peer_certificate", None) + print(f"Has peer certificate: {peer_cert is not None}") + + # Transport parameters + if hasattr(conn, "_remote_transport_parameters"): + params = conn._remote_transport_parameters + if params: + print(f"Remote transport parameters received: {len(params)} params") + + print(f"=== End {name} Debug ===\n") + + +def debug_firstflight_event(server_conn, name="Server"): + """Debug connection ID changes specifically around FIRSTFLIGHT event.""" + print(f"\nšŸŽÆ === {name} FIRSTFLIGHT Event Debug ===") + + # Connection state + state = getattr(server_conn, "_state", "unknown") + print(f"Connection State: {state}") + + # Connection IDs + peer_cid = getattr(server_conn, "_peer_connection_id", None) + host_cid = getattr(server_conn, "_host_connection_id", None) + original_dcid = getattr(server_conn, "original_destination_connection_id", None) + + print(f"Peer CID: {peer_cid.hex() if peer_cid else 'None'}") + print(f"Host CID: {host_cid.hex() if host_cid else 'None'}") + print(f"Original DCID: {original_dcid.hex() if original_dcid else 'None'}") + + print(f"=== End {name} FIRSTFLIGHT Debug ===\n") + + +def create_minimal_quic_test(): + """Simplified test to isolate FIRSTFLIGHT connection ID issues.""" + print("\n=== MINIMAL QUIC FIRSTFLIGHT CONNECTION ID TEST ===") + + from time import time + from aioquic.quic.configuration import QuicConfiguration + from aioquic.quic.connection import QuicConnection + from aioquic.buffer import Buffer + from aioquic.quic.packet import pull_quic_header + + # Minimal configs without certificates first + client_config = QuicConfiguration( + is_client=True, alpn_protocols=["libp2p"], connection_id_length=8 + ) + + server_config = QuicConfiguration( + is_client=False, alpn_protocols=["libp2p"], connection_id_length=8 + ) + + # Create client and connect + client_conn = QuicConnection(configuration=client_config) + server_addr = ("127.0.0.1", 4321) + + print("šŸ”— Client calling connect()...") + client_conn.connect(server_addr, now=time()) + + # Debug client state after connect + debug_quic_connection_state(client_conn, "Client After Connect") + + # Get initial client packet + initial_packets = client_conn.datagrams_to_send(now=time()) + if not initial_packets: + print("āŒ No initial packets from client") + return False + + initial_packet = initial_packets[0][0] + + # Parse header to get client's source CID (what server should use as peer CID) + header = pull_quic_header(Buffer(data=initial_packet), host_cid_length=8) + client_source_cid = header.source_cid + client_dest_cid = header.destination_cid + + print(f"šŸ“¦ Initial packet analysis:") + print( + f" Client Source CID: {client_source_cid.hex()} (server should use as peer CID)" + ) + print(f" Client Dest CID: {client_dest_cid.hex()}") + + # Create server with proper ODCID + print( + f"\nšŸ—ļø Creating server with original_destination_connection_id={client_dest_cid.hex()}..." + ) + server_conn = QuicConnection( + configuration=server_config, + original_destination_connection_id=client_dest_cid, + ) + + # Debug server state after creation (before FIRSTFLIGHT) + debug_firstflight_event(server_conn, "Server After Creation (Pre-FIRSTFLIGHT)") + + # šŸŽÆ CRITICAL: Process initial packet (this triggers FIRSTFLIGHT event) + print(f"šŸš€ Processing initial packet (triggering FIRSTFLIGHT)...") + client_addr = ("127.0.0.1", 1234) + + # Before receive_datagram + print(f"šŸ“Š BEFORE receive_datagram (FIRSTFLIGHT):") + print(f" Server state: {getattr(server_conn, '_state', 'unknown')}") + print( + f" Server peer CID: {server_conn._peer_cid.cid.hex()}" + ) + print(f" Expected peer CID after FIRSTFLIGHT: {client_source_cid.hex()}") + + # This call triggers FIRSTFLIGHT: FIRSTFLIGHT -> CONNECTED + server_conn.receive_datagram(initial_packet, client_addr, now=time()) + + # After receive_datagram (FIRSTFLIGHT should have happened) + print(f"šŸ“Š AFTER receive_datagram (Post-FIRSTFLIGHT):") + print(f" Server state: {getattr(server_conn, '_state', 'unknown')}") + print( + f" Server peer CID: {server_conn._peer_cid.cid.hex()}" + ) + + # Check if FIRSTFLIGHT set peer CID correctly + actual_peer_cid = server_conn._peer_cid.cid + if actual_peer_cid == client_source_cid: + print("āœ… FIRSTFLIGHT correctly set peer CID from client source CID") + firstflight_success = True + else: + print("āŒ FIRSTFLIGHT BUG: peer CID not set correctly!") + print(f" Expected: {client_source_cid.hex()}") + print(f" Actual: {actual_peer_cid.hex() if actual_peer_cid else 'None'}") + firstflight_success = False + + # Debug both connections after FIRSTFLIGHT + debug_firstflight_event(server_conn, "Server After FIRSTFLIGHT") + debug_quic_connection_state(client_conn, "Client After Server Processing") + + # Check server response packets + print(f"\nšŸ“¤ Checking server response packets...") + server_packets = server_conn.datagrams_to_send(now=time()) + if server_packets: + response_packet = server_packets[0][0] + response_header = pull_quic_header( + Buffer(data=response_packet), host_cid_length=8 + ) + + print(f"šŸ“Š Server response packet:") + print(f" Source CID: {response_header.source_cid.hex()}") + print(f" Dest CID: {response_header.destination_cid.hex()}") + print(f" Expected dest CID: {client_source_cid.hex()}") + + # Final verification + if response_header.destination_cid == client_source_cid: + print("āœ… Server response uses correct destination CID!") + return True + else: + print(f"āŒ Server response uses WRONG destination CID!") + print(f" This proves the FIRSTFLIGHT bug - peer CID not set correctly") + print(f" Expected: {client_source_cid.hex()}") + print(f" Actual: {response_header.destination_cid.hex()}") + return False + else: + print("āŒ Server did not generate response packet") + return False + + +def create_minimal_quic_test_with_config(client_config, server_config): + """Run FIRSTFLIGHT test with provided configurations.""" + from time import time + from aioquic.buffer import Buffer + from aioquic.quic.connection import QuicConnection + from aioquic.quic.packet import pull_quic_header + + print("\n=== FIRSTFLIGHT TEST WITH CERTIFICATES ===") + + # Create client and connect + client_conn = QuicConnection(configuration=client_config) + server_addr = ("127.0.0.1", 4321) + + print("šŸ”— Client calling connect() with certificates...") + client_conn.connect(server_addr, now=time()) + + # Get initial packets and extract client source CID + initial_packets = client_conn.datagrams_to_send(now=time()) + if not initial_packets: + print("āŒ No initial packets from client") + return False + + # Extract client source CID from initial packet + initial_packet = initial_packets[0][0] + header = pull_quic_header(Buffer(data=initial_packet), host_cid_length=8) + client_source_cid = header.source_cid + + print(f"šŸ“¦ Client source CID (expected server peer CID): {client_source_cid.hex()}") + + # Create server with client's source CID as original destination + server_conn = QuicConnection( + configuration=server_config, + original_destination_connection_id=client_source_cid, + ) + + # Debug server before FIRSTFLIGHT + print(f"\nšŸ“Š BEFORE FIRSTFLIGHT (server creation):") + print(f" Server state: {getattr(server_conn, '_state', 'unknown')}") + print( + f" Server peer CID: {server_conn._peer_cid.cid.hex()}" + ) + print( + f" Server original DCID: {server_conn.original_destination_connection_id.hex()}" + ) + + # Process initial packet (triggers FIRSTFLIGHT) + client_addr = ("127.0.0.1", 1234) + + print(f"\nšŸš€ Triggering FIRSTFLIGHT by processing initial packet...") + for datagram, _ in initial_packets: + header = pull_quic_header(Buffer(data=datagram)) + print( + f" Processing packet: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}" + ) + + # This triggers FIRSTFLIGHT + server_conn.receive_datagram(datagram, client_addr, now=time()) + + # Debug immediately after FIRSTFLIGHT + print(f"\nšŸ“Š AFTER FIRSTFLIGHT:") + print(f" Server state: {getattr(server_conn, '_state', 'unknown')}") + print( + f" Server peer CID: {server_conn._peer_cid.cid.hex()}" + ) + print(f" Expected peer CID: {header.source_cid.hex()}") + + # Check if FIRSTFLIGHT worked correctly + actual_peer_cid = getattr(server_conn, "_peer_connection_id", None) + if actual_peer_cid == header.source_cid: + print("āœ… FIRSTFLIGHT correctly set peer CID") + else: + print("āŒ FIRSTFLIGHT failed to set peer CID correctly") + print(f" This is the root cause of the handshake failure!") + + # Check server response + server_packets = server_conn.datagrams_to_send(now=time()) + if server_packets: + response_packet = server_packets[0][0] + response_header = pull_quic_header( + Buffer(data=response_packet), host_cid_length=8 + ) + + print(f"\nšŸ“¤ Server response analysis:") + print(f" Response dest CID: {response_header.destination_cid.hex()}") + print(f" Expected dest CID: {client_source_cid.hex()}") + + if response_header.destination_cid == client_source_cid: + print("āœ… Server response uses correct destination CID!") + return True + else: + print("āŒ FIRSTFLIGHT bug confirmed - wrong destination CID in response!") + print( + " This proves aioquic doesn't set peer CID correctly during FIRSTFLIGHT" + ) + return False + + print("āŒ No server response packets") + return False + + +async def test_with_certificates(): + """Test with proper certificate setup and FIRSTFLIGHT debugging.""" + print("\n=== CERTIFICATE-BASED FIRSTFLIGHT TEST ===") + + # Import your existing certificate creation functions + from libp2p.crypto.ed25519 import create_new_key_pair + from libp2p.peer.id import ID + from libp2p.transport.quic.security import create_quic_security_transport + + # Create security configs + client_key_pair = create_new_key_pair() + server_key_pair = create_new_key_pair() + + client_security_config = create_quic_security_transport( + client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key) + ) + server_security_config = create_quic_security_transport( + server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key) + ) + + # Apply the minimal test logic with certificates + from aioquic.quic.configuration import QuicConfiguration + + client_config = QuicConfiguration( + is_client=True, alpn_protocols=["libp2p"], connection_id_length=8 + ) + client_config.certificate = client_security_config.tls_config.certificate + client_config.private_key = client_security_config.tls_config.private_key + client_config.verify_mode = ( + client_security_config.create_client_config().verify_mode + ) + + server_config = QuicConfiguration( + is_client=False, alpn_protocols=["libp2p"], connection_id_length=8 + ) + server_config.certificate = server_security_config.tls_config.certificate + server_config.private_key = server_security_config.tls_config.private_key + server_config.verify_mode = ( + server_security_config.create_server_config().verify_mode + ) + + # Run the FIRSTFLIGHT test with certificates + return create_minimal_quic_test_with_config(client_config, server_config) + + +async def main(): + print("šŸŽÆ Testing FIRSTFLIGHT connection ID behavior...") + + # # First test without certificates + # print("\n" + "=" * 60) + # print("PHASE 1: Testing FIRSTFLIGHT without certificates") + # print("=" * 60) + # minimal_success = create_minimal_quic_test() + + # Then test with certificates + print("\n" + "=" * 60) + print("PHASE 2: Testing FIRSTFLIGHT with certificates") + print("=" * 60) + cert_success = await test_with_certificates() + + # Summary + print("\n" + "=" * 60) + print("FIRSTFLIGHT TEST SUMMARY") + print("=" * 60) + # print(f"Minimal test (no certs): {'āœ… PASS' if minimal_success else 'āŒ FAIL'}") + print(f"Certificate test: {'āœ… PASS' if cert_success else 'āŒ FAIL'}") + + if not cert_success: + print("\nšŸ”„ FIRSTFLIGHT BUG CONFIRMED:") + print(" - aioquic fails to set peer CID correctly during FIRSTFLIGHT event") + print(" - Server uses wrong destination CID in response packets") + print(" - Client drops responses → handshake fails") + print(" - Fix: Override _peer_connection_id after receive_datagram()") + + +if __name__ == "__main__": + import trio + + trio.run(main) diff --git a/examples/echo/test_handshake.py b/examples/echo/test_handshake.py new file mode 100644 index 00000000..e04b083f --- /dev/null +++ b/examples/echo/test_handshake.py @@ -0,0 +1,205 @@ +from aioquic._buffer import Buffer +from aioquic.quic.packet import pull_quic_header +from aioquic.quic.connection import QuicConnection +from aioquic.quic.configuration import QuicConfiguration +from tempfile import NamedTemporaryFile +from libp2p.peer.id import ID +from libp2p.transport.quic.security import create_quic_security_transport +from libp2p.crypto.ed25519 import create_new_key_pair +from time import time +import os +import trio + + +async def test_full_handshake_and_certificate_exchange(): + """ + Test a full handshake to ensure it completes and peer certificates are exchanged. + FIXED VERSION: Corrects connection ID management and address handling. + """ + print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (FIXED) ===") + + # 1. Generate KeyPairs and create libp2p security configs for client and server. + client_key_pair = create_new_key_pair() + server_key_pair = create_new_key_pair() + + client_security_config = create_quic_security_transport( + client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key) + ) + server_security_config = create_quic_security_transport( + server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key) + ) + print("āœ… libp2p security configs created.") + + # 2. Create aioquic configurations with consistent settings + client_secrets_log_file = NamedTemporaryFile( + mode="w", delete=False, suffix="-client.log" + ) + client_aioquic_config = QuicConfiguration( + is_client=True, + alpn_protocols=["libp2p"], + secrets_log_file=client_secrets_log_file, + connection_id_length=8, # Set consistent CID length + ) + client_aioquic_config.certificate = client_security_config.tls_config.certificate + client_aioquic_config.private_key = client_security_config.tls_config.private_key + client_aioquic_config.verify_mode = ( + client_security_config.create_client_config().verify_mode + ) + + server_secrets_log_file = NamedTemporaryFile( + mode="w", delete=False, suffix="-server.log" + ) + server_aioquic_config = QuicConfiguration( + is_client=False, + alpn_protocols=["libp2p"], + secrets_log_file=server_secrets_log_file, + connection_id_length=8, # Set consistent CID length + ) + server_aioquic_config.certificate = server_security_config.tls_config.certificate + server_aioquic_config.private_key = server_security_config.tls_config.private_key + server_aioquic_config.verify_mode = ( + server_security_config.create_server_config().verify_mode + ) + print("āœ… aioquic configurations created and configured.") + print(f"šŸ”‘ Client secrets will be logged to: {client_secrets_log_file.name}") + print(f"šŸ”‘ Server secrets will be logged to: {server_secrets_log_file.name}") + + # 3. Use consistent addresses - this is crucial! + # The client will connect TO the server address, but packets will come FROM client address + client_address = ("127.0.0.1", 1234) # Client binds to this + server_address = ("127.0.0.1", 4321) # Server binds to this + + # 4. Create client connection and initiate connection + client_conn = QuicConnection(configuration=client_aioquic_config) + # Client connects to server address - this sets up the initial packet with proper CIDs + client_conn.connect(server_address, now=time()) + print("āœ… Client connection initiated.") + + # 5. Get the initial client packet and extract ODCID properly + client_datagrams = client_conn.datagrams_to_send(now=time()) + if not client_datagrams: + raise AssertionError("āŒ Client did not generate initial packet") + + client_initial_packet = client_datagrams[0][0] + header = pull_quic_header(Buffer(data=client_initial_packet), host_cid_length=8) + original_dcid = header.destination_cid + client_source_cid = header.source_cid + + print(f"šŸ“Š Client ODCID: {original_dcid.hex()}") + print(f"šŸ“Š Client source CID: {client_source_cid.hex()}") + + # 6. Create server connection with the correct ODCID + server_conn = QuicConnection( + configuration=server_aioquic_config, + original_destination_connection_id=original_dcid, + ) + print("āœ… Server connection created with correct ODCID.") + + # 7. Feed the initial client packet to server + # IMPORTANT: Use client_address as the source for the packet + for datagram, _ in client_datagrams: + header = pull_quic_header(Buffer(data=datagram)) + print( + f"šŸ“¤ Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}" + ) + server_conn.receive_datagram(datagram, client_address, now=time()) + + # 8. Manual handshake loop with proper packet tracking + max_duration_s = 3 # Increased timeout + start_time = time() + packet_count = 0 + + while time() - start_time < max_duration_s: + # Process client -> server packets + client_packets = list(client_conn.datagrams_to_send(now=time())) + for datagram, _ in client_packets: + header = pull_quic_header(Buffer(data=datagram)) + print( + f"šŸ“¤ Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}" + ) + server_conn.receive_datagram(datagram, client_address, now=time()) + packet_count += 1 + + # Process server -> client packets + server_packets = list(server_conn.datagrams_to_send(now=time())) + for datagram, _ in server_packets: + header = pull_quic_header(Buffer(data=datagram)) + print( + f"šŸ“¤ Server -> Client: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}" + ) + # CRITICAL: Server sends back to client_address, not server_address + client_conn.receive_datagram(datagram, server_address, now=time()) + packet_count += 1 + + # Check for completion + client_complete = getattr(client_conn, "_handshake_complete", False) + server_complete = getattr(server_conn, "_handshake_complete", False) + + print( + f"šŸ”„ Handshake status: Client={client_complete}, Server={server_complete}, Packets={packet_count}" + ) + + if client_complete and server_complete: + print("šŸŽ‰ Handshake completed for both peers!") + break + + # If no packets were exchanged in this iteration, wait a bit + if not client_packets and not server_packets: + await trio.sleep(0.01) + + # Safety check - if too many packets, something is wrong + if packet_count > 50: + print("āš ļø Too many packets exchanged, possible handshake loop") + break + + # 9. Enhanced handshake completion checks + client_handshake_complete = getattr(client_conn, "_handshake_complete", False) + server_handshake_complete = getattr(server_conn, "_handshake_complete", False) + + # Debug additional state information + print(f"šŸ” Final client state: {getattr(client_conn, '_state', 'unknown')}") + print(f"šŸ” Final server state: {getattr(server_conn, '_state', 'unknown')}") + + if hasattr(client_conn, "tls") and client_conn.tls: + print(f"šŸ” Client TLS state: {getattr(client_conn.tls, 'state', 'unknown')}") + if hasattr(server_conn, "tls") and server_conn.tls: + print(f"šŸ” Server TLS state: {getattr(server_conn.tls, 'state', 'unknown')}") + + # 10. Cleanup and assertions + client_secrets_log_file.close() + server_secrets_log_file.close() + os.unlink(client_secrets_log_file.name) + os.unlink(server_secrets_log_file.name) + + # Final assertions + assert client_handshake_complete, ( + f"āŒ Client handshake did not complete. " + f"State: {getattr(client_conn, '_state', 'unknown')}, " + f"Packets: {packet_count}" + ) + assert server_handshake_complete, ( + f"āŒ Server handshake did not complete. " + f"State: {getattr(server_conn, '_state', 'unknown')}, " + f"Packets: {packet_count}" + ) + print("āœ… Handshake completed for both peers.") + + # Certificate exchange verification + client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None) + server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None) + + assert client_peer_cert is not None, ( + "āŒ Client FAILED to receive server certificate." + ) + print("āœ… Client successfully received server certificate.") + + assert server_peer_cert is not None, ( + "āŒ Server FAILED to receive client certificate." + ) + print("āœ… Server successfully received client certificate.") + + print("šŸŽ‰ Test Passed: Full handshake and certificate exchange successful.") + return True + +if __name__ == "__main__": + trio.run(test_full_handshake_and_certificate_exchange) \ No newline at end of file diff --git a/examples/echo/test_quic.py b/examples/echo/test_quic.py index 29d62cab..ea97bd20 100644 --- a/examples/echo/test_quic.py +++ b/examples/echo/test_quic.py @@ -1,20 +1,39 @@ #!/usr/bin/env python3 + + """ Fixed QUIC handshake test to debug connection issues. """ import logging +import os from pathlib import Path import secrets import sys +from tempfile import NamedTemporaryFile +from time import time +from aioquic._buffer import Buffer +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import QuicConnection +from aioquic.quic.logger import QuicFileLogger +from aioquic.quic.packet import pull_quic_header import trio from libp2p.crypto.ed25519 import create_new_key_pair -from libp2p.transport.quic.security import LIBP2P_TLS_EXTENSION_OID +from libp2p.peer.id import ID +from libp2p.transport.quic.security import ( + LIBP2P_TLS_EXTENSION_OID, + create_quic_security_transport, +) from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig from libp2p.transport.quic.utils import create_quic_multiaddr +logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG +) + + # Adjust this path to your project structure project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) @@ -256,10 +275,162 @@ async def test_server_startup(): return False +async def test_full_handshake_and_certificate_exchange(): + """ + Test a full handshake to ensure it completes and peer certificates are exchanged. + This version is corrected to use the actual APIs available in the codebase. + """ + print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (CORRECTED) ===") + + # 1. Generate KeyPairs and create libp2p security configs for client and server. + # The `create_quic_security_transport` function from `test_quic.py` is the + # correct helper to use, and it requires a `KeyPair` argument. + client_key_pair = create_new_key_pair() + server_key_pair = create_new_key_pair() + + # This is the correct way to get the security configuration objects. + client_security_config = create_quic_security_transport( + client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key) + ) + server_security_config = create_quic_security_transport( + server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key) + ) + print("āœ… libp2p security configs created.") + + # 2. Create aioquic configurations and manually apply security settings, + # mimicking what the `QUICTransport` class does internally. + client_secrets_log_file = NamedTemporaryFile( + mode="w", delete=False, suffix="-client.log" + ) + client_aioquic_config = QuicConfiguration( + is_client=True, + alpn_protocols=["libp2p"], + secrets_log_file=client_secrets_log_file, + ) + client_aioquic_config.certificate = client_security_config.tls_config.certificate + client_aioquic_config.private_key = client_security_config.tls_config.private_key + client_aioquic_config.verify_mode = ( + client_security_config.create_client_config().verify_mode + ) + client_aioquic_config.quic_logger = QuicFileLogger( + "/home/akmo/GitHub/py-libp2p/examples/echo/logs" + ) + + server_secrets_log_file = NamedTemporaryFile( + mode="w", delete=False, suffix="-server.log" + ) + + server_aioquic_config = QuicConfiguration( + is_client=False, + alpn_protocols=["libp2p"], + secrets_log_file=server_secrets_log_file, + ) + server_aioquic_config.certificate = server_security_config.tls_config.certificate + server_aioquic_config.private_key = server_security_config.tls_config.private_key + server_aioquic_config.verify_mode = ( + server_security_config.create_server_config().verify_mode + ) + server_aioquic_config.quic_logger = QuicFileLogger( + "/home/akmo/GitHub/py-libp2p/examples/echo/logs" + ) + print("āœ… aioquic configurations created and configured.") + print(f"šŸ”‘ Client secrets will be logged to: {client_secrets_log_file.name}") + print(f"šŸ”‘ Server secrets will be logged to: {server_secrets_log_file.name}") + + # 3. Instantiate client, initiate its `connect` call, and get the ODCID for the server. + client_address = ("127.0.0.1", 1234) + server_address = ("127.0.0.1", 4321) + + client_aioquic_config.connection_id_length = 8 + client_conn = QuicConnection(configuration=client_aioquic_config) + client_conn.connect(server_address, now=time()) + print("āœ… aioquic connections instantiated correctly.") + + print("šŸ”§ Client CIDs") + print(f"Local Init CID: ", client_conn._local_initial_source_connection_id.hex()) + print( + f"Remote Init CID: ", + (client_conn._remote_initial_source_connection_id or b"").hex(), + ) + print( + f"Original Destination CID: ", + client_conn.original_destination_connection_id.hex(), + ) + print(f"Host CID: {client_conn._host_cids[0].cid.hex()}") + + # 4. Instantiate the server with the ODCID from the client. + server_aioquic_config.connection_id_length = 8 + server_conn = QuicConnection( + configuration=server_aioquic_config, + original_destination_connection_id=client_conn.original_destination_connection_id, + ) + print("āœ… aioquic connections instantiated correctly.") + + # 5. Manually drive the handshake process by exchanging datagrams. + max_duration_s = 5 + start_time = time() + + while time() - start_time < max_duration_s: + for datagram, _ in client_conn.datagrams_to_send(now=time()): + header = pull_quic_header(Buffer(data=datagram)) + print("Client packet source connection id", header.source_cid.hex()) + print("Client packet destination connection id", header.destination_cid.hex()) + print("--SERVER INJESTING CLIENT PACKET---") + server_conn.receive_datagram(datagram, client_address, now=time()) + + print( + f"Server remote initial source id: {(server_conn._remote_initial_source_connection_id or b'').hex()}" + ) + for datagram, _ in server_conn.datagrams_to_send(now=time()): + header = pull_quic_header(Buffer(data=datagram)) + print("Server packet source connection id", header.source_cid.hex()) + print("Server packet destination connection id", header.destination_cid.hex()) + print("--CLIENT INJESTING SERVER PACKET---") + client_conn.receive_datagram(datagram, server_address, now=time()) + + # Check for completion + if client_conn._handshake_complete and server_conn._handshake_complete: + break + + await trio.sleep(0.01) + + # 6. Assertions to verify the outcome. + assert client_conn._handshake_complete, "āŒ Client handshake did not complete." + assert server_conn._handshake_complete, "āŒ Server handshake did not complete." + print("āœ… Handshake completed for both peers.") + + # The key assertion: check if the peer certificate was received. + client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None) + server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None) + + client_secrets_log_file.close() + server_secrets_log_file.close() + os.unlink(client_secrets_log_file.name) + os.unlink(server_secrets_log_file.name) + + assert client_peer_cert is not None, ( + "āŒ Client FAILED to receive server certificate." + ) + print("āœ… Client successfully received server certificate.") + + assert server_peer_cert is not None, ( + "āŒ Server FAILED to receive client certificate." + ) + print("āœ… Server successfully received client certificate.") + + print("šŸŽ‰ Test Passed: Full handshake and certificate exchange successful.") + + async def main(): """Run all tests with better error handling.""" print("Starting QUIC diagnostic tests...") + handshake_ok = await test_full_handshake_and_certificate_exchange() + if not handshake_ok: + print("\nāŒ CRITICAL: Handshake failed!") + print("Apply the handshake fix and try again.") + return + # Test 1: Certificate generation cert_ok = await test_certificate_generation() if not cert_ok: diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index 7a85e309..0f499817 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -276,9 +276,6 @@ class QUICListener(IListener): # Parse packet to extract connection information packet_info = self.parse_quic_packet(data) - print( - f"šŸ”§ DEBUG: Address mappings: {dict((k, v.hex()) for k, v in self._addr_to_cid.items())}" - ) print( f"šŸ”§ DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}" ) @@ -333,33 +330,6 @@ class QUICListener(IListener): ) return - # If no exact match, try address-based routing (connection ID might not match) - mapped_cid = self._addr_to_cid.get(addr) - if mapped_cid: - print( - f"šŸ”§ PACKET: Found address mapping {addr} -> {mapped_cid.hex()}" - ) - print( - f"šŸ”§ PACKET: Client dest_cid {dest_cid.hex()} != our cid {mapped_cid.hex()}" - ) - - if mapped_cid in self._connections: - print( - "āœ… PACKET: Using established connection via address mapping" - ) - connection = self._connections[mapped_cid] - await self._route_to_connection(connection, data, addr) - return - elif mapped_cid in self._pending_connections: - print( - "āœ… PACKET: Using pending connection via address mapping" - ) - quic_conn = self._pending_connections[mapped_cid] - await self._handle_pending_connection( - quic_conn, data, addr, mapped_cid - ) - return - # No existing connection found, create new one print(f"šŸ”§ PACKET: Creating new connection for {addr}") await self._handle_new_connection(data, addr, packet_info) @@ -491,10 +461,9 @@ class QUICListener(IListener): ) # Create QUIC connection with proper parameters for server - # CRITICAL FIX: Pass the original destination connection ID from the initial packet quic_conn = QuicConnection( configuration=server_config, - original_destination_connection_id=packet_info.destination_cid, # Use the original DCID from packet + original_destination_connection_id=packet_info.destination_cid, ) quic_conn._replenish_connection_ids() diff --git a/libp2p/transport/quic/security.py b/libp2p/transport/quic/security.py index 50683dab..b6fd1050 100644 --- a/libp2p/transport/quic/security.py +++ b/libp2p/transport/quic/security.py @@ -1,3 +1,4 @@ + """ QUIC Security implementation for py-libp2p Module 5. Implements libp2p TLS specification for QUIC transport with peer identity integration. @@ -15,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import ec, rsa from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from cryptography.x509.base import Certificate +from cryptography.x509.extensions import Extension, UnrecognizedExtension from cryptography.x509.oid import NameOID from libp2p.crypto.keys import PrivateKey, PublicKey @@ -128,57 +130,106 @@ class LibP2PExtensionHandler: ) from e @staticmethod - def parse_signed_key_extension(extension_data: bytes) -> tuple[PublicKey, bytes]: + def parse_signed_key_extension(extension: Extension) -> tuple[PublicKey, bytes]: """ - Parse the libp2p Public Key Extension to extract public key and signature. - - Args: - extension_data: The extension data bytes - - Returns: - Tuple of (libp2p_public_key, signature) - - Raises: - QUICCertificateError: If extension parsing fails - + Parse the libp2p Public Key Extension with enhanced debugging. """ try: + print(f"šŸ” Extension type: {type(extension)}") + print(f"šŸ” Extension.value type: {type(extension.value)}") + + # Extract the raw bytes from the extension + if isinstance(extension.value, UnrecognizedExtension): + # Use the .value property to get the bytes + raw_bytes = extension.value.value + print("šŸ” Extension is UnrecognizedExtension, using .value property") + else: + # Fallback if it's already bytes somehow + raw_bytes = extension.value + print("šŸ” Extension.value is already bytes") + + print(f"šŸ” Total extension length: {len(raw_bytes)} bytes") + print(f"šŸ” Extension hex (first 50 bytes): {raw_bytes[:50].hex()}") + + if not isinstance(raw_bytes, bytes): + raise QUICCertificateError(f"Expected bytes, got {type(raw_bytes)}") + offset = 0 # Parse public key length and data - if len(extension_data) < 4: + if len(raw_bytes) < 4: raise QUICCertificateError("Extension too short for public key length") public_key_length = int.from_bytes( - extension_data[offset : offset + 4], byteorder="big" + raw_bytes[offset : offset + 4], byteorder="big" ) + print(f"šŸ” Public key length: {public_key_length} bytes") offset += 4 - if len(extension_data) < offset + public_key_length: + if len(raw_bytes) < offset + public_key_length: raise QUICCertificateError("Extension too short for public key data") - public_key_bytes = extension_data[offset : offset + public_key_length] + public_key_bytes = raw_bytes[offset : offset + public_key_length] + print(f"šŸ” Public key data: {public_key_bytes.hex()}") offset += public_key_length + print(f"šŸ” Offset after public key: {offset}") # Parse signature length and data - if len(extension_data) < offset + 4: + if len(raw_bytes) < offset + 4: raise QUICCertificateError("Extension too short for signature length") signature_length = int.from_bytes( - extension_data[offset : offset + 4], byteorder="big" + raw_bytes[offset : offset + 4], byteorder="big" ) + print(f"šŸ” Signature length: {signature_length} bytes") offset += 4 + print(f"šŸ” Offset after signature length: {offset}") - if len(extension_data) < offset + signature_length: + if len(raw_bytes) < offset + signature_length: raise QUICCertificateError("Extension too short for signature data") - signature = extension_data[offset : offset + signature_length] + signature = raw_bytes[offset : offset + signature_length] + print(f"šŸ” Extracted signature length: {len(signature)} bytes") + print(f"šŸ” Signature hex (first 20 bytes): {signature[:20].hex()}") + print(f"šŸ” Signature starts with DER header: {signature[:2].hex() == '3045'}") + + # Detailed signature analysis + if len(signature) >= 2: + if signature[0] == 0x30: + der_length = signature[1] + print(f"šŸ” DER sequence length field: {der_length}") + print(f"šŸ” Expected DER total: {der_length + 2}") + print(f"šŸ” Actual signature length: {len(signature)}") + + if len(signature) != der_length + 2: + print(f"āš ļø DER length mismatch! Expected {der_length + 2}, got {len(signature)}") + # Try truncating to correct DER length + if der_length + 2 < len(signature): + print(f"šŸ”§ Truncating signature to correct DER length: {der_length + 2}") + signature = signature[:der_length + 2] + + # Check if we have extra data + expected_total = 4 + public_key_length + 4 + signature_length + print(f"šŸ” Expected total length: {expected_total}") + print(f"šŸ” Actual total length: {len(raw_bytes)}") + + if len(raw_bytes) > expected_total: + extra_bytes = len(raw_bytes) - expected_total + print(f"āš ļø Extra {extra_bytes} bytes detected!") + print(f"šŸ” Extra data: {raw_bytes[expected_total:].hex()}") + # Deserialize the public key public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes) + print(f"šŸ” Successfully deserialized public key: {type(public_key)}") + + print(f"šŸ” Final signature to return: {len(signature)} bytes") return public_key, signature except Exception as e: + print(f"āŒ Extension parsing failed: {e}") + import traceback + print(f"āŒ Traceback: {traceback.format_exc()}") raise QUICCertificateError( f"Failed to parse signed key extension: {e}" ) from e @@ -361,9 +412,15 @@ class PeerAuthenticator: if not libp2p_extension: raise QUICPeerVerificationError("Certificate missing libp2p extension") + assert libp2p_extension.value is not None + print(f"Extension type: {type(libp2p_extension)}") + print(f"Extension value type: {type(libp2p_extension.value)}") + if hasattr(libp2p_extension.value, "__len__"): + print(f"Extension value length: {len(libp2p_extension.value)}") + print(f"Extension value: {libp2p_extension.value}") # Parse the extension to get public key and signature public_key, signature = self.extension_handler.parse_signed_key_extension( - libp2p_extension.value + libp2p_extension ) # Get certificate public key for signature verification @@ -376,7 +433,7 @@ class PeerAuthenticator: signature_payload = b"libp2p-tls-handshake:" + cert_public_key_bytes try: - public_key.verify(signature, signature_payload) + public_key.verify(signature_payload, signature) except Exception as e: raise QUICPeerVerificationError( f"Invalid signature in libp2p extension: {e}" @@ -387,6 +444,8 @@ class PeerAuthenticator: # Verify against expected peer ID if provided if expected_peer_id and derived_peer_id != expected_peer_id: + print(f"Expected Peer id: {expected_peer_id}") + print(f"Derived Peer ID: {derived_peer_id}") raise QUICPeerVerificationError( f"Peer ID mismatch: expected {expected_peer_id}, " f"got {derived_peer_id}" diff --git a/pyproject.toml b/pyproject.toml index ac9689d0..e3a38295 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ maintainers = [ dependencies = [ "aioquic>=1.2.0", "base58>=1.0.3", - "coincurve>=10.0.0", + "coincurve==21.0.0", "exceptiongroup>=1.2.0; python_version < '3.11'", "grpcio>=1.41.0", "lru-dict>=1.1.6",