Mirror of https://github.com/varun-r-mallya/py-libp2p.git
Synced 2025-12-31 20:36:24 +00:00
205 lines · 8.6 KiB · Python
from aioquic._buffer import Buffer
|
|
from aioquic.quic.packet import pull_quic_header
|
|
from aioquic.quic.connection import QuicConnection
|
|
from aioquic.quic.configuration import QuicConfiguration
|
|
from tempfile import NamedTemporaryFile
|
|
from libp2p.peer.id import ID
|
|
from libp2p.transport.quic.security import create_quic_security_transport
|
|
from libp2p.crypto.ed25519 import create_new_key_pair
|
|
from time import time
|
|
import os
|
|
import trio
|
|
|
|
|
|
# Connection-ID length issued by both endpoints. It is also passed to every
# pull_quic_header() call: short-header (1-RTT) packets carry no CID length
# byte, so they cannot be parsed without it.
_CID_LENGTH = 8


def _build_aioquic_config(is_client, security_config, secrets_log_file):
    """Return a QuicConfiguration that presents the libp2p TLS identity.

    Args:
        is_client: True for the connecting side, False for the listening side.
        security_config: object returned by create_quic_security_transport();
            its certificate and private key are copied, and its verify mode is
            taken from create_client_config() or create_server_config().
        secrets_log_file: open text file handed to aioquic as the secrets log.
    """
    config = QuicConfiguration(
        is_client=is_client,
        alpn_protocols=["libp2p"],
        secrets_log_file=secrets_log_file,
        connection_id_length=_CID_LENGTH,  # Set consistent CID length
    )
    config.certificate = security_config.tls_config.certificate
    config.private_key = security_config.tls_config.private_key
    if is_client:
        tls_side_config = security_config.create_client_config()
    else:
        tls_side_config = security_config.create_server_config()
    config.verify_mode = tls_side_config.verify_mode
    return config


def _deliver(datagrams, receiver, sender_address, direction):
    """Log each datagram's CIDs and feed it to ``receiver``.

    Args:
        datagrams: list of ``(data, addr)`` pairs from datagrams_to_send().
        receiver: QuicConnection that receives the datagrams.
        sender_address: address the datagrams are reported as coming FROM;
            receive_datagram() expects the sender's address, not its own.
        direction: label used in the log line, e.g. "Client -> Server".

    Returns:
        The number of datagrams delivered.
    """
    for datagram, _ in datagrams:
        # host_cid_length is required so 1-RTT (short-header) packets parse.
        header = pull_quic_header(Buffer(data=datagram), host_cid_length=_CID_LENGTH)
        print(
            f"📤 {direction}: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
        )
        receiver.receive_datagram(datagram, sender_address, now=time())
    return len(datagrams)


async def _pump_until_handshake_complete(
    client_conn,
    server_conn,
    client_address,
    server_address,
    max_duration_s=3,
    max_packets=50,
):
    """Shuttle datagrams between the two in-memory connections.

    Stops when both connections report ``_handshake_complete``, when
    ``max_duration_s`` elapses, or when more than ``max_packets`` datagrams
    have been exchanged.

    Returns:
        The number of datagrams exchanged by this loop.
    """
    start_time = time()
    packet_count = 0

    while time() - start_time < max_duration_s:
        client_packets = list(client_conn.datagrams_to_send(now=time()))
        packet_count += _deliver(
            client_packets, server_conn, client_address, "Client -> Server"
        )

        server_packets = list(server_conn.datagrams_to_send(now=time()))
        # The client receives these as coming FROM server_address.
        packet_count += _deliver(
            server_packets, client_conn, server_address, "Server -> Client"
        )

        client_complete = getattr(client_conn, "_handshake_complete", False)
        server_complete = getattr(server_conn, "_handshake_complete", False)
        print(
            f"🔄 Handshake status: Client={client_complete}, Server={server_complete}, Packets={packet_count}"
        )

        if client_complete and server_complete:
            print("🎉 Handshake completed for both peers!")
            break

        # Nothing in flight this round: yield briefly before polling again.
        if not client_packets and not server_packets:
            await trio.sleep(0.01)

        # Safety check - if too many packets, something is wrong
        if packet_count > max_packets:
            print("⚠️ Too many packets exchanged, possible handshake loop")
            break

    return packet_count


