mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
chore: cleanup and near v1 quic impl
This commit is contained in:
@ -1,371 +0,0 @@
|
||||
def debug_quic_connection_state(conn, name="Connection"):
|
||||
"""Enhanced debugging function for QUIC connection state."""
|
||||
print(f"\n🔍 === {name} Debug Info ===")
|
||||
|
||||
# Basic connection state
|
||||
print(f"State: {getattr(conn, '_state', 'unknown')}")
|
||||
print(f"Handshake complete: {getattr(conn, '_handshake_complete', False)}")
|
||||
|
||||
# Connection IDs
|
||||
if hasattr(conn, "_host_connection_id"):
|
||||
print(
|
||||
f"Host CID: {conn._host_connection_id.hex() if conn._host_connection_id else 'None'}"
|
||||
)
|
||||
if hasattr(conn, "_peer_connection_id"):
|
||||
print(
|
||||
f"Peer CID: {conn._peer_connection_id.hex() if conn._peer_connection_id else 'None'}"
|
||||
)
|
||||
|
||||
# Check for connection ID sequences
|
||||
if hasattr(conn, "_local_connection_ids"):
|
||||
print(
|
||||
f"Local CID sequence: {[cid.cid.hex() for cid in conn._local_connection_ids]}"
|
||||
)
|
||||
if hasattr(conn, "_remote_connection_ids"):
|
||||
print(
|
||||
f"Remote CID sequence: {[cid.cid.hex() for cid in conn._remote_connection_ids]}"
|
||||
)
|
||||
|
||||
# TLS state
|
||||
if hasattr(conn, "tls") and conn.tls:
|
||||
tls_state = getattr(conn.tls, "state", "unknown")
|
||||
print(f"TLS state: {tls_state}")
|
||||
|
||||
# Check for certificates
|
||||
peer_cert = getattr(conn.tls, "_peer_certificate", None)
|
||||
print(f"Has peer certificate: {peer_cert is not None}")
|
||||
|
||||
# Transport parameters
|
||||
if hasattr(conn, "_remote_transport_parameters"):
|
||||
params = conn._remote_transport_parameters
|
||||
if params:
|
||||
print(f"Remote transport parameters received: {len(params)} params")
|
||||
|
||||
print(f"=== End {name} Debug ===\n")
|
||||
|
||||
|
||||
def debug_firstflight_event(server_conn, name="Server"):
|
||||
"""Debug connection ID changes specifically around FIRSTFLIGHT event."""
|
||||
print(f"\n🎯 === {name} FIRSTFLIGHT Event Debug ===")
|
||||
|
||||
# Connection state
|
||||
state = getattr(server_conn, "_state", "unknown")
|
||||
print(f"Connection State: {state}")
|
||||
|
||||
# Connection IDs
|
||||
peer_cid = getattr(server_conn, "_peer_connection_id", None)
|
||||
host_cid = getattr(server_conn, "_host_connection_id", None)
|
||||
original_dcid = getattr(server_conn, "original_destination_connection_id", None)
|
||||
|
||||
print(f"Peer CID: {peer_cid.hex() if peer_cid else 'None'}")
|
||||
print(f"Host CID: {host_cid.hex() if host_cid else 'None'}")
|
||||
print(f"Original DCID: {original_dcid.hex() if original_dcid else 'None'}")
|
||||
|
||||
print(f"=== End {name} FIRSTFLIGHT Debug ===\n")
|
||||
|
||||
|
||||
def create_minimal_quic_test():
|
||||
"""Simplified test to isolate FIRSTFLIGHT connection ID issues."""
|
||||
print("\n=== MINIMAL QUIC FIRSTFLIGHT CONNECTION ID TEST ===")
|
||||
|
||||
from time import time
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.buffer import Buffer
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
|
||||
# Minimal configs without certificates first
|
||||
client_config = QuicConfiguration(
|
||||
is_client=True, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
|
||||
server_config = QuicConfiguration(
|
||||
is_client=False, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
|
||||
# Create client and connect
|
||||
client_conn = QuicConnection(configuration=client_config)
|
||||
server_addr = ("127.0.0.1", 4321)
|
||||
|
||||
print("🔗 Client calling connect()...")
|
||||
client_conn.connect(server_addr, now=time())
|
||||
|
||||
# Debug client state after connect
|
||||
debug_quic_connection_state(client_conn, "Client After Connect")
|
||||
|
||||
# Get initial client packet
|
||||
initial_packets = client_conn.datagrams_to_send(now=time())
|
||||
if not initial_packets:
|
||||
print("❌ No initial packets from client")
|
||||
return False
|
||||
|
||||
initial_packet = initial_packets[0][0]
|
||||
|
||||
# Parse header to get client's source CID (what server should use as peer CID)
|
||||
header = pull_quic_header(Buffer(data=initial_packet), host_cid_length=8)
|
||||
client_source_cid = header.source_cid
|
||||
client_dest_cid = header.destination_cid
|
||||
|
||||
print(f"📦 Initial packet analysis:")
|
||||
print(
|
||||
f" Client Source CID: {client_source_cid.hex()} (server should use as peer CID)"
|
||||
)
|
||||
print(f" Client Dest CID: {client_dest_cid.hex()}")
|
||||
|
||||
# Create server with proper ODCID
|
||||
print(
|
||||
f"\n🏗️ Creating server with original_destination_connection_id={client_dest_cid.hex()}..."
|
||||
)
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=client_dest_cid,
|
||||
)
|
||||
|
||||
# Debug server state after creation (before FIRSTFLIGHT)
|
||||
debug_firstflight_event(server_conn, "Server After Creation (Pre-FIRSTFLIGHT)")
|
||||
|
||||
# 🎯 CRITICAL: Process initial packet (this triggers FIRSTFLIGHT event)
|
||||
print(f"🚀 Processing initial packet (triggering FIRSTFLIGHT)...")
|
||||
client_addr = ("127.0.0.1", 1234)
|
||||
|
||||
# Before receive_datagram
|
||||
print(f"📊 BEFORE receive_datagram (FIRSTFLIGHT):")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
print(f" Expected peer CID after FIRSTFLIGHT: {client_source_cid.hex()}")
|
||||
|
||||
# This call triggers FIRSTFLIGHT: FIRSTFLIGHT -> CONNECTED
|
||||
server_conn.receive_datagram(initial_packet, client_addr, now=time())
|
||||
|
||||
# After receive_datagram (FIRSTFLIGHT should have happened)
|
||||
print(f"📊 AFTER receive_datagram (Post-FIRSTFLIGHT):")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
|
||||
# Check if FIRSTFLIGHT set peer CID correctly
|
||||
actual_peer_cid = server_conn._peer_cid.cid
|
||||
if actual_peer_cid == client_source_cid:
|
||||
print("✅ FIRSTFLIGHT correctly set peer CID from client source CID")
|
||||
firstflight_success = True
|
||||
else:
|
||||
print("❌ FIRSTFLIGHT BUG: peer CID not set correctly!")
|
||||
print(f" Expected: {client_source_cid.hex()}")
|
||||
print(f" Actual: {actual_peer_cid.hex() if actual_peer_cid else 'None'}")
|
||||
firstflight_success = False
|
||||
|
||||
# Debug both connections after FIRSTFLIGHT
|
||||
debug_firstflight_event(server_conn, "Server After FIRSTFLIGHT")
|
||||
debug_quic_connection_state(client_conn, "Client After Server Processing")
|
||||
|
||||
# Check server response packets
|
||||
print(f"\n📤 Checking server response packets...")
|
||||
server_packets = server_conn.datagrams_to_send(now=time())
|
||||
if server_packets:
|
||||
response_packet = server_packets[0][0]
|
||||
response_header = pull_quic_header(
|
||||
Buffer(data=response_packet), host_cid_length=8
|
||||
)
|
||||
|
||||
print(f"📊 Server response packet:")
|
||||
print(f" Source CID: {response_header.source_cid.hex()}")
|
||||
print(f" Dest CID: {response_header.destination_cid.hex()}")
|
||||
print(f" Expected dest CID: {client_source_cid.hex()}")
|
||||
|
||||
# Final verification
|
||||
if response_header.destination_cid == client_source_cid:
|
||||
print("✅ Server response uses correct destination CID!")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Server response uses WRONG destination CID!")
|
||||
print(f" This proves the FIRSTFLIGHT bug - peer CID not set correctly")
|
||||
print(f" Expected: {client_source_cid.hex()}")
|
||||
print(f" Actual: {response_header.destination_cid.hex()}")
|
||||
return False
|
||||
else:
|
||||
print("❌ Server did not generate response packet")
|
||||
return False
|
||||
|
||||
|
||||
def create_minimal_quic_test_with_config(client_config, server_config):
|
||||
"""Run FIRSTFLIGHT test with provided configurations."""
|
||||
from time import time
|
||||
from aioquic.buffer import Buffer
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
|
||||
print("\n=== FIRSTFLIGHT TEST WITH CERTIFICATES ===")
|
||||
|
||||
# Create client and connect
|
||||
client_conn = QuicConnection(configuration=client_config)
|
||||
server_addr = ("127.0.0.1", 4321)
|
||||
|
||||
print("🔗 Client calling connect() with certificates...")
|
||||
client_conn.connect(server_addr, now=time())
|
||||
|
||||
# Get initial packets and extract client source CID
|
||||
initial_packets = client_conn.datagrams_to_send(now=time())
|
||||
if not initial_packets:
|
||||
print("❌ No initial packets from client")
|
||||
return False
|
||||
|
||||
# Extract client source CID from initial packet
|
||||
initial_packet = initial_packets[0][0]
|
||||
header = pull_quic_header(Buffer(data=initial_packet), host_cid_length=8)
|
||||
client_source_cid = header.source_cid
|
||||
|
||||
print(f"📦 Client source CID (expected server peer CID): {client_source_cid.hex()}")
|
||||
|
||||
# Create server with client's source CID as original destination
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=client_source_cid,
|
||||
)
|
||||
|
||||
# Debug server before FIRSTFLIGHT
|
||||
print(f"\n📊 BEFORE FIRSTFLIGHT (server creation):")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
print(
|
||||
f" Server original DCID: {server_conn.original_destination_connection_id.hex()}"
|
||||
)
|
||||
|
||||
# Process initial packet (triggers FIRSTFLIGHT)
|
||||
client_addr = ("127.0.0.1", 1234)
|
||||
|
||||
print(f"\n🚀 Triggering FIRSTFLIGHT by processing initial packet...")
|
||||
for datagram, _ in initial_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f" Processing packet: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
|
||||
# This triggers FIRSTFLIGHT
|
||||
server_conn.receive_datagram(datagram, client_addr, now=time())
|
||||
|
||||
# Debug immediately after FIRSTFLIGHT
|
||||
print(f"\n📊 AFTER FIRSTFLIGHT:")
|
||||
print(f" Server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
print(
|
||||
f" Server peer CID: {server_conn._peer_cid.cid.hex()}"
|
||||
)
|
||||
print(f" Expected peer CID: {header.source_cid.hex()}")
|
||||
|
||||
# Check if FIRSTFLIGHT worked correctly
|
||||
actual_peer_cid = getattr(server_conn, "_peer_connection_id", None)
|
||||
if actual_peer_cid == header.source_cid:
|
||||
print("✅ FIRSTFLIGHT correctly set peer CID")
|
||||
else:
|
||||
print("❌ FIRSTFLIGHT failed to set peer CID correctly")
|
||||
print(f" This is the root cause of the handshake failure!")
|
||||
|
||||
# Check server response
|
||||
server_packets = server_conn.datagrams_to_send(now=time())
|
||||
if server_packets:
|
||||
response_packet = server_packets[0][0]
|
||||
response_header = pull_quic_header(
|
||||
Buffer(data=response_packet), host_cid_length=8
|
||||
)
|
||||
|
||||
print(f"\n📤 Server response analysis:")
|
||||
print(f" Response dest CID: {response_header.destination_cid.hex()}")
|
||||
print(f" Expected dest CID: {client_source_cid.hex()}")
|
||||
|
||||
if response_header.destination_cid == client_source_cid:
|
||||
print("✅ Server response uses correct destination CID!")
|
||||
return True
|
||||
else:
|
||||
print("❌ FIRSTFLIGHT bug confirmed - wrong destination CID in response!")
|
||||
print(
|
||||
" This proves aioquic doesn't set peer CID correctly during FIRSTFLIGHT"
|
||||
)
|
||||
return False
|
||||
|
||||
print("❌ No server response packets")
|
||||
return False
|
||||
|
||||
|
||||
async def test_with_certificates():
|
||||
"""Test with proper certificate setup and FIRSTFLIGHT debugging."""
|
||||
print("\n=== CERTIFICATE-BASED FIRSTFLIGHT TEST ===")
|
||||
|
||||
# Import your existing certificate creation functions
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
|
||||
# Create security configs
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
|
||||
# Apply the minimal test logic with certificates
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
|
||||
client_config = QuicConfiguration(
|
||||
is_client=True, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
client_config.certificate = client_security_config.tls_config.certificate
|
||||
client_config.private_key = client_security_config.tls_config.private_key
|
||||
client_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
|
||||
server_config = QuicConfiguration(
|
||||
is_client=False, alpn_protocols=["libp2p"], connection_id_length=8
|
||||
)
|
||||
server_config.certificate = server_security_config.tls_config.certificate
|
||||
server_config.private_key = server_security_config.tls_config.private_key
|
||||
server_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
|
||||
# Run the FIRSTFLIGHT test with certificates
|
||||
return create_minimal_quic_test_with_config(client_config, server_config)
|
||||
|
||||
|
||||
async def main():
|
||||
print("🎯 Testing FIRSTFLIGHT connection ID behavior...")
|
||||
|
||||
# # First test without certificates
|
||||
# print("\n" + "=" * 60)
|
||||
# print("PHASE 1: Testing FIRSTFLIGHT without certificates")
|
||||
# print("=" * 60)
|
||||
# minimal_success = create_minimal_quic_test()
|
||||
|
||||
# Then test with certificates
|
||||
print("\n" + "=" * 60)
|
||||
print("PHASE 2: Testing FIRSTFLIGHT with certificates")
|
||||
print("=" * 60)
|
||||
cert_success = await test_with_certificates()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("FIRSTFLIGHT TEST SUMMARY")
|
||||
print("=" * 60)
|
||||
# print(f"Minimal test (no certs): {'✅ PASS' if minimal_success else '❌ FAIL'}")
|
||||
print(f"Certificate test: {'✅ PASS' if cert_success else '❌ FAIL'}")
|
||||
|
||||
if not cert_success:
|
||||
print("\n🔥 FIRSTFLIGHT BUG CONFIRMED:")
|
||||
print(" - aioquic fails to set peer CID correctly during FIRSTFLIGHT event")
|
||||
print(" - Server uses wrong destination CID in response packets")
|
||||
print(" - Client drops responses → handshake fails")
|
||||
print(" - Fix: Override _peer_connection_id after receive_datagram()")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import trio
|
||||
|
||||
trio.run(main)
|
||||
@ -1,205 +0,0 @@
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from tempfile import NamedTemporaryFile
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from time import time
|
||||
import os
|
||||
import trio
|
||||
|
||||
|
||||
async def test_full_handshake_and_certificate_exchange():
|
||||
"""
|
||||
Test a full handshake to ensure it completes and peer certificates are exchanged.
|
||||
FIXED VERSION: Corrects connection ID management and address handling.
|
||||
"""
|
||||
print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (FIXED) ===")
|
||||
|
||||
# 1. Generate KeyPairs and create libp2p security configs for client and server.
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
print("✅ libp2p security configs created.")
|
||||
|
||||
# 2. Create aioquic configurations with consistent settings
|
||||
client_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-client.log"
|
||||
)
|
||||
client_aioquic_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=client_secrets_log_file,
|
||||
connection_id_length=8, # Set consistent CID length
|
||||
)
|
||||
client_aioquic_config.certificate = client_security_config.tls_config.certificate
|
||||
client_aioquic_config.private_key = client_security_config.tls_config.private_key
|
||||
client_aioquic_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
|
||||
server_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-server.log"
|
||||
)
|
||||
server_aioquic_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=server_secrets_log_file,
|
||||
connection_id_length=8, # Set consistent CID length
|
||||
)
|
||||
server_aioquic_config.certificate = server_security_config.tls_config.certificate
|
||||
server_aioquic_config.private_key = server_security_config.tls_config.private_key
|
||||
server_aioquic_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
print("✅ aioquic configurations created and configured.")
|
||||
print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
|
||||
print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")
|
||||
|
||||
# 3. Use consistent addresses - this is crucial!
|
||||
# The client will connect TO the server address, but packets will come FROM client address
|
||||
client_address = ("127.0.0.1", 1234) # Client binds to this
|
||||
server_address = ("127.0.0.1", 4321) # Server binds to this
|
||||
|
||||
# 4. Create client connection and initiate connection
|
||||
client_conn = QuicConnection(configuration=client_aioquic_config)
|
||||
# Client connects to server address - this sets up the initial packet with proper CIDs
|
||||
client_conn.connect(server_address, now=time())
|
||||
print("✅ Client connection initiated.")
|
||||
|
||||
# 5. Get the initial client packet and extract ODCID properly
|
||||
client_datagrams = client_conn.datagrams_to_send(now=time())
|
||||
if not client_datagrams:
|
||||
raise AssertionError("❌ Client did not generate initial packet")
|
||||
|
||||
client_initial_packet = client_datagrams[0][0]
|
||||
header = pull_quic_header(Buffer(data=client_initial_packet), host_cid_length=8)
|
||||
original_dcid = header.destination_cid
|
||||
client_source_cid = header.source_cid
|
||||
|
||||
print(f"📊 Client ODCID: {original_dcid.hex()}")
|
||||
print(f"📊 Client source CID: {client_source_cid.hex()}")
|
||||
|
||||
# 6. Create server connection with the correct ODCID
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_aioquic_config,
|
||||
original_destination_connection_id=original_dcid,
|
||||
)
|
||||
print("✅ Server connection created with correct ODCID.")
|
||||
|
||||
# 7. Feed the initial client packet to server
|
||||
# IMPORTANT: Use client_address as the source for the packet
|
||||
for datagram, _ in client_datagrams:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
|
||||
# 8. Manual handshake loop with proper packet tracking
|
||||
max_duration_s = 3 # Increased timeout
|
||||
start_time = time()
|
||||
packet_count = 0
|
||||
|
||||
while time() - start_time < max_duration_s:
|
||||
# Process client -> server packets
|
||||
client_packets = list(client_conn.datagrams_to_send(now=time()))
|
||||
for datagram, _ in client_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
packet_count += 1
|
||||
|
||||
# Process server -> client packets
|
||||
server_packets = list(server_conn.datagrams_to_send(now=time()))
|
||||
for datagram, _ in server_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Server -> Client: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
# CRITICAL: Server sends back to client_address, not server_address
|
||||
client_conn.receive_datagram(datagram, server_address, now=time())
|
||||
packet_count += 1
|
||||
|
||||
# Check for completion
|
||||
client_complete = getattr(client_conn, "_handshake_complete", False)
|
||||
server_complete = getattr(server_conn, "_handshake_complete", False)
|
||||
|
||||
print(
|
||||
f"🔄 Handshake status: Client={client_complete}, Server={server_complete}, Packets={packet_count}"
|
||||
)
|
||||
|
||||
if client_complete and server_complete:
|
||||
print("🎉 Handshake completed for both peers!")
|
||||
break
|
||||
|
||||
# If no packets were exchanged in this iteration, wait a bit
|
||||
if not client_packets and not server_packets:
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# Safety check - if too many packets, something is wrong
|
||||
if packet_count > 50:
|
||||
print("⚠️ Too many packets exchanged, possible handshake loop")
|
||||
break
|
||||
|
||||
# 9. Enhanced handshake completion checks
|
||||
client_handshake_complete = getattr(client_conn, "_handshake_complete", False)
|
||||
server_handshake_complete = getattr(server_conn, "_handshake_complete", False)
|
||||
|
||||
# Debug additional state information
|
||||
print(f"🔍 Final client state: {getattr(client_conn, '_state', 'unknown')}")
|
||||
print(f"🔍 Final server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
|
||||
if hasattr(client_conn, "tls") and client_conn.tls:
|
||||
print(f"🔍 Client TLS state: {getattr(client_conn.tls, 'state', 'unknown')}")
|
||||
if hasattr(server_conn, "tls") and server_conn.tls:
|
||||
print(f"🔍 Server TLS state: {getattr(server_conn.tls, 'state', 'unknown')}")
|
||||
|
||||
# 10. Cleanup and assertions
|
||||
client_secrets_log_file.close()
|
||||
server_secrets_log_file.close()
|
||||
os.unlink(client_secrets_log_file.name)
|
||||
os.unlink(server_secrets_log_file.name)
|
||||
|
||||
# Final assertions
|
||||
assert client_handshake_complete, (
|
||||
f"❌ Client handshake did not complete. "
|
||||
f"State: {getattr(client_conn, '_state', 'unknown')}, "
|
||||
f"Packets: {packet_count}"
|
||||
)
|
||||
assert server_handshake_complete, (
|
||||
f"❌ Server handshake did not complete. "
|
||||
f"State: {getattr(server_conn, '_state', 'unknown')}, "
|
||||
f"Packets: {packet_count}"
|
||||
)
|
||||
print("✅ Handshake completed for both peers.")
|
||||
|
||||
# Certificate exchange verification
|
||||
client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
|
||||
server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)
|
||||
|
||||
assert client_peer_cert is not None, (
|
||||
"❌ Client FAILED to receive server certificate."
|
||||
)
|
||||
print("✅ Client successfully received server certificate.")
|
||||
|
||||
assert server_peer_cert is not None, (
|
||||
"❌ Server FAILED to receive client certificate."
|
||||
)
|
||||
print("✅ Server successfully received client certificate.")
|
||||
|
||||
print("🎉 Test Passed: Full handshake and certificate exchange successful.")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(test_full_handshake_and_certificate_exchange)
|
||||
@ -1,461 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
"""
|
||||
Fixed QUIC handshake test to debug connection issues.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import sys
|
||||
from tempfile import NamedTemporaryFile
|
||||
from time import time
|
||||
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.logger import QuicFileLogger
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import (
|
||||
LIBP2P_TLS_EXTENSION_OID,
|
||||
create_quic_security_transport,
|
||||
)
|
||||
from libp2p.transport.quic.transport import QUICTransport, QUICTransportConfig
|
||||
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s", level=logging.DEBUG
|
||||
)
|
||||
|
||||
|
||||
# Adjust this path to your project structure
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
|
||||
|
||||
async def test_certificate_generation():
|
||||
"""Test certificate generation in isolation."""
|
||||
print("\n=== TESTING CERTIFICATE GENERATION ===")
|
||||
|
||||
try:
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
|
||||
# Create key pair
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
print(f"Generated peer ID: {peer_id}")
|
||||
|
||||
# Create security manager
|
||||
security_manager = create_quic_security_transport(private_key, peer_id)
|
||||
print("✅ Security manager created")
|
||||
|
||||
# Test server config
|
||||
server_config = security_manager.create_server_config()
|
||||
print("✅ Server config created")
|
||||
|
||||
# Validate certificate
|
||||
cert = server_config.certificate
|
||||
private_key_obj = server_config.private_key
|
||||
|
||||
print(f"Certificate type: {type(cert)}")
|
||||
print(f"Private key type: {type(private_key_obj)}")
|
||||
print(f"Certificate subject: {cert.subject}")
|
||||
print(f"Certificate issuer: {cert.issuer}")
|
||||
|
||||
# Check for libp2p extension
|
||||
has_libp2p_ext = False
|
||||
for ext in cert.extensions:
|
||||
if ext.oid == LIBP2P_TLS_EXTENSION_OID:
|
||||
has_libp2p_ext = True
|
||||
print(f"✅ Found libp2p extension: {ext.oid}")
|
||||
print(f"Extension critical: {ext.critical}")
|
||||
break
|
||||
|
||||
if not has_libp2p_ext:
|
||||
print("❌ No libp2p extension found!")
|
||||
print("Available extensions:")
|
||||
for ext in cert.extensions:
|
||||
print(f" - {ext.oid} (critical: {ext.critical})")
|
||||
|
||||
# Check certificate/key match
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
cert_public_key = cert.public_key()
|
||||
private_public_key = private_key_obj.public_key()
|
||||
|
||||
cert_pub_bytes = cert_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
private_pub_bytes = private_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
if cert_pub_bytes == private_pub_bytes:
|
||||
print("✅ Certificate and private key match")
|
||||
return has_libp2p_ext
|
||||
else:
|
||||
print("❌ Certificate and private key DO NOT match")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Certificate test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_basic_quic_connection():
|
||||
"""Test basic QUIC connection with proper server setup."""
|
||||
print("\n=== TESTING BASIC QUIC CONNECTION ===")
|
||||
|
||||
try:
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
|
||||
# Create certificates
|
||||
server_key = create_new_key_pair().private_key
|
||||
server_peer_id = ID.from_pubkey(server_key.get_public_key())
|
||||
server_security = create_quic_security_transport(server_key, server_peer_id)
|
||||
|
||||
client_key = create_new_key_pair().private_key
|
||||
client_peer_id = ID.from_pubkey(client_key.get_public_key())
|
||||
client_security = create_quic_security_transport(client_key, client_peer_id)
|
||||
|
||||
# Create server config
|
||||
server_tls_config = server_security.create_server_config()
|
||||
server_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
certificate=server_tls_config.certificate,
|
||||
private_key=server_tls_config.private_key,
|
||||
alpn_protocols=["libp2p"],
|
||||
)
|
||||
|
||||
# Create client config
|
||||
client_tls_config = client_security.create_client_config()
|
||||
client_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
certificate=client_tls_config.certificate,
|
||||
private_key=client_tls_config.private_key,
|
||||
alpn_protocols=["libp2p"],
|
||||
)
|
||||
|
||||
print("✅ QUIC configurations created")
|
||||
|
||||
# Test creating connections with proper parameters
|
||||
# For server, we need to provide original_destination_connection_id
|
||||
original_dcid = secrets.token_bytes(8)
|
||||
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=original_dcid,
|
||||
)
|
||||
|
||||
# For client, no original_destination_connection_id needed
|
||||
client_conn = QuicConnection(configuration=client_config)
|
||||
|
||||
print("✅ QUIC connections created")
|
||||
print(f"Server state: {server_conn._state}")
|
||||
print(f"Client state: {client_conn._state}")
|
||||
|
||||
# Test that certificates are valid
|
||||
print(f"Server has certificate: {server_config.certificate is not None}")
|
||||
print(f"Server has private key: {server_config.private_key is not None}")
|
||||
print(f"Client has certificate: {client_config.certificate is not None}")
|
||||
print(f"Client has private key: {client_config.private_key is not None}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Basic QUIC test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_server_startup():
|
||||
"""Test server startup with timeout."""
|
||||
print("\n=== TESTING SERVER STARTUP ===")
|
||||
|
||||
try:
|
||||
# Create transport
|
||||
private_key = create_new_key_pair().private_key
|
||||
config = QUICTransportConfig(
|
||||
idle_timeout=10.0, # Reduced timeout for testing
|
||||
connection_timeout=10.0,
|
||||
enable_draft29=False,
|
||||
)
|
||||
|
||||
transport = QUICTransport(private_key, config)
|
||||
print("✅ Transport created successfully")
|
||||
|
||||
# Test configuration
|
||||
print(f"Available configs: {list(transport._quic_configs.keys())}")
|
||||
|
||||
config_valid = True
|
||||
for config_key, quic_config in transport._quic_configs.items():
|
||||
print(f"\n--- Testing config: {config_key} ---")
|
||||
print(f"is_client: {quic_config.is_client}")
|
||||
print(f"has_certificate: {quic_config.certificate is not None}")
|
||||
print(f"has_private_key: {quic_config.private_key is not None}")
|
||||
print(f"alpn_protocols: {quic_config.alpn_protocols}")
|
||||
print(f"verify_mode: {quic_config.verify_mode}")
|
||||
|
||||
if quic_config.certificate:
|
||||
cert = quic_config.certificate
|
||||
print(f"Certificate subject: {cert.subject}")
|
||||
|
||||
# Check for libp2p extension
|
||||
has_libp2p_ext = False
|
||||
for ext in cert.extensions:
|
||||
if ext.oid == LIBP2P_TLS_EXTENSION_OID:
|
||||
has_libp2p_ext = True
|
||||
break
|
||||
print(f"Has libp2p extension: {has_libp2p_ext}")
|
||||
|
||||
if not has_libp2p_ext:
|
||||
config_valid = False
|
||||
|
||||
if not config_valid:
|
||||
print("❌ Transport configuration invalid - missing libp2p extensions")
|
||||
return False
|
||||
|
||||
# Create listener
|
||||
async def dummy_handler(connection):
|
||||
print(f"New connection: {connection}")
|
||||
|
||||
listener = transport.create_listener(dummy_handler)
|
||||
print("✅ Listener created successfully")
|
||||
|
||||
# Try to bind with timeout
|
||||
maddr = create_quic_multiaddr("127.0.0.1", 0, "quic-v1")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
result = await listener.listen(maddr, nursery)
|
||||
if result:
|
||||
print("✅ Server bound successfully")
|
||||
addresses = listener.get_addresses()
|
||||
print(f"Listening on: {addresses}")
|
||||
|
||||
# Keep running for a short time
|
||||
with trio.move_on_after(3.0): # 3 second timeout
|
||||
await trio.sleep(5.0)
|
||||
|
||||
print("✅ Server test completed (timed out normally)")
|
||||
nursery.cancel_scope.cancel()
|
||||
return True
|
||||
else:
|
||||
print("❌ Failed to bind server")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Server test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def test_full_handshake_and_certificate_exchange():
|
||||
"""
|
||||
Test a full handshake to ensure it completes and peer certificates are exchanged.
|
||||
This version is corrected to use the actual APIs available in the codebase.
|
||||
"""
|
||||
print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (CORRECTED) ===")
|
||||
|
||||
# 1. Generate KeyPairs and create libp2p security configs for client and server.
|
||||
# The `create_quic_security_transport` function from `test_quic.py` is the
|
||||
# correct helper to use, and it requires a `KeyPair` argument.
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
# This is the correct way to get the security configuration objects.
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
print("✅ libp2p security configs created.")
|
||||
|
||||
# 2. Create aioquic configurations and manually apply security settings,
|
||||
# mimicking what the `QUICTransport` class does internally.
|
||||
client_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-client.log"
|
||||
)
|
||||
client_aioquic_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=client_secrets_log_file,
|
||||
)
|
||||
client_aioquic_config.certificate = client_security_config.tls_config.certificate
|
||||
client_aioquic_config.private_key = client_security_config.tls_config.private_key
|
||||
client_aioquic_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
client_aioquic_config.quic_logger = QuicFileLogger(
|
||||
"/home/akmo/GitHub/py-libp2p/examples/echo/logs"
|
||||
)
|
||||
|
||||
server_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-server.log"
|
||||
)
|
||||
|
||||
server_aioquic_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=server_secrets_log_file,
|
||||
)
|
||||
server_aioquic_config.certificate = server_security_config.tls_config.certificate
|
||||
server_aioquic_config.private_key = server_security_config.tls_config.private_key
|
||||
server_aioquic_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
server_aioquic_config.quic_logger = QuicFileLogger(
|
||||
"/home/akmo/GitHub/py-libp2p/examples/echo/logs"
|
||||
)
|
||||
print("✅ aioquic configurations created and configured.")
|
||||
print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
|
||||
print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")
|
||||
|
||||
# 3. Instantiate client, initiate its `connect` call, and get the ODCID for the server.
|
||||
client_address = ("127.0.0.1", 1234)
|
||||
server_address = ("127.0.0.1", 4321)
|
||||
|
||||
client_aioquic_config.connection_id_length = 8
|
||||
client_conn = QuicConnection(configuration=client_aioquic_config)
|
||||
client_conn.connect(server_address, now=time())
|
||||
print("✅ aioquic connections instantiated correctly.")
|
||||
|
||||
print("🔧 Client CIDs")
|
||||
print("Local Init CID: ", client_conn._local_initial_source_connection_id.hex())
|
||||
print(
|
||||
"Remote Init CID: ",
|
||||
(client_conn._remote_initial_source_connection_id or b"").hex(),
|
||||
)
|
||||
print(
|
||||
"Original Destination CID: ",
|
||||
client_conn.original_destination_connection_id.hex(),
|
||||
)
|
||||
print(f"Host CID: {client_conn._host_cids[0].cid.hex()}")
|
||||
|
||||
# 4. Instantiate the server with the ODCID from the client.
|
||||
server_aioquic_config.connection_id_length = 8
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_aioquic_config,
|
||||
original_destination_connection_id=client_conn.original_destination_connection_id,
|
||||
)
|
||||
print("✅ aioquic connections instantiated correctly.")
|
||||
|
||||
# 5. Manually drive the handshake process by exchanging datagrams.
|
||||
max_duration_s = 5
|
||||
start_time = time()
|
||||
|
||||
while time() - start_time < max_duration_s:
|
||||
for datagram, _ in client_conn.datagrams_to_send(now=time()):
|
||||
header = pull_quic_header(Buffer(data=datagram), host_cid_length=8)
|
||||
print("Client packet source connection id", header.source_cid.hex())
|
||||
print(
|
||||
"Client packet destination connection id", header.destination_cid.hex()
|
||||
)
|
||||
print("--SERVER INJESTING CLIENT PACKET---")
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
|
||||
print(
|
||||
f"Server remote initial source id: {(server_conn._remote_initial_source_connection_id or b'').hex()}"
|
||||
)
|
||||
for datagram, _ in server_conn.datagrams_to_send(now=time()):
|
||||
header = pull_quic_header(Buffer(data=datagram), host_cid_length=8)
|
||||
print("Server packet source connection id", header.source_cid.hex())
|
||||
print(
|
||||
"Server packet destination connection id", header.destination_cid.hex()
|
||||
)
|
||||
print("--CLIENT INJESTING SERVER PACKET---")
|
||||
client_conn.receive_datagram(datagram, server_address, now=time())
|
||||
|
||||
# Check for completion
|
||||
if client_conn._handshake_complete and server_conn._handshake_complete:
|
||||
break
|
||||
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# 6. Assertions to verify the outcome.
|
||||
assert client_conn._handshake_complete, "❌ Client handshake did not complete."
|
||||
assert server_conn._handshake_complete, "❌ Server handshake did not complete."
|
||||
print("✅ Handshake completed for both peers.")
|
||||
|
||||
# The key assertion: check if the peer certificate was received.
|
||||
client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
|
||||
server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)
|
||||
|
||||
client_secrets_log_file.close()
|
||||
server_secrets_log_file.close()
|
||||
os.unlink(client_secrets_log_file.name)
|
||||
os.unlink(server_secrets_log_file.name)
|
||||
|
||||
assert client_peer_cert is not None, (
|
||||
"❌ Client FAILED to receive server certificate."
|
||||
)
|
||||
print("✅ Client successfully received server certificate.")
|
||||
|
||||
print("🎉 Test Passed: Full handshake and certificate exchange successful.")
|
||||
return True
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests with better error handling."""
|
||||
print("Starting QUIC diagnostic tests...")
|
||||
|
||||
handshake_ok = await test_full_handshake_and_certificate_exchange()
|
||||
if not handshake_ok:
|
||||
print("\n❌ CRITICAL: Handshake failed!")
|
||||
print("Apply the handshake fix and try again.")
|
||||
return
|
||||
|
||||
# Test 1: Certificate generation
|
||||
cert_ok = await test_certificate_generation()
|
||||
if not cert_ok:
|
||||
print("\n❌ CRITICAL: Certificate generation failed!")
|
||||
print("Apply the certificate generation fix and try again.")
|
||||
return
|
||||
|
||||
# Test 2: Basic QUIC connection
|
||||
quic_ok = await test_basic_quic_connection()
|
||||
if not quic_ok:
|
||||
print("\n❌ CRITICAL: Basic QUIC connection test failed!")
|
||||
return
|
||||
|
||||
# Test 3: Server startup
|
||||
server_ok = await test_server_startup()
|
||||
if not server_ok:
|
||||
print("\n❌ Server startup test failed!")
|
||||
return
|
||||
|
||||
print("\n✅ ALL TESTS PASSED!")
|
||||
print("=== DIAGNOSTIC COMPLETE ===")
|
||||
print("Your QUIC implementation should now work correctly.")
|
||||
print("Try running your echo example again.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
@ -183,14 +183,6 @@ class Swarm(Service, INetworkService):
|
||||
"""
|
||||
Try to create a connection to peer_id with addr.
|
||||
"""
|
||||
# QUIC Transport
|
||||
if isinstance(self.transport, QUICTransport):
|
||||
raw_conn = await self.transport.dial(addr, peer_id)
|
||||
print("detected QUIC connection, skipping upgrade steps")
|
||||
swarm_conn = await self.add_conn(raw_conn)
|
||||
print("successfully dialed peer %s via QUIC", peer_id)
|
||||
return swarm_conn
|
||||
|
||||
try:
|
||||
raw_conn = await self.transport.dial(addr)
|
||||
except OpenConnectionError as error:
|
||||
|
||||
@ -179,7 +179,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"connection_id_changes": 0,
|
||||
}
|
||||
|
||||
print(
|
||||
logger.info(
|
||||
f"Created QUIC connection to {remote_peer_id} "
|
||||
f"(initiator: {is_initiator}, addr: {remote_addr}, "
|
||||
"security: {security_manager is not None})"
|
||||
@ -278,7 +278,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
self._started = True
|
||||
self.event_started.set()
|
||||
print(f"Starting QUIC connection to {self._remote_peer_id}")
|
||||
logger.info(f"Starting QUIC connection to {self._remote_peer_id}")
|
||||
|
||||
try:
|
||||
# If this is a client connection, we need to establish the connection
|
||||
@ -289,7 +289,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._established = True
|
||||
self._connected_event.set()
|
||||
|
||||
print(f"QUIC connection to {self._remote_peer_id} started")
|
||||
logger.info(f"QUIC connection to {self._remote_peer_id} started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start connection: {e}")
|
||||
@ -300,7 +300,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
with QUICErrorContext("connection_initiation", "connection"):
|
||||
if not self._socket:
|
||||
print("Creating new socket for outbound connection")
|
||||
logger.info("Creating new socket for outbound connection")
|
||||
self._socket = trio.socket.socket(
|
||||
family=socket.AF_INET, type=socket.SOCK_DGRAM
|
||||
)
|
||||
@ -312,7 +312,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Send initial packet(s)
|
||||
await self._transmit()
|
||||
|
||||
print(f"Initiated QUIC connection to {self._remote_addr}")
|
||||
logger.info(f"Initiated QUIC connection to {self._remote_addr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate connection: {e}")
|
||||
@ -334,16 +334,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
with QUICErrorContext("connection_establishment", "connection"):
|
||||
# Start the connection if not already started
|
||||
print("STARTING TO CONNECT")
|
||||
logger.info("STARTING TO CONNECT")
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
# Start background event processing
|
||||
if not self._background_tasks_started:
|
||||
print("STARTING BACKGROUND TASK")
|
||||
logger.info("STARTING BACKGROUND TASK")
|
||||
await self._start_background_tasks()
|
||||
else:
|
||||
print("BACKGROUND TASK ALREADY STARTED")
|
||||
logger.info("BACKGROUND TASK ALREADY STARTED")
|
||||
|
||||
# Wait for handshake completion with timeout
|
||||
with trio.move_on_after(
|
||||
@ -357,13 +357,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
f"{self.CONNECTION_HANDSHAKE_TIMEOUT}s"
|
||||
)
|
||||
|
||||
print("QUICConnection: Verifying peer identity with security manager")
|
||||
logger.info(
|
||||
"QUICConnection: Verifying peer identity with security manager"
|
||||
)
|
||||
# Verify peer identity using security manager
|
||||
await self._verify_peer_identity_with_security()
|
||||
|
||||
print("QUICConnection: Peer identity verified")
|
||||
logger.info("QUICConnection: Peer identity verified")
|
||||
self._established = True
|
||||
print(f"QUIC connection established with {self._remote_peer_id}")
|
||||
logger.info(f"QUIC connection established with {self._remote_peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to establish connection: {e}")
|
||||
@ -378,22 +380,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._background_tasks_started = True
|
||||
|
||||
if self.__is_initiator:
|
||||
print(f"CLIENT CONNECTION {id(self)}: Starting processing event loop")
|
||||
self._nursery.start_soon(async_fn=self._client_packet_receiver)
|
||||
self._nursery.start_soon(async_fn=self._event_processing_loop)
|
||||
else:
|
||||
print(
|
||||
f"SERVER CONNECTION {id(self)}: Using listener event forwarding, not own loop"
|
||||
)
|
||||
|
||||
# Start periodic tasks
|
||||
self._nursery.start_soon(async_fn=self._event_processing_loop)
|
||||
self._nursery.start_soon(async_fn=self._periodic_maintenance)
|
||||
|
||||
print("Started background tasks for QUIC connection")
|
||||
logger.info("Started background tasks for QUIC connection")
|
||||
|
||||
async def _event_processing_loop(self) -> None:
|
||||
"""Main event processing loop for the connection."""
|
||||
print(
|
||||
logger.info(
|
||||
f"Started QUIC event processing loop for connection id: {id(self)} "
|
||||
f"and local peer id {str(self.local_peer_id())}"
|
||||
)
|
||||
@ -416,7 +412,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
logger.error(f"Error in event processing loop: {e}")
|
||||
await self._handle_connection_error(e)
|
||||
finally:
|
||||
print("QUIC event processing loop finished")
|
||||
logger.info("QUIC event processing loop finished")
|
||||
|
||||
async def _periodic_maintenance(self) -> None:
|
||||
"""Perform periodic connection maintenance."""
|
||||
@ -431,7 +427,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# *** NEW: Log connection ID status periodically ***
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
cid_stats = self.get_connection_id_stats()
|
||||
print(f"Connection ID stats: {cid_stats}")
|
||||
logger.info(f"Connection ID stats: {cid_stats}")
|
||||
|
||||
# Sleep for maintenance interval
|
||||
await trio.sleep(30.0) # 30 seconds
|
||||
@ -441,15 +437,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
async def _client_packet_receiver(self) -> None:
|
||||
"""Receive packets for client connections."""
|
||||
print("Starting client packet receiver")
|
||||
print("Started QUIC client packet receiver")
|
||||
logger.info("Starting client packet receiver")
|
||||
logger.info("Started QUIC client packet receiver")
|
||||
|
||||
try:
|
||||
while not self._closed and self._socket:
|
||||
try:
|
||||
# Receive UDP packets
|
||||
data, addr = await self._socket.recvfrom(65536)
|
||||
print(f"Client received {len(data)} bytes from {addr}")
|
||||
logger.info(f"Client received {len(data)} bytes from {addr}")
|
||||
|
||||
# Feed packet to QUIC connection
|
||||
self._quic.receive_datagram(data, addr, now=time.time())
|
||||
@ -461,7 +457,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await self._transmit()
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
print("Client socket closed")
|
||||
logger.info("Client socket closed")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving client packet: {e}")
|
||||
@ -471,7 +467,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
logger.info("Client packet receiver cancelled")
|
||||
raise
|
||||
finally:
|
||||
print("Client packet receiver terminated")
|
||||
logger.info("Client packet receiver terminated")
|
||||
|
||||
# Security and identity methods
|
||||
|
||||
@ -483,7 +479,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
QUICPeerVerificationError: If peer verification fails
|
||||
|
||||
"""
|
||||
print("VERIFYING PEER IDENTITY")
|
||||
logger.info("VERIFYING PEER IDENTITY")
|
||||
if not self._security_manager:
|
||||
logger.warning("No security manager available for peer verification")
|
||||
return
|
||||
@ -512,7 +508,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
logger.info(f"Discovered peer ID from certificate: {verified_peer_id}")
|
||||
elif self._remote_peer_id != verified_peer_id:
|
||||
raise QUICPeerVerificationError(
|
||||
f"Peer ID mismatch: expected {self._remote_peer_id}, got {verified_peer_id}"
|
||||
f"Peer ID mismatch: expected {self._remote_peer_id}, "
|
||||
"got {verified_peer_id}"
|
||||
)
|
||||
|
||||
self._peer_verified = True
|
||||
@ -541,14 +538,14 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# aioquic stores the peer certificate as cryptography
|
||||
# x509.Certificate
|
||||
self._peer_certificate = tls_context._peer_certificate
|
||||
print(
|
||||
logger.info(
|
||||
f"Extracted peer certificate: {self._peer_certificate.subject}"
|
||||
)
|
||||
else:
|
||||
print("No peer certificate found in TLS context")
|
||||
logger.info("No peer certificate found in TLS context")
|
||||
|
||||
else:
|
||||
print("No TLS context available for certificate extraction")
|
||||
logger.info("No TLS context available for certificate extraction")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract peer certificate: {e}")
|
||||
@ -556,15 +553,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Try alternative approach - check if certificate is in handshake events
|
||||
try:
|
||||
# Some versions of aioquic might expose certificate differently
|
||||
if hasattr(self._quic, "configuration") and self._quic.configuration:
|
||||
config = self._quic.configuration
|
||||
if hasattr(config, "certificate") and config.certificate:
|
||||
# This would be the local certificate, not peer certificate
|
||||
# but we can use it for debugging
|
||||
print("Found local certificate in configuration")
|
||||
config = self._quic.configuration
|
||||
if hasattr(config, "certificate") and config.certificate:
|
||||
# This would be the local certificate, not peer certificate
|
||||
# but we can use it for debugging
|
||||
logger.debug("Found local certificate in configuration")
|
||||
|
||||
except Exception as inner_e:
|
||||
print(f"Alternative certificate extraction also failed: {inner_e}")
|
||||
logger.error(
|
||||
f"Alternative certificate extraction also failed: {inner_e}"
|
||||
)
|
||||
|
||||
async def get_peer_certificate(self) -> x509.Certificate | None:
|
||||
"""
|
||||
@ -596,7 +594,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
subject = self._peer_certificate.subject
|
||||
serial_number = self._peer_certificate.serial_number
|
||||
|
||||
print(
|
||||
logger.info(
|
||||
f"Certificate validation - Subject: {subject}, Serial: {serial_number}"
|
||||
)
|
||||
return True
|
||||
@ -721,7 +719,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._outbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
|
||||
print(f"Opened outbound QUIC stream {stream_id}")
|
||||
logger.info(f"Opened outbound QUIC stream {stream_id}")
|
||||
return stream
|
||||
|
||||
raise QUICStreamTimeoutError(f"Stream creation timed out after {timeout}s")
|
||||
@ -754,7 +752,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
async with self._accept_queue_lock:
|
||||
if self._stream_accept_queue:
|
||||
stream = self._stream_accept_queue.pop(0)
|
||||
print(f"Accepted inbound stream {stream.stream_id}")
|
||||
logger.debug(f"Accepted inbound stream {stream.stream_id}")
|
||||
return stream
|
||||
|
||||
if self._closed:
|
||||
@ -765,8 +763,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Wait for new streams
|
||||
await self._stream_accept_event.wait()
|
||||
|
||||
print(
|
||||
f"{id(self)} ACCEPT STREAM TIMEOUT: CONNECTION STATE {self._closed_event.is_set() or self._closed}"
|
||||
logger.error(
|
||||
"Timeout occured while accepting stream for local peer "
|
||||
f"{self._local_peer_id.to_string()} on QUIC connection"
|
||||
)
|
||||
if self._closed_event.is_set() or self._closed:
|
||||
raise MuxedConnUnavailable("QUIC connection closed during timeout")
|
||||
@ -782,7 +781,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
"""
|
||||
self._stream_handler = handler_function
|
||||
print("Set stream handler for incoming streams")
|
||||
logger.info("Set stream handler for incoming streams")
|
||||
|
||||
def _remove_stream(self, stream_id: int) -> None:
|
||||
"""
|
||||
@ -809,7 +808,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(update_counts)
|
||||
|
||||
print(f"Removed stream {stream_id} from connection")
|
||||
logger.info(f"Removed stream {stream_id} from connection")
|
||||
|
||||
# *** UPDATED: Complete QUIC event handling - FIXES THE ORIGINAL ISSUE ***
|
||||
|
||||
@ -831,15 +830,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await self._handle_quic_event(event)
|
||||
|
||||
if events_processed > 0:
|
||||
print(f"Processed {events_processed} QUIC events")
|
||||
logger.info(f"Processed {events_processed} QUIC events")
|
||||
|
||||
finally:
|
||||
self._event_processing_active = False
|
||||
|
||||
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
|
||||
"""Handle a single QUIC event with COMPLETE event type coverage."""
|
||||
print(f"Handling QUIC event: {type(event).__name__}")
|
||||
print(f"QUIC event: {type(event).__name__}")
|
||||
logger.info(f"Handling QUIC event: {type(event).__name__}")
|
||||
logger.info(f"QUIC event: {type(event).__name__}")
|
||||
|
||||
try:
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
@ -865,8 +864,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
elif isinstance(event, events.StopSendingReceived):
|
||||
await self._handle_stop_sending_received(event)
|
||||
else:
|
||||
print(f"Unhandled QUIC event type: {type(event).__name__}")
|
||||
print(f"Unhandled QUIC event: {type(event).__name__}")
|
||||
logger.info(f"Unhandled QUIC event type: {type(event).__name__}")
|
||||
logger.info(f"Unhandled QUIC event: {type(event).__name__}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
|
||||
@ -882,7 +881,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
This is the CRITICAL missing functionality that was causing your issue!
|
||||
"""
|
||||
logger.info(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
print(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
logger.info(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
|
||||
# Add to available connection IDs
|
||||
self._available_connection_ids.add(event.connection_id)
|
||||
@ -891,13 +890,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._current_connection_id is None:
|
||||
self._current_connection_id = event.connection_id
|
||||
logger.info(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
|
||||
print(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
|
||||
logger.info(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
|
||||
|
||||
# Update statistics
|
||||
self._stats["connection_ids_issued"] += 1
|
||||
|
||||
print(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
print(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
logger.info(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
logger.info(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
|
||||
async def _handle_connection_id_retired(
|
||||
self, event: events.ConnectionIdRetired
|
||||
@ -908,7 +907,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
This handles when the peer tells us to stop using a connection ID.
|
||||
"""
|
||||
logger.info(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
print(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
logger.info(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
|
||||
# Remove from available IDs and add to retired set
|
||||
self._available_connection_ids.discard(event.connection_id)
|
||||
@ -918,17 +917,14 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._current_connection_id == event.connection_id:
|
||||
if self._available_connection_ids:
|
||||
self._current_connection_id = next(iter(self._available_connection_ids))
|
||||
logger.info(
|
||||
f"🆔 Switched to new connection ID: {self._current_connection_id.hex()}"
|
||||
)
|
||||
print(
|
||||
f"🆔 Switched to new connection ID: {self._current_connection_id.hex()}"
|
||||
logger.debug(
|
||||
f"Switching new connection ID: {self._current_connection_id.hex()}"
|
||||
)
|
||||
self._stats["connection_id_changes"] += 1
|
||||
else:
|
||||
self._current_connection_id = None
|
||||
logger.warning("⚠️ No available connection IDs after retirement!")
|
||||
print("⚠️ No available connection IDs after retirement!")
|
||||
logger.info("⚠️ No available connection IDs after retirement!")
|
||||
|
||||
# Update statistics
|
||||
self._stats["connection_ids_retired"] += 1
|
||||
@ -937,7 +933,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None:
|
||||
"""Handle ping acknowledgment."""
|
||||
print(f"Ping acknowledged: uid={event.uid}")
|
||||
logger.info(f"Ping acknowledged: uid={event.uid}")
|
||||
|
||||
async def _handle_protocol_negotiated(
|
||||
self, event: events.ProtocolNegotiated
|
||||
@ -949,15 +945,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self, event: events.StopSendingReceived
|
||||
) -> None:
|
||||
"""Handle stop sending request from peer."""
|
||||
print(
|
||||
f"Stop sending received: stream_id={event.stream_id}, error_code={event.error_code}"
|
||||
logger.debug(
|
||||
"Stop sending received: "
|
||||
f"stream_id={event.stream_id}, error_code={event.error_code}"
|
||||
)
|
||||
|
||||
if event.stream_id in self._streams:
|
||||
stream = self._streams[event.stream_id]
|
||||
stream: QUICStream = self._streams[event.stream_id]
|
||||
# Handle stop sending on the stream if method exists
|
||||
if hasattr(stream, "handle_stop_sending"):
|
||||
await stream.handle_stop_sending(event.error_code)
|
||||
await stream.handle_stop_sending(event.error_code)
|
||||
|
||||
# *** EXISTING event handlers (unchanged) ***
|
||||
|
||||
@ -965,7 +961,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self, event: events.HandshakeCompleted
|
||||
) -> None:
|
||||
"""Handle handshake completion with security integration."""
|
||||
print("QUIC handshake completed")
|
||||
logger.info("QUIC handshake completed")
|
||||
self._handshake_completed = True
|
||||
|
||||
# Store handshake event for security verification
|
||||
@ -974,14 +970,14 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Try to extract certificate information after handshake
|
||||
await self._extract_peer_certificate()
|
||||
|
||||
print("✅ Setting connected event")
|
||||
logger.info("✅ Setting connected event")
|
||||
self._connected_event.set()
|
||||
|
||||
async def _handle_connection_terminated(
|
||||
self, event: events.ConnectionTerminated
|
||||
) -> None:
|
||||
"""Handle connection termination."""
|
||||
print(f"QUIC connection terminated: {event.reason_phrase}")
|
||||
logger.info(f"QUIC connection terminated: {event.reason_phrase}")
|
||||
|
||||
# Close all streams
|
||||
for stream in list(self._streams.values()):
|
||||
@ -995,7 +991,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._closed_event.set()
|
||||
|
||||
self._stream_accept_event.set()
|
||||
print(f"✅ TERMINATION: Woke up pending accept_stream() calls, {id(self)}")
|
||||
logger.debug(f"Woke up pending accept_stream() calls, {id(self)}")
|
||||
|
||||
await self._notify_parent_of_termination()
|
||||
|
||||
@ -1005,11 +1001,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._stats["bytes_received"] += len(event.data)
|
||||
|
||||
try:
|
||||
print(f"🔧 STREAM_DATA: Handling data for stream {stream_id}")
|
||||
|
||||
if stream_id not in self._streams:
|
||||
if self._is_incoming_stream(stream_id):
|
||||
print(f"🔧 STREAM_DATA: Creating new incoming stream {stream_id}")
|
||||
logger.info(f"Creating new incoming stream {stream_id}")
|
||||
|
||||
from .stream import QUICStream, StreamDirection
|
||||
|
||||
@ -1027,29 +1021,24 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
async with self._accept_queue_lock:
|
||||
self._stream_accept_queue.append(stream)
|
||||
self._stream_accept_event.set()
|
||||
print(
|
||||
f"✅ STREAM_DATA: Added stream {stream_id} to accept queue"
|
||||
)
|
||||
logger.debug(f"Added stream {stream_id} to accept queue")
|
||||
|
||||
async with self._stream_count_lock:
|
||||
self._inbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
|
||||
else:
|
||||
print(
|
||||
f"❌ STREAM_DATA: Unexpected outbound stream {stream_id} in data event"
|
||||
logger.error(
|
||||
f"Unexpected outbound stream {stream_id} in data event"
|
||||
)
|
||||
return
|
||||
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_data_received(event.data, event.end_stream)
|
||||
print(
|
||||
f"✅ STREAM_DATA: Forwarded {len(event.data)} bytes to stream {stream_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling stream data for stream {stream_id}: {e}")
|
||||
print(f"❌ STREAM_DATA: Error: {e}")
|
||||
logger.info(f"❌ STREAM_DATA: Error: {e}")
|
||||
|
||||
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
|
||||
"""Get existing stream or create new inbound stream."""
|
||||
@ -1106,7 +1095,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
|
||||
|
||||
print(f"Created inbound stream {stream_id}")
|
||||
logger.info(f"Created inbound stream {stream_id}")
|
||||
return stream
|
||||
|
||||
def _is_incoming_stream(self, stream_id: int) -> bool:
|
||||
@ -1133,7 +1122,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_reset(event.error_code)
|
||||
print(
|
||||
logger.info(
|
||||
f"Handled reset for stream {stream_id}"
|
||||
f"with error code {event.error_code}"
|
||||
)
|
||||
@ -1142,13 +1131,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Force remove the stream
|
||||
self._remove_stream(stream_id)
|
||||
else:
|
||||
print(f"Received reset for unknown stream {stream_id}")
|
||||
logger.info(f"Received reset for unknown stream {stream_id}")
|
||||
|
||||
async def _handle_datagram_received(
|
||||
self, event: events.DatagramFrameReceived
|
||||
) -> None:
|
||||
"""Handle datagram frame (if using QUIC datagrams)."""
|
||||
print(f"Datagram frame received: size={len(event.data)}")
|
||||
logger.info(f"Datagram frame received: size={len(event.data)}")
|
||||
# For now, just log. Could be extended for custom datagram handling
|
||||
|
||||
async def _handle_timer_events(self) -> None:
|
||||
@ -1165,7 +1154,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Transmit pending QUIC packets using available socket."""
|
||||
sock = self._socket
|
||||
if not sock:
|
||||
print("No socket to transmit")
|
||||
logger.info("No socket to transmit")
|
||||
return
|
||||
|
||||
try:
|
||||
@ -1183,11 +1172,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await self._handle_connection_error(e)
|
||||
|
||||
# Additional methods for stream data processing
|
||||
async def _process_quic_event(self, event):
|
||||
async def _process_quic_event(self, event: events.QuicEvent) -> None:
|
||||
"""Process a single QUIC event."""
|
||||
await self._handle_quic_event(event)
|
||||
|
||||
async def _transmit_pending_data(self):
|
||||
async def _transmit_pending_data(self) -> None:
|
||||
"""Transmit any pending data."""
|
||||
await self._transmit()
|
||||
|
||||
@ -1211,7 +1200,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
print(f"Closing QUIC connection to {self._remote_peer_id}")
|
||||
logger.info(f"Closing QUIC connection to {self._remote_peer_id}")
|
||||
|
||||
try:
|
||||
# Close all streams gracefully
|
||||
@ -1253,7 +1242,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._streams.clear()
|
||||
self._closed_event.set()
|
||||
|
||||
print(f"QUIC connection to {self._remote_peer_id} closed")
|
||||
logger.info(f"QUIC connection to {self._remote_peer_id} closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during connection close: {e}")
|
||||
@ -1268,13 +1257,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
if self._transport:
|
||||
await self._transport._cleanup_terminated_connection(self)
|
||||
print("Notified transport of connection termination")
|
||||
logger.info("Notified transport of connection termination")
|
||||
return
|
||||
|
||||
for listener in self._transport._listeners:
|
||||
try:
|
||||
await listener._remove_connection_by_object(self)
|
||||
print("Found and notified listener of connection termination")
|
||||
logger.info("Found and notified listener of connection termination")
|
||||
return
|
||||
except Exception:
|
||||
continue
|
||||
@ -1285,7 +1274,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Could not notify parent of connection termination - no parent reference found"
|
||||
"Could not notify parent of connection termination - no"
|
||||
f" parent reference found for conn host {self._quic.host_cid.hex()}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -1298,12 +1288,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
for tracked_cid, tracked_conn in list(listener._connections.items()):
|
||||
if tracked_conn is self:
|
||||
await listener._remove_connection(tracked_cid)
|
||||
print(
|
||||
f"Removed connection {tracked_cid.hex()} by object reference"
|
||||
)
|
||||
logger.info(f"Removed connection {tracked_cid.hex()}")
|
||||
return
|
||||
|
||||
print("Fallback cleanup by connection ID completed")
|
||||
logger.info("Fallback cleanup by connection ID completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fallback cleanup: {e}")
|
||||
|
||||
@ -1401,6 +1389,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# String representation
|
||||
|
||||
def __repr__(self) -> str:
|
||||
current_cid: str | None = (
|
||||
self._current_connection_id.hex() if self._current_connection_id else None
|
||||
)
|
||||
return (
|
||||
f"QUICConnection(peer={self._remote_peer_id}, "
|
||||
f"addr={self._remote_addr}, "
|
||||
@ -1408,7 +1399,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
f"verified={self._peer_verified}, "
|
||||
f"established={self._established}, "
|
||||
f"streams={len(self._streams)}, "
|
||||
f"current_cid={self._current_connection_id.hex() if self._current_connection_id else None})"
|
||||
f"current_cid={current_cid})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@ -42,7 +42,6 @@ if TYPE_CHECKING:
|
||||
from .transport import QUICTransport
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
@ -277,63 +276,40 @@ class QUICListener(IListener):
|
||||
self._stats["packets_processed"] += 1
|
||||
self._stats["bytes_received"] += len(data)
|
||||
|
||||
print(f"🔧 PACKET: Processing {len(data)} bytes from {addr}")
|
||||
logger.debug(f"Processing packet of {len(data)} bytes from {addr}")
|
||||
|
||||
# Parse packet header OUTSIDE the lock
|
||||
packet_info = self.parse_quic_packet(data)
|
||||
if packet_info is None:
|
||||
print("❌ PACKET: Failed to parse packet header")
|
||||
logger.error(f"Failed to parse packet header quic packet from {addr}")
|
||||
self._stats["invalid_packets"] += 1
|
||||
return
|
||||
|
||||
dest_cid = packet_info.destination_cid
|
||||
print(f"🔧 DEBUG: Packet info: {packet_info is not None}")
|
||||
print(f"🔧 DEBUG: Packet type: {packet_info.packet_type}")
|
||||
print(
|
||||
f"🔧 DEBUG: Is short header: {packet_info.packet_type.name != 'INITIAL'}"
|
||||
)
|
||||
|
||||
# CRITICAL FIX: Reduce lock scope - only protect connection lookups
|
||||
# Get connection references with minimal lock time
|
||||
connection_obj = None
|
||||
pending_quic_conn = None
|
||||
|
||||
async with self._connection_lock:
|
||||
# Quick lookup operations only
|
||||
print(
|
||||
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
|
||||
)
|
||||
print(
|
||||
f"🔧 DEBUG: Established connections: {[cid.hex() for cid in self._connections.keys()]}"
|
||||
)
|
||||
|
||||
if dest_cid in self._connections:
|
||||
connection_obj = self._connections[dest_cid]
|
||||
print(
|
||||
f"✅ PACKET: Routing to established connection {dest_cid.hex()}"
|
||||
)
|
||||
print(f"PACKET: Routing to established connection {dest_cid.hex()}")
|
||||
|
||||
elif dest_cid in self._pending_connections:
|
||||
pending_quic_conn = self._pending_connections[dest_cid]
|
||||
print(f"✅ PACKET: Routing to pending connection {dest_cid.hex()}")
|
||||
print(f"PACKET: Routing to pending connection {dest_cid.hex()}")
|
||||
|
||||
else:
|
||||
# Check if this is a new connection
|
||||
print(
|
||||
f"🔧 PACKET: Parsed packet - version: {packet_info.version:#x}, dest_cid: {dest_cid.hex()}, src_cid: {packet_info.source_cid.hex()}"
|
||||
)
|
||||
|
||||
if packet_info.packet_type.name == "INITIAL":
|
||||
print(f"🔧 PACKET: Creating new connection for {addr}")
|
||||
logger.debug(
|
||||
f"Received INITIAL Packet Creating new conn for {addr}"
|
||||
)
|
||||
|
||||
# Create new connection INSIDE the lock for safety
|
||||
pending_quic_conn = await self._handle_new_connection(
|
||||
data, addr, packet_info
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"❌ PACKET: Unknown connection for non-initial packet {dest_cid.hex()}"
|
||||
)
|
||||
return
|
||||
|
||||
# CRITICAL: Process packets OUTSIDE the lock to prevent deadlock
|
||||
@ -364,7 +340,7 @@ class QUICListener(IListener):
|
||||
) -> None:
|
||||
"""Handle packet for established connection WITHOUT holding connection lock."""
|
||||
try:
|
||||
print(f"🔧 ESTABLISHED: Handling packet for connection {dest_cid.hex()}")
|
||||
print(f" ESTABLISHED: Handling packet for connection {dest_cid.hex()}")
|
||||
|
||||
# Forward packet to connection object
|
||||
# This may trigger event processing and stream creation
|
||||
@ -382,21 +358,19 @@ class QUICListener(IListener):
|
||||
) -> None:
|
||||
"""Handle packet for pending connection WITHOUT holding connection lock."""
|
||||
try:
|
||||
print(
|
||||
f"🔧 PENDING: Handling packet for pending connection {dest_cid.hex()}"
|
||||
)
|
||||
print(f"🔧 PENDING: Packet size: {len(data)} bytes from {addr}")
|
||||
print(f"Handling packet for pending connection {dest_cid.hex()}")
|
||||
print(f"Packet size: {len(data)} bytes from {addr}")
|
||||
|
||||
# Feed data to QUIC connection
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
print("✅ PENDING: Datagram received by QUIC connection")
|
||||
print("PENDING: Datagram received by QUIC connection")
|
||||
|
||||
# Process events - this is crucial for handshake progression
|
||||
print("🔧 PENDING: Processing QUIC events...")
|
||||
print("Processing QUIC events...")
|
||||
await self._process_quic_events(quic_conn, addr, dest_cid)
|
||||
|
||||
# Send any outgoing packets
|
||||
print("🔧 PENDING: Transmitting response...")
|
||||
print("Transmitting response...")
|
||||
await self._transmit_for_connection(quic_conn, addr)
|
||||
|
||||
# Check if handshake completed (with minimal locking)
|
||||
@ -404,10 +378,10 @@ class QUICListener(IListener):
|
||||
hasattr(quic_conn, "_handshake_complete")
|
||||
and quic_conn._handshake_complete
|
||||
):
|
||||
print("✅ PENDING: Handshake completed, promoting connection")
|
||||
print("PENDING: Handshake completed, promoting connection")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
else:
|
||||
print("🔧 PENDING: Handshake still in progress")
|
||||
print("Handshake still in progress")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
|
||||
@ -455,35 +429,28 @@ class QUICListener(IListener):
|
||||
|
||||
async def _handle_new_connection(
|
||||
self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo
|
||||
) -> None:
|
||||
) -> QuicConnection | None:
|
||||
"""Handle new connection with proper connection ID handling."""
|
||||
try:
|
||||
print(f"🔧 NEW_CONN: Starting handshake for {addr}")
|
||||
logger.debug(f"Starting handshake for {addr}")
|
||||
|
||||
# Find appropriate QUIC configuration
|
||||
quic_config = None
|
||||
config_key = None
|
||||
|
||||
for protocol, config in self._quic_configs.items():
|
||||
wire_versions = custom_quic_version_to_wire_format(protocol)
|
||||
if wire_versions == packet_info.version:
|
||||
quic_config = config
|
||||
config_key = protocol
|
||||
break
|
||||
|
||||
if not quic_config:
|
||||
print(
|
||||
f"❌ NEW_CONN: No configuration found for version 0x{packet_info.version:08x}"
|
||||
)
|
||||
print(
|
||||
f"🔧 NEW_CONN: Available configs: {list(self._quic_configs.keys())}"
|
||||
logger.error(
|
||||
f"No configuration found for version 0x{packet_info.version:08x}"
|
||||
)
|
||||
await self._send_version_negotiation(addr, packet_info.source_cid)
|
||||
return
|
||||
|
||||
print(
|
||||
f"✅ NEW_CONN: Using config {config_key} for version 0x{packet_info.version:08x}"
|
||||
)
|
||||
if not quic_config:
|
||||
raise QUICListenError("Cannot determine QUIC configuration")
|
||||
|
||||
# Create server-side QUIC configuration
|
||||
server_config = create_server_config_from_base(
|
||||
@ -492,19 +459,6 @@ class QUICListener(IListener):
|
||||
transport_config=self._config,
|
||||
)
|
||||
|
||||
# Debug the server configuration
|
||||
print(f"🔧 NEW_CONN: Server config - is_client: {server_config.is_client}")
|
||||
print(
|
||||
f"🔧 NEW_CONN: Server config - has_certificate: {server_config.certificate is not None}"
|
||||
)
|
||||
print(
|
||||
f"🔧 NEW_CONN: Server config - has_private_key: {server_config.private_key is not None}"
|
||||
)
|
||||
print(f"🔧 NEW_CONN: Server config - ALPN: {server_config.alpn_protocols}")
|
||||
print(
|
||||
f"🔧 NEW_CONN: Server config - verify_mode: {server_config.verify_mode}"
|
||||
)
|
||||
|
||||
# Validate certificate has libp2p extension
|
||||
if server_config.certificate:
|
||||
cert = server_config.certificate
|
||||
@ -513,24 +467,15 @@ class QUICListener(IListener):
|
||||
if ext.oid == LIBP2P_TLS_EXTENSION_OID:
|
||||
has_libp2p_ext = True
|
||||
break
|
||||
print(
|
||||
f"🔧 NEW_CONN: Certificate has libp2p extension: {has_libp2p_ext}"
|
||||
)
|
||||
logger.debug(f"Certificate has libp2p extension: {has_libp2p_ext}")
|
||||
|
||||
if not has_libp2p_ext:
|
||||
print("❌ NEW_CONN: Certificate missing libp2p extension!")
|
||||
logger.error("Certificate missing libp2p extension!")
|
||||
|
||||
# Generate a new destination connection ID for this connection
|
||||
import secrets
|
||||
|
||||
destination_cid = secrets.token_bytes(8)
|
||||
|
||||
print(f"🔧 NEW_CONN: Generated new CID: {destination_cid.hex()}")
|
||||
print(
|
||||
f"🔧 NEW_CONN: Original destination CID: {packet_info.destination_cid.hex()}"
|
||||
logger.debug(
|
||||
f"Original destination CID: {packet_info.destination_cid.hex()}"
|
||||
)
|
||||
|
||||
# Create QUIC connection with proper parameters for server
|
||||
quic_conn = QuicConnection(
|
||||
configuration=server_config,
|
||||
original_destination_connection_id=packet_info.destination_cid,
|
||||
@ -540,38 +485,28 @@ class QUICListener(IListener):
|
||||
# Use the first host CID as our routing CID
|
||||
if quic_conn._host_cids:
|
||||
destination_cid = quic_conn._host_cids[0].cid
|
||||
print(
|
||||
f"🔧 NEW_CONN: Using host CID as routing CID: {destination_cid.hex()}"
|
||||
)
|
||||
logger.debug(f"Using host CID as routing CID: {destination_cid.hex()}")
|
||||
else:
|
||||
# Fallback to random if no host CIDs generated
|
||||
import secrets
|
||||
|
||||
destination_cid = secrets.token_bytes(8)
|
||||
print(f"🔧 NEW_CONN: Fallback to random CID: {destination_cid.hex()}")
|
||||
logger.debug(f"Fallback to random CID: {destination_cid.hex()}")
|
||||
|
||||
print(
|
||||
f"🔧 NEW_CONN: Original destination CID: {packet_info.destination_cid.hex()}"
|
||||
logger.debug(f"Generated {len(quic_conn._host_cids)} host CIDs for client")
|
||||
|
||||
logger.debug(
|
||||
f"QUIC connection created for destination CID {destination_cid.hex()}"
|
||||
)
|
||||
|
||||
print(f"🔧 Generated {len(quic_conn._host_cids)} host CIDs for client")
|
||||
|
||||
print("✅ NEW_CONN: QUIC connection created successfully")
|
||||
|
||||
# Store connection mapping using our generated CID
|
||||
self._pending_connections[destination_cid] = quic_conn
|
||||
self._addr_to_cid[addr] = destination_cid
|
||||
self._cid_to_addr[destination_cid] = addr
|
||||
|
||||
print(
|
||||
f"🔧 NEW_CONN: Stored mappings for {addr} <-> {destination_cid.hex()}"
|
||||
)
|
||||
print("Receiving Datagram")
|
||||
|
||||
# Process initial packet
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
|
||||
# Debug connection state after receiving packet
|
||||
await self._debug_quic_connection_state_detailed(quic_conn, destination_cid)
|
||||
|
||||
# Process events and send response
|
||||
await self._process_quic_events(quic_conn, addr, destination_cid)
|
||||
await self._transmit_for_connection(quic_conn, addr)
|
||||
@ -581,109 +516,27 @@ class QUICListener(IListener):
|
||||
f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})"
|
||||
)
|
||||
|
||||
return quic_conn
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling new connection from {addr}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
self._stats["connections_rejected"] += 1
|
||||
|
||||
async def _debug_quic_connection_state_detailed(
|
||||
self, quic_conn: QuicConnection, connection_id: bytes
|
||||
):
|
||||
"""Enhanced connection state debugging."""
|
||||
try:
|
||||
print(f"🔧 QUIC_STATE: Debugging connection {connection_id.hex()}")
|
||||
|
||||
if not quic_conn:
|
||||
print("❌ QUIC_STATE: QUIC CONNECTION NOT FOUND")
|
||||
return
|
||||
|
||||
# Check TLS state
|
||||
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||
print("✅ QUIC_STATE: TLS context exists")
|
||||
if hasattr(quic_conn.tls, "state"):
|
||||
print(f"🔧 QUIC_STATE: TLS state: {quic_conn.tls.state}")
|
||||
|
||||
# Check if we have peer certificate
|
||||
if (
|
||||
hasattr(quic_conn.tls, "_peer_certificate")
|
||||
and quic_conn.tls._peer_certificate
|
||||
):
|
||||
print("✅ QUIC_STATE: Peer certificate available")
|
||||
else:
|
||||
print("🔧 QUIC_STATE: No peer certificate yet")
|
||||
|
||||
# Check TLS handshake completion
|
||||
if hasattr(quic_conn.tls, "handshake_complete"):
|
||||
handshake_status = quic_conn._handshake_complete
|
||||
print(f"🔧 QUIC_STATE: TLS handshake complete: {handshake_status}")
|
||||
else:
|
||||
print("❌ QUIC_STATE: No TLS context!")
|
||||
|
||||
# Check connection state
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 QUIC_STATE: Connection state: {quic_conn._state}")
|
||||
|
||||
# Check if handshake is complete
|
||||
if hasattr(quic_conn, "_handshake_complete"):
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Handshake complete: {quic_conn._handshake_complete}"
|
||||
)
|
||||
|
||||
# Check configuration
|
||||
if hasattr(quic_conn, "configuration"):
|
||||
config = quic_conn.configuration
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Config certificate: {config.certificate is not None}"
|
||||
)
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Config private_key: {config.private_key is not None}"
|
||||
)
|
||||
print(f"🔧 QUIC_STATE: Config is_client: {config.is_client}")
|
||||
print(f"🔧 QUIC_STATE: Config verify_mode: {config.verify_mode}")
|
||||
print(f"🔧 QUIC_STATE: Config ALPN: {config.alpn_protocols}")
|
||||
|
||||
if config.certificate:
|
||||
cert = config.certificate
|
||||
print(f"🔧 QUIC_STATE: Certificate subject: {cert.subject}")
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Certificate valid from: {cert.not_valid_before_utc}"
|
||||
)
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Certificate valid until: {cert.not_valid_after_utc}"
|
||||
)
|
||||
|
||||
# Check for connection errors
|
||||
if hasattr(quic_conn, "_close_event") and quic_conn._close_event:
|
||||
print(
|
||||
f"❌ QUIC_STATE: Connection has close event: {quic_conn._close_event}"
|
||||
)
|
||||
|
||||
# Check for TLS errors
|
||||
if (
|
||||
hasattr(quic_conn, "_handshake_complete")
|
||||
and not quic_conn._handshake_complete
|
||||
):
|
||||
print("⚠️ QUIC_STATE: Handshake not yet complete")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ QUIC_STATE: Error checking state: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def _handle_short_header_packet(
|
||||
self, data: bytes, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""Handle short header packets for established connections."""
|
||||
try:
|
||||
print(f"🔧 SHORT_HDR: Handling short header packet from {addr}")
|
||||
print(f" SHORT_HDR: Handling short header packet from {addr}")
|
||||
|
||||
# First, try address-based lookup
|
||||
dest_cid = self._addr_to_cid.get(addr)
|
||||
if dest_cid and dest_cid in self._connections:
|
||||
print(f"✅ SHORT_HDR: Routing via address mapping to {dest_cid.hex()}")
|
||||
print(f"SHORT_HDR: Routing via address mapping to {dest_cid.hex()}")
|
||||
connection = self._connections[dest_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
return
|
||||
@ -693,9 +546,7 @@ class QUICListener(IListener):
|
||||
potential_cid = data[1:9]
|
||||
|
||||
if potential_cid in self._connections:
|
||||
print(
|
||||
f"✅ SHORT_HDR: Routing via extracted CID {potential_cid.hex()}"
|
||||
)
|
||||
print(f"SHORT_HDR: Routing via extracted CID {potential_cid.hex()}")
|
||||
connection = self._connections[potential_cid]
|
||||
|
||||
# Update mappings for future packets
|
||||
@ -734,59 +585,26 @@ class QUICListener(IListener):
|
||||
addr: tuple[str, int],
|
||||
dest_cid: bytes,
|
||||
) -> None:
|
||||
"""Handle packet for a pending (handshaking) connection with enhanced debugging."""
|
||||
"""Handle packet for a pending (handshaking) connection."""
|
||||
try:
|
||||
print(
|
||||
f"🔧 PENDING: Handling packet for pending connection {dest_cid.hex()}"
|
||||
)
|
||||
print(f"🔧 PENDING: Packet size: {len(data)} bytes from {addr}")
|
||||
|
||||
# Check connection state before processing
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 PENDING: Connection state before: {quic_conn._state}")
|
||||
|
||||
if (
|
||||
hasattr(quic_conn, "tls")
|
||||
and quic_conn.tls
|
||||
and hasattr(quic_conn.tls, "state")
|
||||
):
|
||||
print(f"🔧 PENDING: TLS state before: {quic_conn.tls.state}")
|
||||
logger.debug(f"Handling packet for pending connection {dest_cid.hex()}")
|
||||
|
||||
# Feed data to QUIC connection
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
print("✅ PENDING: Datagram received by QUIC connection")
|
||||
|
||||
# Check state after receiving packet
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 PENDING: Connection state after: {quic_conn._state}")
|
||||
|
||||
if (
|
||||
hasattr(quic_conn, "tls")
|
||||
and quic_conn.tls
|
||||
and hasattr(quic_conn.tls, "state")
|
||||
):
|
||||
print(f"🔧 PENDING: TLS state after: {quic_conn.tls.state}")
|
||||
if quic_conn.tls:
|
||||
print(f"TLS state after: {quic_conn.tls.state}")
|
||||
|
||||
# Process events - this is crucial for handshake progression
|
||||
print("🔧 PENDING: Processing QUIC events...")
|
||||
await self._process_quic_events(quic_conn, addr, dest_cid)
|
||||
|
||||
# Send any outgoing packets - this is where the response should be sent
|
||||
print("🔧 PENDING: Transmitting response...")
|
||||
await self._transmit_for_connection(quic_conn, addr)
|
||||
|
||||
# Check if handshake completed
|
||||
if (
|
||||
hasattr(quic_conn, "_handshake_complete")
|
||||
and quic_conn._handshake_complete
|
||||
):
|
||||
print("✅ PENDING: Handshake completed, promoting connection")
|
||||
if quic_conn._handshake_complete:
|
||||
logger.debug("PENDING: Handshake completed, promoting connection")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
else:
|
||||
print("🔧 PENDING: Handshake still in progress")
|
||||
|
||||
# Debug why handshake might be stuck
|
||||
await self._debug_handshake_state(quic_conn, dest_cid)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
|
||||
@ -795,7 +613,7 @@ class QUICListener(IListener):
|
||||
traceback.print_exc()
|
||||
|
||||
# Remove problematic pending connection
|
||||
print(f"❌ PENDING: Removing problematic connection {dest_cid.hex()}")
|
||||
logger.error(f"Removing problematic connection {dest_cid.hex()}")
|
||||
await self._remove_pending_connection(dest_cid)
|
||||
|
||||
async def _process_quic_events(
|
||||
@ -810,15 +628,15 @@ class QUICListener(IListener):
|
||||
break
|
||||
|
||||
events_processed += 1
|
||||
print(
|
||||
f"🔧 EVENT: Processing event {events_processed}: {type(event).__name__}"
|
||||
logger.debug(
|
||||
"QUIC EVENT: Processing event "
|
||||
f"{events_processed}: {type(event).__name__}"
|
||||
)
|
||||
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
print(
|
||||
f"❌ EVENT: Connection terminated - code: {event.error_code}, reason: {event.reason_phrase}"
|
||||
)
|
||||
logger.debug(
|
||||
"QUIC EVENT: Connection terminated "
|
||||
f"- code: {event.error_code}, reason: {event.reason_phrase}"
|
||||
f"Connection {dest_cid.hex()} from {addr} "
|
||||
f"terminated: {event.reason_phrase}"
|
||||
)
|
||||
@ -826,47 +644,44 @@ class QUICListener(IListener):
|
||||
break
|
||||
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
print(
|
||||
f"✅ EVENT: Handshake completed for connection {dest_cid.hex()}"
|
||||
logger.debug(
|
||||
"QUIC EVENT: Handshake completed for connection "
|
||||
f"{dest_cid.hex()}"
|
||||
)
|
||||
logger.debug(f"Handshake completed for connection {dest_cid.hex()}")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
print(f"🔧 EVENT: Stream data received on stream {event.stream_id}")
|
||||
# Forward to established connection if available
|
||||
logger.debug(
|
||||
f"QUIC EVENT: Stream data received on stream {event.stream_id}"
|
||||
)
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
print(
|
||||
f"📨 FORWARDING: Stream data to connection {id(connection)}"
|
||||
)
|
||||
await connection._handle_stream_data(event)
|
||||
|
||||
elif isinstance(event, events.StreamReset):
|
||||
print(f"🔧 EVENT: Stream reset on stream {event.stream_id}")
|
||||
# Forward to established connection if available
|
||||
logger.debug(
|
||||
f"QUIC EVENT: Stream reset on stream {event.stream_id}"
|
||||
)
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
await connection._handle_stream_reset(event)
|
||||
|
||||
elif isinstance(event, events.ConnectionIdIssued):
|
||||
print(
|
||||
f"🔧 EVENT: Connection ID issued: {event.connection_id.hex()}"
|
||||
f"QUIC EVENT: Connection ID issued: {event.connection_id.hex()}"
|
||||
)
|
||||
# ADD: Update mappings using existing data structures
|
||||
# Add new CID to the same address mapping
|
||||
taddr = self._cid_to_addr.get(dest_cid)
|
||||
if taddr:
|
||||
# Don't overwrite, but note that this CID is also valid for this address
|
||||
print(
|
||||
f"🔧 EVENT: New CID {event.connection_id.hex()} available for {taddr}"
|
||||
# Don't overwrite, but this CID is also valid for this address
|
||||
logger.debug(
|
||||
f"QUIC EVENT: New CID {event.connection_id.hex()} "
|
||||
f"available for {taddr}"
|
||||
)
|
||||
|
||||
elif isinstance(event, events.ConnectionIdRetired):
|
||||
print(
|
||||
f"🔧 EVENT: Connection ID retired: {event.connection_id.hex()}"
|
||||
)
|
||||
# ADD: Clean up using existing patterns
|
||||
print(f"EVENT: Connection ID retired: {event.connection_id.hex()}")
|
||||
retired_cid = event.connection_id
|
||||
if retired_cid in self._cid_to_addr:
|
||||
addr = self._cid_to_addr[retired_cid]
|
||||
@ -874,16 +689,13 @@ class QUICListener(IListener):
|
||||
# Only remove addr mapping if this was the active CID
|
||||
if self._addr_to_cid.get(addr) == retired_cid:
|
||||
del self._addr_to_cid[addr]
|
||||
print(
|
||||
f"🔧 EVENT: Cleaned up mapping for retired CID {retired_cid.hex()}"
|
||||
)
|
||||
else:
|
||||
print(f"🔧 EVENT: Unhandled event type: {type(event).__name__}")
|
||||
print(f" EVENT: Unhandled event type: {type(event).__name__}")
|
||||
|
||||
if events_processed == 0:
|
||||
print("🔧 EVENT: No events to process")
|
||||
print(" EVENT: No events to process")
|
||||
else:
|
||||
print(f"🔧 EVENT: Processed {events_processed} events total")
|
||||
print(f" EVENT: Processed {events_processed} events total")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ EVENT: Error processing events: {e}")
|
||||
@ -891,62 +703,18 @@ class QUICListener(IListener):
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _debug_quic_connection_state(
|
||||
self, quic_conn: QuicConnection, connection_id: bytes
|
||||
):
|
||||
"""Debug the internal state of the QUIC connection."""
|
||||
try:
|
||||
print(f"🔧 QUIC_STATE: Debugging connection {connection_id}")
|
||||
|
||||
if not quic_conn:
|
||||
print("🔧 QUIC_STATE: QUIC CONNECTION NOT FOUND")
|
||||
return
|
||||
|
||||
# Check TLS state
|
||||
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||
print("🔧 QUIC_STATE: TLS context exists")
|
||||
if hasattr(quic_conn.tls, "state"):
|
||||
print(f"🔧 QUIC_STATE: TLS state: {quic_conn.tls.state}")
|
||||
else:
|
||||
print("❌ QUIC_STATE: No TLS context!")
|
||||
|
||||
# Check connection state
|
||||
if hasattr(quic_conn, "_state"):
|
||||
print(f"🔧 QUIC_STATE: Connection state: {quic_conn._state}")
|
||||
|
||||
# Check if handshake is complete
|
||||
if hasattr(quic_conn, "_handshake_complete"):
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Handshake complete: {quic_conn._handshake_complete}"
|
||||
)
|
||||
|
||||
# Check configuration
|
||||
if hasattr(quic_conn, "configuration"):
|
||||
config = quic_conn.configuration
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Config certificate: {config.certificate is not None}"
|
||||
)
|
||||
print(
|
||||
f"🔧 QUIC_STATE: Config private_key: {config.private_key is not None}"
|
||||
)
|
||||
print(f"🔧 QUIC_STATE: Config is_client: {config.is_client}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ QUIC_STATE: Error checking state: {e}")
|
||||
|
||||
async def _promote_pending_connection(
|
||||
self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
|
||||
):
|
||||
) -> None:
|
||||
"""Promote pending connection - avoid duplicate creation."""
|
||||
try:
|
||||
# Remove from pending connections
|
||||
self._pending_connections.pop(dest_cid, None)
|
||||
|
||||
# CHECK: Does QUICConnection already exist?
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
print(
|
||||
f"🔄 PROMOTION: Using existing QUICConnection {id(connection)} for {dest_cid.hex()}"
|
||||
logger.debug(
|
||||
f"Using existing QUICConnection {id(connection)} "
|
||||
f"for {dest_cid.hex()}"
|
||||
)
|
||||
|
||||
else:
|
||||
@ -968,22 +736,17 @@ class QUICListener(IListener):
|
||||
listener_socket=self._socket,
|
||||
)
|
||||
|
||||
print(
|
||||
f"🔄 PROMOTION: Created NEW QUICConnection {id(connection)} for {dest_cid.hex()}"
|
||||
)
|
||||
logger.debug(f"🔄 Created NEW QUICConnection for {dest_cid.hex()}")
|
||||
|
||||
# Store the connection
|
||||
self._connections[dest_cid] = connection
|
||||
|
||||
# Update mappings
|
||||
self._addr_to_cid[addr] = dest_cid
|
||||
self._cid_to_addr[dest_cid] = addr
|
||||
|
||||
# Rest of the existing promotion code...
|
||||
if self._nursery:
|
||||
connection._nursery = self._nursery
|
||||
await connection.connect(self._nursery)
|
||||
print("QUICListener: Connection connected succesfully")
|
||||
logger.debug(f"Connection connected succesfully for {dest_cid.hex()}")
|
||||
|
||||
if self._security_manager:
|
||||
try:
|
||||
@ -1001,27 +764,23 @@ class QUICListener(IListener):
|
||||
if self._nursery:
|
||||
connection._nursery = self._nursery
|
||||
await connection._start_background_tasks()
|
||||
print(f"Started background tasks for connection {dest_cid.hex()}")
|
||||
|
||||
if self._transport._swarm:
|
||||
print(f"🔄 PROMOTION: Adding connection {id(connection)} to swarm")
|
||||
await self._transport._swarm.add_conn(connection)
|
||||
print(
|
||||
f"🔄 PROMOTION: Successfully added connection {id(connection)} to swarm"
|
||||
logger.debug(
|
||||
f"Started background tasks for connection {dest_cid.hex()}"
|
||||
)
|
||||
|
||||
if self._handler:
|
||||
try:
|
||||
print(f"Invoking user callback {dest_cid.hex()}")
|
||||
await self._handler(connection)
|
||||
if self._transport._swarm:
|
||||
await self._transport._swarm.add_conn(connection)
|
||||
logger.debug(f"Successfully added connection {dest_cid.hex()} to swarm")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in user callback: {e}")
|
||||
try:
|
||||
print(f"Invoking user callback {dest_cid.hex()}")
|
||||
await self._handler(connection)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in user callback: {e}")
|
||||
|
||||
self._stats["connections_accepted"] += 1
|
||||
logger.info(
|
||||
f"✅ Enhanced connection {dest_cid.hex()} established from {addr}"
|
||||
)
|
||||
logger.info(f"Enhanced connection {dest_cid.hex()} established from {addr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error promoting connection {dest_cid.hex()}: {e}")
|
||||
@ -1062,10 +821,12 @@ class QUICListener(IListener):
|
||||
if dest_cid:
|
||||
await self._remove_connection(dest_cid)
|
||||
|
||||
async def _transmit_for_connection(self, quic_conn, addr):
|
||||
async def _transmit_for_connection(
|
||||
self, quic_conn: QuicConnection, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""Enhanced transmission diagnostics to analyze datagram content."""
|
||||
try:
|
||||
print(f"🔧 TRANSMIT: Starting transmission to {addr}")
|
||||
print(f" TRANSMIT: Starting transmission to {addr}")
|
||||
|
||||
# Get current timestamp for timing
|
||||
import time
|
||||
@ -1073,56 +834,31 @@ class QUICListener(IListener):
|
||||
now = time.time()
|
||||
|
||||
datagrams = quic_conn.datagrams_to_send(now=now)
|
||||
print(f"🔧 TRANSMIT: Got {len(datagrams)} datagrams to send")
|
||||
print(f" TRANSMIT: Got {len(datagrams)} datagrams to send")
|
||||
|
||||
if not datagrams:
|
||||
print("⚠️ TRANSMIT: No datagrams to send")
|
||||
return
|
||||
|
||||
for i, (datagram, dest_addr) in enumerate(datagrams):
|
||||
print(f"🔧 TRANSMIT: Analyzing datagram {i}")
|
||||
print(f"🔧 TRANSMIT: Datagram size: {len(datagram)} bytes")
|
||||
print(f"🔧 TRANSMIT: Destination: {dest_addr}")
|
||||
print(f"🔧 TRANSMIT: Expected destination: {addr}")
|
||||
print(f" TRANSMIT: Analyzing datagram {i}")
|
||||
print(f" TRANSMIT: Datagram size: {len(datagram)} bytes")
|
||||
print(f" TRANSMIT: Destination: {dest_addr}")
|
||||
print(f" TRANSMIT: Expected destination: {addr}")
|
||||
|
||||
# Analyze datagram content
|
||||
if len(datagram) > 0:
|
||||
# QUIC packet format analysis
|
||||
first_byte = datagram[0]
|
||||
header_form = (first_byte & 0x80) >> 7 # Bit 7
|
||||
fixed_bit = (first_byte & 0x40) >> 6 # Bit 6
|
||||
packet_type = (first_byte & 0x30) >> 4 # Bits 4-5
|
||||
type_specific = first_byte & 0x0F # Bits 0-3
|
||||
|
||||
print(f"🔧 TRANSMIT: First byte: 0x{first_byte:02x}")
|
||||
print(
|
||||
f"🔧 TRANSMIT: Header form: {header_form} ({'Long' if header_form else 'Short'})"
|
||||
)
|
||||
print(
|
||||
f"🔧 TRANSMIT: Fixed bit: {fixed_bit} ({'Valid' if fixed_bit else 'INVALID!'})"
|
||||
)
|
||||
print(f"🔧 TRANSMIT: Packet type: {packet_type}")
|
||||
|
||||
# For long header packets (handshake), analyze further
|
||||
if header_form == 1: # Long header
|
||||
packet_types = {
|
||||
0: "Initial",
|
||||
1: "0-RTT",
|
||||
2: "Handshake",
|
||||
3: "Retry",
|
||||
}
|
||||
type_name = packet_types.get(packet_type, "Unknown")
|
||||
print(f"🔧 TRANSMIT: Long header packet type: {type_name}")
|
||||
|
||||
# Look for CRYPTO frame indicators
|
||||
# CRYPTO frame type is 0x06
|
||||
crypto_frame_found = False
|
||||
for offset in range(len(datagram)):
|
||||
if datagram[offset] == 0x06: # CRYPTO frame type
|
||||
if datagram[offset] == 0x06:
|
||||
crypto_frame_found = True
|
||||
print(
|
||||
f"✅ TRANSMIT: Found CRYPTO frame at offset {offset}"
|
||||
)
|
||||
break
|
||||
|
||||
if not crypto_frame_found:
|
||||
@ -1138,21 +874,11 @@ class QUICListener(IListener):
|
||||
elif frame_type == 0x06: # CRYPTO
|
||||
frame_types_found.add("CRYPTO")
|
||||
|
||||
print(
|
||||
f"🔧 TRANSMIT: Frame types detected: {frame_types_found}"
|
||||
)
|
||||
|
||||
# Show first few bytes for debugging
|
||||
preview_bytes = min(32, len(datagram))
|
||||
hex_preview = " ".join(f"{b:02x}" for b in datagram[:preview_bytes])
|
||||
print(f"🔧 TRANSMIT: First {preview_bytes} bytes: {hex_preview}")
|
||||
|
||||
# Actually send the datagram
|
||||
if self._socket:
|
||||
try:
|
||||
print(f"🔧 TRANSMIT: Sending datagram {i} via socket...")
|
||||
print(f" TRANSMIT: Sending datagram {i} via socket...")
|
||||
await self._socket.sendto(datagram, addr)
|
||||
print(f"✅ TRANSMIT: Successfully sent datagram {i}")
|
||||
print(f"TRANSMIT: Successfully sent datagram {i}")
|
||||
except Exception as send_error:
|
||||
print(f"❌ TRANSMIT: Socket send failed: {send_error}")
|
||||
else:
|
||||
@ -1160,10 +886,9 @@ class QUICListener(IListener):
|
||||
|
||||
# Check if there are more datagrams after sending
|
||||
remaining_datagrams = quic_conn.datagrams_to_send(now=time.time())
|
||||
print(
|
||||
f"🔧 TRANSMIT: After sending, {len(remaining_datagrams)} datagrams remain"
|
||||
logger.debug(
|
||||
f" TRANSMIT: After sending, {len(remaining_datagrams)} datagrams remain"
|
||||
)
|
||||
print("------END OF THIS DATAGRAM LOG-----")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ TRANSMIT: Transmission error: {e}")
|
||||
@ -1184,6 +909,7 @@ class QUICListener(IListener):
|
||||
logger.debug("Using transport background nursery for listener")
|
||||
elif nursery:
|
||||
active_nursery = nursery
|
||||
self._transport._background_nursery = nursery
|
||||
logger.debug("Using provided nursery for listener")
|
||||
else:
|
||||
raise QUICListenError("No nursery available")
|
||||
@ -1299,8 +1025,10 @@ class QUICListener(IListener):
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing listener: {e}")
|
||||
|
||||
async def _remove_connection_by_object(self, connection_obj) -> None:
|
||||
"""Remove a connection by object reference (called when connection terminates)."""
|
||||
async def _remove_connection_by_object(
|
||||
self, connection_obj: QUICConnection
|
||||
) -> None:
|
||||
"""Remove a connection by object reference."""
|
||||
try:
|
||||
# Find the connection ID for this object
|
||||
connection_cid = None
|
||||
@ -1311,19 +1039,12 @@ class QUICListener(IListener):
|
||||
|
||||
if connection_cid:
|
||||
await self._remove_connection(connection_cid)
|
||||
logger.debug(
|
||||
f"✅ TERMINATION: Removed connection {connection_cid.hex()} by object reference"
|
||||
)
|
||||
print(
|
||||
f"✅ TERMINATION: Removed connection {connection_cid.hex()} by object reference"
|
||||
)
|
||||
logger.debug(f"Removed connection {connection_cid.hex()}")
|
||||
else:
|
||||
logger.warning("⚠️ TERMINATION: Connection object not found in tracking")
|
||||
print("⚠️ TERMINATION: Connection object not found in tracking")
|
||||
logger.warning("Connection object not found in tracking")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ TERMINATION: Error removing connection by object: {e}")
|
||||
print(f"❌ TERMINATION: Error removing connection by object: {e}")
|
||||
logger.error(f"Error removing connection by object: {e}")
|
||||
|
||||
def get_addresses(self) -> list[Multiaddr]:
|
||||
"""Get the bound addresses."""
|
||||
@ -1376,63 +1097,3 @@ class QUICListener(IListener):
|
||||
stats["active_connections"] = len(self._connections)
|
||||
stats["pending_connections"] = len(self._pending_connections)
|
||||
return stats
|
||||
|
||||
async def _debug_handshake_state(self, quic_conn: QuicConnection, dest_cid: bytes):
|
||||
"""Debug why handshake might be stuck."""
|
||||
try:
|
||||
print(f"🔧 HANDSHAKE_DEBUG: Analyzing stuck handshake for {dest_cid.hex()}")
|
||||
|
||||
# Check TLS handshake state
|
||||
if hasattr(quic_conn, "tls") and quic_conn.tls:
|
||||
tls = quic_conn.tls
|
||||
print(
|
||||
f"🔧 HANDSHAKE_DEBUG: TLS state: {getattr(tls, 'state', 'Unknown')}"
|
||||
)
|
||||
|
||||
# Check for TLS errors
|
||||
if hasattr(tls, "_error") and tls._error:
|
||||
print(f"❌ HANDSHAKE_DEBUG: TLS error: {tls._error}")
|
||||
|
||||
# Check certificate validation
|
||||
if hasattr(tls, "_peer_certificate"):
|
||||
if tls._peer_certificate:
|
||||
print("✅ HANDSHAKE_DEBUG: Peer certificate received")
|
||||
else:
|
||||
print("❌ HANDSHAKE_DEBUG: No peer certificate")
|
||||
|
||||
# Check ALPN negotiation
|
||||
if hasattr(tls, "_alpn_protocols"):
|
||||
if tls._alpn_protocols:
|
||||
print(
|
||||
f"✅ HANDSHAKE_DEBUG: ALPN negotiated: {tls._alpn_protocols}"
|
||||
)
|
||||
else:
|
||||
print("❌ HANDSHAKE_DEBUG: No ALPN protocol negotiated")
|
||||
|
||||
# Check QUIC connection state
|
||||
if hasattr(quic_conn, "_state"):
|
||||
state = quic_conn._state
|
||||
print(f"🔧 HANDSHAKE_DEBUG: QUIC state: {state}")
|
||||
|
||||
# Check specific states that might indicate problems
|
||||
if "FIRSTFLIGHT" in str(state):
|
||||
print("⚠️ HANDSHAKE_DEBUG: Connection stuck in FIRSTFLIGHT state")
|
||||
elif "CONNECTED" in str(state):
|
||||
print(
|
||||
"⚠️ HANDSHAKE_DEBUG: Connection shows CONNECTED but handshake not complete"
|
||||
)
|
||||
|
||||
# Check for pending crypto data
|
||||
if hasattr(quic_conn, "_cryptos") and quic_conn._cryptos:
|
||||
print(
|
||||
f"🔧 HANDSHAKE_DEBUG: Crypto data present {len(quic_conn._cryptos.keys())}"
|
||||
)
|
||||
|
||||
# Check loss detection state
|
||||
if hasattr(quic_conn, "_loss") and quic_conn._loss:
|
||||
loss_detection = quic_conn._loss
|
||||
if hasattr(loss_detection, "_pto_count"):
|
||||
print(f"🔧 HANDSHAKE_DEBUG: PTO count: {loss_detection._pto_count}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ HANDSHAKE_DEBUG: Error during debug: {e}")
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
"""
|
||||
QUIC Security implementation for py-libp2p Module 5.
|
||||
Implements libp2p TLS specification for QUIC transport with peer identity integration.
|
||||
@ -8,7 +7,7 @@ Based on go-libp2p and js-libp2p security patterns.
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
import ssl
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
@ -130,14 +129,16 @@ class LibP2PExtensionHandler:
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def parse_signed_key_extension(extension: Extension) -> tuple[PublicKey, bytes]:
|
||||
def parse_signed_key_extension(
|
||||
extension: Extension[Any],
|
||||
) -> tuple[PublicKey, bytes]:
|
||||
"""
|
||||
Parse the libp2p Public Key Extension with enhanced debugging.
|
||||
"""
|
||||
try:
|
||||
print(f"🔍 Extension type: {type(extension)}")
|
||||
print(f"🔍 Extension.value type: {type(extension.value)}")
|
||||
|
||||
|
||||
# Extract the raw bytes from the extension
|
||||
if isinstance(extension.value, UnrecognizedExtension):
|
||||
# Use the .value property to get the bytes
|
||||
@ -147,10 +148,10 @@ class LibP2PExtensionHandler:
|
||||
# Fallback if it's already bytes somehow
|
||||
raw_bytes = extension.value
|
||||
print("🔍 Extension.value is already bytes")
|
||||
|
||||
|
||||
print(f"🔍 Total extension length: {len(raw_bytes)} bytes")
|
||||
print(f"🔍 Extension hex (first 50 bytes): {raw_bytes[:50].hex()}")
|
||||
|
||||
|
||||
if not isinstance(raw_bytes, bytes):
|
||||
raise QUICCertificateError(f"Expected bytes, got {type(raw_bytes)}")
|
||||
|
||||
@ -191,28 +192,37 @@ class LibP2PExtensionHandler:
|
||||
signature = raw_bytes[offset : offset + signature_length]
|
||||
print(f"🔍 Extracted signature length: {len(signature)} bytes")
|
||||
print(f"🔍 Signature hex (first 20 bytes): {signature[:20].hex()}")
|
||||
print(f"🔍 Signature starts with DER header: {signature[:2].hex() == '3045'}")
|
||||
|
||||
print(
|
||||
f"🔍 Signature starts with DER header: {signature[:2].hex() == '3045'}"
|
||||
)
|
||||
|
||||
# Detailed signature analysis
|
||||
if len(signature) >= 2:
|
||||
if signature[0] == 0x30:
|
||||
der_length = signature[1]
|
||||
print(f"🔍 DER sequence length field: {der_length}")
|
||||
print(f"🔍 Expected DER total: {der_length + 2}")
|
||||
print(f"🔍 Actual signature length: {len(signature)}")
|
||||
|
||||
logger.debug(
|
||||
f"🔍 Expected DER total: {der_length + 2}"
|
||||
f"🔍 Actual signature length: {len(signature)}"
|
||||
)
|
||||
|
||||
if len(signature) != der_length + 2:
|
||||
print(f"⚠️ DER length mismatch! Expected {der_length + 2}, got {len(signature)}")
|
||||
logger.debug(
|
||||
"⚠️ DER length mismatch! "
|
||||
f"Expected {der_length + 2}, got {len(signature)}"
|
||||
)
|
||||
# Try truncating to correct DER length
|
||||
if der_length + 2 < len(signature):
|
||||
print(f"🔧 Truncating signature to correct DER length: {der_length + 2}")
|
||||
signature = signature[:der_length + 2]
|
||||
|
||||
logger.debug(
|
||||
"🔧 Truncating signature to correct DER length: "
|
||||
f"{der_length + 2}"
|
||||
)
|
||||
signature = signature[: der_length + 2]
|
||||
|
||||
# Check if we have extra data
|
||||
expected_total = 4 + public_key_length + 4 + signature_length
|
||||
print(f"🔍 Expected total length: {expected_total}")
|
||||
print(f"🔍 Actual total length: {len(raw_bytes)}")
|
||||
|
||||
|
||||
if len(raw_bytes) > expected_total:
|
||||
extra_bytes = len(raw_bytes) - expected_total
|
||||
print(f"⚠️ Extra {extra_bytes} bytes detected!")
|
||||
@ -221,7 +231,7 @@ class LibP2PExtensionHandler:
|
||||
# Deserialize the public key
|
||||
public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes)
|
||||
print(f"🔍 Successfully deserialized public key: {type(public_key)}")
|
||||
|
||||
|
||||
print(f"🔍 Final signature to return: {len(signature)} bytes")
|
||||
|
||||
return public_key, signature
|
||||
@ -229,6 +239,7 @@ class LibP2PExtensionHandler:
|
||||
except Exception as e:
|
||||
print(f"❌ Extension parsing failed: {e}")
|
||||
import traceback
|
||||
|
||||
print(f"❌ Traceback: {traceback.format_exc()}")
|
||||
raise QUICCertificateError(
|
||||
f"Failed to parse signed key extension: {e}"
|
||||
@ -470,26 +481,26 @@ class QUICTLSSecurityConfig:
|
||||
|
||||
# Core TLS components (required)
|
||||
certificate: Certificate
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey]
|
||||
private_key: EllipticCurvePrivateKey | RSAPrivateKey
|
||||
|
||||
# Certificate chain (optional)
|
||||
certificate_chain: List[Certificate] = field(default_factory=list)
|
||||
certificate_chain: list[Certificate] = field(default_factory=list)
|
||||
|
||||
# ALPN protocols
|
||||
alpn_protocols: List[str] = field(default_factory=lambda: ["libp2p"])
|
||||
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
|
||||
|
||||
# TLS verification settings
|
||||
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
|
||||
check_hostname: bool = False
|
||||
|
||||
# Optional peer ID for validation
|
||||
peer_id: Optional[ID] = None
|
||||
peer_id: ID | None = None
|
||||
|
||||
# Configuration metadata
|
||||
is_client_config: bool = False
|
||||
config_name: Optional[str] = None
|
||||
config_name: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration after initialization."""
|
||||
self._validate()
|
||||
|
||||
@ -516,46 +527,6 @@ class QUICTLSSecurityConfig:
|
||||
if not self.alpn_protocols:
|
||||
raise ValueError("At least one ALPN protocol is required")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convert to dictionary format for compatibility with existing code.
|
||||
|
||||
Returns:
|
||||
Dictionary compatible with the original TSecurityConfig format
|
||||
|
||||
"""
|
||||
return {
|
||||
"certificate": self.certificate,
|
||||
"private_key": self.private_key,
|
||||
"certificate_chain": self.certificate_chain.copy(),
|
||||
"alpn_protocols": self.alpn_protocols.copy(),
|
||||
"verify_mode": self.verify_mode,
|
||||
"check_hostname": self.check_hostname,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict, **kwargs) -> "QUICTLSSecurityConfig":
|
||||
"""
|
||||
Create instance from dictionary format.
|
||||
|
||||
Args:
|
||||
config_dict: Dictionary in TSecurityConfig format
|
||||
**kwargs: Additional parameters for the config
|
||||
|
||||
Returns:
|
||||
QUICTLSSecurityConfig instance
|
||||
|
||||
"""
|
||||
return cls(
|
||||
certificate=config_dict["certificate"],
|
||||
private_key=config_dict["private_key"],
|
||||
certificate_chain=config_dict.get("certificate_chain", []),
|
||||
alpn_protocols=config_dict.get("alpn_protocols", ["libp2p"]),
|
||||
verify_mode=config_dict.get("verify_mode", False),
|
||||
check_hostname=config_dict.get("check_hostname", False),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def validate_certificate_key_match(self) -> bool:
|
||||
"""
|
||||
Validate that the certificate and private key match.
|
||||
@ -621,7 +592,7 @@ class QUICTLSSecurityConfig:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_certificate_info(self) -> dict:
|
||||
def get_certificate_info(self) -> dict[Any, Any]:
|
||||
"""
|
||||
Get certificate information for debugging.
|
||||
|
||||
@ -652,7 +623,7 @@ class QUICTLSSecurityConfig:
|
||||
print(f"Check hostname: {self.check_hostname}")
|
||||
print(f"Certificate chain length: {len(self.certificate_chain)}")
|
||||
|
||||
cert_info = self.get_certificate_info()
|
||||
cert_info: dict[Any, Any] = self.get_certificate_info()
|
||||
for key, value in cert_info.items():
|
||||
print(f"Certificate {key}: {value}")
|
||||
|
||||
@ -663,9 +634,9 @@ class QUICTLSSecurityConfig:
|
||||
|
||||
def create_server_tls_config(
|
||||
certificate: Certificate,
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||
peer_id: Optional[ID] = None,
|
||||
**kwargs,
|
||||
private_key: EllipticCurvePrivateKey | RSAPrivateKey,
|
||||
peer_id: ID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create a server TLS configuration.
|
||||
@ -694,9 +665,9 @@ def create_server_tls_config(
|
||||
|
||||
def create_client_tls_config(
|
||||
certificate: Certificate,
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||
peer_id: Optional[ID] = None,
|
||||
**kwargs,
|
||||
private_key: EllipticCurvePrivateKey | RSAPrivateKey,
|
||||
peer_id: ID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create a client TLS configuration.
|
||||
@ -729,7 +700,7 @@ class QUICTLSConfigManager:
|
||||
Integrates with aioquic's TLS configuration system.
|
||||
"""
|
||||
|
||||
def __init__(self, libp2p_private_key: PrivateKey, peer_id: ID):
|
||||
def __init__(self, libp2p_private_key: PrivateKey, peer_id: ID) -> None:
|
||||
self.libp2p_private_key = libp2p_private_key
|
||||
self.peer_id = peer_id
|
||||
self.certificate_generator = CertificateGenerator()
|
||||
|
||||
@ -472,6 +472,45 @@ class QUICStream(IMuxedStream):
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||
|
||||
async def handle_stop_sending(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle STOP_SENDING frame from remote peer.
|
||||
|
||||
When a STOP_SENDING frame is received, the peer is requesting that we
|
||||
stop sending data on this stream. We respond by resetting the stream.
|
||||
|
||||
Args:
|
||||
error_code: Error code from the STOP_SENDING frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
|
||||
)
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
# Wake up any pending write operations
|
||||
self._backpressure_event.set()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.direction == StreamDirection.OUTBOUND:
|
||||
self._state = StreamState.CLOSED
|
||||
elif self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
# Only write side closed - add WRITE_CLOSED state if needed
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
# Send RESET_STREAM in response (QUIC protocol requirement)
|
||||
try:
|
||||
self._connection._quic.reset_stream(int(self.stream_id), error_code)
|
||||
await self._connection._transmit()
|
||||
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def handle_reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle stream reset from remote peer.
|
||||
|
||||
@ -128,7 +128,7 @@ class QUICTransport(ITransport):
|
||||
self._background_nursery = nursery
|
||||
print("Transport background nursery set")
|
||||
|
||||
def set_swarm(self, swarm) -> None:
|
||||
def set_swarm(self, swarm: Swarm) -> None:
|
||||
"""Set the swarm for adding incoming connections."""
|
||||
self._swarm = swarm
|
||||
|
||||
@ -232,12 +232,9 @@ class QUICTransport(ITransport):
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
|
||||
|
||||
# type: ignore
|
||||
async def dial(
|
||||
self,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
peer_id: ID,
|
||||
nursery: trio.Nursery | None = None,
|
||||
) -> QUICConnection:
|
||||
"""
|
||||
Dial a remote peer using QUIC transport with security verification.
|
||||
@ -261,9 +258,6 @@ class QUICTransport(ITransport):
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
|
||||
|
||||
if not peer_id:
|
||||
raise QUICDialError("Peer id cannot be null")
|
||||
|
||||
try:
|
||||
# Extract connection details from multiaddr
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
@ -288,7 +282,7 @@ class QUICTransport(ITransport):
|
||||
connection = QUICConnection(
|
||||
quic_connection=native_quic_connection,
|
||||
remote_addr=(host, port),
|
||||
remote_peer_id=peer_id,
|
||||
remote_peer_id=None,
|
||||
local_peer_id=self._peer_id,
|
||||
is_initiator=True,
|
||||
maddr=maddr,
|
||||
@ -297,25 +291,19 @@ class QUICTransport(ITransport):
|
||||
)
|
||||
print("QUIC Connection Created")
|
||||
|
||||
active_nursery = nursery or self._background_nursery
|
||||
|
||||
if active_nursery is None:
|
||||
if self._background_nursery is None:
|
||||
logger.error("No nursery set to execute background tasks")
|
||||
raise QUICDialError("No nursery found to execute tasks")
|
||||
|
||||
await connection.connect(active_nursery)
|
||||
await connection.connect(self._background_nursery)
|
||||
|
||||
print("Starting to verify peer identity")
|
||||
# Verify peer identity after TLS handshake
|
||||
if peer_id:
|
||||
await self._verify_peer_identity(connection, peer_id)
|
||||
|
||||
print("Identity verification done")
|
||||
# Store connection for management
|
||||
conn_id = f"{host}:{port}:{peer_id}"
|
||||
conn_id = f"{host}:{port}"
|
||||
self._connections[conn_id] = connection
|
||||
|
||||
print(f"Successfully dialed secure QUIC connection to {peer_id}")
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
@ -456,7 +444,7 @@ class QUICTransport(ITransport):
|
||||
|
||||
print("QUIC transport closed")
|
||||
|
||||
async def _cleanup_terminated_connection(self, connection) -> None:
|
||||
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
|
||||
"""Clean up a terminated connection from all listeners."""
|
||||
try:
|
||||
for listener in self._listeners:
|
||||
|
||||
@ -1,415 +0,0 @@
|
||||
"""
|
||||
Basic QUIC Echo Test
|
||||
|
||||
Simple test to verify the basic QUIC flow:
|
||||
1. Client connects to server
|
||||
2. Client sends data
|
||||
3. Server receives data and echoes back
|
||||
4. Client receives the echo
|
||||
|
||||
This test focuses on identifying where the accept_stream issue occurs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||
|
||||
# Set up logging to see what's happening
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestBasicQUICFlow:
|
||||
"""Test basic QUIC client-server communication flow."""
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server key pair."""
|
||||
return create_new_key_pair()
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client key pair."""
|
||||
return create_new_key_pair()
|
||||
|
||||
@pytest.fixture
|
||||
def server_config(self):
|
||||
"""Simple server configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=10,
|
||||
max_connections=5,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def client_config(self):
|
||||
"""Simple client configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=5,
|
||||
)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_basic_echo_flow(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test basic client-server echo flow with detailed logging."""
|
||||
print("\n=== BASIC QUIC ECHO TEST ===")
|
||||
|
||||
# Create server components
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
server_peer_id = ID.from_pubkey(server_key.public_key)
|
||||
|
||||
# Track test state
|
||||
server_received_data = None
|
||||
server_connection_established = False
|
||||
echo_sent = False
|
||||
|
||||
async def echo_server_handler(connection: QUICConnection) -> None:
|
||||
"""Simple echo server handler with detailed logging."""
|
||||
nonlocal server_received_data, server_connection_established, echo_sent
|
||||
|
||||
print("🔗 SERVER: Connection handler called")
|
||||
server_connection_established = True
|
||||
|
||||
try:
|
||||
print("📡 SERVER: Waiting for incoming stream...")
|
||||
|
||||
# Accept stream with timeout and detailed logging
|
||||
print("📡 SERVER: Calling accept_stream...")
|
||||
stream = await connection.accept_stream(timeout=5.0)
|
||||
|
||||
if stream is None:
|
||||
print("❌ SERVER: accept_stream returned None")
|
||||
return
|
||||
|
||||
print(f"✅ SERVER: Stream accepted! Stream ID: {stream.stream_id}")
|
||||
|
||||
# Read data from the stream
|
||||
print("📖 SERVER: Reading data from stream...")
|
||||
server_data = await stream.read(1024)
|
||||
|
||||
if not server_data:
|
||||
print("❌ SERVER: No data received from stream")
|
||||
return
|
||||
|
||||
server_received_data = server_data.decode("utf-8", errors="ignore")
|
||||
print(f"📨 SERVER: Received data: '{server_received_data}'")
|
||||
|
||||
# Echo the data back
|
||||
echo_message = f"ECHO: {server_received_data}"
|
||||
print(f"📤 SERVER: Sending echo: '{echo_message}'")
|
||||
|
||||
await stream.write(echo_message.encode())
|
||||
echo_sent = True
|
||||
print("✅ SERVER: Echo sent successfully")
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
print("🔒 SERVER: Stream closed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ SERVER: Error in handler: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Create listener
|
||||
listener = server_transport.create_listener(echo_server_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
# Variables to track client state
|
||||
client_connected = False
|
||||
client_sent_data = False
|
||||
client_received_echo = None
|
||||
|
||||
try:
|
||||
print("🚀 Starting server...")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server listener
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success, "Failed to start server listener"
|
||||
|
||||
# Get server address
|
||||
server_addrs = listener.get_addrs()
|
||||
server_addr = server_addrs[0]
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Give server a moment to be ready
|
||||
await trio.sleep(0.1)
|
||||
|
||||
print("🚀 Starting client...")
|
||||
|
||||
# Create client transport
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
|
||||
try:
|
||||
# Connect to server
|
||||
print(f"📞 CLIENT: Connecting to {server_addr}")
|
||||
connection = await client_transport.dial(
|
||||
server_addr, peer_id=server_peer_id, nursery=nursery
|
||||
)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected to server")
|
||||
|
||||
# Open a stream
|
||||
print("📤 CLIENT: Opening stream...")
|
||||
stream = await connection.open_stream()
|
||||
print(f"✅ CLIENT: Stream opened with ID: {stream.stream_id}")
|
||||
|
||||
# Send test data
|
||||
test_message = "Hello QUIC Server!"
|
||||
print(f"📨 CLIENT: Sending message: '{test_message}'")
|
||||
await stream.write(test_message.encode())
|
||||
client_sent_data = True
|
||||
print("✅ CLIENT: Message sent")
|
||||
|
||||
# Read echo response
|
||||
print("📖 CLIENT: Waiting for echo response...")
|
||||
response_data = await stream.read(1024)
|
||||
|
||||
if response_data:
|
||||
client_received_echo = response_data.decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
print(f"📬 CLIENT: Received echo: '{client_received_echo}'")
|
||||
else:
|
||||
print("❌ CLIENT: No echo response received")
|
||||
|
||||
print("🔒 CLIENT: Closing connection")
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
print("🔒 CLIENT: Closing transport")
|
||||
await client_transport.close()
|
||||
print("🔒 CLIENT: Transport closed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ CLIENT: Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
print("🔒 CLIENT: Transport closed")
|
||||
|
||||
# Give everything time to complete
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# Cancel nursery to stop server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if not listener._closed:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
# Verify the flow worked
|
||||
print("\n📊 TEST RESULTS:")
|
||||
print(f" Server connection established: {server_connection_established}")
|
||||
print(f" Client connected: {client_connected}")
|
||||
print(f" Client sent data: {client_sent_data}")
|
||||
print(f" Server received data: '{server_received_data}'")
|
||||
print(f" Echo sent by server: {echo_sent}")
|
||||
print(f" Client received echo: '{client_received_echo}'")
|
||||
|
||||
# Test assertions
|
||||
assert server_connection_established, "Server connection handler was not called"
|
||||
assert client_connected, "Client failed to connect"
|
||||
assert client_sent_data, "Client failed to send data"
|
||||
assert server_received_data == "Hello QUIC Server!", (
|
||||
f"Server received wrong data: '{server_received_data}'"
|
||||
)
|
||||
assert echo_sent, "Server failed to send echo"
|
||||
assert client_received_echo == "ECHO: Hello QUIC Server!", (
|
||||
f"Client received wrong echo: '{client_received_echo}'"
|
||||
)
|
||||
|
||||
print("✅ BASIC ECHO TEST PASSED!")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_server_accept_stream_timeout(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test what happens when server accept_stream times out."""
|
||||
print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===")
|
||||
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
server_peer_id = ID.from_pubkey(server_key.public_key)
|
||||
|
||||
accept_stream_called = False
|
||||
accept_stream_timeout = False
|
||||
|
||||
async def timeout_test_handler(connection: QUICConnection) -> None:
|
||||
"""Handler that tests accept_stream timeout."""
|
||||
nonlocal accept_stream_called, accept_stream_timeout
|
||||
|
||||
print("🔗 SERVER: Connection established, testing accept_stream timeout")
|
||||
accept_stream_called = True
|
||||
|
||||
try:
|
||||
print("📡 SERVER: Calling accept_stream with 2 second timeout...")
|
||||
stream = await connection.accept_stream(timeout=2.0)
|
||||
print(f"✅ SERVER: accept_stream returned: {stream}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⏰ SERVER: accept_stream timed out or failed: {e}")
|
||||
accept_stream_timeout = True
|
||||
|
||||
listener = server_transport.create_listener(timeout_test_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
client_connected = False
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
|
||||
server_addr = listener.get_addrs()[0]
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Create client but DON'T open a stream
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
|
||||
try:
|
||||
print("📞 CLIENT: Connecting (but NOT opening stream)...")
|
||||
connection = await client_transport.dial(
|
||||
server_addr, peer_id=server_peer_id, nursery=nursery
|
||||
)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected (no stream opened)")
|
||||
|
||||
# Wait for server timeout
|
||||
await trio.sleep(3.0)
|
||||
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
print("\n📊 TIMEOUT TEST RESULTS:")
|
||||
print(f" Client connected: {client_connected}")
|
||||
print(f" accept_stream called: {accept_stream_called}")
|
||||
print(f" accept_stream timeout: {accept_stream_timeout}")
|
||||
|
||||
assert client_connected, "Client should have connected"
|
||||
assert accept_stream_called, "accept_stream should have been called"
|
||||
assert accept_stream_timeout, (
|
||||
"accept_stream should have timed out when no stream was opened"
|
||||
)
|
||||
|
||||
print("✅ TIMEOUT TEST PASSED!")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_debug_accept_stream_hanging(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Debug test to see exactly where accept_stream might be hanging."""
|
||||
print("\n=== DEBUGGING ACCEPT_STREAM HANGING ===")
|
||||
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
server_peer_id = ID.from_pubkey(server_key.public_key)
|
||||
|
||||
async def debug_handler(connection: QUICConnection) -> None:
|
||||
"""Handler with extensive debugging."""
|
||||
print(f"🔗 SERVER: Handler called for connection {id(connection)} ")
|
||||
print(f" Connection closed: {connection.is_closed}")
|
||||
print(f" Connection started: {connection._started}")
|
||||
print(f" Connection established: {connection._established}")
|
||||
|
||||
try:
|
||||
print("📡 SERVER: About to call accept_stream...")
|
||||
print(f" Accept queue length: {len(connection._stream_accept_queue)}")
|
||||
print(
|
||||
f" Accept event set: {connection._stream_accept_event.is_set()}"
|
||||
)
|
||||
|
||||
# Use a short timeout to avoid hanging the test
|
||||
with trio.move_on_after(3.0) as cancel_scope:
|
||||
stream = await connection.accept_stream()
|
||||
if stream:
|
||||
print(f"✅ SERVER: Got stream {stream.stream_id}")
|
||||
else:
|
||||
print("❌ SERVER: accept_stream returned None")
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
print("⏰ SERVER: accept_stream cancelled due to timeout")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ SERVER: Exception in accept_stream: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
listener = server_transport.create_listener(debug_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
|
||||
server_addr = listener.get_addrs()[0]
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Create client and connect
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
|
||||
try:
|
||||
print("📞 CLIENT: Connecting...")
|
||||
connection = await client_transport.dial(
|
||||
server_addr, peer_id=server_peer_id, nursery=nursery
|
||||
)
|
||||
print("✅ CLIENT: Connected")
|
||||
|
||||
# Open stream after a short delay
|
||||
await trio.sleep(0.1)
|
||||
print("📤 CLIENT: Opening stream...")
|
||||
stream = await connection.open_stream()
|
||||
print(f"📤 CLIENT: Stream {stream.stream_id} opened")
|
||||
|
||||
# Send some data
|
||||
await stream.write(b"test data")
|
||||
print("📨 CLIENT: Data sent")
|
||||
|
||||
# Give server time to process
|
||||
await trio.sleep(1.0)
|
||||
|
||||
# Cleanup
|
||||
await stream.close()
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Cleaned up")
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
|
||||
await trio.sleep(0.5)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
print("✅ DEBUG TEST COMPLETED!")
|
||||
|
||||
@ -16,7 +16,6 @@ import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
@ -68,7 +67,6 @@ class TestBasicQUICFlow:
|
||||
|
||||
# Create server components
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
server_peer_id = ID.from_pubkey(server_key.public_key)
|
||||
|
||||
# Track test state
|
||||
server_received_data = None
|
||||
@ -153,13 +151,12 @@ class TestBasicQUICFlow:
|
||||
|
||||
# Create client transport
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
client_transport.set_background_nursery(nursery)
|
||||
|
||||
try:
|
||||
# Connect to server
|
||||
print(f"📞 CLIENT: Connecting to {server_addr}")
|
||||
connection = await client_transport.dial(
|
||||
server_addr, peer_id=server_peer_id, nursery=nursery
|
||||
)
|
||||
connection = await client_transport.dial(server_addr)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected to server")
|
||||
|
||||
@ -248,7 +245,6 @@ class TestBasicQUICFlow:
|
||||
print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===")
|
||||
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
server_peer_id = ID.from_pubkey(server_key.public_key)
|
||||
|
||||
accept_stream_called = False
|
||||
accept_stream_timeout = False
|
||||
@ -277,6 +273,7 @@ class TestBasicQUICFlow:
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server
|
||||
server_transport.set_background_nursery(nursery)
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
|
||||
@ -284,24 +281,26 @@ class TestBasicQUICFlow:
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Create client but DON'T open a stream
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
|
||||
try:
|
||||
print("📞 CLIENT: Connecting (but NOT opening stream)...")
|
||||
connection = await client_transport.dial(
|
||||
server_addr, peer_id=server_peer_id, nursery=nursery
|
||||
async with trio.open_nursery() as client_nursery:
|
||||
client_transport = QUICTransport(
|
||||
client_key.private_key, client_config
|
||||
)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected (no stream opened)")
|
||||
client_transport.set_background_nursery(client_nursery)
|
||||
|
||||
# Wait for server timeout
|
||||
await trio.sleep(3.0)
|
||||
try:
|
||||
print("📞 CLIENT: Connecting (but NOT opening stream)...")
|
||||
connection = await client_transport.dial(server_addr)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected (no stream opened)")
|
||||
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
# Wait for server timeout
|
||||
await trio.sleep(3.0)
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
@ -105,7 +104,7 @@ class TestQUICTransport:
|
||||
await transport.close()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dial_closed_transport(self, transport):
|
||||
async def test_dial_closed_transport(self, transport: QUICTransport) -> None:
|
||||
"""Test dialing with closed transport raises error."""
|
||||
import multiaddr
|
||||
|
||||
@ -114,10 +113,9 @@ class TestQUICTransport:
|
||||
with pytest.raises(QUICDialError, match="Transport is closed"):
|
||||
await transport.dial(
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
ID.from_pubkey(create_new_key_pair().public_key),
|
||||
)
|
||||
|
||||
def test_create_listener_closed_transport(self, transport):
|
||||
def test_create_listener_closed_transport(self, transport: QUICTransport) -> None:
|
||||
"""Test creating listener with closed transport raises error."""
|
||||
transport._closed = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user