mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
fix: duplication connection creation for same sessions
This commit is contained in:
289
examples/echo/test_quic.py
Normal file
289
examples/echo/test_quic.py
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Fixed QUIC handshake test to debug connection issues.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import secrets
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||||
|
from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig
|
||||||
|
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||||
|
|
||||||
|
# Adjust this path to your project structure
|
||||||
|
project_root = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_certificate_generation():
|
||||||
|
"""Test certificate generation in isolation."""
|
||||||
|
print("\n=== TESTING CERTIFICATE GENERATION ===")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.transport.quic.security import create_quic_security_transport
|
||||||
|
|
||||||
|
# Create key pair
|
||||||
|
private_key = create_new_key_pair().private_key
|
||||||
|
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||||
|
|
||||||
|
print(f"Generated peer ID: {peer_id}")
|
||||||
|
|
||||||
|
# Create security manager
|
||||||
|
security_manager = create_quic_security_transport(private_key, peer_id)
|
||||||
|
print("✅ Security manager created")
|
||||||
|
|
||||||
|
# Test server config
|
||||||
|
server_config = security_manager.create_server_config()
|
||||||
|
print("✅ Server config created")
|
||||||
|
|
||||||
|
# Validate certificate
|
||||||
|
cert = server_config.certificate
|
||||||
|
private_key_obj = server_config.private_key
|
||||||
|
|
||||||
|
print(f"Certificate type: {type(cert)}")
|
||||||
|
print(f"Private key type: {type(private_key_obj)}")
|
||||||
|
print(f"Certificate subject: {cert.subject}")
|
||||||
|
print(f"Certificate issuer: {cert.issuer}")
|
||||||
|
|
||||||
|
# Check for libp2p extension
|
||||||
|
has_libp2p_ext = False
|
||||||
|
for ext in cert.extensions:
|
||||||
|
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
|
||||||
|
has_libp2p_ext = True
|
||||||
|
print(f"✅ Found libp2p extension: {ext.oid}")
|
||||||
|
print(f"Extension critical: {ext.critical}")
|
||||||
|
print(f"Extension value length: {len(ext.value)} bytes")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not has_libp2p_ext:
|
||||||
|
print("❌ No libp2p extension found!")
|
||||||
|
print("Available extensions:")
|
||||||
|
for ext in cert.extensions:
|
||||||
|
print(f" - {ext.oid} (critical: {ext.critical})")
|
||||||
|
|
||||||
|
# Check certificate/key match
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
cert_public_key = cert.public_key()
|
||||||
|
private_public_key = private_key_obj.public_key()
|
||||||
|
|
||||||
|
cert_pub_bytes = cert_public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
private_pub_bytes = private_public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cert_pub_bytes == private_pub_bytes:
|
||||||
|
print("✅ Certificate and private key match")
|
||||||
|
return has_libp2p_ext
|
||||||
|
else:
|
||||||
|
print("❌ Certificate and private key DO NOT match")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Certificate test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_basic_quic_connection():
|
||||||
|
"""Test basic QUIC connection with proper server setup."""
|
||||||
|
print("\n=== TESTING BASIC QUIC CONNECTION ===")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aioquic.quic.configuration import QuicConfiguration
|
||||||
|
from aioquic.quic.connection import QuicConnection
|
||||||
|
|
||||||
|
from libp2p.peer.id import ID
|
||||||
|
from libp2p.transport.quic.security import create_quic_security_transport
|
||||||
|
|
||||||
|
# Create certificates
|
||||||
|
server_key = create_new_key_pair().private_key
|
||||||
|
server_peer_id = ID.from_pubkey(server_key.get_public_key())
|
||||||
|
server_security = create_quic_security_transport(server_key, server_peer_id)
|
||||||
|
|
||||||
|
client_key = create_new_key_pair().private_key
|
||||||
|
client_peer_id = ID.from_pubkey(client_key.get_public_key())
|
||||||
|
client_security = create_quic_security_transport(client_key, client_peer_id)
|
||||||
|
|
||||||
|
# Create server config
|
||||||
|
server_tls_config = server_security.create_server_config()
|
||||||
|
server_config = QuicConfiguration(
|
||||||
|
is_client=False,
|
||||||
|
certificate=server_tls_config.certificate,
|
||||||
|
private_key=server_tls_config.private_key,
|
||||||
|
alpn_protocols=["libp2p"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create client config
|
||||||
|
client_tls_config = client_security.create_client_config()
|
||||||
|
client_config = QuicConfiguration(
|
||||||
|
is_client=True,
|
||||||
|
certificate=client_tls_config.certificate,
|
||||||
|
private_key=client_tls_config.private_key,
|
||||||
|
alpn_protocols=["libp2p"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print("✅ QUIC configurations created")
|
||||||
|
|
||||||
|
# Test creating connections with proper parameters
|
||||||
|
# For server, we need to provide original_destination_connection_id
|
||||||
|
original_dcid = secrets.token_bytes(8)
|
||||||
|
|
||||||
|
server_conn = QuicConnection(
|
||||||
|
configuration=server_config,
|
||||||
|
original_destination_connection_id=original_dcid,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For client, no original_destination_connection_id needed
|
||||||
|
client_conn = QuicConnection(configuration=client_config)
|
||||||
|
|
||||||
|
print("✅ QUIC connections created")
|
||||||
|
print(f"Server state: {server_conn._state}")
|
||||||
|
print(f"Client state: {client_conn._state}")
|
||||||
|
|
||||||
|
# Test that certificates are valid
|
||||||
|
print(f"Server has certificate: {server_config.certificate is not None}")
|
||||||
|
print(f"Server has private key: {server_config.private_key is not None}")
|
||||||
|
print(f"Client has certificate: {client_config.certificate is not None}")
|
||||||
|
print(f"Client has private key: {client_config.private_key is not None}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Basic QUIC test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_server_startup():
|
||||||
|
"""Test server startup with timeout."""
|
||||||
|
print("\n=== TESTING SERVER STARTUP ===")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create transport
|
||||||
|
private_key = create_new_key_pair().private_key
|
||||||
|
config = QUICTransportConfig(
|
||||||
|
idle_timeout=10.0, # Reduced timeout for testing
|
||||||
|
connection_timeout=10.0,
|
||||||
|
enable_draft29=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport = QUICTransport(private_key, config)
|
||||||
|
print("✅ Transport created successfully")
|
||||||
|
|
||||||
|
# Test configuration
|
||||||
|
print(f"Available configs: {list(transport._quic_configs.keys())}")
|
||||||
|
|
||||||
|
config_valid = True
|
||||||
|
for config_key, quic_config in transport._quic_configs.items():
|
||||||
|
print(f"\n--- Testing config: {config_key} ---")
|
||||||
|
print(f"is_client: {quic_config.is_client}")
|
||||||
|
print(f"has_certificate: {quic_config.certificate is not None}")
|
||||||
|
print(f"has_private_key: {quic_config.private_key is not None}")
|
||||||
|
print(f"alpn_protocols: {quic_config.alpn_protocols}")
|
||||||
|
print(f"verify_mode: {quic_config.verify_mode}")
|
||||||
|
|
||||||
|
if quic_config.certificate:
|
||||||
|
cert = quic_config.certificate
|
||||||
|
print(f"Certificate subject: {cert.subject}")
|
||||||
|
|
||||||
|
# Check for libp2p extension
|
||||||
|
has_libp2p_ext = False
|
||||||
|
for ext in cert.extensions:
|
||||||
|
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
|
||||||
|
has_libp2p_ext = True
|
||||||
|
break
|
||||||
|
print(f"Has libp2p extension: {has_libp2p_ext}")
|
||||||
|
|
||||||
|
if not has_libp2p_ext:
|
||||||
|
config_valid = False
|
||||||
|
|
||||||
|
if not config_valid:
|
||||||
|
print("❌ Transport configuration invalid - missing libp2p extensions")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create listener
|
||||||
|
async def dummy_handler(connection):
|
||||||
|
print(f"New connection: {connection}")
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
print("✅ Listener created successfully")
|
||||||
|
|
||||||
|
# Try to bind with timeout
|
||||||
|
maddr = create_quic_multiaddr("127.0.0.1", 0, "quic-v1")
|
||||||
|
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
result = await listener.listen(maddr, nursery)
|
||||||
|
if result:
|
||||||
|
print("✅ Server bound successfully")
|
||||||
|
addresses = listener.get_addresses()
|
||||||
|
print(f"Listening on: {addresses}")
|
||||||
|
|
||||||
|
# Keep running for a short time
|
||||||
|
with trio.move_on_after(3.0): # 3 second timeout
|
||||||
|
await trio.sleep(5.0)
|
||||||
|
|
||||||
|
print("✅ Server test completed (timed out normally)")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("❌ Failed to bind server")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Server test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run all tests with better error handling."""
|
||||||
|
print("Starting QUIC diagnostic tests...")
|
||||||
|
|
||||||
|
# Test 1: Certificate generation
|
||||||
|
cert_ok = await test_certificate_generation()
|
||||||
|
if not cert_ok:
|
||||||
|
print("\n❌ CRITICAL: Certificate generation failed!")
|
||||||
|
print("Apply the certificate generation fix and try again.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Test 2: Basic QUIC connection
|
||||||
|
quic_ok = await test_basic_quic_connection()
|
||||||
|
if not quic_ok:
|
||||||
|
print("\n❌ CRITICAL: Basic QUIC connection test failed!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Test 3: Server startup
|
||||||
|
server_ok = await test_server_startup()
|
||||||
|
if not server_ok:
|
||||||
|
print("\n❌ Server startup test failed!")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n✅ ALL TESTS PASSED!")
|
||||||
|
print("=== DIAGNOSTIC COMPLETE ===")
|
||||||
|
print("Your QUIC implementation should now work correctly.")
|
||||||
|
print("Try running your echo example again.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
trio.run(main)
|
||||||
@ -249,23 +249,35 @@ class QUICListener(IListener):
|
|||||||
|
|
||||||
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
|
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||||
"""
|
"""
|
||||||
Enhanced packet processing with connection ID routing and version negotiation.
|
Enhanced packet processing with better connection ID routing and debugging.
|
||||||
FIXED: Added address-based connection reuse to prevent multiple connections.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._stats["packets_processed"] += 1
|
self._stats["packets_processed"] += 1
|
||||||
self._stats["bytes_received"] += len(data)
|
self._stats["bytes_received"] += len(data)
|
||||||
|
|
||||||
|
print(f"🔧 PACKET: Processing {len(data)} bytes from {addr}")
|
||||||
|
|
||||||
# Parse packet to extract connection information
|
# Parse packet to extract connection information
|
||||||
packet_info = self.parse_quic_packet(data)
|
packet_info = self.parse_quic_packet(data)
|
||||||
|
|
||||||
print(f"🔧 DEBUG: Address mappings: {self._addr_to_cid}")
|
|
||||||
print(
|
print(
|
||||||
f"🔧 DEBUG: Pending connections: {list(self._pending_connections.keys())}"
|
f"🔧 DEBUG: Address mappings: {dict((k, v.hex()) for k, v in self._addr_to_cid.items())}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"🔧 DEBUG: Established connections: {[cid.hex() for cid in self._connections.keys()]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._connection_lock:
|
async with self._connection_lock:
|
||||||
if packet_info:
|
if packet_info:
|
||||||
|
print(
|
||||||
|
f"🔧 PACKET: Parsed packet - version: 0x{packet_info.version:08x}, "
|
||||||
|
f"dest_cid: {packet_info.destination_cid.hex()}, "
|
||||||
|
f"src_cid: {packet_info.source_cid.hex()}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check for version negotiation
|
# Check for version negotiation
|
||||||
if packet_info.version == 0:
|
if packet_info.version == 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -275,6 +287,9 @@ class QUICListener(IListener):
|
|||||||
|
|
||||||
# Check if version is supported
|
# Check if version is supported
|
||||||
if packet_info.version not in self._supported_versions:
|
if packet_info.version not in self._supported_versions:
|
||||||
|
print(
|
||||||
|
f"❌ PACKET: Unsupported version 0x{packet_info.version:08x}"
|
||||||
|
)
|
||||||
await self._send_version_negotiation(
|
await self._send_version_negotiation(
|
||||||
addr, packet_info.source_cid
|
addr, packet_info.source_cid
|
||||||
)
|
)
|
||||||
@ -283,87 +298,66 @@ class QUICListener(IListener):
|
|||||||
# Route based on destination connection ID
|
# Route based on destination connection ID
|
||||||
dest_cid = packet_info.destination_cid
|
dest_cid = packet_info.destination_cid
|
||||||
|
|
||||||
|
# First, try exact connection ID match
|
||||||
if dest_cid in self._connections:
|
if dest_cid in self._connections:
|
||||||
# Existing established connection
|
print(
|
||||||
print(f"🔧 ROUTING: To established connection {dest_cid.hex()}")
|
f"✅ PACKET: Routing to established connection {dest_cid.hex()}"
|
||||||
|
)
|
||||||
connection = self._connections[dest_cid]
|
connection = self._connections[dest_cid]
|
||||||
await self._route_to_connection(connection, data, addr)
|
await self._route_to_connection(connection, data, addr)
|
||||||
|
return
|
||||||
|
|
||||||
elif dest_cid in self._pending_connections:
|
elif dest_cid in self._pending_connections:
|
||||||
# Existing pending connection
|
print(
|
||||||
print(f"🔧 ROUTING: To pending connection {dest_cid.hex()}")
|
f"✅ PACKET: Routing to pending connection {dest_cid.hex()}"
|
||||||
|
)
|
||||||
quic_conn = self._pending_connections[dest_cid]
|
quic_conn = self._pending_connections[dest_cid]
|
||||||
await self._handle_pending_connection(
|
await self._handle_pending_connection(
|
||||||
quic_conn, data, addr, dest_cid
|
quic_conn, data, addr, dest_cid
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
else:
|
# If no exact match, try address-based routing (connection ID might not match)
|
||||||
# CRITICAL FIX: Check for existing connection by address BEFORE creating new
|
mapped_cid = self._addr_to_cid.get(addr)
|
||||||
existing_cid = self._addr_to_cid.get(addr)
|
if mapped_cid:
|
||||||
|
print(
|
||||||
|
f"🔧 PACKET: Found address mapping {addr} -> {mapped_cid.hex()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"🔧 PACKET: Client dest_cid {dest_cid.hex()} != our cid {mapped_cid.hex()}"
|
||||||
|
)
|
||||||
|
|
||||||
if existing_cid is not None:
|
if mapped_cid in self._connections:
|
||||||
print(
|
print(
|
||||||
f"✅ FOUND: Existing connection {existing_cid.hex()} for address {addr}"
|
"✅ PACKET: Using established connection via address mapping"
|
||||||
)
|
)
|
||||||
|
connection = self._connections[mapped_cid]
|
||||||
|
await self._route_to_connection(connection, data, addr)
|
||||||
|
return
|
||||||
|
elif mapped_cid in self._pending_connections:
|
||||||
print(
|
print(
|
||||||
f"🔧 NOTE: Client dest_cid {dest_cid.hex()} != our cid {existing_cid.hex()}"
|
"✅ PACKET: Using pending connection via address mapping"
|
||||||
)
|
)
|
||||||
|
quic_conn = self._pending_connections[mapped_cid]
|
||||||
|
await self._handle_pending_connection(
|
||||||
|
quic_conn, data, addr, mapped_cid
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Route to existing connection by address
|
# No existing connection found, create new one
|
||||||
if existing_cid in self._pending_connections:
|
print(f"🔧 PACKET: Creating new connection for {addr}")
|
||||||
print(
|
await self._handle_new_connection(data, addr, packet_info)
|
||||||
"🔧 ROUTING: Using existing pending connection by address"
|
|
||||||
)
|
|
||||||
quic_conn = self._pending_connections[existing_cid]
|
|
||||||
await self._handle_pending_connection(
|
|
||||||
quic_conn, data, addr, existing_cid
|
|
||||||
)
|
|
||||||
elif existing_cid in self._connections:
|
|
||||||
print(
|
|
||||||
"🔧 ROUTING: Using existing established connection by address"
|
|
||||||
)
|
|
||||||
connection = self._connections[existing_cid]
|
|
||||||
await self._route_to_connection(connection, data, addr)
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"❌ ERROR: Address mapping exists but connection {existing_cid.hex()} not found!"
|
|
||||||
)
|
|
||||||
# Clean up broken mapping and create new
|
|
||||||
self._addr_to_cid.pop(addr, None)
|
|
||||||
if packet_info.packet_type == 0: # Initial packet
|
|
||||||
print(
|
|
||||||
"🔧 NEW: Creating new connection after cleanup"
|
|
||||||
)
|
|
||||||
await self._handle_new_connection(
|
|
||||||
data, addr, packet_info
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Truly new connection - only handle Initial packets
|
|
||||||
if packet_info.packet_type == 0: # Initial packet
|
|
||||||
print(f"🔧 NEW: Creating first connection for {addr}")
|
|
||||||
await self._handle_new_connection(
|
|
||||||
data, addr, packet_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# Debug the newly created connection
|
|
||||||
new_cid = self._addr_to_cid.get(addr)
|
|
||||||
if new_cid and new_cid in self._pending_connections:
|
|
||||||
quic_conn = self._pending_connections[new_cid]
|
|
||||||
await self._debug_quic_connection_state(
|
|
||||||
quic_conn, new_cid
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"Ignoring non-Initial packet for unknown connection ID from {addr}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Fallback to address-based routing for short header packets
|
# Failed to parse packet
|
||||||
|
print(f"❌ PACKET: Failed to parse packet from {addr}")
|
||||||
await self._handle_short_header_packet(data, addr)
|
await self._handle_short_header_packet(data, addr)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing packet from {addr}: {e}")
|
logger.error(f"Error processing packet from {addr}: {e}")
|
||||||
self._stats["invalid_packets"] += 1
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _send_version_negotiation(
|
async def _send_version_negotiation(
|
||||||
self, addr: tuple[str, int], source_cid: bytes
|
self, addr: tuple[str, int], source_cid: bytes
|
||||||
@ -404,29 +398,31 @@ class QUICListener(IListener):
|
|||||||
logger.error(f"Failed to send version negotiation to {addr}: {e}")
|
logger.error(f"Failed to send version negotiation to {addr}: {e}")
|
||||||
|
|
||||||
async def _handle_new_connection(
|
async def _handle_new_connection(
|
||||||
self,
|
self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo
|
||||||
data: bytes,
|
|
||||||
addr: tuple[str, int],
|
|
||||||
packet_info: QUICPacketInfo,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""Handle new connection with proper connection ID handling."""
|
||||||
Handle new connection with proper version negotiation.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
|
print(f"🔧 NEW_CONN: Starting handshake for {addr}")
|
||||||
|
|
||||||
|
# Find appropriate QUIC configuration
|
||||||
quic_config = None
|
quic_config = None
|
||||||
|
config_key = None
|
||||||
|
|
||||||
for protocol, config in self._quic_configs.items():
|
for protocol, config in self._quic_configs.items():
|
||||||
wire_versions = custom_quic_version_to_wire_format(protocol)
|
wire_versions = custom_quic_version_to_wire_format(protocol)
|
||||||
if wire_versions == packet_info.version:
|
if wire_versions == packet_info.version:
|
||||||
quic_config = config
|
quic_config = config
|
||||||
|
config_key = protocol
|
||||||
break
|
break
|
||||||
|
|
||||||
if not quic_config:
|
if not quic_config:
|
||||||
logger.warning(
|
print(f"❌ NEW_CONN: No configuration found for version 0x{packet_info.version:08x}")
|
||||||
f"No configuration found for version {packet_info.version:08x}"
|
print(f"🔧 NEW_CONN: Available configs: {list(self._quic_configs.keys())}")
|
||||||
)
|
|
||||||
await self._send_version_negotiation(addr, packet_info.source_cid)
|
await self._send_version_negotiation(addr, packet_info.source_cid)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
print(f"✅ NEW_CONN: Using config {config_key} for version 0x{packet_info.version:08x}")
|
||||||
|
|
||||||
# Create server-side QUIC configuration
|
# Create server-side QUIC configuration
|
||||||
server_config = create_server_config_from_base(
|
server_config = create_server_config_from_base(
|
||||||
base_config=quic_config,
|
base_config=quic_config,
|
||||||
@ -434,39 +430,158 @@ class QUICListener(IListener):
|
|||||||
transport_config=self._config,
|
transport_config=self._config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate a new destination connection ID for this connection
|
# Debug the server configuration
|
||||||
# In a real implementation, this should be cryptographically secure
|
print(f"🔧 NEW_CONN: Server config - is_client: {server_config.is_client}")
|
||||||
import secrets
|
print(f"🔧 NEW_CONN: Server config - has_certificate: {server_config.certificate is not None}")
|
||||||
|
print(f"🔧 NEW_CONN: Server config - has_private_key: {server_config.private_key is not None}")
|
||||||
|
print(f"🔧 NEW_CONN: Server config - ALPN: {server_config.alpn_protocols}")
|
||||||
|
print(f"🔧 NEW_CONN: Server config - verify_mode: {server_config.verify_mode}")
|
||||||
|
|
||||||
|
# Validate certificate has libp2p extension
|
||||||
|
if server_config.certificate:
|
||||||
|
cert = server_config.certificate
|
||||||
|
has_libp2p_ext = False
|
||||||
|
for ext in cert.extensions:
|
||||||
|
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
|
||||||
|
has_libp2p_ext = True
|
||||||
|
break
|
||||||
|
print(f"🔧 NEW_CONN: Certificate has libp2p extension: {has_libp2p_ext}")
|
||||||
|
|
||||||
|
if not has_libp2p_ext:
|
||||||
|
print("❌ NEW_CONN: Certificate missing libp2p extension!")
|
||||||
|
|
||||||
|
# Generate a new destination connection ID for this connection
|
||||||
|
import secrets
|
||||||
destination_cid = secrets.token_bytes(8)
|
destination_cid = secrets.token_bytes(8)
|
||||||
|
|
||||||
# Create QUIC connection with specific version
|
print(f"🔧 NEW_CONN: Generated new CID: {destination_cid.hex()}")
|
||||||
|
print(f"🔧 NEW_CONN: Original destination CID: {packet_info.destination_cid.hex()}")
|
||||||
|
|
||||||
|
# Create QUIC connection with proper parameters for server
|
||||||
|
# CRITICAL FIX: Pass the original destination connection ID from the initial packet
|
||||||
quic_conn = QuicConnection(
|
quic_conn = QuicConnection(
|
||||||
configuration=server_config,
|
configuration=server_config,
|
||||||
original_destination_connection_id=packet_info.destination_cid,
|
original_destination_connection_id=packet_info.destination_cid, # Use the original DCID from packet
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store connection mapping
|
print("✅ NEW_CONN: QUIC connection created successfully")
|
||||||
|
|
||||||
|
# Store connection mapping using our generated CID
|
||||||
self._pending_connections[destination_cid] = quic_conn
|
self._pending_connections[destination_cid] = quic_conn
|
||||||
self._addr_to_cid[addr] = destination_cid
|
self._addr_to_cid[addr] = destination_cid
|
||||||
self._cid_to_addr[destination_cid] = addr
|
self._cid_to_addr[destination_cid] = addr
|
||||||
|
|
||||||
|
print(f"🔧 NEW_CONN: Stored mappings for {addr} <-> {destination_cid.hex()}")
|
||||||
print("Receiving Datagram")
|
print("Receiving Datagram")
|
||||||
|
|
||||||
# Process initial packet
|
# Process initial packet
|
||||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||||
|
|
||||||
|
# Debug connection state after receiving packet
|
||||||
|
await self._debug_quic_connection_state_detailed(quic_conn, destination_cid)
|
||||||
|
|
||||||
|
# Process events and send response
|
||||||
await self._process_quic_events(quic_conn, addr, destination_cid)
|
await self._process_quic_events(quic_conn, addr, destination_cid)
|
||||||
await self._transmit_for_connection(quic_conn, addr)
|
await self._transmit_for_connection(quic_conn, addr)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Started handshake for new connection from {addr} "
|
f"Started handshake for new connection from {addr} "
|
||||||
f"(version: {packet_info.version:08x}, cid: {destination_cid.hex()})"
|
f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling new connection from {addr}: {e}")
|
logger.error(f"Error handling new connection from {addr}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
self._stats["connections_rejected"] += 1
|
self._stats["connections_rejected"] += 1
|
||||||
|
|
||||||
|
async def _debug_quic_connection_state_detailed(
|
||||||
|
self, quic_conn: QuicConnection, connection_id: bytes
|
||||||
|
):
|
||||||
|
"""Enhanced connection state debugging."""
|
||||||
|
try:
|
||||||
|
print(f"🔧 QUIC_STATE: Debugging connection {connection_id.hex()}")
|
||||||
|
|
||||||
|
if not quic_conn:
|
||||||
|
print("❌ QUIC_STATE: QUIC CONNECTION NOT FOUND")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check TLS state
|
||||||
|
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||||
|
print("✅ QUIC_STATE: TLS context exists")
|
||||||
|
if hasattr(quic_conn.tls, "state"):
|
||||||
|
print(f"🔧 QUIC_STATE: TLS state: {quic_conn.tls.state}")
|
||||||
|
|
||||||
|
# Check if we have peer certificate
|
||||||
|
if (
|
||||||
|
hasattr(quic_conn.tls, "_peer_certificate")
|
||||||
|
and quic_conn.tls._peer_certificate
|
||||||
|
):
|
||||||
|
print("✅ QUIC_STATE: Peer certificate available")
|
||||||
|
else:
|
||||||
|
print("🔧 QUIC_STATE: No peer certificate yet")
|
||||||
|
|
||||||
|
# Check TLS handshake completion
|
||||||
|
if hasattr(quic_conn.tls, "handshake_complete"):
|
||||||
|
handshake_status = quic_conn._handshake_complete
|
||||||
|
print(
|
||||||
|
f"🔧 QUIC_STATE: TLS handshake complete: {handshake_status}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("❌ QUIC_STATE: No TLS context!")
|
||||||
|
|
||||||
|
# Check connection state
|
||||||
|
if hasattr(quic_conn, "_state"):
|
||||||
|
print(f"🔧 QUIC_STATE: Connection state: {quic_conn._state}")
|
||||||
|
|
||||||
|
# Check if handshake is complete
|
||||||
|
if hasattr(quic_conn, "_handshake_complete"):
|
||||||
|
print(
|
||||||
|
f"🔧 QUIC_STATE: Handshake complete: {quic_conn._handshake_complete}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check configuration
|
||||||
|
if hasattr(quic_conn, "configuration"):
|
||||||
|
config = quic_conn.configuration
|
||||||
|
print(
|
||||||
|
f"🔧 QUIC_STATE: Config certificate: {config.certificate is not None}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"🔧 QUIC_STATE: Config private_key: {config.private_key is not None}"
|
||||||
|
)
|
||||||
|
print(f"🔧 QUIC_STATE: Config is_client: {config.is_client}")
|
||||||
|
print(f"🔧 QUIC_STATE: Config verify_mode: {config.verify_mode}")
|
||||||
|
print(f"🔧 QUIC_STATE: Config ALPN: {config.alpn_protocols}")
|
||||||
|
|
||||||
|
if config.certificate:
|
||||||
|
cert = config.certificate
|
||||||
|
print(f"🔧 QUIC_STATE: Certificate subject: {cert.subject}")
|
||||||
|
print(
|
||||||
|
f"🔧 QUIC_STATE: Certificate valid from: {cert.not_valid_before}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"🔧 QUIC_STATE: Certificate valid until: {cert.not_valid_after}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for connection errors
|
||||||
|
if hasattr(quic_conn, "_close_event") and quic_conn._close_event:
|
||||||
|
print(
|
||||||
|
f"❌ QUIC_STATE: Connection has close event: {quic_conn._close_event}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for TLS errors
|
||||||
|
if (
|
||||||
|
hasattr(quic_conn, "_handshake_complete")
|
||||||
|
and not quic_conn._handshake_complete
|
||||||
|
):
|
||||||
|
print("⚠️ QUIC_STATE: Handshake not yet complete")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ QUIC_STATE: Error checking state: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _handle_short_header_packet(
|
async def _handle_short_header_packet(
|
||||||
self, data: bytes, addr: tuple[str, int]
|
self, data: bytes, addr: tuple[str, int]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -515,54 +630,141 @@ class QUICListener(IListener):
|
|||||||
addr: tuple[str, int],
|
addr: tuple[str, int],
|
||||||
dest_cid: bytes,
|
dest_cid: bytes,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle packet for a pending (handshaking) connection."""
|
"""Handle packet for a pending (handshaking) connection with enhanced debugging."""
|
||||||
try:
|
try:
|
||||||
|
print(
|
||||||
|
f"🔧 PENDING: Handling packet for pending connection {dest_cid.hex()}"
|
||||||
|
)
|
||||||
|
print(f"🔧 PENDING: Packet size: {len(data)} bytes from {addr}")
|
||||||
|
|
||||||
|
# Check connection state before processing
|
||||||
|
if hasattr(quic_conn, "_state"):
|
||||||
|
print(f"🔧 PENDING: Connection state before: {quic_conn._state}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(quic_conn, "tls")
|
||||||
|
and quic_conn.tls
|
||||||
|
and hasattr(quic_conn.tls, "state")
|
||||||
|
):
|
||||||
|
print(f"🔧 PENDING: TLS state before: {quic_conn.tls.state}")
|
||||||
|
|
||||||
# Feed data to QUIC connection
|
# Feed data to QUIC connection
|
||||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||||
|
print("✅ PENDING: Datagram received by QUIC connection")
|
||||||
|
|
||||||
# Process events
|
# Check state after receiving packet
|
||||||
|
if hasattr(quic_conn, "_state"):
|
||||||
|
print(f"🔧 PENDING: Connection state after: {quic_conn._state}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(quic_conn, "tls")
|
||||||
|
and quic_conn.tls
|
||||||
|
and hasattr(quic_conn.tls, "state")
|
||||||
|
):
|
||||||
|
print(f"🔧 PENDING: TLS state after: {quic_conn.tls.state}")
|
||||||
|
|
||||||
|
# Process events - this is crucial for handshake progression
|
||||||
|
print("🔧 PENDING: Processing QUIC events...")
|
||||||
await self._process_quic_events(quic_conn, addr, dest_cid)
|
await self._process_quic_events(quic_conn, addr, dest_cid)
|
||||||
|
|
||||||
# Send any outgoing packets
|
# Send any outgoing packets - this is where the response should be sent
|
||||||
|
print("🔧 PENDING: Transmitting response...")
|
||||||
await self._transmit_for_connection(quic_conn, addr)
|
await self._transmit_for_connection(quic_conn, addr)
|
||||||
|
|
||||||
|
# Check if handshake completed
|
||||||
|
if (
|
||||||
|
hasattr(quic_conn, "_handshake_complete")
|
||||||
|
and quic_conn._handshake_complete
|
||||||
|
):
|
||||||
|
print("✅ PENDING: Handshake completed, promoting connection")
|
||||||
|
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||||
|
else:
|
||||||
|
print("🔧 PENDING: Handshake still in progress")
|
||||||
|
|
||||||
|
# Debug why handshake might be stuck
|
||||||
|
await self._debug_handshake_state(quic_conn, dest_cid)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
|
logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
|
||||||
# Remove from pending connections
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# Remove problematic pending connection
|
||||||
|
print(f"❌ PENDING: Removing problematic connection {dest_cid.hex()}")
|
||||||
await self._remove_pending_connection(dest_cid)
|
await self._remove_pending_connection(dest_cid)
|
||||||
|
|
||||||
async def _process_quic_events(
|
async def _process_quic_events(
|
||||||
self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
|
self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process QUIC events for a connection with connection ID context."""
|
"""Process QUIC events with enhanced debugging."""
|
||||||
while True:
|
try:
|
||||||
event = quic_conn.next_event()
|
events_processed = 0
|
||||||
if event is None:
|
while True:
|
||||||
break
|
event = quic_conn.next_event()
|
||||||
|
if event is None:
|
||||||
|
break
|
||||||
|
|
||||||
if isinstance(event, events.ConnectionTerminated):
|
events_processed += 1
|
||||||
logger.debug(
|
print(
|
||||||
f"Connection {dest_cid.hex()} from {addr} "
|
f"🔧 EVENT: Processing event {events_processed}: {type(event).__name__}"
|
||||||
f"terminated: {event.reason_phrase}"
|
|
||||||
)
|
)
|
||||||
await self._remove_connection(dest_cid)
|
|
||||||
break
|
|
||||||
|
|
||||||
elif isinstance(event, events.HandshakeCompleted):
|
if isinstance(event, events.ConnectionTerminated):
|
||||||
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
|
print(
|
||||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
f"❌ EVENT: Connection terminated - code: {event.error_code}, reason: {event.reason_phrase}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Connection {dest_cid.hex()} from {addr} "
|
||||||
|
f"terminated: {event.reason_phrase}"
|
||||||
|
)
|
||||||
|
await self._remove_connection(dest_cid)
|
||||||
|
break
|
||||||
|
|
||||||
elif isinstance(event, events.StreamDataReceived):
|
elif isinstance(event, events.HandshakeCompleted):
|
||||||
# Forward to established connection if available
|
print(
|
||||||
if dest_cid in self._connections:
|
f"✅ EVENT: Handshake completed for connection {dest_cid.hex()}"
|
||||||
connection = self._connections[dest_cid]
|
)
|
||||||
await connection._handle_stream_data(event)
|
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
|
||||||
|
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||||
|
|
||||||
elif isinstance(event, events.StreamReset):
|
elif isinstance(event, events.StreamDataReceived):
|
||||||
# Forward to established connection if available
|
print(f"🔧 EVENT: Stream data received on stream {event.stream_id}")
|
||||||
if dest_cid in self._connections:
|
# Forward to established connection if available
|
||||||
connection = self._connections[dest_cid]
|
if dest_cid in self._connections:
|
||||||
await connection._handle_stream_reset(event)
|
connection = self._connections[dest_cid]
|
||||||
|
await connection._handle_stream_data(event)
|
||||||
|
|
||||||
|
elif isinstance(event, events.StreamReset):
|
||||||
|
print(f"🔧 EVENT: Stream reset on stream {event.stream_id}")
|
||||||
|
# Forward to established connection if available
|
||||||
|
if dest_cid in self._connections:
|
||||||
|
connection = self._connections[dest_cid]
|
||||||
|
await connection._handle_stream_reset(event)
|
||||||
|
|
||||||
|
elif isinstance(event, events.ConnectionIdIssued):
|
||||||
|
print(
|
||||||
|
f"🔧 EVENT: Connection ID issued: {event.connection_id.hex()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(event, events.ConnectionIdRetired):
|
||||||
|
print(
|
||||||
|
f"🔧 EVENT: Connection ID retired: {event.connection_id.hex()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"🔧 EVENT: Unhandled event type: {type(event).__name__}")
|
||||||
|
|
||||||
|
if events_processed == 0:
|
||||||
|
print("🔧 EVENT: No events to process")
|
||||||
|
else:
|
||||||
|
print(f"🔧 EVENT: Processed {events_processed} events total")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ EVENT: Error processing events: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
async def _debug_quic_connection_state(
|
async def _debug_quic_connection_state(
|
||||||
self, quic_conn: QuicConnection, connection_id: bytes
|
self, quic_conn: QuicConnection, connection_id: bytes
|
||||||
@ -972,3 +1174,61 @@ class QUICListener(IListener):
|
|||||||
stats["active_connections"] = len(self._connections)
|
stats["active_connections"] = len(self._connections)
|
||||||
stats["pending_connections"] = len(self._pending_connections)
|
stats["pending_connections"] = len(self._pending_connections)
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
async def _debug_handshake_state(self, quic_conn: QuicConnection, dest_cid: bytes):
|
||||||
|
"""Debug why handshake might be stuck."""
|
||||||
|
try:
|
||||||
|
print(f"🔧 HANDSHAKE_DEBUG: Analyzing stuck handshake for {dest_cid.hex()}")
|
||||||
|
|
||||||
|
# Check TLS handshake state
|
||||||
|
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||||
|
tls = quic_conn.tls
|
||||||
|
print(
|
||||||
|
f"🔧 HANDSHAKE_DEBUG: TLS state: {getattr(tls, 'state', 'Unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for TLS errors
|
||||||
|
if hasattr(tls, "_error") and tls._error:
|
||||||
|
print(f"❌ HANDSHAKE_DEBUG: TLS error: {tls._error}")
|
||||||
|
|
||||||
|
# Check certificate validation
|
||||||
|
if hasattr(tls, "_peer_certificate"):
|
||||||
|
if tls._peer_certificate:
|
||||||
|
print("✅ HANDSHAKE_DEBUG: Peer certificate received")
|
||||||
|
else:
|
||||||
|
print("❌ HANDSHAKE_DEBUG: No peer certificate")
|
||||||
|
|
||||||
|
# Check ALPN negotiation
|
||||||
|
if hasattr(tls, "_alpn_protocols"):
|
||||||
|
if tls._alpn_protocols:
|
||||||
|
print(
|
||||||
|
f"✅ HANDSHAKE_DEBUG: ALPN negotiated: {tls._alpn_protocols}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("❌ HANDSHAKE_DEBUG: No ALPN protocol negotiated")
|
||||||
|
|
||||||
|
# Check QUIC connection state
|
||||||
|
if hasattr(quic_conn, "_state"):
|
||||||
|
state = quic_conn._state
|
||||||
|
print(f"🔧 HANDSHAKE_DEBUG: QUIC state: {state}")
|
||||||
|
|
||||||
|
# Check specific states that might indicate problems
|
||||||
|
if "FIRSTFLIGHT" in str(state):
|
||||||
|
print("⚠️ HANDSHAKE_DEBUG: Connection stuck in FIRSTFLIGHT state")
|
||||||
|
elif "CONNECTED" in str(state):
|
||||||
|
print(
|
||||||
|
"⚠️ HANDSHAKE_DEBUG: Connection shows CONNECTED but handshake not complete"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for pending crypto data
|
||||||
|
if hasattr(quic_conn, "_cryptos") and quic_conn._cryptos:
|
||||||
|
print(f"🔧 HANDSHAKE_DEBUG: Crypto data present {len(quic_conn._cryptos.keys())}")
|
||||||
|
|
||||||
|
# Check loss detection state
|
||||||
|
if hasattr(quic_conn, "_loss") and quic_conn._loss:
|
||||||
|
loss_detection = quic_conn._loss
|
||||||
|
if hasattr(loss_detection, "_pto_count"):
|
||||||
|
print(f"🔧 HANDSHAKE_DEBUG: PTO count: {loss_detection._pto_count}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ HANDSHAKE_DEBUG: Error during debug: {e}")
|
||||||
|
|||||||
@ -4,9 +4,11 @@ Implements libp2p TLS specification for QUIC transport with peer identity integr
|
|||||||
Based on go-libp2p and js-libp2p security patterns.
|
Based on go-libp2p and js-libp2p security patterns.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
|
import ssl
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
@ -25,11 +27,6 @@ from .exceptions import (
|
|||||||
QUICPeerVerificationError,
|
QUICPeerVerificationError,
|
||||||
)
|
)
|
||||||
|
|
||||||
TSecurityConfig = dict[
|
|
||||||
str,
|
|
||||||
Certificate | EllipticCurvePrivateKey | RSAPrivateKey | bool | list[str],
|
|
||||||
]
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# libp2p TLS Extension OID - Official libp2p specification
|
# libp2p TLS Extension OID - Official libp2p specification
|
||||||
@ -312,7 +309,7 @@ class CertificateGenerator:
|
|||||||
x509.UnrecognizedExtension(
|
x509.UnrecognizedExtension(
|
||||||
oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data
|
oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data
|
||||||
),
|
),
|
||||||
critical=True, # This extension is critical for libp2p
|
critical=False,
|
||||||
)
|
)
|
||||||
.sign(cert_private_key, hashes.SHA256())
|
.sign(cert_private_key, hashes.SHA256())
|
||||||
)
|
)
|
||||||
@ -407,6 +404,269 @@ class PeerAuthenticator:
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class QUICTLSSecurityConfig:
|
||||||
|
"""
|
||||||
|
Type-safe TLS security configuration for QUIC transport.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core TLS components (required)
|
||||||
|
certificate: Certificate
|
||||||
|
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey]
|
||||||
|
|
||||||
|
# Certificate chain (optional)
|
||||||
|
certificate_chain: List[Certificate] = field(default_factory=list)
|
||||||
|
|
||||||
|
# ALPN protocols
|
||||||
|
alpn_protocols: List[str] = field(default_factory=lambda: ["libp2p"])
|
||||||
|
|
||||||
|
# TLS verification settings
|
||||||
|
verify_mode: Union[bool, ssl.VerifyMode] = False
|
||||||
|
check_hostname: bool = False
|
||||||
|
|
||||||
|
# Optional peer ID for validation
|
||||||
|
peer_id: Optional[ID] = None
|
||||||
|
|
||||||
|
# Configuration metadata
|
||||||
|
is_client_config: bool = False
|
||||||
|
config_name: Optional[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Validate configuration after initialization."""
|
||||||
|
self._validate()
|
||||||
|
|
||||||
|
def _validate(self) -> None:
|
||||||
|
"""Validate the TLS configuration."""
|
||||||
|
if self.certificate is None:
|
||||||
|
raise ValueError("Certificate is required")
|
||||||
|
|
||||||
|
if self.private_key is None:
|
||||||
|
raise ValueError("Private key is required")
|
||||||
|
|
||||||
|
if not isinstance(self.certificate, x509.Certificate):
|
||||||
|
raise TypeError(
|
||||||
|
f"Certificate must be x509.Certificate, got {type(self.certificate)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(
|
||||||
|
self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey)
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
f"Private key must be EC or RSA key, got {type(self.private_key)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.alpn_protocols:
|
||||||
|
raise ValueError("At least one ALPN protocol is required")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""
|
||||||
|
Convert to dictionary format for compatibility with existing code.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary compatible with the original TSecurityConfig format
|
||||||
|
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"certificate": self.certificate,
|
||||||
|
"private_key": self.private_key,
|
||||||
|
"certificate_chain": self.certificate_chain.copy(),
|
||||||
|
"alpn_protocols": self.alpn_protocols.copy(),
|
||||||
|
"verify_mode": self.verify_mode,
|
||||||
|
"check_hostname": self.check_hostname,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config_dict: dict, **kwargs) -> "QUICTLSSecurityConfig":
|
||||||
|
"""
|
||||||
|
Create instance from dictionary format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: Dictionary in TSecurityConfig format
|
||||||
|
**kwargs: Additional parameters for the config
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QUICTLSSecurityConfig instance
|
||||||
|
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
certificate=config_dict["certificate"],
|
||||||
|
private_key=config_dict["private_key"],
|
||||||
|
certificate_chain=config_dict.get("certificate_chain", []),
|
||||||
|
alpn_protocols=config_dict.get("alpn_protocols", ["libp2p"]),
|
||||||
|
verify_mode=config_dict.get("verify_mode", False),
|
||||||
|
check_hostname=config_dict.get("check_hostname", False),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_certificate_key_match(self) -> bool:
|
||||||
|
"""
|
||||||
|
Validate that the certificate and private key match.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if certificate and private key match
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
# Get public keys from both certificate and private key
|
||||||
|
cert_public_key = self.certificate.public_key()
|
||||||
|
private_public_key = self.private_key.public_key()
|
||||||
|
|
||||||
|
# Compare their PEM representations
|
||||||
|
cert_pub_pem = cert_public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
private_pub_pem = private_public_key.public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cert_pub_pem == private_pub_pem
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def has_libp2p_extension(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the certificate has the required libp2p extension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if libp2p extension is present
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
libp2p_oid = "1.3.6.1.4.1.53594.1.1"
|
||||||
|
for ext in self.certificate.extensions:
|
||||||
|
if str(ext.oid) == libp2p_oid:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_certificate_valid(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the certificate is currently valid (not expired).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if certificate is valid
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
return (
|
||||||
|
self.certificate.not_valid_before
|
||||||
|
<= now
|
||||||
|
<= self.certificate.not_valid_after
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_certificate_info(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get certificate information for debugging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with certificate details
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"subject": str(self.certificate.subject),
|
||||||
|
"issuer": str(self.certificate.issuer),
|
||||||
|
"serial_number": self.certificate.serial_number,
|
||||||
|
"not_valid_before": self.certificate.not_valid_before,
|
||||||
|
"not_valid_after": self.certificate.not_valid_after,
|
||||||
|
"has_libp2p_extension": self.has_libp2p_extension(),
|
||||||
|
"is_valid": self.is_certificate_valid(),
|
||||||
|
"certificate_key_match": self.validate_certificate_key_match(),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
def debug_print(self) -> None:
|
||||||
|
"""Print debugging information about this configuration."""
|
||||||
|
print(f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===")
|
||||||
|
print(f"Is client config: {self.is_client_config}")
|
||||||
|
print(f"ALPN protocols: {self.alpn_protocols}")
|
||||||
|
print(f"Verify mode: {self.verify_mode}")
|
||||||
|
print(f"Check hostname: {self.check_hostname}")
|
||||||
|
print(f"Certificate chain length: {len(self.certificate_chain)}")
|
||||||
|
|
||||||
|
cert_info = self.get_certificate_info()
|
||||||
|
for key, value in cert_info.items():
|
||||||
|
print(f"Certificate {key}: {value}")
|
||||||
|
|
||||||
|
print(f"Private key type: {type(self.private_key).__name__}")
|
||||||
|
if hasattr(self.private_key, "key_size"):
|
||||||
|
print(f"Private key size: {self.private_key.key_size}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_server_tls_config(
|
||||||
|
certificate: Certificate,
|
||||||
|
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||||
|
peer_id: Optional[ID] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> QUICTLSSecurityConfig:
|
||||||
|
"""
|
||||||
|
Create a server TLS configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
certificate: X.509 certificate
|
||||||
|
private_key: Private key corresponding to certificate
|
||||||
|
peer_id: Optional peer ID for validation
|
||||||
|
**kwargs: Additional configuration parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Server TLS configuration
|
||||||
|
|
||||||
|
"""
|
||||||
|
return QUICTLSSecurityConfig(
|
||||||
|
certificate=certificate,
|
||||||
|
private_key=private_key,
|
||||||
|
peer_id=peer_id,
|
||||||
|
is_client_config=False,
|
||||||
|
config_name="server",
|
||||||
|
verify_mode=False, # Server doesn't verify client certs in libp2p
|
||||||
|
check_hostname=False,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_client_tls_config(
|
||||||
|
certificate: Certificate,
|
||||||
|
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||||
|
peer_id: Optional[ID] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> QUICTLSSecurityConfig:
|
||||||
|
"""
|
||||||
|
Create a client TLS configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
certificate: X.509 certificate
|
||||||
|
private_key: Private key corresponding to certificate
|
||||||
|
peer_id: Optional peer ID for validation
|
||||||
|
**kwargs: Additional configuration parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Client TLS configuration
|
||||||
|
|
||||||
|
"""
|
||||||
|
return QUICTLSSecurityConfig(
|
||||||
|
certificate=certificate,
|
||||||
|
private_key=private_key,
|
||||||
|
peer_id=peer_id,
|
||||||
|
is_client_config=True,
|
||||||
|
config_name="client",
|
||||||
|
verify_mode=False, # Client doesn't verify server certs in libp2p
|
||||||
|
check_hostname=False,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QUICTLSConfigManager:
|
class QUICTLSConfigManager:
|
||||||
"""
|
"""
|
||||||
Manages TLS configuration for QUIC transport with libp2p security.
|
Manages TLS configuration for QUIC transport with libp2p security.
|
||||||
@ -424,44 +684,40 @@ class QUICTLSConfigManager:
|
|||||||
libp2p_private_key, peer_id
|
libp2p_private_key, peer_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_server_config(
|
def create_server_config(self) -> QUICTLSSecurityConfig:
|
||||||
self,
|
|
||||||
) -> TSecurityConfig:
|
|
||||||
"""
|
"""
|
||||||
Create aioquic server configuration with libp2p TLS settings.
|
Create server configuration using the new class-based approach.
|
||||||
Returns cryptography objects instead of DER bytes.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configuration dictionary for aioquic QuicConfiguration
|
QUICTLSSecurityConfig instance for server
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config: TSecurityConfig = {
|
config = create_server_tls_config(
|
||||||
"certificate": self.tls_config.certificate,
|
certificate=self.tls_config.certificate,
|
||||||
"private_key": self.tls_config.private_key,
|
private_key=self.tls_config.private_key,
|
||||||
"certificate_chain": [],
|
peer_id=self.peer_id,
|
||||||
"alpn_protocols": ["libp2p"],
|
)
|
||||||
"verify_mode": False,
|
|
||||||
"check_hostname": False,
|
print("🔧 SECURITY: Created server config")
|
||||||
}
|
config.debug_print()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def create_client_config(self) -> TSecurityConfig:
|
def create_client_config(self) -> QUICTLSSecurityConfig:
|
||||||
"""
|
"""
|
||||||
Create aioquic client configuration with libp2p TLS settings.
|
Create client configuration using the new class-based approach.
|
||||||
Returns cryptography objects instead of DER bytes.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configuration dictionary for aioquic QuicConfiguration
|
QUICTLSSecurityConfig instance for client
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config: TSecurityConfig = {
|
config = create_client_tls_config(
|
||||||
"certificate": self.tls_config.certificate,
|
certificate=self.tls_config.certificate,
|
||||||
"private_key": self.tls_config.private_key,
|
private_key=self.tls_config.private_key,
|
||||||
"certificate_chain": [],
|
peer_id=self.peer_id,
|
||||||
"alpn_protocols": ["libp2p"],
|
)
|
||||||
"verify_mode": False,
|
|
||||||
"check_hostname": False,
|
print("🔧 SECURITY: Created client config")
|
||||||
}
|
config.debug_print()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def verify_peer_identity(
|
def verify_peer_identity(
|
||||||
|
|||||||
@ -5,7 +5,6 @@ Based on aioquic library with interface consistency to go-libp2p and js-libp2p.
|
|||||||
Updated to include Module 5 security integration.
|
Updated to include Module 5 security integration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Iterable
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
@ -31,7 +30,7 @@ from libp2p.custom_types import THandler, TProtocol
|
|||||||
from libp2p.peer.id import (
|
from libp2p.peer.id import (
|
||||||
ID,
|
ID,
|
||||||
)
|
)
|
||||||
from libp2p.transport.quic.security import TSecurityConfig
|
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||||
from libp2p.transport.quic.utils import (
|
from libp2p.transport.quic.utils import (
|
||||||
get_alpn_protocols,
|
get_alpn_protocols,
|
||||||
is_quic_multiaddr,
|
is_quic_multiaddr,
|
||||||
@ -192,7 +191,7 @@ class QUICTransport(ITransport):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
def _apply_tls_configuration(
|
def _apply_tls_configuration(
|
||||||
self, config: QuicConfiguration, tls_config: TSecurityConfig
|
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||||
@ -203,52 +202,47 @@ class QUICTransport(ITransport):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Set certificate and private key directly on the configuration
|
|
||||||
# aioquic expects cryptography objects, not DER bytes
|
|
||||||
if "certificate" in tls_config and "private_key" in tls_config:
|
|
||||||
# The security manager should return cryptography objects
|
|
||||||
# not DER bytes, but if it returns DER bytes, we need to handle that
|
|
||||||
certificate = tls_config["certificate"]
|
|
||||||
private_key = tls_config["private_key"]
|
|
||||||
|
|
||||||
# Check if we received DER bytes and need
|
# The security manager should return cryptography objects
|
||||||
# to convert to cryptography objects
|
# not DER bytes, but if it returns DER bytes, we need to handle that
|
||||||
if isinstance(certificate, bytes):
|
certificate = tls_config.certificate
|
||||||
|
private_key = tls_config.private_key
|
||||||
|
|
||||||
|
# Check if we received DER bytes and need
|
||||||
|
# to convert to cryptography objects
|
||||||
|
if isinstance(certificate, bytes):
|
||||||
|
from cryptography import x509
|
||||||
|
|
||||||
|
certificate = x509.load_der_x509_certificate(certificate)
|
||||||
|
|
||||||
|
if isinstance(private_key, bytes):
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
|
||||||
|
private_key = serialization.load_der_private_key( # type: ignore
|
||||||
|
private_key, password=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set directly on the configuration object
|
||||||
|
config.certificate = certificate
|
||||||
|
config.private_key = private_key
|
||||||
|
|
||||||
|
# Handle certificate chain if provided
|
||||||
|
certificate_chain = tls_config.certificate_chain
|
||||||
|
# Convert DER bytes to cryptography objects if needed
|
||||||
|
chain_objects = []
|
||||||
|
for cert in certificate_chain:
|
||||||
|
if isinstance(cert, bytes):
|
||||||
from cryptography import x509
|
from cryptography import x509
|
||||||
|
|
||||||
certificate = x509.load_der_x509_certificate(certificate)
|
cert = x509.load_der_x509_certificate(cert)
|
||||||
|
chain_objects.append(cert)
|
||||||
if isinstance(private_key, bytes):
|
config.certificate_chain = chain_objects
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
|
|
||||||
private_key = serialization.load_der_private_key( # type: ignore
|
|
||||||
private_key, password=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set directly on the configuration object
|
|
||||||
config.certificate = certificate
|
|
||||||
config.private_key = private_key
|
|
||||||
|
|
||||||
# Handle certificate chain if provided
|
|
||||||
certificate_chain = tls_config.get("certificate_chain", [])
|
|
||||||
if certificate_chain and isinstance(certificate_chain, Iterable):
|
|
||||||
# Convert DER bytes to cryptography objects if needed
|
|
||||||
chain_objects = []
|
|
||||||
for cert in certificate_chain:
|
|
||||||
if isinstance(cert, bytes):
|
|
||||||
from cryptography import x509
|
|
||||||
|
|
||||||
cert = x509.load_der_x509_certificate(cert)
|
|
||||||
chain_objects.append(cert)
|
|
||||||
config.certificate_chain = chain_objects
|
|
||||||
|
|
||||||
# Set ALPN protocols
|
# Set ALPN protocols
|
||||||
if "alpn_protocols" in tls_config:
|
config.alpn_protocols = tls_config.alpn_protocols
|
||||||
config.alpn_protocols = tls_config["alpn_protocols"] # type: ignore
|
|
||||||
|
|
||||||
# Set certificate verification mode
|
# Set certificate verification mode
|
||||||
if "verify_mode" in tls_config:
|
config.verify_mode = tls_config.verify_mode
|
||||||
config.verify_mode = tls_config["verify_mode"] # type: ignore
|
|
||||||
|
|
||||||
logger.debug("Successfully applied TLS configuration to QUIC config")
|
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user