async def test_full_handshake_and_certificate_exchange():
    """
    Test a full handshake to ensure it completes and peer certificates are exchanged.

    Both QuicConnection objects live in this process. Datagrams are moved
    between them by hand, so no sockets are opened. The TLS secrets logs are
    temporary files that are removed on exit, including when an assertion
    fails.

    Returns:
        True when the handshake completed and both peers received a certificate.
    """
    print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (FIXED) ===")

    # 1. Generate KeyPairs and create libp2p security configs for client and server.
    client_key_pair = create_new_key_pair()
    server_key_pair = create_new_key_pair()

    client_security_config = create_quic_security_transport(
        client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
    )
    server_security_config = create_quic_security_transport(
        server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
    )
    print("✅ libp2p security configs created.")

    # 2. Create aioquic configurations with consistent settings. The context
    # managers guarantee the secrets logs are closed and deleted on any exit.
    with NamedTemporaryFile(
        mode="w", suffix="-client.log"
    ) as client_secrets_log_file, NamedTemporaryFile(
        mode="w", suffix="-server.log"
    ) as server_secrets_log_file:
        client_aioquic_config = _build_aioquic_config(
            True, client_security_config, client_secrets_log_file
        )
        server_aioquic_config = _build_aioquic_config(
            False, server_security_config, server_secrets_log_file
        )
        print("✅ aioquic configurations created and configured.")
        print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
        print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")

        # 3. Fixed fake addresses. receive_datagram() is always given the
        # address of the peer that SENT the datagram.
        client_address = ("127.0.0.1", 1234)
        server_address = ("127.0.0.1", 4321)

        # 4. Create client connection and initiate connection
        client_conn = QuicConnection(configuration=client_aioquic_config)
        client_conn.connect(server_address, now=time())
        print("✅ Client connection initiated.")

        # 5. Get the initial client packet and extract the ODCID
        client_datagrams = client_conn.datagrams_to_send(now=time())
        if not client_datagrams:
            raise AssertionError("❌ Client did not generate initial packet")

        client_initial_packet = client_datagrams[0][0]
        header = pull_quic_header(
            Buffer(data=client_initial_packet), host_cid_length=_CID_LENGTH
        )
        original_dcid = header.destination_cid
        client_source_cid = header.source_cid

        print(f"📊 Client ODCID: {original_dcid.hex()}")
        print(f"📊 Client source CID: {client_source_cid.hex()}")

        # 6. Create server connection with the correct ODCID
        server_conn = QuicConnection(
            configuration=server_aioquic_config,
            original_destination_connection_id=original_dcid,
        )
        print("✅ Server connection created with correct ODCID.")

        # 7. Feed the initial client flight to the server. These datagrams are
        # not included in packet_count, matching the loop's accounting.
        _deliver(client_datagrams, server_conn, client_address, "Client -> Server")

        # 8. Manual handshake loop
        packet_count = await _pump_until_handshake_complete(
            client_conn, server_conn, client_address, server_address
        )

        # 9. Handshake completion checks and diagnostics
        client_handshake_complete = getattr(client_conn, "_handshake_complete", False)
        server_handshake_complete = getattr(server_conn, "_handshake_complete", False)

        print(f"🔍 Final client state: {getattr(client_conn, '_state', 'unknown')}")
        print(f"🔍 Final server state: {getattr(server_conn, '_state', 'unknown')}")

        if hasattr(client_conn, "tls") and client_conn.tls:
            print(f"🔍 Client TLS state: {getattr(client_conn.tls, 'state', 'unknown')}")
        if hasattr(server_conn, "tls") and server_conn.tls:
            print(f"🔍 Server TLS state: {getattr(server_conn.tls, 'state', 'unknown')}")

        # 10. Final assertions (temp files are cleaned up by the with-block)
        assert client_handshake_complete, (
            f"❌ Client handshake did not complete. "
            f"State: {getattr(client_conn, '_state', 'unknown')}, "
            f"Packets: {packet_count}"
        )
        assert server_handshake_complete, (
            f"❌ Server handshake did not complete. "
            f"State: {getattr(server_conn, '_state', 'unknown')}, "
            f"Packets: {packet_count}"
        )
        print("✅ Handshake completed for both peers.")

        # Certificate exchange verification
        client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
        server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)

        assert client_peer_cert is not None, (
            "❌ Client FAILED to receive server certificate."
        )
        print("✅ Client successfully received server certificate.")

        assert server_peer_cert is not None, (
            "❌ Server FAILED to receive client certificate."
        )
        print("✅ Server successfully received client certificate.")

        print("🎉 Test Passed: Full handshake and certificate exchange successful.")
        return True
|
|
|
|
def _run_as_script():
    """Drive the async handshake test with trio when this file is executed directly."""
    trio.run(test_full_handshake_and_certificate_exchange)


if __name__ == "__main__":
    _run_as_script()