Files
py-libp2p/examples/echo/test_handshake.py
2025-08-30 14:08:06 +05:30

205 lines
8.6 KiB
Python

from aioquic._buffer import Buffer
from aioquic.quic.packet import pull_quic_header
from aioquic.quic.connection import QuicConnection
from aioquic.quic.configuration import QuicConfiguration
from tempfile import NamedTemporaryFile
from libp2p.peer.id import ID
from libp2p.transport.quic.security import create_quic_security_transport
from libp2p.crypto.ed25519 import create_new_key_pair
from time import time
import os
import trio
async def test_full_handshake_and_certificate_exchange():
"""
Test a full handshake to ensure it completes and peer certificates are exchanged.
FIXED VERSION: Corrects connection ID management and address handling.
"""
print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (FIXED) ===")
# 1. Generate KeyPairs and create libp2p security configs for client and server.
client_key_pair = create_new_key_pair()
server_key_pair = create_new_key_pair()
client_security_config = create_quic_security_transport(
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
)
server_security_config = create_quic_security_transport(
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
)
print("✅ libp2p security configs created.")
# 2. Create aioquic configurations with consistent settings
client_secrets_log_file = NamedTemporaryFile(
mode="w", delete=False, suffix="-client.log"
)
client_aioquic_config = QuicConfiguration(
is_client=True,
alpn_protocols=["libp2p"],
secrets_log_file=client_secrets_log_file,
connection_id_length=8, # Set consistent CID length
)
client_aioquic_config.certificate = client_security_config.tls_config.certificate
client_aioquic_config.private_key = client_security_config.tls_config.private_key
client_aioquic_config.verify_mode = (
client_security_config.create_client_config().verify_mode
)
server_secrets_log_file = NamedTemporaryFile(
mode="w", delete=False, suffix="-server.log"
)
server_aioquic_config = QuicConfiguration(
is_client=False,
alpn_protocols=["libp2p"],
secrets_log_file=server_secrets_log_file,
connection_id_length=8, # Set consistent CID length
)
server_aioquic_config.certificate = server_security_config.tls_config.certificate
server_aioquic_config.private_key = server_security_config.tls_config.private_key
server_aioquic_config.verify_mode = (
server_security_config.create_server_config().verify_mode
)
print("✅ aioquic configurations created and configured.")
print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")
# 3. Use consistent addresses - this is crucial!
# The client will connect TO the server address, but packets will come FROM client address
client_address = ("127.0.0.1", 1234) # Client binds to this
server_address = ("127.0.0.1", 4321) # Server binds to this
# 4. Create client connection and initiate connection
client_conn = QuicConnection(configuration=client_aioquic_config)
# Client connects to server address - this sets up the initial packet with proper CIDs
client_conn.connect(server_address, now=time())
print("✅ Client connection initiated.")
# 5. Get the initial client packet and extract ODCID properly
client_datagrams = client_conn.datagrams_to_send(now=time())
if not client_datagrams:
raise AssertionError("❌ Client did not generate initial packet")
client_initial_packet = client_datagrams[0][0]
header = pull_quic_header(Buffer(data=client_initial_packet), host_cid_length=8)
original_dcid = header.destination_cid
client_source_cid = header.source_cid
print(f"📊 Client ODCID: {original_dcid.hex()}")
print(f"📊 Client source CID: {client_source_cid.hex()}")
# 6. Create server connection with the correct ODCID
server_conn = QuicConnection(
configuration=server_aioquic_config,
original_destination_connection_id=original_dcid,
)
print("✅ Server connection created with correct ODCID.")
# 7. Feed the initial client packet to server
# IMPORTANT: Use client_address as the source for the packet
for datagram, _ in client_datagrams:
header = pull_quic_header(Buffer(data=datagram))
print(
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
)
server_conn.receive_datagram(datagram, client_address, now=time())
# 8. Manual handshake loop with proper packet tracking
max_duration_s = 3 # Increased timeout
start_time = time()
packet_count = 0
while time() - start_time < max_duration_s:
# Process client -> server packets
client_packets = list(client_conn.datagrams_to_send(now=time()))
for datagram, _ in client_packets:
header = pull_quic_header(Buffer(data=datagram))
print(
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
)
server_conn.receive_datagram(datagram, client_address, now=time())
packet_count += 1
# Process server -> client packets
server_packets = list(server_conn.datagrams_to_send(now=time()))
for datagram, _ in server_packets:
header = pull_quic_header(Buffer(data=datagram))
print(
f"📤 Server -> Client: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
)
# CRITICAL: Server sends back to client_address, not server_address
client_conn.receive_datagram(datagram, server_address, now=time())
packet_count += 1
# Check for completion
client_complete = getattr(client_conn, "_handshake_complete", False)
server_complete = getattr(server_conn, "_handshake_complete", False)
print(
f"🔄 Handshake status: Client={client_complete}, Server={server_complete}, Packets={packet_count}"
)
if client_complete and server_complete:
print("🎉 Handshake completed for both peers!")
break
# If no packets were exchanged in this iteration, wait a bit
if not client_packets and not server_packets:
await trio.sleep(0.01)
# Safety check - if too many packets, something is wrong
if packet_count > 50:
print("⚠️ Too many packets exchanged, possible handshake loop")
break
# 9. Enhanced handshake completion checks
client_handshake_complete = getattr(client_conn, "_handshake_complete", False)
server_handshake_complete = getattr(server_conn, "_handshake_complete", False)
# Debug additional state information
print(f"🔍 Final client state: {getattr(client_conn, '_state', 'unknown')}")
print(f"🔍 Final server state: {getattr(server_conn, '_state', 'unknown')}")
if hasattr(client_conn, "tls") and client_conn.tls:
print(f"🔍 Client TLS state: {getattr(client_conn.tls, 'state', 'unknown')}")
if hasattr(server_conn, "tls") and server_conn.tls:
print(f"🔍 Server TLS state: {getattr(server_conn.tls, 'state', 'unknown')}")
# 10. Cleanup and assertions
client_secrets_log_file.close()
server_secrets_log_file.close()
os.unlink(client_secrets_log_file.name)
os.unlink(server_secrets_log_file.name)
# Final assertions
assert client_handshake_complete, (
f"❌ Client handshake did not complete. "
f"State: {getattr(client_conn, '_state', 'unknown')}, "
f"Packets: {packet_count}"
)
assert server_handshake_complete, (
f"❌ Server handshake did not complete. "
f"State: {getattr(server_conn, '_state', 'unknown')}, "
f"Packets: {packet_count}"
)
print("✅ Handshake completed for both peers.")
# Certificate exchange verification
client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)
assert client_peer_cert is not None, (
"❌ Client FAILED to receive server certificate."
)
print("✅ Client successfully received server certificate.")
assert server_peer_cert is not None, (
"❌ Server FAILED to receive client certificate."
)
print("✅ Server successfully received client certificate.")
print("🎉 Test Passed: Full handshake and certificate exchange successful.")
return True
if __name__ == "__main__":
trio.run(test_full_handshake_and_certificate_exchange)