fix: duplication connection creation for same sessions

This commit is contained in:
Akash Mondal
2025-06-17 13:54:32 +00:00
committed by lla-dane
parent 369f79306f
commit 123c86c091
4 changed files with 982 additions and 183 deletions

289
examples/echo/test_quic.py Normal file
View File

@ -0,0 +1,289 @@
#!/usr/bin/env python3
"""
Fixed QUIC handshake test to debug connection issues.
"""
import logging
from pathlib import Path
import secrets
import sys
import trio
from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig
from libp2p.transport.quic.utils import create_quic_multiaddr
# Adjust this path to your project structure
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
# Setup logging
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
async def test_certificate_generation():
"""Test certificate generation in isolation."""
print("\n=== TESTING CERTIFICATE GENERATION ===")
try:
from libp2p.peer.id import ID
from libp2p.transport.quic.security import create_quic_security_transport
# Create key pair
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
print(f"Generated peer ID: {peer_id}")
# Create security manager
security_manager = create_quic_security_transport(private_key, peer_id)
print("✅ Security manager created")
# Test server config
server_config = security_manager.create_server_config()
print("✅ Server config created")
# Validate certificate
cert = server_config.certificate
private_key_obj = server_config.private_key
print(f"Certificate type: {type(cert)}")
print(f"Private key type: {type(private_key_obj)}")
print(f"Certificate subject: {cert.subject}")
print(f"Certificate issuer: {cert.issuer}")
# Check for libp2p extension
has_libp2p_ext = False
for ext in cert.extensions:
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
has_libp2p_ext = True
print(f"✅ Found libp2p extension: {ext.oid}")
print(f"Extension critical: {ext.critical}")
print(f"Extension value length: {len(ext.value)} bytes")
break
if not has_libp2p_ext:
print("❌ No libp2p extension found!")
print("Available extensions:")
for ext in cert.extensions:
print(f" - {ext.oid} (critical: {ext.critical})")
# Check certificate/key match
from cryptography.hazmat.primitives import serialization
cert_public_key = cert.public_key()
private_public_key = private_key_obj.public_key()
cert_pub_bytes = cert_public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
private_pub_bytes = private_public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
if cert_pub_bytes == private_pub_bytes:
print("✅ Certificate and private key match")
return has_libp2p_ext
else:
print("❌ Certificate and private key DO NOT match")
return False
except Exception as e:
print(f"❌ Certificate test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_basic_quic_connection():
"""Test basic QUIC connection with proper server setup."""
print("\n=== TESTING BASIC QUIC CONNECTION ===")
try:
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from libp2p.peer.id import ID
from libp2p.transport.quic.security import create_quic_security_transport
# Create certificates
server_key = create_new_key_pair().private_key
server_peer_id = ID.from_pubkey(server_key.get_public_key())
server_security = create_quic_security_transport(server_key, server_peer_id)
client_key = create_new_key_pair().private_key
client_peer_id = ID.from_pubkey(client_key.get_public_key())
client_security = create_quic_security_transport(client_key, client_peer_id)
# Create server config
server_tls_config = server_security.create_server_config()
server_config = QuicConfiguration(
is_client=False,
certificate=server_tls_config.certificate,
private_key=server_tls_config.private_key,
alpn_protocols=["libp2p"],
)
# Create client config
client_tls_config = client_security.create_client_config()
client_config = QuicConfiguration(
is_client=True,
certificate=client_tls_config.certificate,
private_key=client_tls_config.private_key,
alpn_protocols=["libp2p"],
)
print("✅ QUIC configurations created")
# Test creating connections with proper parameters
# For server, we need to provide original_destination_connection_id
original_dcid = secrets.token_bytes(8)
server_conn = QuicConnection(
configuration=server_config,
original_destination_connection_id=original_dcid,
)
# For client, no original_destination_connection_id needed
client_conn = QuicConnection(configuration=client_config)
print("✅ QUIC connections created")
print(f"Server state: {server_conn._state}")
print(f"Client state: {client_conn._state}")
# Test that certificates are valid
print(f"Server has certificate: {server_config.certificate is not None}")
print(f"Server has private key: {server_config.private_key is not None}")
print(f"Client has certificate: {client_config.certificate is not None}")
print(f"Client has private key: {client_config.private_key is not None}")
return True
except Exception as e:
print(f"❌ Basic QUIC test failed: {e}")
import traceback
traceback.print_exc()
return False
async def test_server_startup():
"""Test server startup with timeout."""
print("\n=== TESTING SERVER STARTUP ===")
try:
# Create transport
private_key = create_new_key_pair().private_key
config = QUICTransportConfig(
idle_timeout=10.0, # Reduced timeout for testing
connection_timeout=10.0,
enable_draft29=False,
)
transport = QUICTransport(private_key, config)
print("✅ Transport created successfully")
# Test configuration
print(f"Available configs: {list(transport._quic_configs.keys())}")
config_valid = True
for config_key, quic_config in transport._quic_configs.items():
print(f"\n--- Testing config: {config_key} ---")
print(f"is_client: {quic_config.is_client}")
print(f"has_certificate: {quic_config.certificate is not None}")
print(f"has_private_key: {quic_config.private_key is not None}")
print(f"alpn_protocols: {quic_config.alpn_protocols}")
print(f"verify_mode: {quic_config.verify_mode}")
if quic_config.certificate:
cert = quic_config.certificate
print(f"Certificate subject: {cert.subject}")
# Check for libp2p extension
has_libp2p_ext = False
for ext in cert.extensions:
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
has_libp2p_ext = True
break
print(f"Has libp2p extension: {has_libp2p_ext}")
if not has_libp2p_ext:
config_valid = False
if not config_valid:
print("❌ Transport configuration invalid - missing libp2p extensions")
return False
# Create listener
async def dummy_handler(connection):
print(f"New connection: {connection}")
listener = transport.create_listener(dummy_handler)
print("✅ Listener created successfully")
# Try to bind with timeout
maddr = create_quic_multiaddr("127.0.0.1", 0, "quic-v1")
async with trio.open_nursery() as nursery:
result = await listener.listen(maddr, nursery)
if result:
print("✅ Server bound successfully")
addresses = listener.get_addresses()
print(f"Listening on: {addresses}")
# Keep running for a short time
with trio.move_on_after(3.0): # 3 second timeout
await trio.sleep(5.0)
print("✅ Server test completed (timed out normally)")
return True
else:
print("❌ Failed to bind server")
return False
except Exception as e:
print(f"❌ Server test failed: {e}")
import traceback
traceback.print_exc()
return False
async def main():
"""Run all tests with better error handling."""
print("Starting QUIC diagnostic tests...")
# Test 1: Certificate generation
cert_ok = await test_certificate_generation()
if not cert_ok:
print("\n❌ CRITICAL: Certificate generation failed!")
print("Apply the certificate generation fix and try again.")
return
# Test 2: Basic QUIC connection
quic_ok = await test_basic_quic_connection()
if not quic_ok:
print("\n❌ CRITICAL: Basic QUIC connection test failed!")
return
# Test 3: Server startup
server_ok = await test_server_startup()
if not server_ok:
print("\n❌ Server startup test failed!")
return
print("\n✅ ALL TESTS PASSED!")
print("=== DIAGNOSTIC COMPLETE ===")
print("Your QUIC implementation should now work correctly.")
print("Try running your echo example again.")
if __name__ == "__main__":
trio.run(main)

View File

@ -249,23 +249,35 @@ class QUICListener(IListener):
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
"""
Enhanced packet processing with connection ID routing and version negotiation.
FIXED: Added address-based connection reuse to prevent multiple connections.
Enhanced packet processing with better connection ID routing and debugging.
"""
try:
self._stats["packets_processed"] += 1
self._stats["bytes_received"] += len(data)
print(f"🔧 PACKET: Processing {len(data)} bytes from {addr}")
# Parse packet to extract connection information
packet_info = self.parse_quic_packet(data)
print(f"🔧 DEBUG: Address mappings: {self._addr_to_cid}")
print(
f"🔧 DEBUG: Pending connections: {list(self._pending_connections.keys())}"
f"🔧 DEBUG: Address mappings: {dict((k, v.hex()) for k, v in self._addr_to_cid.items())}"
)
print(
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
)
print(
f"🔧 DEBUG: Established connections: {[cid.hex() for cid in self._connections.keys()]}"
)
async with self._connection_lock:
if packet_info:
print(
f"🔧 PACKET: Parsed packet - version: 0x{packet_info.version:08x}, "
f"dest_cid: {packet_info.destination_cid.hex()}, "
f"src_cid: {packet_info.source_cid.hex()}"
)
# Check for version negotiation
if packet_info.version == 0:
logger.warning(
@ -275,6 +287,9 @@ class QUICListener(IListener):
# Check if version is supported
if packet_info.version not in self._supported_versions:
print(
f"❌ PACKET: Unsupported version 0x{packet_info.version:08x}"
)
await self._send_version_negotiation(
addr, packet_info.source_cid
)
@ -283,87 +298,66 @@ class QUICListener(IListener):
# Route based on destination connection ID
dest_cid = packet_info.destination_cid
# First, try exact connection ID match
if dest_cid in self._connections:
# Existing established connection
print(f"🔧 ROUTING: To established connection {dest_cid.hex()}")
print(
f"✅ PACKET: Routing to established connection {dest_cid.hex()}"
)
connection = self._connections[dest_cid]
await self._route_to_connection(connection, data, addr)
return
elif dest_cid in self._pending_connections:
# Existing pending connection
print(f"🔧 ROUTING: To pending connection {dest_cid.hex()}")
print(
f"✅ PACKET: Routing to pending connection {dest_cid.hex()}"
)
quic_conn = self._pending_connections[dest_cid]
await self._handle_pending_connection(
quic_conn, data, addr, dest_cid
)
return
else:
# CRITICAL FIX: Check for existing connection by address BEFORE creating new
existing_cid = self._addr_to_cid.get(addr)
# If no exact match, try address-based routing (connection ID might not match)
mapped_cid = self._addr_to_cid.get(addr)
if mapped_cid:
print(
f"🔧 PACKET: Found address mapping {addr} -> {mapped_cid.hex()}"
)
print(
f"🔧 PACKET: Client dest_cid {dest_cid.hex()} != our cid {mapped_cid.hex()}"
)
if existing_cid is not None:
if mapped_cid in self._connections:
print(
f"FOUND: Existing connection {existing_cid.hex()} for address {addr}"
"PACKET: Using established connection via address mapping"
)
connection = self._connections[mapped_cid]
await self._route_to_connection(connection, data, addr)
return
elif mapped_cid in self._pending_connections:
print(
f"🔧 NOTE: Client dest_cid {dest_cid.hex()} != our cid {existing_cid.hex()}"
"✅ PACKET: Using pending connection via address mapping"
)
quic_conn = self._pending_connections[mapped_cid]
await self._handle_pending_connection(
quic_conn, data, addr, mapped_cid
)
return
# Route to existing connection by address
if existing_cid in self._pending_connections:
print(
"🔧 ROUTING: Using existing pending connection by address"
)
quic_conn = self._pending_connections[existing_cid]
await self._handle_pending_connection(
quic_conn, data, addr, existing_cid
)
elif existing_cid in self._connections:
print(
"🔧 ROUTING: Using existing established connection by address"
)
connection = self._connections[existing_cid]
await self._route_to_connection(connection, data, addr)
else:
print(
f"❌ ERROR: Address mapping exists but connection {existing_cid.hex()} not found!"
)
# Clean up broken mapping and create new
self._addr_to_cid.pop(addr, None)
if packet_info.packet_type == 0: # Initial packet
print(
"🔧 NEW: Creating new connection after cleanup"
)
await self._handle_new_connection(
data, addr, packet_info
)
# No existing connection found, create new one
print(f"🔧 PACKET: Creating new connection for {addr}")
await self._handle_new_connection(data, addr, packet_info)
else:
# Truly new connection - only handle Initial packets
if packet_info.packet_type == 0: # Initial packet
print(f"🔧 NEW: Creating first connection for {addr}")
await self._handle_new_connection(
data, addr, packet_info
)
# Debug the newly created connection
new_cid = self._addr_to_cid.get(addr)
if new_cid and new_cid in self._pending_connections:
quic_conn = self._pending_connections[new_cid]
await self._debug_quic_connection_state(
quic_conn, new_cid
)
else:
logger.debug(
f"Ignoring non-Initial packet for unknown connection ID from {addr}"
)
else:
# Fallback to address-based routing for short header packets
# Failed to parse packet
print(f"❌ PACKET: Failed to parse packet from {addr}")
await self._handle_short_header_packet(data, addr)
except Exception as e:
logger.error(f"Error processing packet from {addr}: {e}")
self._stats["invalid_packets"] += 1
import traceback
traceback.print_exc()
async def _send_version_negotiation(
self, addr: tuple[str, int], source_cid: bytes
@ -404,29 +398,31 @@ class QUICListener(IListener):
logger.error(f"Failed to send version negotiation to {addr}: {e}")
async def _handle_new_connection(
self,
data: bytes,
addr: tuple[str, int],
packet_info: QUICPacketInfo,
self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo
) -> None:
"""
Handle new connection with proper version negotiation.
"""
"""Handle new connection with proper connection ID handling."""
try:
print(f"🔧 NEW_CONN: Starting handshake for {addr}")
# Find appropriate QUIC configuration
quic_config = None
config_key = None
for protocol, config in self._quic_configs.items():
wire_versions = custom_quic_version_to_wire_format(protocol)
if wire_versions == packet_info.version:
quic_config = config
config_key = protocol
break
if not quic_config:
logger.warning(
f"No configuration found for version {packet_info.version:08x}"
)
print(f"❌ NEW_CONN: No configuration found for version 0x{packet_info.version:08x}")
print(f"🔧 NEW_CONN: Available configs: {list(self._quic_configs.keys())}")
await self._send_version_negotiation(addr, packet_info.source_cid)
return
print(f"✅ NEW_CONN: Using config {config_key} for version 0x{packet_info.version:08x}")
# Create server-side QUIC configuration
server_config = create_server_config_from_base(
base_config=quic_config,
@ -434,39 +430,158 @@ class QUICListener(IListener):
transport_config=self._config,
)
# Generate a new destination connection ID for this connection
# In a real implementation, this should be cryptographically secure
import secrets
# Debug the server configuration
print(f"🔧 NEW_CONN: Server config - is_client: {server_config.is_client}")
print(f"🔧 NEW_CONN: Server config - has_certificate: {server_config.certificate is not None}")
print(f"🔧 NEW_CONN: Server config - has_private_key: {server_config.private_key is not None}")
print(f"🔧 NEW_CONN: Server config - ALPN: {server_config.alpn_protocols}")
print(f"🔧 NEW_CONN: Server config - verify_mode: {server_config.verify_mode}")
# Validate certificate has libp2p extension
if server_config.certificate:
cert = server_config.certificate
has_libp2p_ext = False
for ext in cert.extensions:
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
has_libp2p_ext = True
break
print(f"🔧 NEW_CONN: Certificate has libp2p extension: {has_libp2p_ext}")
if not has_libp2p_ext:
print("❌ NEW_CONN: Certificate missing libp2p extension!")
# Generate a new destination connection ID for this connection
import secrets
destination_cid = secrets.token_bytes(8)
# Create QUIC connection with specific version
print(f"🔧 NEW_CONN: Generated new CID: {destination_cid.hex()}")
print(f"🔧 NEW_CONN: Original destination CID: {packet_info.destination_cid.hex()}")
# Create QUIC connection with proper parameters for server
# CRITICAL FIX: Pass the original destination connection ID from the initial packet
quic_conn = QuicConnection(
configuration=server_config,
original_destination_connection_id=packet_info.destination_cid,
original_destination_connection_id=packet_info.destination_cid, # Use the original DCID from packet
)
# Store connection mapping
print("✅ NEW_CONN: QUIC connection created successfully")
# Store connection mapping using our generated CID
self._pending_connections[destination_cid] = quic_conn
self._addr_to_cid[addr] = destination_cid
self._cid_to_addr[destination_cid] = addr
print(f"🔧 NEW_CONN: Stored mappings for {addr} <-> {destination_cid.hex()}")
print("Receiving Datagram")
# Process initial packet
quic_conn.receive_datagram(data, addr, now=time.time())
# Debug connection state after receiving packet
await self._debug_quic_connection_state_detailed(quic_conn, destination_cid)
# Process events and send response
await self._process_quic_events(quic_conn, addr, destination_cid)
await self._transmit_for_connection(quic_conn, addr)
logger.debug(
f"Started handshake for new connection from {addr} "
f"(version: {packet_info.version:08x}, cid: {destination_cid.hex()})"
f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})"
)
except Exception as e:
logger.error(f"Error handling new connection from {addr}: {e}")
import traceback
traceback.print_exc()
self._stats["connections_rejected"] += 1
async def _debug_quic_connection_state_detailed(
self, quic_conn: QuicConnection, connection_id: bytes
):
"""Enhanced connection state debugging."""
try:
print(f"🔧 QUIC_STATE: Debugging connection {connection_id.hex()}")
if not quic_conn:
print("❌ QUIC_STATE: QUIC CONNECTION NOT FOUND")
return
# Check TLS state
if hasattr(quic_conn, "tls") and quic_conn.tls:
print("✅ QUIC_STATE: TLS context exists")
if hasattr(quic_conn.tls, "state"):
print(f"🔧 QUIC_STATE: TLS state: {quic_conn.tls.state}")
# Check if we have peer certificate
if (
hasattr(quic_conn.tls, "_peer_certificate")
and quic_conn.tls._peer_certificate
):
print("✅ QUIC_STATE: Peer certificate available")
else:
print("🔧 QUIC_STATE: No peer certificate yet")
# Check TLS handshake completion
if hasattr(quic_conn.tls, "handshake_complete"):
handshake_status = quic_conn._handshake_complete
print(
f"🔧 QUIC_STATE: TLS handshake complete: {handshake_status}"
)
else:
print("❌ QUIC_STATE: No TLS context!")
# Check connection state
if hasattr(quic_conn, "_state"):
print(f"🔧 QUIC_STATE: Connection state: {quic_conn._state}")
# Check if handshake is complete
if hasattr(quic_conn, "_handshake_complete"):
print(
f"🔧 QUIC_STATE: Handshake complete: {quic_conn._handshake_complete}"
)
# Check configuration
if hasattr(quic_conn, "configuration"):
config = quic_conn.configuration
print(
f"🔧 QUIC_STATE: Config certificate: {config.certificate is not None}"
)
print(
f"🔧 QUIC_STATE: Config private_key: {config.private_key is not None}"
)
print(f"🔧 QUIC_STATE: Config is_client: {config.is_client}")
print(f"🔧 QUIC_STATE: Config verify_mode: {config.verify_mode}")
print(f"🔧 QUIC_STATE: Config ALPN: {config.alpn_protocols}")
if config.certificate:
cert = config.certificate
print(f"🔧 QUIC_STATE: Certificate subject: {cert.subject}")
print(
f"🔧 QUIC_STATE: Certificate valid from: {cert.not_valid_before}"
)
print(
f"🔧 QUIC_STATE: Certificate valid until: {cert.not_valid_after}"
)
# Check for connection errors
if hasattr(quic_conn, "_close_event") and quic_conn._close_event:
print(
f"❌ QUIC_STATE: Connection has close event: {quic_conn._close_event}"
)
# Check for TLS errors
if (
hasattr(quic_conn, "_handshake_complete")
and not quic_conn._handshake_complete
):
print("⚠️ QUIC_STATE: Handshake not yet complete")
except Exception as e:
print(f"❌ QUIC_STATE: Error checking state: {e}")
import traceback
traceback.print_exc()
async def _handle_short_header_packet(
self, data: bytes, addr: tuple[str, int]
) -> None:
@ -515,54 +630,141 @@ class QUICListener(IListener):
addr: tuple[str, int],
dest_cid: bytes,
) -> None:
"""Handle packet for a pending (handshaking) connection."""
"""Handle packet for a pending (handshaking) connection with enhanced debugging."""
try:
print(
f"🔧 PENDING: Handling packet for pending connection {dest_cid.hex()}"
)
print(f"🔧 PENDING: Packet size: {len(data)} bytes from {addr}")
# Check connection state before processing
if hasattr(quic_conn, "_state"):
print(f"🔧 PENDING: Connection state before: {quic_conn._state}")
if (
hasattr(quic_conn, "tls")
and quic_conn.tls
and hasattr(quic_conn.tls, "state")
):
print(f"🔧 PENDING: TLS state before: {quic_conn.tls.state}")
# Feed data to QUIC connection
quic_conn.receive_datagram(data, addr, now=time.time())
print("✅ PENDING: Datagram received by QUIC connection")
# Process events
# Check state after receiving packet
if hasattr(quic_conn, "_state"):
print(f"🔧 PENDING: Connection state after: {quic_conn._state}")
if (
hasattr(quic_conn, "tls")
and quic_conn.tls
and hasattr(quic_conn.tls, "state")
):
print(f"🔧 PENDING: TLS state after: {quic_conn.tls.state}")
# Process events - this is crucial for handshake progression
print("🔧 PENDING: Processing QUIC events...")
await self._process_quic_events(quic_conn, addr, dest_cid)
# Send any outgoing packets
# Send any outgoing packets - this is where the response should be sent
print("🔧 PENDING: Transmitting response...")
await self._transmit_for_connection(quic_conn, addr)
# Check if handshake completed
if (
hasattr(quic_conn, "_handshake_complete")
and quic_conn._handshake_complete
):
print("✅ PENDING: Handshake completed, promoting connection")
await self._promote_pending_connection(quic_conn, addr, dest_cid)
else:
print("🔧 PENDING: Handshake still in progress")
# Debug why handshake might be stuck
await self._debug_handshake_state(quic_conn, dest_cid)
except Exception as e:
logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
# Remove from pending connections
import traceback
traceback.print_exc()
# Remove problematic pending connection
print(f"❌ PENDING: Removing problematic connection {dest_cid.hex()}")
await self._remove_pending_connection(dest_cid)
async def _process_quic_events(
self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
) -> None:
"""Process QUIC events for a connection with connection ID context."""
while True:
event = quic_conn.next_event()
if event is None:
break
"""Process QUIC events with enhanced debugging."""
try:
events_processed = 0
while True:
event = quic_conn.next_event()
if event is None:
break
if isinstance(event, events.ConnectionTerminated):
logger.debug(
f"Connection {dest_cid.hex()} from {addr} "
f"terminated: {event.reason_phrase}"
events_processed += 1
print(
f"🔧 EVENT: Processing event {events_processed}: {type(event).__name__}"
)
await self._remove_connection(dest_cid)
break
elif isinstance(event, events.HandshakeCompleted):
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
await self._promote_pending_connection(quic_conn, addr, dest_cid)
if isinstance(event, events.ConnectionTerminated):
print(
f"❌ EVENT: Connection terminated - code: {event.error_code}, reason: {event.reason_phrase}"
)
logger.debug(
f"Connection {dest_cid.hex()} from {addr} "
f"terminated: {event.reason_phrase}"
)
await self._remove_connection(dest_cid)
break
elif isinstance(event, events.StreamDataReceived):
# Forward to established connection if available
if dest_cid in self._connections:
connection = self._connections[dest_cid]
await connection._handle_stream_data(event)
elif isinstance(event, events.HandshakeCompleted):
print(
f"✅ EVENT: Handshake completed for connection {dest_cid.hex()}"
)
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
await self._promote_pending_connection(quic_conn, addr, dest_cid)
elif isinstance(event, events.StreamReset):
# Forward to established connection if available
if dest_cid in self._connections:
connection = self._connections[dest_cid]
await connection._handle_stream_reset(event)
elif isinstance(event, events.StreamDataReceived):
print(f"🔧 EVENT: Stream data received on stream {event.stream_id}")
# Forward to established connection if available
if dest_cid in self._connections:
connection = self._connections[dest_cid]
await connection._handle_stream_data(event)
elif isinstance(event, events.StreamReset):
print(f"🔧 EVENT: Stream reset on stream {event.stream_id}")
# Forward to established connection if available
if dest_cid in self._connections:
connection = self._connections[dest_cid]
await connection._handle_stream_reset(event)
elif isinstance(event, events.ConnectionIdIssued):
print(
f"🔧 EVENT: Connection ID issued: {event.connection_id.hex()}"
)
elif isinstance(event, events.ConnectionIdRetired):
print(
f"🔧 EVENT: Connection ID retired: {event.connection_id.hex()}"
)
else:
print(f"🔧 EVENT: Unhandled event type: {type(event).__name__}")
if events_processed == 0:
print("🔧 EVENT: No events to process")
else:
print(f"🔧 EVENT: Processed {events_processed} events total")
except Exception as e:
print(f"❌ EVENT: Error processing events: {e}")
import traceback
traceback.print_exc()
async def _debug_quic_connection_state(
self, quic_conn: QuicConnection, connection_id: bytes
@ -972,3 +1174,61 @@ class QUICListener(IListener):
stats["active_connections"] = len(self._connections)
stats["pending_connections"] = len(self._pending_connections)
return stats
async def _debug_handshake_state(self, quic_conn: QuicConnection, dest_cid: bytes):
"""Debug why handshake might be stuck."""
try:
print(f"🔧 HANDSHAKE_DEBUG: Analyzing stuck handshake for {dest_cid.hex()}")
# Check TLS handshake state
if hasattr(quic_conn, "tls") and quic_conn.tls:
tls = quic_conn.tls
print(
f"🔧 HANDSHAKE_DEBUG: TLS state: {getattr(tls, 'state', 'Unknown')}"
)
# Check for TLS errors
if hasattr(tls, "_error") and tls._error:
print(f"❌ HANDSHAKE_DEBUG: TLS error: {tls._error}")
# Check certificate validation
if hasattr(tls, "_peer_certificate"):
if tls._peer_certificate:
print("✅ HANDSHAKE_DEBUG: Peer certificate received")
else:
print("❌ HANDSHAKE_DEBUG: No peer certificate")
# Check ALPN negotiation
if hasattr(tls, "_alpn_protocols"):
if tls._alpn_protocols:
print(
f"✅ HANDSHAKE_DEBUG: ALPN negotiated: {tls._alpn_protocols}"
)
else:
print("❌ HANDSHAKE_DEBUG: No ALPN protocol negotiated")
# Check QUIC connection state
if hasattr(quic_conn, "_state"):
state = quic_conn._state
print(f"🔧 HANDSHAKE_DEBUG: QUIC state: {state}")
# Check specific states that might indicate problems
if "FIRSTFLIGHT" in str(state):
print("⚠️ HANDSHAKE_DEBUG: Connection stuck in FIRSTFLIGHT state")
elif "CONNECTED" in str(state):
print(
"⚠️ HANDSHAKE_DEBUG: Connection shows CONNECTED but handshake not complete"
)
# Check for pending crypto data
if hasattr(quic_conn, "_cryptos") and quic_conn._cryptos:
print(f"🔧 HANDSHAKE_DEBUG: Crypto data present {len(quic_conn._cryptos.keys())}")
# Check loss detection state
if hasattr(quic_conn, "_loss") and quic_conn._loss:
loss_detection = quic_conn._loss
if hasattr(loss_detection, "_pto_count"):
print(f"🔧 HANDSHAKE_DEBUG: PTO count: {loss_detection._pto_count}")
except Exception as e:
print(f"❌ HANDSHAKE_DEBUG: Error during debug: {e}")

View File

@ -4,9 +4,11 @@ Implements libp2p TLS specification for QUIC transport with peer identity integr
Based on go-libp2p and js-libp2p security patterns.
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import logging
import ssl
from typing import List, Optional, Union
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
@ -25,11 +27,6 @@ from .exceptions import (
QUICPeerVerificationError,
)
TSecurityConfig = dict[
str,
Certificate | EllipticCurvePrivateKey | RSAPrivateKey | bool | list[str],
]
logger = logging.getLogger(__name__)
# libp2p TLS Extension OID - Official libp2p specification
@ -312,7 +309,7 @@ class CertificateGenerator:
x509.UnrecognizedExtension(
oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data
),
critical=True, # This extension is critical for libp2p
critical=False,
)
.sign(cert_private_key, hashes.SHA256())
)
@ -407,6 +404,269 @@ class PeerAuthenticator:
) from e
@dataclass
class QUICTLSSecurityConfig:
"""
Type-safe TLS security configuration for QUIC transport.
"""
# Core TLS components (required)
certificate: Certificate
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey]
# Certificate chain (optional)
certificate_chain: List[Certificate] = field(default_factory=list)
# ALPN protocols
alpn_protocols: List[str] = field(default_factory=lambda: ["libp2p"])
# TLS verification settings
verify_mode: Union[bool, ssl.VerifyMode] = False
check_hostname: bool = False
# Optional peer ID for validation
peer_id: Optional[ID] = None
# Configuration metadata
is_client_config: bool = False
config_name: Optional[str] = None
def __post_init__(self):
"""Validate configuration after initialization."""
self._validate()
def _validate(self) -> None:
"""Validate the TLS configuration."""
if self.certificate is None:
raise ValueError("Certificate is required")
if self.private_key is None:
raise ValueError("Private key is required")
if not isinstance(self.certificate, x509.Certificate):
raise TypeError(
f"Certificate must be x509.Certificate, got {type(self.certificate)}"
)
if not isinstance(
self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey)
):
raise TypeError(
f"Private key must be EC or RSA key, got {type(self.private_key)}"
)
if not self.alpn_protocols:
raise ValueError("At least one ALPN protocol is required")
def to_dict(self) -> dict:
"""
Convert to dictionary format for compatibility with existing code.
Returns:
Dictionary compatible with the original TSecurityConfig format
"""
return {
"certificate": self.certificate,
"private_key": self.private_key,
"certificate_chain": self.certificate_chain.copy(),
"alpn_protocols": self.alpn_protocols.copy(),
"verify_mode": self.verify_mode,
"check_hostname": self.check_hostname,
}
@classmethod
def from_dict(cls, config_dict: dict, **kwargs) -> "QUICTLSSecurityConfig":
"""
Create instance from dictionary format.
Args:
config_dict: Dictionary in TSecurityConfig format
**kwargs: Additional parameters for the config
Returns:
QUICTLSSecurityConfig instance
"""
return cls(
certificate=config_dict["certificate"],
private_key=config_dict["private_key"],
certificate_chain=config_dict.get("certificate_chain", []),
alpn_protocols=config_dict.get("alpn_protocols", ["libp2p"]),
verify_mode=config_dict.get("verify_mode", False),
check_hostname=config_dict.get("check_hostname", False),
**kwargs,
)
def validate_certificate_key_match(self) -> bool:
"""
Validate that the certificate and private key match.
Returns:
True if certificate and private key match
"""
try:
from cryptography.hazmat.primitives import serialization
# Get public keys from both certificate and private key
cert_public_key = self.certificate.public_key()
private_public_key = self.private_key.public_key()
# Compare their PEM representations
cert_pub_pem = cert_public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
private_pub_pem = private_public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return cert_pub_pem == private_pub_pem
except Exception:
return False
def has_libp2p_extension(self) -> bool:
"""
Check if the certificate has the required libp2p extension.
Returns:
True if libp2p extension is present
"""
try:
libp2p_oid = "1.3.6.1.4.1.53594.1.1"
for ext in self.certificate.extensions:
if str(ext.oid) == libp2p_oid:
return True
return False
except Exception:
return False
def is_certificate_valid(self) -> bool:
"""
Check if the certificate is currently valid (not expired).
Returns:
True if certificate is valid
"""
try:
from datetime import datetime
now = datetime.utcnow()
return (
self.certificate.not_valid_before
<= now
<= self.certificate.not_valid_after
)
except Exception:
return False
def get_certificate_info(self) -> dict:
"""
Get certificate information for debugging.
Returns:
Dictionary with certificate details
"""
try:
return {
"subject": str(self.certificate.subject),
"issuer": str(self.certificate.issuer),
"serial_number": self.certificate.serial_number,
"not_valid_before": self.certificate.not_valid_before,
"not_valid_after": self.certificate.not_valid_after,
"has_libp2p_extension": self.has_libp2p_extension(),
"is_valid": self.is_certificate_valid(),
"certificate_key_match": self.validate_certificate_key_match(),
}
except Exception as e:
return {"error": str(e)}
def debug_print(self) -> None:
"""Print debugging information about this configuration."""
print(f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===")
print(f"Is client config: {self.is_client_config}")
print(f"ALPN protocols: {self.alpn_protocols}")
print(f"Verify mode: {self.verify_mode}")
print(f"Check hostname: {self.check_hostname}")
print(f"Certificate chain length: {len(self.certificate_chain)}")
cert_info = self.get_certificate_info()
for key, value in cert_info.items():
print(f"Certificate {key}: {value}")
print(f"Private key type: {type(self.private_key).__name__}")
if hasattr(self.private_key, "key_size"):
print(f"Private key size: {self.private_key.key_size}")
def create_server_tls_config(
certificate: Certificate,
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
peer_id: Optional[ID] = None,
**kwargs,
) -> QUICTLSSecurityConfig:
"""
Create a server TLS configuration.
Args:
certificate: X.509 certificate
private_key: Private key corresponding to certificate
peer_id: Optional peer ID for validation
**kwargs: Additional configuration parameters
Returns:
Server TLS configuration
"""
return QUICTLSSecurityConfig(
certificate=certificate,
private_key=private_key,
peer_id=peer_id,
is_client_config=False,
config_name="server",
verify_mode=False, # Server doesn't verify client certs in libp2p
check_hostname=False,
**kwargs,
)
def create_client_tls_config(
certificate: Certificate,
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
peer_id: Optional[ID] = None,
**kwargs,
) -> QUICTLSSecurityConfig:
"""
Create a client TLS configuration.
Args:
certificate: X.509 certificate
private_key: Private key corresponding to certificate
peer_id: Optional peer ID for validation
**kwargs: Additional configuration parameters
Returns:
Client TLS configuration
"""
return QUICTLSSecurityConfig(
certificate=certificate,
private_key=private_key,
peer_id=peer_id,
is_client_config=True,
config_name="client",
verify_mode=False, # Client doesn't verify server certs in libp2p
check_hostname=False,
**kwargs,
)
class QUICTLSConfigManager:
"""
Manages TLS configuration for QUIC transport with libp2p security.
@ -424,44 +684,40 @@ class QUICTLSConfigManager:
libp2p_private_key, peer_id
)
def create_server_config(
self,
) -> TSecurityConfig:
def create_server_config(self) -> QUICTLSSecurityConfig:
"""
Create aioquic server configuration with libp2p TLS settings.
Returns cryptography objects instead of DER bytes.
Create server configuration using the new class-based approach.
Returns:
Configuration dictionary for aioquic QuicConfiguration
QUICTLSSecurityConfig instance for server
"""
config: TSecurityConfig = {
"certificate": self.tls_config.certificate,
"private_key": self.tls_config.private_key,
"certificate_chain": [],
"alpn_protocols": ["libp2p"],
"verify_mode": False,
"check_hostname": False,
}
config = create_server_tls_config(
certificate=self.tls_config.certificate,
private_key=self.tls_config.private_key,
peer_id=self.peer_id,
)
print("🔧 SECURITY: Created server config")
config.debug_print()
return config
def create_client_config(self) -> TSecurityConfig:
def create_client_config(self) -> QUICTLSSecurityConfig:
"""
Create aioquic client configuration with libp2p TLS settings.
Returns cryptography objects instead of DER bytes.
Create client configuration using the new class-based approach.
Returns:
Configuration dictionary for aioquic QuicConfiguration
QUICTLSSecurityConfig instance for client
"""
config: TSecurityConfig = {
"certificate": self.tls_config.certificate,
"private_key": self.tls_config.private_key,
"certificate_chain": [],
"alpn_protocols": ["libp2p"],
"verify_mode": False,
"check_hostname": False,
}
config = create_client_tls_config(
certificate=self.tls_config.certificate,
private_key=self.tls_config.private_key,
peer_id=self.peer_id,
)
print("🔧 SECURITY: Created client config")
config.debug_print()
return config
def verify_peer_identity(

View File

@ -5,7 +5,6 @@ Based on aioquic library with interface consistency to go-libp2p and js-libp2p.
Updated to include Module 5 security integration.
"""
from collections.abc import Iterable
import copy
import logging
import sys
@ -31,7 +30,7 @@ from libp2p.custom_types import THandler, TProtocol
from libp2p.peer.id import (
ID,
)
from libp2p.transport.quic.security import TSecurityConfig
from libp2p.transport.quic.security import QUICTLSSecurityConfig
from libp2p.transport.quic.utils import (
get_alpn_protocols,
is_quic_multiaddr,
@ -192,7 +191,7 @@ class QUICTransport(ITransport):
) from e
def _apply_tls_configuration(
self, config: QuicConfiguration, tls_config: TSecurityConfig
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
) -> None:
"""
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
@ -203,52 +202,47 @@ class QUICTransport(ITransport):
"""
try:
# Set certificate and private key directly on the configuration
# aioquic expects cryptography objects, not DER bytes
if "certificate" in tls_config and "private_key" in tls_config:
# The security manager should return cryptography objects
# not DER bytes, but if it returns DER bytes, we need to handle that
certificate = tls_config["certificate"]
private_key = tls_config["private_key"]
# Check if we received DER bytes and need
# to convert to cryptography objects
if isinstance(certificate, bytes):
# The security manager should return cryptography objects
# not DER bytes, but if it returns DER bytes, we need to handle that
certificate = tls_config.certificate
private_key = tls_config.private_key
# Check if we received DER bytes and need
# to convert to cryptography objects
if isinstance(certificate, bytes):
from cryptography import x509
certificate = x509.load_der_x509_certificate(certificate)
if isinstance(private_key, bytes):
from cryptography.hazmat.primitives import serialization
private_key = serialization.load_der_private_key( # type: ignore
private_key, password=None
)
# Set directly on the configuration object
config.certificate = certificate
config.private_key = private_key
# Handle certificate chain if provided
certificate_chain = tls_config.certificate_chain
# Convert DER bytes to cryptography objects if needed
chain_objects = []
for cert in certificate_chain:
if isinstance(cert, bytes):
from cryptography import x509
certificate = x509.load_der_x509_certificate(certificate)
if isinstance(private_key, bytes):
from cryptography.hazmat.primitives import serialization
private_key = serialization.load_der_private_key( # type: ignore
private_key, password=None
)
# Set directly on the configuration object
config.certificate = certificate
config.private_key = private_key
# Handle certificate chain if provided
certificate_chain = tls_config.get("certificate_chain", [])
if certificate_chain and isinstance(certificate_chain, Iterable):
# Convert DER bytes to cryptography objects if needed
chain_objects = []
for cert in certificate_chain:
if isinstance(cert, bytes):
from cryptography import x509
cert = x509.load_der_x509_certificate(cert)
chain_objects.append(cert)
config.certificate_chain = chain_objects
cert = x509.load_der_x509_certificate(cert)
chain_objects.append(cert)
config.certificate_chain = chain_objects
# Set ALPN protocols
if "alpn_protocols" in tls_config:
config.alpn_protocols = tls_config["alpn_protocols"] # type: ignore
config.alpn_protocols = tls_config.alpn_protocols
# Set certificate verification mode
if "verify_mode" in tls_config:
config.verify_mode = tls_config["verify_mode"] # type: ignore
config.verify_mode = tls_config.verify_mode
logger.debug("Successfully applied TLS configuration to QUIC config")