mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: duplication connection creation for same sessions
This commit is contained in:
289
examples/echo/test_quic.py
Normal file
289
examples/echo/test_quic.py
Normal file
@ -0,0 +1,289 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fixed QUIC handshake test to debug connection issues.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig
|
||||
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||
|
||||
# Adjust this path to your project structure
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
|
||||
async def test_certificate_generation():
|
||||
"""Test certificate generation in isolation."""
|
||||
print("\n=== TESTING CERTIFICATE GENERATION ===")
|
||||
|
||||
try:
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
|
||||
# Create key pair
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
print(f"Generated peer ID: {peer_id}")
|
||||
|
||||
# Create security manager
|
||||
security_manager = create_quic_security_transport(private_key, peer_id)
|
||||
print("✅ Security manager created")
|
||||
|
||||
# Test server config
|
||||
server_config = security_manager.create_server_config()
|
||||
print("✅ Server config created")
|
||||
|
||||
# Validate certificate
|
||||
cert = server_config.certificate
|
||||
private_key_obj = server_config.private_key
|
||||
|
||||
print(f"Certificate type: {type(cert)}")
|
||||
print(f"Private key type: {type(private_key_obj)}")
|
||||
print(f"Certificate subject: {cert.subject}")
|
||||
print(f"Certificate issuer: {cert.issuer}")
|
||||
|
||||
# Check for libp2p extension
|
||||
has_libp2p_ext = False
|
||||
for ext in cert.extensions:
|
||||
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
|
||||
has_libp2p_ext = True
|
||||
print(f"✅ Found libp2p extension: {ext.oid}")
|
||||
print(f"Extension critical: {ext.critical}")
|
||||
print(f"Extension value length: {len(ext.value)} bytes")
|
||||
break
|
||||
|
||||
if not has_libp2p_ext:
|
||||
print("❌ No libp2p extension found!")
|
||||
print("Available extensions:")
|
||||
for ext in cert.extensions:
|
||||
print(f" - {ext.oid} (critical: {ext.critical})")
|
||||
|
||||
# Check certificate/key match
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
cert_public_key = cert.public_key()
|
||||
private_public_key = private_key_obj.public_key()
|
||||
|
||||
cert_pub_bytes = cert_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
private_pub_bytes = private_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
if cert_pub_bytes == private_pub_bytes:
|
||||
print("✅ Certificate and private key match")
|
||||
return has_libp2p_ext
|
||||
else:
|
||||
print("❌ Certificate and private key DO NOT match")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Certificate test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_basic_quic_connection():
|
||||
"""Test basic QUIC connection with proper server setup."""
|
||||
print("\n=== TESTING BASIC QUIC CONNECTION ===")
|
||||
|
||||
try:
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
|
||||
# Create certificates
|
||||
server_key = create_new_key_pair().private_key
|
||||
server_peer_id = ID.from_pubkey(server_key.get_public_key())
|
||||
server_security = create_quic_security_transport(server_key, server_peer_id)
|
||||
|
||||
client_key = create_new_key_pair().private_key
|
||||
client_peer_id = ID.from_pubkey(client_key.get_public_key())
|
||||
client_security = create_quic_security_transport(client_key, client_peer_id)
|
||||
|
||||
# Create server config
|
||||
server_tls_config = server_security.create_server_config()
|
||||
server_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
certificate=server_tls_config.certificate,
|
||||
private_key=server_tls_config.private_key,
|
||||
alpn_protocols=["libp2p"],
|
||||
)
|
||||
|
||||
# Create client config
|
||||
client_tls_config = client_security.create_client_config()
|
||||
client_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
certificate=client_tls_config.certificate,
|
||||
private_key=client_tls_config.private_key,
|
||||
alpn_protocols=["libp2p"],
|
||||
)
|
||||
|
||||
print("✅ QUIC configurations created")
|
||||
|
||||
# Test creating connections with proper parameters
|
||||
# For server, we need to provide original_destination_connection_id
|
||||
original_dcid = secrets.token_bytes(8)
|
||||
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=original_dcid,
|
||||
)
|
||||
|
||||
# For client, no original_destination_connection_id needed
|
||||
client_conn = QuicConnection(configuration=client_config)
|
||||
|
||||
print("✅ QUIC connections created")
|
||||
print(f"Server state: {server_conn._state}")
|
||||
print(f"Client state: {client_conn._state}")
|
||||
|
||||
# Test that certificates are valid
|
||||
print(f"Server has certificate: {server_config.certificate is not None}")
|
||||
print(f"Server has private key: {server_config.private_key is not None}")
|
||||
print(f"Client has certificate: {client_config.certificate is not None}")
|
||||
print(f"Client has private key: {client_config.private_key is not None}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Basic QUIC test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_server_startup():
|
||||
"""Test server startup with timeout."""
|
||||
print("\n=== TESTING SERVER STARTUP ===")
|
||||
|
||||
try:
|
||||
# Create transport
|
||||
private_key = create_new_key_pair().private_key
|
||||
config = QUICTransportConfig(
|
||||
idle_timeout=10.0, # Reduced timeout for testing
|
||||
connection_timeout=10.0,
|
||||
enable_draft29=False,
|
||||
)
|
||||
|
||||
transport = QUICTransport(private_key, config)
|
||||
print("✅ Transport created successfully")
|
||||
|
||||
# Test configuration
|
||||
print(f"Available configs: {list(transport._quic_configs.keys())}")
|
||||
|
||||
config_valid = True
|
||||
for config_key, quic_config in transport._quic_configs.items():
|
||||
print(f"\n--- Testing config: {config_key} ---")
|
||||
print(f"is_client: {quic_config.is_client}")
|
||||
print(f"has_certificate: {quic_config.certificate is not None}")
|
||||
print(f"has_private_key: {quic_config.private_key is not None}")
|
||||
print(f"alpn_protocols: {quic_config.alpn_protocols}")
|
||||
print(f"verify_mode: {quic_config.verify_mode}")
|
||||
|
||||
if quic_config.certificate:
|
||||
cert = quic_config.certificate
|
||||
print(f"Certificate subject: {cert.subject}")
|
||||
|
||||
# Check for libp2p extension
|
||||
has_libp2p_ext = False
|
||||
for ext in cert.extensions:
|
||||
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
|
||||
has_libp2p_ext = True
|
||||
break
|
||||
print(f"Has libp2p extension: {has_libp2p_ext}")
|
||||
|
||||
if not has_libp2p_ext:
|
||||
config_valid = False
|
||||
|
||||
if not config_valid:
|
||||
print("❌ Transport configuration invalid - missing libp2p extensions")
|
||||
return False
|
||||
|
||||
# Create listener
|
||||
async def dummy_handler(connection):
|
||||
print(f"New connection: {connection}")
|
||||
|
||||
listener = transport.create_listener(dummy_handler)
|
||||
print("✅ Listener created successfully")
|
||||
|
||||
# Try to bind with timeout
|
||||
maddr = create_quic_multiaddr("127.0.0.1", 0, "quic-v1")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
result = await listener.listen(maddr, nursery)
|
||||
if result:
|
||||
print("✅ Server bound successfully")
|
||||
addresses = listener.get_addresses()
|
||||
print(f"Listening on: {addresses}")
|
||||
|
||||
# Keep running for a short time
|
||||
with trio.move_on_after(3.0): # 3 second timeout
|
||||
await trio.sleep(5.0)
|
||||
|
||||
print("✅ Server test completed (timed out normally)")
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to bind server")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Server test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests with better error handling."""
|
||||
print("Starting QUIC diagnostic tests...")
|
||||
|
||||
# Test 1: Certificate generation
|
||||
cert_ok = await test_certificate_generation()
|
||||
if not cert_ok:
|
||||
print("\n❌ CRITICAL: Certificate generation failed!")
|
||||
print("Apply the certificate generation fix and try again.")
|
||||
return
|
||||
|
||||
# Test 2: Basic QUIC connection
|
||||
quic_ok = await test_basic_quic_connection()
|
||||
if not quic_ok:
|
||||
print("\n❌ CRITICAL: Basic QUIC connection test failed!")
|
||||
return
|
||||
|
||||
# Test 3: Server startup
|
||||
server_ok = await test_server_startup()
|
||||
if not server_ok:
|
||||
print("\n❌ Server startup test failed!")
|
||||
return
|
||||
|
||||
print("\n✅ ALL TESTS PASSED!")
|
||||
print("=== DIAGNOSTIC COMPLETE ===")
|
||||
print("Your QUIC implementation should now work correctly.")
|
||||
print("Try running your echo example again.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
@ -249,23 +249,35 @@ class QUICListener(IListener):
|
||||
|
||||
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""
|
||||
Enhanced packet processing with connection ID routing and version negotiation.
|
||||
FIXED: Added address-based connection reuse to prevent multiple connections.
|
||||
Enhanced packet processing with better connection ID routing and debugging.
|
||||
"""
|
||||
try:
|
||||
self._stats["packets_processed"] += 1
|
||||
self._stats["bytes_received"] += len(data)
|
||||
|
||||
print(f"🔧 PACKET: Processing {len(data)} bytes from {addr}")
|
||||
|
||||
# Parse packet to extract connection information
|
||||
packet_info = self.parse_quic_packet(data)
|
||||
|
||||
print(f"🔧 DEBUG: Address mappings: {self._addr_to_cid}")
|
||||
print(
|
||||
f"🔧 DEBUG: Pending connections: {list(self._pending_connections.keys())}"
|
||||
f"🔧 DEBUG: Address mappings: {dict((k, v.hex()) for k, v in self._addr_to_cid.items())}"
|
||||
)
|
||||
print(
|
||||
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
|
||||
)
|
||||
print(
|
||||
f"🔧 DEBUG: Established connections: {[cid.hex() for cid in self._connections.keys()]}"
|
||||
)
|
||||
|
||||
async with self._connection_lock:
|
||||
if packet_info:
|
||||
print(
|
||||
f"🔧 PACKET: Parsed packet - version: 0x{packet_info.version:08x}, "
|
||||
f"dest_cid: {packet_info.destination_cid.hex()}, "
|
||||
f"src_cid: {packet_info.source_cid.hex()}"
|
||||
)
|
||||
|
||||
# Check for version negotiation
|
||||
if packet_info.version == 0:
|
||||
logger.warning(
|
||||
@ -275,6 +287,9 @@ class QUICListener(IListener):
|
||||
|
||||
# Check if version is supported
|
||||
if packet_info.version not in self._supported_versions:
|
||||
print(
|
||||
f"❌ PACKET: Unsupported version 0x{packet_info.version:08x}"
|
||||
)
|
||||
await self._send_version_negotiation(
|
||||
addr, packet_info.source_cid
|
||||
)
|
||||
@ -283,87 +298,66 @@ class QUICListener(IListener):
|
||||
# Route based on destination connection ID
|
||||
dest_cid = packet_info.destination_cid
|
||||
|
||||
# First, try exact connection ID match
|
||||
if dest_cid in self._connections:
|
||||
# Existing established connection
|
||||
print(f"🔧 ROUTING: To established connection {dest_cid.hex()}")
|
||||
print(
|
||||
f"✅ PACKET: Routing to established connection {dest_cid.hex()}"
|
||||
)
|
||||
connection = self._connections[dest_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
return
|
||||
|
||||
elif dest_cid in self._pending_connections:
|
||||
# Existing pending connection
|
||||
print(f"🔧 ROUTING: To pending connection {dest_cid.hex()}")
|
||||
print(
|
||||
f"✅ PACKET: Routing to pending connection {dest_cid.hex()}"
|
||||
)
|
||||
quic_conn = self._pending_connections[dest_cid]
|
||||
await self._handle_pending_connection(
|
||||
quic_conn, data, addr, dest_cid
|
||||
)
|
||||
return
|
||||
|
||||
else:
|
||||
# CRITICAL FIX: Check for existing connection by address BEFORE creating new
|
||||
existing_cid = self._addr_to_cid.get(addr)
|
||||
# If no exact match, try address-based routing (connection ID might not match)
|
||||
mapped_cid = self._addr_to_cid.get(addr)
|
||||
if mapped_cid:
|
||||
print(
|
||||
f"🔧 PACKET: Found address mapping {addr} -> {mapped_cid.hex()}"
|
||||
)
|
||||
print(
|
||||
f"🔧 PACKET: Client dest_cid {dest_cid.hex()} != our cid {mapped_cid.hex()}"
|
||||
)
|
||||
|
||||
if existing_cid is not None:
|
||||
if mapped_cid in self._connections:
|
||||
print(
|
||||
f"✅ FOUND: Existing connection {existing_cid.hex()} for address {addr}"
|
||||
"✅ PACKET: Using established connection via address mapping"
|
||||
)
|
||||
connection = self._connections[mapped_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
return
|
||||
elif mapped_cid in self._pending_connections:
|
||||
print(
|
||||
f"🔧 NOTE: Client dest_cid {dest_cid.hex()} != our cid {existing_cid.hex()}"
|
||||
"✅ PACKET: Using pending connection via address mapping"
|
||||
)
|
||||
quic_conn = self._pending_connections[mapped_cid]
|
||||
await self._handle_pending_connection(
|
||||
quic_conn, data, addr, mapped_cid
|
||||
)
|
||||
return
|
||||
|
||||
# Route to existing connection by address
|
||||
if existing_cid in self._pending_connections:
|
||||
print(
|
||||
"🔧 ROUTING: Using existing pending connection by address"
|
||||
)
|
||||
quic_conn = self._pending_connections[existing_cid]
|
||||
await self._handle_pending_connection(
|
||||
quic_conn, data, addr, existing_cid
|
||||
)
|
||||
elif existing_cid in self._connections:
|
||||
print(
|
||||
"🔧 ROUTING: Using existing established connection by address"
|
||||
)
|
||||
connection = self._connections[existing_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
else:
|
||||
print(
|
||||
f"❌ ERROR: Address mapping exists but connection {existing_cid.hex()} not found!"
|
||||
)
|
||||
# Clean up broken mapping and create new
|
||||
self._addr_to_cid.pop(addr, None)
|
||||
if packet_info.packet_type == 0: # Initial packet
|
||||
print(
|
||||
"🔧 NEW: Creating new connection after cleanup"
|
||||
)
|
||||
await self._handle_new_connection(
|
||||
data, addr, packet_info
|
||||
)
|
||||
# No existing connection found, create new one
|
||||
print(f"🔧 PACKET: Creating new connection for {addr}")
|
||||
await self._handle_new_connection(data, addr, packet_info)
|
||||
|
||||
else:
|
||||
# Truly new connection - only handle Initial packets
|
||||
if packet_info.packet_type == 0: # Initial packet
|
||||
print(f"🔧 NEW: Creating first connection for {addr}")
|
||||
await self._handle_new_connection(
|
||||
data, addr, packet_info
|
||||
)
|
||||
|
||||
# Debug the newly created connection
|
||||
new_cid = self._addr_to_cid.get(addr)
|
||||
if new_cid and new_cid in self._pending_connections:
|
||||
quic_conn = self._pending_connections[new_cid]
|
||||
await self._debug_quic_connection_state(
|
||||
quic_conn, new_cid
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Ignoring non-Initial packet for unknown connection ID from {addr}"
|
||||
)
|
||||
else:
|
||||
# Fallback to address-based routing for short header packets
|
||||
# Failed to parse packet
|
||||
print(f"❌ PACKET: Failed to parse packet from {addr}")
|
||||
await self._handle_short_header_packet(data, addr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing packet from {addr}: {e}")
|
||||
self._stats["invalid_packets"] += 1
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _send_version_negotiation(
|
||||
self, addr: tuple[str, int], source_cid: bytes
|
||||
@ -404,29 +398,31 @@ class QUICListener(IListener):
|
||||
logger.error(f"Failed to send version negotiation to {addr}: {e}")
|
||||
|
||||
async def _handle_new_connection(
|
||||
self,
|
||||
data: bytes,
|
||||
addr: tuple[str, int],
|
||||
packet_info: QUICPacketInfo,
|
||||
self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo
|
||||
) -> None:
|
||||
"""
|
||||
Handle new connection with proper version negotiation.
|
||||
"""
|
||||
"""Handle new connection with proper connection ID handling."""
|
||||
try:
|
||||
print(f"🔧 NEW_CONN: Starting handshake for {addr}")
|
||||
|
||||
# Find appropriate QUIC configuration
|
||||
quic_config = None
|
||||
config_key = None
|
||||
|
||||
for protocol, config in self._quic_configs.items():
|
||||
wire_versions = custom_quic_version_to_wire_format(protocol)
|
||||
if wire_versions == packet_info.version:
|
||||
quic_config = config
|
||||
config_key = protocol
|
||||
break
|
||||
|
||||
if not quic_config:
|
||||
logger.warning(
|
||||
f"No configuration found for version {packet_info.version:08x}"
|
||||
)
|
||||
print(f"❌ NEW_CONN: No configuration found for version 0x{packet_info.version:08x}")
|
||||
print(f"🔧 NEW_CONN: Available configs: {list(self._quic_configs.keys())}")
|
||||
await self._send_version_negotiation(addr, packet_info.source_cid)
|
||||
return
|
||||
|
||||
print(f"✅ NEW_CONN: Using config {config_key} for version 0x{packet_info.version:08x}")
|
||||
|
||||
# Create server-side QUIC configuration
|
||||
server_config = create_server_config_from_base(
|
||||
base_config=quic_config,
|
||||
@ -434,39 +430,158 @@ class QUICListener(IListener):
|
||||
transport_config=self._config,
|
||||
)
|
||||
|
||||
# Generate a new destination connection ID for this connection
|
||||
# In a real implementation, this should be cryptographically secure
|
||||
import secrets
|
||||
# Debug the server configuration
|
||||
print(f"🔧 NEW_CONN: Server config - is_client: {server_config.is_client}")
|
||||
print(f"🔧 NEW_CONN: Server config - has_certificate: {server_config.certificate is not None}")
|
||||
print(f"🔧 NEW_CONN: Server config - has_private_key: {server_config.private_key is not None}")
|
||||
print(f"🔧 NEW_CONN: Server config - ALPN: {server_config.alpn_protocols}")
|
||||
print(f"🔧 NEW_CONN: Server config - verify_mode: {server_config.verify_mode}")
|
||||
|
||||
# Validate certificate has libp2p extension
|
||||
if server_config.certificate:
|
||||
cert = server_config.certificate
|
||||
has_libp2p_ext = False
|
||||
for ext in cert.extensions:
|
||||
if str(ext.oid) == "1.3.6.1.4.1.53594.1.1":
|
||||
has_libp2p_ext = True
|
||||
break
|
||||
print(f"🔧 NEW_CONN: Certificate has libp2p extension: {has_libp2p_ext}")
|
||||
|
||||
if not has_libp2p_ext:
|
||||
print("❌ NEW_CONN: Certificate missing libp2p extension!")
|
||||
|
||||
# Generate a new destination connection ID for this connection
|
||||
import secrets
|
||||
destination_cid = secrets.token_bytes(8)
|
||||
|
||||
# Create QUIC connection with specific version
|
||||
print(f"🔧 NEW_CONN: Generated new CID: {destination_cid.hex()}")
|
||||
print(f"🔧 NEW_CONN: Original destination CID: {packet_info.destination_cid.hex()}")
|
||||
|
||||
# Create QUIC connection with proper parameters for server
|
||||
# CRITICAL FIX: Pass the original destination connection ID from the initial packet
|
||||
quic_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=packet_info.destination_cid,
|
||||
original_destination_connection_id=packet_info.destination_cid, # Use the original DCID from packet
|
||||
)
|
||||
|
||||
# Store connection mapping
|
||||
print("✅ NEW_CONN: QUIC connection created successfully")
|
||||
|
||||
# Store connection mapping using our generated CID
|
||||
self._pending_connections[destination_cid] = quic_conn
|
||||
self._addr_to_cid[addr] = destination_cid
|
||||
self._cid_to_addr[destination_cid] = addr
|
||||
|
||||
print(f"🔧 NEW_CONN: Stored mappings for {addr} <-> {destination_cid.hex()}")
|
||||
print("Receiving Datagram")
|
||||
|
||||
# Process initial packet
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
|
||||
# Debug connection state after receiving packet
|
||||
await self._debug_quic_connection_state_detailed(quic_conn, destination_cid)
|
||||
|
||||
# Process events and send response
|
||||
await self._process_quic_events(quic_conn, addr, destination_cid)
|
||||
await self._transmit_for_connection(quic_conn, addr)
|
||||
|
||||
logger.debug(
|
||||
f"Started handshake for new connection from {addr} "
|
||||
f"(version: {packet_info.version:08x}, cid: {destination_cid.hex()})"
|
||||
f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling new connection from {addr}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
self._stats["connections_rejected"] += 1
|
||||
|
||||
async def _debug_quic_connection_state_detailed(
|
||||
self, quic_conn: QuicConnection, connection_id: bytes
|
||||
):
|
||||
"""Enhanced connection state debugging."""
|
||||
try:
|
||||
print(f"🔧 QUIC_STATE: Debugging connection {connection_id.hex()}")
|
||||
|
||||
if not quic_conn:
|
||||
print("❌ QUIC_STATE: QUIC CONNECTION NOT FOUND")
|
||||
return
|
||||
|
||||
# Check TLS state
|
||||
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||
print("✅ QUIC_STATE: TLS context exists")
|
||||
if hasattr(quic_conn.tls, "state"):
|
||||
print(f"🔧 QUIC_STATE: TLS state: {quic_conn.tls.state}")
|
||||
|
||||
# Check if we have peer certificate
|
||||
if (
|
||||
hasattr(quic_conn.tls, "_peer_certificate")
|
||||
and quic_conn.tls._peer_certificate
|
||||
):
|
||||
print("✅ QUIC_STATE: Peer certificate available")
|
||||
else:
|
||||
print("🔧 QUIC_STATE: No peer certificate yet")
|
||||
|
||||
# Check TLS handshake completion
|
||||
if hasattr(quic_conn.tls, "handshake_complete"):
|
||||
handshake_status = quic_conn._handshake_complete
|
||||
print(
|
||||
f"🔧 QUIC_STATE: TLS handshake complete: {handshake_status}"
|
||||
)
|
||||
else:
|
||||
print("❌ QUIC_STATE: No TLS context!")
|
||||
|
||||
# Check connection state
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 QUIC_STATE: Connection state: {quic_conn._state}")
|
||||
|
||||
# Check if handshake is complete
|
||||
if hasattr(quic_conn, "_handshake_complete"):
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Handshake complete: {quic_conn._handshake_complete}"
|
||||
)
|
||||
|
||||
# Check configuration
|
||||
if hasattr(quic_conn, "configuration"):
|
||||
config = quic_conn.configuration
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Config certificate: {config.certificate is not None}"
|
||||
)
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Config private_key: {config.private_key is not None}"
|
||||
)
|
||||
print(f"🔧 QUIC_STATE: Config is_client: {config.is_client}")
|
||||
print(f"🔧 QUIC_STATE: Config verify_mode: {config.verify_mode}")
|
||||
print(f"🔧 QUIC_STATE: Config ALPN: {config.alpn_protocols}")
|
||||
|
||||
if config.certificate:
|
||||
cert = config.certificate
|
||||
print(f"🔧 QUIC_STATE: Certificate subject: {cert.subject}")
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Certificate valid from: {cert.not_valid_before}"
|
||||
)
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Certificate valid until: {cert.not_valid_after}"
|
||||
)
|
||||
|
||||
# Check for connection errors
|
||||
if hasattr(quic_conn, "_close_event") and quic_conn._close_event:
|
||||
print(
|
||||
f"❌ QUIC_STATE: Connection has close event: {quic_conn._close_event}"
|
||||
)
|
||||
|
||||
# Check for TLS errors
|
||||
if (
|
||||
hasattr(quic_conn, "_handshake_complete")
|
||||
and not quic_conn._handshake_complete
|
||||
):
|
||||
print("⚠️ QUIC_STATE: Handshake not yet complete")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ QUIC_STATE: Error checking state: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _handle_short_header_packet(
|
||||
self, data: bytes, addr: tuple[str, int]
|
||||
) -> None:
|
||||
@ -515,54 +630,141 @@ class QUICListener(IListener):
|
||||
addr: tuple[str, int],
|
||||
dest_cid: bytes,
|
||||
) -> None:
|
||||
"""Handle packet for a pending (handshaking) connection."""
|
||||
"""Handle packet for a pending (handshaking) connection with enhanced debugging."""
|
||||
try:
|
||||
print(
|
||||
f"🔧 PENDING: Handling packet for pending connection {dest_cid.hex()}"
|
||||
)
|
||||
print(f"🔧 PENDING: Packet size: {len(data)} bytes from {addr}")
|
||||
|
||||
# Check connection state before processing
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 PENDING: Connection state before: {quic_conn._state}")
|
||||
|
||||
if (
|
||||
hasattr(quic_conn, "tls")
|
||||
and quic_conn.tls
|
||||
and hasattr(quic_conn.tls, "state")
|
||||
):
|
||||
print(f"🔧 PENDING: TLS state before: {quic_conn.tls.state}")
|
||||
|
||||
# Feed data to QUIC connection
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
print("✅ PENDING: Datagram received by QUIC connection")
|
||||
|
||||
# Process events
|
||||
# Check state after receiving packet
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 PENDING: Connection state after: {quic_conn._state}")
|
||||
|
||||
if (
|
||||
hasattr(quic_conn, "tls")
|
||||
and quic_conn.tls
|
||||
and hasattr(quic_conn.tls, "state")
|
||||
):
|
||||
print(f"🔧 PENDING: TLS state after: {quic_conn.tls.state}")
|
||||
|
||||
# Process events - this is crucial for handshake progression
|
||||
print("🔧 PENDING: Processing QUIC events...")
|
||||
await self._process_quic_events(quic_conn, addr, dest_cid)
|
||||
|
||||
# Send any outgoing packets
|
||||
# Send any outgoing packets - this is where the response should be sent
|
||||
print("🔧 PENDING: Transmitting response...")
|
||||
await self._transmit_for_connection(quic_conn, addr)
|
||||
|
||||
# Check if handshake completed
|
||||
if (
|
||||
hasattr(quic_conn, "_handshake_complete")
|
||||
and quic_conn._handshake_complete
|
||||
):
|
||||
print("✅ PENDING: Handshake completed, promoting connection")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
else:
|
||||
print("🔧 PENDING: Handshake still in progress")
|
||||
|
||||
# Debug why handshake might be stuck
|
||||
await self._debug_handshake_state(quic_conn, dest_cid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
|
||||
# Remove from pending connections
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Remove problematic pending connection
|
||||
print(f"❌ PENDING: Removing problematic connection {dest_cid.hex()}")
|
||||
await self._remove_pending_connection(dest_cid)
|
||||
|
||||
async def _process_quic_events(
|
||||
self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
|
||||
) -> None:
|
||||
"""Process QUIC events for a connection with connection ID context."""
|
||||
while True:
|
||||
event = quic_conn.next_event()
|
||||
if event is None:
|
||||
break
|
||||
"""Process QUIC events with enhanced debugging."""
|
||||
try:
|
||||
events_processed = 0
|
||||
while True:
|
||||
event = quic_conn.next_event()
|
||||
if event is None:
|
||||
break
|
||||
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
logger.debug(
|
||||
f"Connection {dest_cid.hex()} from {addr} "
|
||||
f"terminated: {event.reason_phrase}"
|
||||
events_processed += 1
|
||||
print(
|
||||
f"🔧 EVENT: Processing event {events_processed}: {type(event).__name__}"
|
||||
)
|
||||
await self._remove_connection(dest_cid)
|
||||
break
|
||||
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
print(
|
||||
f"❌ EVENT: Connection terminated - code: {event.error_code}, reason: {event.reason_phrase}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Connection {dest_cid.hex()} from {addr} "
|
||||
f"terminated: {event.reason_phrase}"
|
||||
)
|
||||
await self._remove_connection(dest_cid)
|
||||
break
|
||||
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
# Forward to established connection if available
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
await connection._handle_stream_data(event)
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
print(
|
||||
f"✅ EVENT: Handshake completed for connection {dest_cid.hex()}"
|
||||
)
|
||||
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
|
||||
elif isinstance(event, events.StreamReset):
|
||||
# Forward to established connection if available
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
await connection._handle_stream_reset(event)
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
print(f"🔧 EVENT: Stream data received on stream {event.stream_id}")
|
||||
# Forward to established connection if available
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
await connection._handle_stream_data(event)
|
||||
|
||||
elif isinstance(event, events.StreamReset):
|
||||
print(f"🔧 EVENT: Stream reset on stream {event.stream_id}")
|
||||
# Forward to established connection if available
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
await connection._handle_stream_reset(event)
|
||||
|
||||
elif isinstance(event, events.ConnectionIdIssued):
|
||||
print(
|
||||
f"🔧 EVENT: Connection ID issued: {event.connection_id.hex()}"
|
||||
)
|
||||
|
||||
elif isinstance(event, events.ConnectionIdRetired):
|
||||
print(
|
||||
f"🔧 EVENT: Connection ID retired: {event.connection_id.hex()}"
|
||||
)
|
||||
|
||||
else:
|
||||
print(f"🔧 EVENT: Unhandled event type: {type(event).__name__}")
|
||||
|
||||
if events_processed == 0:
|
||||
print("🔧 EVENT: No events to process")
|
||||
else:
|
||||
print(f"🔧 EVENT: Processed {events_processed} events total")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ EVENT: Error processing events: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _debug_quic_connection_state(
|
||||
self, quic_conn: QuicConnection, connection_id: bytes
|
||||
@ -972,3 +1174,61 @@ class QUICListener(IListener):
|
||||
stats["active_connections"] = len(self._connections)
|
||||
stats["pending_connections"] = len(self._pending_connections)
|
||||
return stats
|
||||
|
||||
async def _debug_handshake_state(self, quic_conn: QuicConnection, dest_cid: bytes):
|
||||
"""Debug why handshake might be stuck."""
|
||||
try:
|
||||
print(f"🔧 HANDSHAKE_DEBUG: Analyzing stuck handshake for {dest_cid.hex()}")
|
||||
|
||||
# Check TLS handshake state
|
||||
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||
tls = quic_conn.tls
|
||||
print(
|
||||
f"🔧 HANDSHAKE_DEBUG: TLS state: {getattr(tls, 'state', 'Unknown')}"
|
||||
)
|
||||
|
||||
# Check for TLS errors
|
||||
if hasattr(tls, "_error") and tls._error:
|
||||
print(f"❌ HANDSHAKE_DEBUG: TLS error: {tls._error}")
|
||||
|
||||
# Check certificate validation
|
||||
if hasattr(tls, "_peer_certificate"):
|
||||
if tls._peer_certificate:
|
||||
print("✅ HANDSHAKE_DEBUG: Peer certificate received")
|
||||
else:
|
||||
print("❌ HANDSHAKE_DEBUG: No peer certificate")
|
||||
|
||||
# Check ALPN negotiation
|
||||
if hasattr(tls, "_alpn_protocols"):
|
||||
if tls._alpn_protocols:
|
||||
print(
|
||||
f"✅ HANDSHAKE_DEBUG: ALPN negotiated: {tls._alpn_protocols}"
|
||||
)
|
||||
else:
|
||||
print("❌ HANDSHAKE_DEBUG: No ALPN protocol negotiated")
|
||||
|
||||
# Check QUIC connection state
|
||||
if hasattr(quic_conn, "_state"):
|
||||
state = quic_conn._state
|
||||
print(f"🔧 HANDSHAKE_DEBUG: QUIC state: {state}")
|
||||
|
||||
# Check specific states that might indicate problems
|
||||
if "FIRSTFLIGHT" in str(state):
|
||||
print("⚠️ HANDSHAKE_DEBUG: Connection stuck in FIRSTFLIGHT state")
|
||||
elif "CONNECTED" in str(state):
|
||||
print(
|
||||
"⚠️ HANDSHAKE_DEBUG: Connection shows CONNECTED but handshake not complete"
|
||||
)
|
||||
|
||||
# Check for pending crypto data
|
||||
if hasattr(quic_conn, "_cryptos") and quic_conn._cryptos:
|
||||
print(f"🔧 HANDSHAKE_DEBUG: Crypto data present {len(quic_conn._cryptos.keys())}")
|
||||
|
||||
# Check loss detection state
|
||||
if hasattr(quic_conn, "_loss") and quic_conn._loss:
|
||||
loss_detection = quic_conn._loss
|
||||
if hasattr(loss_detection, "_pto_count"):
|
||||
print(f"🔧 HANDSHAKE_DEBUG: PTO count: {loss_detection._pto_count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ HANDSHAKE_DEBUG: Error during debug: {e}")
|
||||
|
||||
@ -4,9 +4,11 @@ Implements libp2p TLS specification for QUIC transport with peer identity integr
|
||||
Based on go-libp2p and js-libp2p security patterns.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import ssl
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
@ -25,11 +27,6 @@ from .exceptions import (
|
||||
QUICPeerVerificationError,
|
||||
)
|
||||
|
||||
TSecurityConfig = dict[
|
||||
str,
|
||||
Certificate | EllipticCurvePrivateKey | RSAPrivateKey | bool | list[str],
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# libp2p TLS Extension OID - Official libp2p specification
|
||||
@ -312,7 +309,7 @@ class CertificateGenerator:
|
||||
x509.UnrecognizedExtension(
|
||||
oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data
|
||||
),
|
||||
critical=True, # This extension is critical for libp2p
|
||||
critical=False,
|
||||
)
|
||||
.sign(cert_private_key, hashes.SHA256())
|
||||
)
|
||||
@ -407,6 +404,269 @@ class PeerAuthenticator:
|
||||
) from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class QUICTLSSecurityConfig:
|
||||
"""
|
||||
Type-safe TLS security configuration for QUIC transport.
|
||||
"""
|
||||
|
||||
# Core TLS components (required)
|
||||
certificate: Certificate
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey]
|
||||
|
||||
# Certificate chain (optional)
|
||||
certificate_chain: List[Certificate] = field(default_factory=list)
|
||||
|
||||
# ALPN protocols
|
||||
alpn_protocols: List[str] = field(default_factory=lambda: ["libp2p"])
|
||||
|
||||
# TLS verification settings
|
||||
verify_mode: Union[bool, ssl.VerifyMode] = False
|
||||
check_hostname: bool = False
|
||||
|
||||
# Optional peer ID for validation
|
||||
peer_id: Optional[ID] = None
|
||||
|
||||
# Configuration metadata
|
||||
is_client_config: bool = False
|
||||
config_name: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
self._validate()
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Validate the TLS configuration."""
|
||||
if self.certificate is None:
|
||||
raise ValueError("Certificate is required")
|
||||
|
||||
if self.private_key is None:
|
||||
raise ValueError("Private key is required")
|
||||
|
||||
if not isinstance(self.certificate, x509.Certificate):
|
||||
raise TypeError(
|
||||
f"Certificate must be x509.Certificate, got {type(self.certificate)}"
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Private key must be EC or RSA key, got {type(self.private_key)}"
|
||||
)
|
||||
|
||||
if not self.alpn_protocols:
|
||||
raise ValueError("At least one ALPN protocol is required")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convert to dictionary format for compatibility with existing code.
|
||||
|
||||
Returns:
|
||||
Dictionary compatible with the original TSecurityConfig format
|
||||
|
||||
"""
|
||||
return {
|
||||
"certificate": self.certificate,
|
||||
"private_key": self.private_key,
|
||||
"certificate_chain": self.certificate_chain.copy(),
|
||||
"alpn_protocols": self.alpn_protocols.copy(),
|
||||
"verify_mode": self.verify_mode,
|
||||
"check_hostname": self.check_hostname,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict, **kwargs) -> "QUICTLSSecurityConfig":
|
||||
"""
|
||||
Create instance from dictionary format.
|
||||
|
||||
Args:
|
||||
config_dict: Dictionary in TSecurityConfig format
|
||||
**kwargs: Additional parameters for the config
|
||||
|
||||
Returns:
|
||||
QUICTLSSecurityConfig instance
|
||||
|
||||
"""
|
||||
return cls(
|
||||
certificate=config_dict["certificate"],
|
||||
private_key=config_dict["private_key"],
|
||||
certificate_chain=config_dict.get("certificate_chain", []),
|
||||
alpn_protocols=config_dict.get("alpn_protocols", ["libp2p"]),
|
||||
verify_mode=config_dict.get("verify_mode", False),
|
||||
check_hostname=config_dict.get("check_hostname", False),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def validate_certificate_key_match(self) -> bool:
|
||||
"""
|
||||
Validate that the certificate and private key match.
|
||||
|
||||
Returns:
|
||||
True if certificate and private key match
|
||||
|
||||
"""
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
# Get public keys from both certificate and private key
|
||||
cert_public_key = self.certificate.public_key()
|
||||
private_public_key = self.private_key.public_key()
|
||||
|
||||
# Compare their PEM representations
|
||||
cert_pub_pem = cert_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
private_pub_pem = private_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
return cert_pub_pem == private_pub_pem
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def has_libp2p_extension(self) -> bool:
|
||||
"""
|
||||
Check if the certificate has the required libp2p extension.
|
||||
|
||||
Returns:
|
||||
True if libp2p extension is present
|
||||
|
||||
"""
|
||||
try:
|
||||
libp2p_oid = "1.3.6.1.4.1.53594.1.1"
|
||||
for ext in self.certificate.extensions:
|
||||
if str(ext.oid) == libp2p_oid:
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def is_certificate_valid(self) -> bool:
|
||||
"""
|
||||
Check if the certificate is currently valid (not expired).
|
||||
|
||||
Returns:
|
||||
True if certificate is valid
|
||||
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.utcnow()
|
||||
return (
|
||||
self.certificate.not_valid_before
|
||||
<= now
|
||||
<= self.certificate.not_valid_after
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_certificate_info(self) -> dict:
|
||||
"""
|
||||
Get certificate information for debugging.
|
||||
|
||||
Returns:
|
||||
Dictionary with certificate details
|
||||
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
"subject": str(self.certificate.subject),
|
||||
"issuer": str(self.certificate.issuer),
|
||||
"serial_number": self.certificate.serial_number,
|
||||
"not_valid_before": self.certificate.not_valid_before,
|
||||
"not_valid_after": self.certificate.not_valid_after,
|
||||
"has_libp2p_extension": self.has_libp2p_extension(),
|
||||
"is_valid": self.is_certificate_valid(),
|
||||
"certificate_key_match": self.validate_certificate_key_match(),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def debug_print(self) -> None:
|
||||
"""Print debugging information about this configuration."""
|
||||
print(f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===")
|
||||
print(f"Is client config: {self.is_client_config}")
|
||||
print(f"ALPN protocols: {self.alpn_protocols}")
|
||||
print(f"Verify mode: {self.verify_mode}")
|
||||
print(f"Check hostname: {self.check_hostname}")
|
||||
print(f"Certificate chain length: {len(self.certificate_chain)}")
|
||||
|
||||
cert_info = self.get_certificate_info()
|
||||
for key, value in cert_info.items():
|
||||
print(f"Certificate {key}: {value}")
|
||||
|
||||
print(f"Private key type: {type(self.private_key).__name__}")
|
||||
if hasattr(self.private_key, "key_size"):
|
||||
print(f"Private key size: {self.private_key.key_size}")
|
||||
|
||||
|
||||
def create_server_tls_config(
|
||||
certificate: Certificate,
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||
peer_id: Optional[ID] = None,
|
||||
**kwargs,
|
||||
) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create a server TLS configuration.
|
||||
|
||||
Args:
|
||||
certificate: X.509 certificate
|
||||
private_key: Private key corresponding to certificate
|
||||
peer_id: Optional peer ID for validation
|
||||
**kwargs: Additional configuration parameters
|
||||
|
||||
Returns:
|
||||
Server TLS configuration
|
||||
|
||||
"""
|
||||
return QUICTLSSecurityConfig(
|
||||
certificate=certificate,
|
||||
private_key=private_key,
|
||||
peer_id=peer_id,
|
||||
is_client_config=False,
|
||||
config_name="server",
|
||||
verify_mode=False, # Server doesn't verify client certs in libp2p
|
||||
check_hostname=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def create_client_tls_config(
|
||||
certificate: Certificate,
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||
peer_id: Optional[ID] = None,
|
||||
**kwargs,
|
||||
) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create a client TLS configuration.
|
||||
|
||||
Args:
|
||||
certificate: X.509 certificate
|
||||
private_key: Private key corresponding to certificate
|
||||
peer_id: Optional peer ID for validation
|
||||
**kwargs: Additional configuration parameters
|
||||
|
||||
Returns:
|
||||
Client TLS configuration
|
||||
|
||||
"""
|
||||
return QUICTLSSecurityConfig(
|
||||
certificate=certificate,
|
||||
private_key=private_key,
|
||||
peer_id=peer_id,
|
||||
is_client_config=True,
|
||||
config_name="client",
|
||||
verify_mode=False, # Client doesn't verify server certs in libp2p
|
||||
check_hostname=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class QUICTLSConfigManager:
|
||||
"""
|
||||
Manages TLS configuration for QUIC transport with libp2p security.
|
||||
@ -424,44 +684,40 @@ class QUICTLSConfigManager:
|
||||
libp2p_private_key, peer_id
|
||||
)
|
||||
|
||||
def create_server_config(
|
||||
self,
|
||||
) -> TSecurityConfig:
|
||||
def create_server_config(self) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create aioquic server configuration with libp2p TLS settings.
|
||||
Returns cryptography objects instead of DER bytes.
|
||||
Create server configuration using the new class-based approach.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for aioquic QuicConfiguration
|
||||
QUICTLSSecurityConfig instance for server
|
||||
|
||||
"""
|
||||
config: TSecurityConfig = {
|
||||
"certificate": self.tls_config.certificate,
|
||||
"private_key": self.tls_config.private_key,
|
||||
"certificate_chain": [],
|
||||
"alpn_protocols": ["libp2p"],
|
||||
"verify_mode": False,
|
||||
"check_hostname": False,
|
||||
}
|
||||
config = create_server_tls_config(
|
||||
certificate=self.tls_config.certificate,
|
||||
private_key=self.tls_config.private_key,
|
||||
peer_id=self.peer_id,
|
||||
)
|
||||
|
||||
print("🔧 SECURITY: Created server config")
|
||||
config.debug_print()
|
||||
return config
|
||||
|
||||
def create_client_config(self) -> TSecurityConfig:
|
||||
def create_client_config(self) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create aioquic client configuration with libp2p TLS settings.
|
||||
Returns cryptography objects instead of DER bytes.
|
||||
Create client configuration using the new class-based approach.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for aioquic QuicConfiguration
|
||||
QUICTLSSecurityConfig instance for client
|
||||
|
||||
"""
|
||||
config: TSecurityConfig = {
|
||||
"certificate": self.tls_config.certificate,
|
||||
"private_key": self.tls_config.private_key,
|
||||
"certificate_chain": [],
|
||||
"alpn_protocols": ["libp2p"],
|
||||
"verify_mode": False,
|
||||
"check_hostname": False,
|
||||
}
|
||||
config = create_client_tls_config(
|
||||
certificate=self.tls_config.certificate,
|
||||
private_key=self.tls_config.private_key,
|
||||
peer_id=self.peer_id,
|
||||
)
|
||||
|
||||
print("🔧 SECURITY: Created client config")
|
||||
config.debug_print()
|
||||
return config
|
||||
|
||||
def verify_peer_identity(
|
||||
|
||||
@ -5,7 +5,6 @@ Based on aioquic library with interface consistency to go-libp2p and js-libp2p.
|
||||
Updated to include Module 5 security integration.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
import copy
|
||||
import logging
|
||||
import sys
|
||||
@ -31,7 +30,7 @@ from libp2p.custom_types import THandler, TProtocol
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.transport.quic.security import TSecurityConfig
|
||||
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
@ -192,7 +191,7 @@ class QUICTransport(ITransport):
|
||||
) from e
|
||||
|
||||
def _apply_tls_configuration(
|
||||
self, config: QuicConfiguration, tls_config: TSecurityConfig
|
||||
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
|
||||
) -> None:
|
||||
"""
|
||||
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||
@ -203,52 +202,47 @@ class QUICTransport(ITransport):
|
||||
|
||||
"""
|
||||
try:
|
||||
# Set certificate and private key directly on the configuration
|
||||
# aioquic expects cryptography objects, not DER bytes
|
||||
if "certificate" in tls_config and "private_key" in tls_config:
|
||||
# The security manager should return cryptography objects
|
||||
# not DER bytes, but if it returns DER bytes, we need to handle that
|
||||
certificate = tls_config["certificate"]
|
||||
private_key = tls_config["private_key"]
|
||||
|
||||
# Check if we received DER bytes and need
|
||||
# to convert to cryptography objects
|
||||
if isinstance(certificate, bytes):
|
||||
# The security manager should return cryptography objects
|
||||
# not DER bytes, but if it returns DER bytes, we need to handle that
|
||||
certificate = tls_config.certificate
|
||||
private_key = tls_config.private_key
|
||||
|
||||
# Check if we received DER bytes and need
|
||||
# to convert to cryptography objects
|
||||
if isinstance(certificate, bytes):
|
||||
from cryptography import x509
|
||||
|
||||
certificate = x509.load_der_x509_certificate(certificate)
|
||||
|
||||
if isinstance(private_key, bytes):
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
private_key = serialization.load_der_private_key( # type: ignore
|
||||
private_key, password=None
|
||||
)
|
||||
|
||||
# Set directly on the configuration object
|
||||
config.certificate = certificate
|
||||
config.private_key = private_key
|
||||
|
||||
# Handle certificate chain if provided
|
||||
certificate_chain = tls_config.certificate_chain
|
||||
# Convert DER bytes to cryptography objects if needed
|
||||
chain_objects = []
|
||||
for cert in certificate_chain:
|
||||
if isinstance(cert, bytes):
|
||||
from cryptography import x509
|
||||
|
||||
certificate = x509.load_der_x509_certificate(certificate)
|
||||
|
||||
if isinstance(private_key, bytes):
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
private_key = serialization.load_der_private_key( # type: ignore
|
||||
private_key, password=None
|
||||
)
|
||||
|
||||
# Set directly on the configuration object
|
||||
config.certificate = certificate
|
||||
config.private_key = private_key
|
||||
|
||||
# Handle certificate chain if provided
|
||||
certificate_chain = tls_config.get("certificate_chain", [])
|
||||
if certificate_chain and isinstance(certificate_chain, Iterable):
|
||||
# Convert DER bytes to cryptography objects if needed
|
||||
chain_objects = []
|
||||
for cert in certificate_chain:
|
||||
if isinstance(cert, bytes):
|
||||
from cryptography import x509
|
||||
|
||||
cert = x509.load_der_x509_certificate(cert)
|
||||
chain_objects.append(cert)
|
||||
config.certificate_chain = chain_objects
|
||||
cert = x509.load_der_x509_certificate(cert)
|
||||
chain_objects.append(cert)
|
||||
config.certificate_chain = chain_objects
|
||||
|
||||
# Set ALPN protocols
|
||||
if "alpn_protocols" in tls_config:
|
||||
config.alpn_protocols = tls_config["alpn_protocols"] # type: ignore
|
||||
config.alpn_protocols = tls_config.alpn_protocols
|
||||
|
||||
# Set certificate verification mode
|
||||
if "verify_mode" in tls_config:
|
||||
config.verify_mode = tls_config["verify_mode"] # type: ignore
|
||||
config.verify_mode = tls_config.verify_mode
|
||||
|
||||
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user