mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: peer verification successful
This commit is contained in:
371
examples/echo/debug_handshake.py
Normal file
371
examples/echo/debug_handshake.py
Normal file
@ -0,0 +1,371 @@
|
||||
def debug_quic_connection_state(conn, name="Connection"):
|
||||
"""Enhanced debugging function for QUIC connection state."""
|
||||
print(f"\n🔍 === {name} Debug Info ===")
|
||||
|
||||
# Basic connection state
|
||||
print(f"State: {getattr(conn, '_state', 'unknown')}")
|
||||
print(f"Handshake complete: {getattr(conn, '_handshake_complete', False)}")
|
||||
|
||||
# Connection IDs
|
||||
if hasattr(conn, "_host_connection_id"):
|
||||
print(
|
||||
f"Host CID: {conn._host_connection_id.hex() if conn._host_connection_id else 'None'}"
|
||||
)
|
||||
if hasattr(conn, "_peer_connection_id"):
|
||||
print(
|
||||
f"Peer CID: {conn._peer_connection_id.hex() if conn._peer_connection_id else 'None'}"
|
||||
)
|
||||
|
||||
# Check for connection ID sequences
|
||||
if hasattr(conn, "_local_connection_ids"):
|
||||
print(
|
||||
f"Local CID sequence: {[cid.cid.hex() for cid in conn._local_connection_ids]}"
|
||||
)
|
||||
if hasattr(conn, "_remote_connection_ids"):
|
||||
print(
|
||||
f"Remote CID sequence: {[cid.cid.hex() for cid in conn._remote_connection_ids]}"
|
||||
)
|
||||
|
||||
# TLS state
|
||||
if hasattr(conn, "tls") and conn.tls:
|
||||
tls_state = getattr(conn.tls, "state", "unknown")
|
||||
print(f"TLS state: {tls_state}")
|
||||
|
||||
# Check for certificates
|
||||
peer_cert = getattr(conn.tls, "_peer_certificate", None)
|
||||
print(f"Has peer certificate: {peer_cert is not None}")
|
||||
|
||||
# Transport parameters
|
||||
if hasattr(conn, "_remote_transport_parameters"):
|
||||
params = conn._remote_transport_parameters
|
||||
if params:
|
||||
print(f"Remote transport parameters received: {len(params)} params")
|
||||
|
||||
print(f"=== End {name} Debug ===\n")
|
||||
|
||||
|
||||
def debug_firstflight_event(server_conn, name="Server"):
|
||||
"""Debug connection ID changes specifically around FIRSTFLIGHT event."""
|
||||
print(f"\n🎯 === {name} FIRSTFLIGHT Event Debug ===")
|
||||
|
||||
# Connection state
|
||||
state = getattr(server_conn, "_state", "unknown")
|
||||
print(f"Connection State: {state}")
|
||||
|
||||
# Connection IDs
|
||||
peer_cid = getattr(server_conn, "_peer_connection_id", None)
|
||||
host_cid = getattr(server_conn, "_host_connection_id", None)
|
||||
original_dcid = getattr(server_conn, "original_destination_connection_id", None)
|
||||
|
||||
print(f"Peer CID: {peer_cid.hex() if peer_cid else 'None'}")
|
||||
print(f"Host CID: {host_cid.hex() if host_cid else 'None'}")
|
||||
print(f"Original DCID: {original_dcid.hex() if original_dcid else 'None'}")
|
||||
|
||||
print(f"=== End {name} FIRSTFLIGHT Debug ===\n")
|
||||
|
||||
|
||||
def create_minimal_quic_test():
|
||||
"""Simplified test to isolate FIRSTFLIGHT connection ID issues."""
|
||||
print("\n=== MINIMAL QUIC FIRSTFLIGHT CONNECTION ID TEST ===")
|
||||
|
||||
from time import time
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.buffer import Buffer
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
|
||||
# Minimal configs without certificates first
|
||||
client_config = QuicConfiguration(
|
||||
is_client=True, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
|
||||
server_config = QuicConfiguration(
|
||||
is_client=False, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
|
||||
# Create client and connect
|
||||
client_conn = QuicConnection(configuration=client_config)
|
||||
server_addr = ("127.0.0.1", 4321)
|
||||
|
||||
print("🔗 Client calling connect()...")
|
||||
client_conn.connect(server_addr, now=time())
|
||||
|
||||
# Debug client state after connect
|
||||
debug_quic_connection_state(client_conn, "Client After Connect")
|
||||
|
||||
# Get initial client packet
|
||||
initial_packets = client_conn.datagrams_to_send(now=time())
|
||||
if not initial_packets:
|
||||
print("❌ No initial packets from client")
|
||||
return False
|
||||
|
||||
initial_packet = initial_packets[0][0]
|
||||
|
||||
# Parse header to get client's source CID (what server should use as peer CID)
|
||||
header = pull_quic_header(Buffer(data=initial_packet), host_cid_length=8)
|
||||
client_source_cid = header.source_cid
|
||||
client_dest_cid = header.destination_cid
|
||||
|
||||
print(f"📦 Initial packet analysis:")
|
||||
print(
|
||||
f" Client Source CID: {client_source_cid.hex()} (server should use as peer CID)"
|
||||
)
|
||||
print(f" Client Dest CID: {client_dest_cid.hex()}")
|
||||
|
||||
# Create server with proper ODCID
|
||||
print(
|
||||
f"\n🏗️ Creating server with original_destination_connection_id={client_dest_cid.hex()}..."
|
||||
)
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=client_dest_cid,
|
||||
)
|
||||
|
||||
# Debug server state after creation (before FIRSTFLIGHT)
|
||||
debug_firstflight_event(server_conn, "Server After Creation (Pre-FIRSTFLIGHT)")
|
||||
|
||||
# 🎯 CRITICAL: Process initial packet (this triggers FIRSTFLIGHT event)
|
||||
print(f"🚀 Processing initial packet (triggering FIRSTFLIGHT)...")
|
||||
client_addr = ("127.0.0.1", 1234)
|
||||
|
||||
# Before receive_datagram
|
||||
print(f"📊 BEFORE receive_datagram (FIRSTFLIGHT):")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
print(f" Expected peer CID after FIRSTFLIGHT: {client_source_cid.hex()}")
|
||||
|
||||
# This call triggers FIRSTFLIGHT: FIRSTFLIGHT -> CONNECTED
|
||||
server_conn.receive_datagram(initial_packet, client_addr, now=time())
|
||||
|
||||
# After receive_datagram (FIRSTFLIGHT should have happened)
|
||||
print(f"📊 AFTER receive_datagram (Post-FIRSTFLIGHT):")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
|
||||
# Check if FIRSTFLIGHT set peer CID correctly
|
||||
actual_peer_cid = server_conn._peer_cid.cid
|
||||
if actual_peer_cid == client_source_cid:
|
||||
print("✅ FIRSTFLIGHT correctly set peer CID from client source CID")
|
||||
firstflight_success = True
|
||||
else:
|
||||
print("❌ FIRSTFLIGHT BUG: peer CID not set correctly!")
|
||||
print(f" Expected: {client_source_cid.hex()}")
|
||||
print(f" Actual: {actual_peer_cid.hex() if actual_peer_cid else 'None'}")
|
||||
firstflight_success = False
|
||||
|
||||
# Debug both connections after FIRSTFLIGHT
|
||||
debug_firstflight_event(server_conn, "Server After FIRSTFLIGHT")
|
||||
debug_quic_connection_state(client_conn, "Client After Server Processing")
|
||||
|
||||
# Check server response packets
|
||||
print(f"\n📤 Checking server response packets...")
|
||||
server_packets = server_conn.datagrams_to_send(now=time())
|
||||
if server_packets:
|
||||
response_packet = server_packets[0][0]
|
||||
response_header = pull_quic_header(
|
||||
Buffer(data=response_packet), host_cid_length=8
|
||||
)
|
||||
|
||||
print(f"📊 Server response packet:")
|
||||
print(f" Source CID: {response_header.source_cid.hex()}")
|
||||
print(f" Dest CID: {response_header.destination_cid.hex()}")
|
||||
print(f" Expected dest CID: {client_source_cid.hex()}")
|
||||
|
||||
# Final verification
|
||||
if response_header.destination_cid == client_source_cid:
|
||||
print("✅ Server response uses correct destination CID!")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Server response uses WRONG destination CID!")
|
||||
print(f" This proves the FIRSTFLIGHT bug - peer CID not set correctly")
|
||||
print(f" Expected: {client_source_cid.hex()}")
|
||||
print(f" Actual: {response_header.destination_cid.hex()}")
|
||||
return False
|
||||
else:
|
||||
print("❌ Server did not generate response packet")
|
||||
return False
|
||||
|
||||
|
||||
def create_minimal_quic_test_with_config(client_config, server_config):
|
||||
"""Run FIRSTFLIGHT test with provided configurations."""
|
||||
from time import time
|
||||
from aioquic.buffer import Buffer
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
|
||||
print("\n=== FIRSTFLIGHT TEST WITH CERTIFICATES ===")
|
||||
|
||||
# Create client and connect
|
||||
client_conn = QuicConnection(configuration=client_config)
|
||||
server_addr = ("127.0.0.1", 4321)
|
||||
|
||||
print("🔗 Client calling connect() with certificates...")
|
||||
client_conn.connect(server_addr, now=time())
|
||||
|
||||
# Get initial packets and extract client source CID
|
||||
initial_packets = client_conn.datagrams_to_send(now=time())
|
||||
if not initial_packets:
|
||||
print("❌ No initial packets from client")
|
||||
return False
|
||||
|
||||
# Extract client source CID from initial packet
|
||||
initial_packet = initial_packets[0][0]
|
||||
header = pull_quic_header(Buffer(data=initial_packet), host_cid_length=8)
|
||||
client_source_cid = header.source_cid
|
||||
|
||||
print(f"📦 Client source CID (expected server peer CID): {client_source_cid.hex()}")
|
||||
|
||||
# Create server with client's source CID as original destination
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=client_source_cid,
|
||||
)
|
||||
|
||||
# Debug server before FIRSTFLIGHT
|
||||
print(f"\n📊 BEFORE FIRSTFLIGHT (server creation):")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
print(
|
||||
f" Server original DCID: {server_conn.original_destination_connection_id.hex()}"
|
||||
)
|
||||
|
||||
# Process initial packet (triggers FIRSTFLIGHT)
|
||||
client_addr = ("127.0.0.1", 1234)
|
||||
|
||||
print(f"\n🚀 Triggering FIRSTFLIGHT by processing initial packet...")
|
||||
for datagram, _ in initial_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f" Processing packet: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
|
||||
# This triggers FIRSTFLIGHT
|
||||
server_conn.receive_datagram(datagram, client_addr, now=time())
|
||||
|
||||
# Debug immediately after FIRSTFLIGHT
|
||||
print(f"\n📊 AFTER FIRSTFLIGHT:")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
print(f" Expected peer CID: {header.source_cid.hex()}")
|
||||
|
||||
# Check if FIRSTFLIGHT worked correctly
|
||||
actual_peer_cid = getattr(server_conn, "_peer_connection_id", None)
|
||||
if actual_peer_cid == header.source_cid:
|
||||
print("✅ FIRSTFLIGHT correctly set peer CID")
|
||||
else:
|
||||
print("❌ FIRSTFLIGHT failed to set peer CID correctly")
|
||||
print(f" This is the root cause of the handshake failure!")
|
||||
|
||||
# Check server response
|
||||
server_packets = server_conn.datagrams_to_send(now=time())
|
||||
if server_packets:
|
||||
response_packet = server_packets[0][0]
|
||||
response_header = pull_quic_header(
|
||||
Buffer(data=response_packet), host_cid_length=8
|
||||
)
|
||||
|
||||
print(f"\n📤 Server response analysis:")
|
||||
print(f" Response dest CID: {response_header.destination_cid.hex()}")
|
||||
print(f" Expected dest CID: {client_source_cid.hex()}")
|
||||
|
||||
if response_header.destination_cid == client_source_cid:
|
||||
print("✅ Server response uses correct destination CID!")
|
||||
return True
|
||||
else:
|
||||
print("❌ FIRSTFLIGHT bug confirmed - wrong destination CID in response!")
|
||||
print(
|
||||
" This proves aioquic doesn't set peer CID correctly during FIRSTFLIGHT"
|
||||
)
|
||||
return False
|
||||
|
||||
print("❌ No server response packets")
|
||||
return False
|
||||
|
||||
|
||||
async def test_with_certificates():
|
||||
"""Test with proper certificate setup and FIRSTFLIGHT debugging."""
|
||||
print("\n=== CERTIFICATE-BASED FIRSTFLIGHT TEST ===")
|
||||
|
||||
# Import your existing certificate creation functions
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
|
||||
# Create security configs
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
|
||||
# Apply the minimal test logic with certificates
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
|
||||
client_config = QuicConfiguration(
|
||||
is_client=True, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
client_config.certificate = client_security_config.tls_config.certificate
|
||||
client_config.private_key = client_security_config.tls_config.private_key
|
||||
client_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
|
||||
server_config = QuicConfiguration(
|
||||
is_client=False, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
server_config.certificate = server_security_config.tls_config.certificate
|
||||
server_config.private_key = server_security_config.tls_config.private_key
|
||||
server_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
|
||||
# Run the FIRSTFLIGHT test with certificates
|
||||
return create_minimal_quic_test_with_config(client_config, server_config)
|
||||
|
||||
|
||||
async def main():
|
||||
print("🎯 Testing FIRSTFLIGHT connection ID behavior...")
|
||||
|
||||
# # First test without certificates
|
||||
# print("\n" + "=" * 60)
|
||||
# print("PHASE 1: Testing FIRSTFLIGHT without certificates")
|
||||
# print("=" * 60)
|
||||
# minimal_success = create_minimal_quic_test()
|
||||
|
||||
# Then test with certificates
|
||||
print("\n" + "=" * 60)
|
||||
print("PHASE 2: Testing FIRSTFLIGHT with certificates")
|
||||
print("=" * 60)
|
||||
cert_success = await test_with_certificates()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("FIRSTFLIGHT TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
# print(f"Minimal test (no certs): {'✅ PASS' if minimal_success else '❌ FAIL'}")
|
||||
print(f"Certificate test: {'✅ PASS' if cert_success else '❌ FAIL'}")
|
||||
|
||||
if not cert_success:
|
||||
print("\n🔥 FIRSTFLIGHT BUG CONFIRMED:")
|
||||
print(" - aioquic fails to set peer CID correctly during FIRSTFLIGHT event")
|
||||
print(" - Server uses wrong destination CID in response packets")
|
||||
print(" - Client drops responses → handshake fails")
|
||||
print(" - Fix: Override _peer_connection_id after receive_datagram()")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import trio
|
||||
|
||||
trio.run(main)
|
||||
205
examples/echo/test_handshake.py
Normal file
205
examples/echo/test_handshake.py
Normal file
@ -0,0 +1,205 @@
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from tempfile import NamedTemporaryFile
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from time import time
|
||||
import os
|
||||
import trio
|
||||
|
||||
|
||||
async def test_full_handshake_and_certificate_exchange():
|
||||
"""
|
||||
Test a full handshake to ensure it completes and peer certificates are exchanged.
|
||||
FIXED VERSION: Corrects connection ID management and address handling.
|
||||
"""
|
||||
print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (FIXED) ===")
|
||||
|
||||
# 1. Generate KeyPairs and create libp2p security configs for client and server.
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
print("✅ libp2p security configs created.")
|
||||
|
||||
# 2. Create aioquic configurations with consistent settings
|
||||
client_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-client.log"
|
||||
)
|
||||
client_aioquic_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=client_secrets_log_file,
|
||||
connection_id_length=8, # Set consistent CID length
|
||||
)
|
||||
client_aioquic_config.certificate = client_security_config.tls_config.certificate
|
||||
client_aioquic_config.private_key = client_security_config.tls_config.private_key
|
||||
client_aioquic_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
|
||||
server_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-server.log"
|
||||
)
|
||||
server_aioquic_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=server_secrets_log_file,
|
||||
connection_id_length=8, # Set consistent CID length
|
||||
)
|
||||
server_aioquic_config.certificate = server_security_config.tls_config.certificate
|
||||
server_aioquic_config.private_key = server_security_config.tls_config.private_key
|
||||
server_aioquic_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
print("✅ aioquic configurations created and configured.")
|
||||
print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
|
||||
print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")
|
||||
|
||||
# 3. Use consistent addresses - this is crucial!
|
||||
# The client will connect TO the server address, but packets will come FROM client address
|
||||
client_address = ("127.0.0.1", 1234) # Client binds to this
|
||||
server_address = ("127.0.0.1", 4321) # Server binds to this
|
||||
|
||||
# 4. Create client connection and initiate connection
|
||||
client_conn = QuicConnection(configuration=client_aioquic_config)
|
||||
# Client connects to server address - this sets up the initial packet with proper CIDs
|
||||
client_conn.connect(server_address, now=time())
|
||||
print("✅ Client connection initiated.")
|
||||
|
||||
# 5. Get the initial client packet and extract ODCID properly
|
||||
client_datagrams = client_conn.datagrams_to_send(now=time())
|
||||
if not client_datagrams:
|
||||
raise AssertionError("❌ Client did not generate initial packet")
|
||||
|
||||
client_initial_packet = client_datagrams[0][0]
|
||||
header = pull_quic_header(Buffer(data=client_initial_packet), host_cid_length=8)
|
||||
original_dcid = header.destination_cid
|
||||
client_source_cid = header.source_cid
|
||||
|
||||
print(f"📊 Client ODCID: {original_dcid.hex()}")
|
||||
print(f"📊 Client source CID: {client_source_cid.hex()}")
|
||||
|
||||
# 6. Create server connection with the correct ODCID
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_aioquic_config,
|
||||
original_destination_connection_id=original_dcid,
|
||||
)
|
||||
print("✅ Server connection created with correct ODCID.")
|
||||
|
||||
# 7. Feed the initial client packet to server
|
||||
# IMPORTANT: Use client_address as the source for the packet
|
||||
for datagram, _ in client_datagrams:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
|
||||
# 8. Manual handshake loop with proper packet tracking
|
||||
max_duration_s = 3 # Increased timeout
|
||||
start_time = time()
|
||||
packet_count = 0
|
||||
|
||||
while time() - start_time < max_duration_s:
|
||||
# Process client -> server packets
|
||||
client_packets = list(client_conn.datagrams_to_send(now=time()))
|
||||
for datagram, _ in client_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
packet_count += 1
|
||||
|
||||
# Process server -> client packets
|
||||
server_packets = list(server_conn.datagrams_to_send(now=time()))
|
||||
for datagram, _ in server_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Server -> Client: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
# CRITICAL: Server sends back to client_address, not server_address
|
||||
client_conn.receive_datagram(datagram, server_address, now=time())
|
||||
packet_count += 1
|
||||
|
||||
# Check for completion
|
||||
client_complete = getattr(client_conn, "_handshake_complete", False)
|
||||
server_complete = getattr(server_conn, "_handshake_complete", False)
|
||||
|
||||
print(
|
||||
f"🔄 Handshake status: Client={client_complete}, Server={server_complete}, Packets={packet_count}"
|
||||
)
|
||||
|
||||
if client_complete and server_complete:
|
||||
print("🎉 Handshake completed for both peers!")
|
||||
break
|
||||
|
||||
# If no packets were exchanged in this iteration, wait a bit
|
||||
if not client_packets and not server_packets:
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# Safety check - if too many packets, something is wrong
|
||||
if packet_count > 50:
|
||||
print("⚠️ Too many packets exchanged, possible handshake loop")
|
||||
break
|
||||
|
||||
# 9. Enhanced handshake completion checks
|
||||
client_handshake_complete = getattr(client_conn, "_handshake_complete", False)
|
||||
server_handshake_complete = getattr(server_conn, "_handshake_complete", False)
|
||||
|
||||
# Debug additional state information
|
||||
print(f"🔍 Final client state: {getattr(client_conn, '_state', 'unknown')}")
|
||||
print(f"🔍 Final server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
|
||||
if hasattr(client_conn, "tls") and client_conn.tls:
|
||||
print(f"🔍 Client TLS state: {getattr(client_conn.tls, 'state', 'unknown')}")
|
||||
if hasattr(server_conn, "tls") and server_conn.tls:
|
||||
print(f"🔍 Server TLS state: {getattr(server_conn.tls, 'state', 'unknown')}")
|
||||
|
||||
# 10. Cleanup and assertions
|
||||
client_secrets_log_file.close()
|
||||
server_secrets_log_file.close()
|
||||
os.unlink(client_secrets_log_file.name)
|
||||
os.unlink(server_secrets_log_file.name)
|
||||
|
||||
# Final assertions
|
||||
assert client_handshake_complete, (
|
||||
f"❌ Client handshake did not complete. "
|
||||
f"State: {getattr(client_conn, '_state', 'unknown')}, "
|
||||
f"Packets: {packet_count}"
|
||||
)
|
||||
assert server_handshake_complete, (
|
||||
f"❌ Server handshake did not complete. "
|
||||
f"State: {getattr(server_conn, '_state', 'unknown')}, "
|
||||
f"Packets: {packet_count}"
|
||||
)
|
||||
print("✅ Handshake completed for both peers.")
|
||||
|
||||
# Certificate exchange verification
|
||||
client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
|
||||
server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)
|
||||
|
||||
assert client_peer_cert is not None, (
|
||||
"❌ Client FAILED to receive server certificate."
|
||||
)
|
||||
print("✅ Client successfully received server certificate.")
|
||||
|
||||
assert server_peer_cert is not None, (
|
||||
"❌ Server FAILED to receive client certificate."
|
||||
)
|
||||
print("✅ Server successfully received client certificate.")
|
||||
|
||||
print("🎉 Test Passed: Full handshake and certificate exchange successful.")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(test_full_handshake_and_certificate_exchange)
|
||||
@ -1,20 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
"""
|
||||
Fixed QUIC handshake test to debug connection issues.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
from time import time
|
||||
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.logger import QuicFileLogger
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.transport.quic.security import LIBP2P_TLS_EXTENSION_OID
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import (
|
||||
LIBP2P_TLS_EXTENSION_OID,
|
||||
create_quic_security_transport,
|
||||
)
|
||||
from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig
|
||||
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG
|
||||
)
|
||||
|
||||
|
||||
# Adjust this path to your project structure
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
@ -256,10 +275,162 @@ async def test_server_startup():
|
||||
return False
|
||||
|
||||
|
||||
async def test_full_handshake_and_certificate_exchange():
|
||||
"""
|
||||
Test a full handshake to ensure it completes and peer certificates are exchanged.
|
||||
This version is corrected to use the actual APIs available in the codebase.
|
||||
"""
|
||||
print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (CORRECTED) ===")
|
||||
|
||||
# 1. Generate KeyPairs and create libp2p security configs for client and server.
|
||||
# The `create_quic_security_transport` function from `test_quic.py` is the
|
||||
# correct helper to use, and it requires a `KeyPair` argument.
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
# This is the correct way to get the security configuration objects.
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
print("✅ libp2p security configs created.")
|
||||
|
||||
# 2. Create aioquic configurations and manually apply security settings,
|
||||
# mimicking what the `QUICTransport` class does internally.
|
||||
client_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-client.log"
|
||||
)
|
||||
client_aioquic_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=client_secrets_log_file,
|
||||
)
|
||||
client_aioquic_config.certificate = client_security_config.tls_config.certificate
|
||||
client_aioquic_config.private_key = client_security_config.tls_config.private_key
|
||||
client_aioquic_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
client_aioquic_config.quic_logger = QuicFileLogger(
|
||||
"/home/akmo/GitHub/py-libp2p/examples/echo/logs"
|
||||
)
|
||||
|
||||
server_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-server.log"
|
||||
)
|
||||
|
||||
server_aioquic_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=server_secrets_log_file,
|
||||
)
|
||||
server_aioquic_config.certificate = server_security_config.tls_config.certificate
|
||||
server_aioquic_config.private_key = server_security_config.tls_config.private_key
|
||||
server_aioquic_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
server_aioquic_config.quic_logger = QuicFileLogger(
|
||||
"/home/akmo/GitHub/py-libp2p/examples/echo/logs"
|
||||
)
|
||||
print("✅ aioquic configurations created and configured.")
|
||||
print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
|
||||
print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")
|
||||
|
||||
# 3. Instantiate client, initiate its `connect` call, and get the ODCID for the server.
|
||||
client_address = ("127.0.0.1", 1234)
|
||||
server_address = ("127.0.0.1", 4321)
|
||||
|
||||
client_aioquic_config.connection_id_length = 8
|
||||
client_conn = QuicConnection(configuration=client_aioquic_config)
|
||||
client_conn.connect(server_address, now=time())
|
||||
print("✅ aioquic connections instantiated correctly.")
|
||||
|
||||
print("🔧 Client CIDs")
|
||||
print(f"Local Init CID: ", client_conn._local_initial_source_connection_id.hex())
|
||||
print(
|
||||
f"Remote Init CID: ",
|
||||
(client_conn._remote_initial_source_connection_id or b"").hex(),
|
||||
)
|
||||
print(
|
||||
f"Original Destination CID: ",
|
||||
client_conn.original_destination_connection_id.hex(),
|
||||
)
|
||||
print(f"Host CID: {client_conn._host_cids[0].cid.hex()}")
|
||||
|
||||
# 4. Instantiate the server with the ODCID from the client.
|
||||
server_aioquic_config.connection_id_length = 8
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_aioquic_config,
|
||||
original_destination_connection_id=client_conn.original_destination_connection_id,
|
||||
)
|
||||
print("✅ aioquic connections instantiated correctly.")
|
||||
|
||||
# 5. Manually drive the handshake process by exchanging datagrams.
|
||||
max_duration_s = 5
|
||||
start_time = time()
|
||||
|
||||
while time() - start_time < max_duration_s:
|
||||
for datagram, _ in client_conn.datagrams_to_send(now=time()):
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print("Client packet source connection id", header.source_cid.hex())
|
||||
print("Client packet destination connection id", header.destination_cid.hex())
|
||||
print("--SERVER INJESTING CLIENT PACKET---")
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
|
||||
print(
|
||||
f"Server remote initial source id: {(server_conn._remote_initial_source_connection_id or b'').hex()}"
|
||||
)
|
||||
for datagram, _ in server_conn.datagrams_to_send(now=time()):
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print("Server packet source connection id", header.source_cid.hex())
|
||||
print("Server packet destination connection id", header.destination_cid.hex())
|
||||
print("--CLIENT INJESTING SERVER PACKET---")
|
||||
client_conn.receive_datagram(datagram, server_address, now=time())
|
||||
|
||||
# Check for completion
|
||||
if client_conn._handshake_complete and server_conn._handshake_complete:
|
||||
break
|
||||
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# 6. Assertions to verify the outcome.
|
||||
assert client_conn._handshake_complete, "❌ Client handshake did not complete."
|
||||
assert server_conn._handshake_complete, "❌ Server handshake did not complete."
|
||||
print("✅ Handshake completed for both peers.")
|
||||
|
||||
# The key assertion: check if the peer certificate was received.
|
||||
client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
|
||||
server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)
|
||||
|
||||
client_secrets_log_file.close()
|
||||
server_secrets_log_file.close()
|
||||
os.unlink(client_secrets_log_file.name)
|
||||
os.unlink(server_secrets_log_file.name)
|
||||
|
||||
assert client_peer_cert is not None, (
|
||||
"❌ Client FAILED to receive server certificate."
|
||||
)
|
||||
print("✅ Client successfully received server certificate.")
|
||||
|
||||
assert server_peer_cert is not None, (
|
||||
"❌ Server FAILED to receive client certificate."
|
||||
)
|
||||
print("✅ Server successfully received client certificate.")
|
||||
|
||||
print("🎉 Test Passed: Full handshake and certificate exchange successful.")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests with better error handling."""
|
||||
print("Starting QUIC diagnostic tests...")
|
||||
|
||||
handshake_ok = await test_full_handshake_and_certificate_exchange()
|
||||
if not handshake_ok:
|
||||
print("\n❌ CRITICAL: Handshake failed!")
|
||||
print("Apply the handshake fix and try again.")
|
||||
return
|
||||
|
||||
# Test 1: Certificate generation
|
||||
cert_ok = await test_certificate_generation()
|
||||
if not cert_ok:
|
||||
|
||||
@ -276,9 +276,6 @@ class QUICListener(IListener):
|
||||
# Parse packet to extract connection information
|
||||
packet_info = self.parse_quic_packet(data)
|
||||
|
||||
print(
|
||||
f"🔧 DEBUG: Address mappings: {dict((k, v.hex()) for k, v in self._addr_to_cid.items())}"
|
||||
)
|
||||
print(
|
||||
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
|
||||
)
|
||||
@ -333,33 +330,6 @@ class QUICListener(IListener):
|
||||
)
|
||||
return
|
||||
|
||||
# If no exact match, try address-based routing (connection ID might not match)
|
||||
mapped_cid = self._addr_to_cid.get(addr)
|
||||
if mapped_cid:
|
||||
print(
|
||||
f"🔧 PACKET: Found address mapping {addr} -> {mapped_cid.hex()}"
|
||||
)
|
||||
print(
|
||||
f"🔧 PACKET: Client dest_cid {dest_cid.hex()} != our cid {mapped_cid.hex()}"
|
||||
)
|
||||
|
||||
if mapped_cid in self._connections:
|
||||
print(
|
||||
"✅ PACKET: Using established connection via address mapping"
|
||||
)
|
||||
connection = self._connections[mapped_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
return
|
||||
elif mapped_cid in self._pending_connections:
|
||||
print(
|
||||
"✅ PACKET: Using pending connection via address mapping"
|
||||
)
|
||||
quic_conn = self._pending_connections[mapped_cid]
|
||||
await self._handle_pending_connection(
|
||||
quic_conn, data, addr, mapped_cid
|
||||
)
|
||||
return
|
||||
|
||||
# No existing connection found, create new one
|
||||
print(f"🔧 PACKET: Creating new connection for {addr}")
|
||||
await self._handle_new_connection(data, addr, packet_info)
|
||||
@ -491,10 +461,9 @@ class QUICListener(IListener):
|
||||
)
|
||||
|
||||
# Create QUIC connection with proper parameters for server
|
||||
# CRITICAL FIX: Pass the original destination connection ID from the initial packet
|
||||
quic_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=packet_info.destination_cid, # Use the original DCID from packet
|
||||
original_destination_connection_id=packet_info.destination_cid,
|
||||
)
|
||||
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
|
||||
"""
|
||||
QUIC Security implementation for py-libp2p Module 5.
|
||||
Implements libp2p TLS specification for QUIC transport with peer identity integration.
|
||||
@ -15,6 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import ec, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||
from cryptography.x509.base import Certificate
|
||||
from cryptography.x509.extensions import Extension, UnrecognizedExtension
|
||||
from cryptography.x509.oid import NameOID
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
@ -128,57 +130,106 @@ class LibP2PExtensionHandler:
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def parse_signed_key_extension(extension_data: bytes) -> tuple[PublicKey, bytes]:
|
||||
def parse_signed_key_extension(extension: Extension) -> tuple[PublicKey, bytes]:
|
||||
"""
|
||||
Parse the libp2p Public Key Extension to extract public key and signature.
|
||||
|
||||
Args:
|
||||
extension_data: The extension data bytes
|
||||
|
||||
Returns:
|
||||
Tuple of (libp2p_public_key, signature)
|
||||
|
||||
Raises:
|
||||
QUICCertificateError: If extension parsing fails
|
||||
|
||||
Parse the libp2p Public Key Extension with enhanced debugging.
|
||||
"""
|
||||
try:
|
||||
print(f"🔍 Extension type: {type(extension)}")
|
||||
print(f"🔍 Extension.value type: {type(extension.value)}")
|
||||
|
||||
# Extract the raw bytes from the extension
|
||||
if isinstance(extension.value, UnrecognizedExtension):
|
||||
# Use the .value property to get the bytes
|
||||
raw_bytes = extension.value.value
|
||||
print("🔍 Extension is UnrecognizedExtension, using .value property")
|
||||
else:
|
||||
# Fallback if it's already bytes somehow
|
||||
raw_bytes = extension.value
|
||||
print("🔍 Extension.value is already bytes")
|
||||
|
||||
print(f"🔍 Total extension length: {len(raw_bytes)} bytes")
|
||||
print(f"🔍 Extension hex (first 50 bytes): {raw_bytes[:50].hex()}")
|
||||
|
||||
if not isinstance(raw_bytes, bytes):
|
||||
raise QUICCertificateError(f"Expected bytes, got {type(raw_bytes)}")
|
||||
|
||||
offset = 0
|
||||
|
||||
# Parse public key length and data
|
||||
if len(extension_data) < 4:
|
||||
if len(raw_bytes) < 4:
|
||||
raise QUICCertificateError("Extension too short for public key length")
|
||||
|
||||
public_key_length = int.from_bytes(
|
||||
extension_data[offset : offset + 4], byteorder="big"
|
||||
raw_bytes[offset : offset + 4], byteorder="big"
|
||||
)
|
||||
print(f"🔍 Public key length: {public_key_length} bytes")
|
||||
offset += 4
|
||||
|
||||
if len(extension_data) < offset + public_key_length:
|
||||
if len(raw_bytes) < offset + public_key_length:
|
||||
raise QUICCertificateError("Extension too short for public key data")
|
||||
|
||||
public_key_bytes = extension_data[offset : offset + public_key_length]
|
||||
public_key_bytes = raw_bytes[offset : offset + public_key_length]
|
||||
print(f"🔍 Public key data: {public_key_bytes.hex()}")
|
||||
offset += public_key_length
|
||||
print(f"🔍 Offset after public key: {offset}")
|
||||
|
||||
# Parse signature length and data
|
||||
if len(extension_data) < offset + 4:
|
||||
if len(raw_bytes) < offset + 4:
|
||||
raise QUICCertificateError("Extension too short for signature length")
|
||||
|
||||
signature_length = int.from_bytes(
|
||||
extension_data[offset : offset + 4], byteorder="big"
|
||||
raw_bytes[offset : offset + 4], byteorder="big"
|
||||
)
|
||||
print(f"🔍 Signature length: {signature_length} bytes")
|
||||
offset += 4
|
||||
print(f"🔍 Offset after signature length: {offset}")
|
||||
|
||||
if len(extension_data) < offset + signature_length:
|
||||
if len(raw_bytes) < offset + signature_length:
|
||||
raise QUICCertificateError("Extension too short for signature data")
|
||||
|
||||
signature = extension_data[offset : offset + signature_length]
|
||||
signature = raw_bytes[offset : offset + signature_length]
|
||||
print(f"🔍 Extracted signature length: {len(signature)} bytes")
|
||||
print(f"🔍 Signature hex (first 20 bytes): {signature[:20].hex()}")
|
||||
print(f"🔍 Signature starts with DER header: {signature[:2].hex() == '3045'}")
|
||||
|
||||
# Detailed signature analysis
|
||||
if len(signature) >= 2:
|
||||
if signature[0] == 0x30:
|
||||
der_length = signature[1]
|
||||
print(f"🔍 DER sequence length field: {der_length}")
|
||||
print(f"🔍 Expected DER total: {der_length + 2}")
|
||||
print(f"🔍 Actual signature length: {len(signature)}")
|
||||
|
||||
if len(signature) != der_length + 2:
|
||||
print(f"⚠️ DER length mismatch! Expected {der_length + 2}, got {len(signature)}")
|
||||
# Try truncating to correct DER length
|
||||
if der_length + 2 < len(signature):
|
||||
print(f"🔧 Truncating signature to correct DER length: {der_length + 2}")
|
||||
signature = signature[:der_length + 2]
|
||||
|
||||
# Check if we have extra data
|
||||
expected_total = 4 + public_key_length + 4 + signature_length
|
||||
print(f"🔍 Expected total length: {expected_total}")
|
||||
print(f"🔍 Actual total length: {len(raw_bytes)}")
|
||||
|
||||
if len(raw_bytes) > expected_total:
|
||||
extra_bytes = len(raw_bytes) - expected_total
|
||||
print(f"⚠️ Extra {extra_bytes} bytes detected!")
|
||||
print(f"🔍 Extra data: {raw_bytes[expected_total:].hex()}")
|
||||
|
||||
# Deserialize the public key
|
||||
public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes)
|
||||
print(f"🔍 Successfully deserialized public key: {type(public_key)}")
|
||||
|
||||
print(f"🔍 Final signature to return: {len(signature)} bytes")
|
||||
|
||||
return public_key, signature
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Extension parsing failed: {e}")
|
||||
import traceback
|
||||
print(f"❌ Traceback: {traceback.format_exc()}")
|
||||
raise QUICCertificateError(
|
||||
f"Failed to parse signed key extension: {e}"
|
||||
) from e
|
||||
@ -361,9 +412,15 @@ class PeerAuthenticator:
|
||||
if not libp2p_extension:
|
||||
raise QUICPeerVerificationError("Certificate missing libp2p extension")
|
||||
|
||||
assert libp2p_extension.value is not None
|
||||
print(f"Extension type: {type(libp2p_extension)}")
|
||||
print(f"Extension value type: {type(libp2p_extension.value)}")
|
||||
if hasattr(libp2p_extension.value, "__len__"):
|
||||
print(f"Extension value length: {len(libp2p_extension.value)}")
|
||||
print(f"Extension value: {libp2p_extension.value}")
|
||||
# Parse the extension to get public key and signature
|
||||
public_key, signature = self.extension_handler.parse_signed_key_extension(
|
||||
libp2p_extension.value
|
||||
libp2p_extension
|
||||
)
|
||||
|
||||
# Get certificate public key for signature verification
|
||||
@ -376,7 +433,7 @@ class PeerAuthenticator:
|
||||
signature_payload = b"libp2p-tls-handshake:" + cert_public_key_bytes
|
||||
|
||||
try:
|
||||
public_key.verify(signature, signature_payload)
|
||||
public_key.verify(signature_payload, signature)
|
||||
except Exception as e:
|
||||
raise QUICPeerVerificationError(
|
||||
f"Invalid signature in libp2p extension: {e}"
|
||||
@ -387,6 +444,8 @@ class PeerAuthenticator:
|
||||
|
||||
# Verify against expected peer ID if provided
|
||||
if expected_peer_id and derived_peer_id != expected_peer_id:
|
||||
print(f"Expected Peer id: {expected_peer_id}")
|
||||
print(f"Derived Peer ID: {derived_peer_id}")
|
||||
raise QUICPeerVerificationError(
|
||||
f"Peer ID mismatch: expected {expected_peer_id}, "
|
||||
f"got {derived_peer_id}"
|
||||
|
||||
@ -18,7 +18,7 @@ maintainers = [
|
||||
dependencies = [
|
||||
"aioquic>=1.2.0",
|
||||
"base58>=1.0.3",
|
||||
"coincurve>=10.0.0",
|
||||
"coincurve==21.0.0",
|
||||
"exceptiongroup>=1.2.0; python_version < '3.11'",
|
||||
"grpcio>=1.41.0",
|
||||
"lru-dict>=1.1.6",
|
||||
|
||||
Reference in New Issue
Block a user