mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: peer verification successful
This commit is contained in:
205
examples/echo/test_handshake.py
Normal file
205
examples/echo/test_handshake.py
Normal file
@ -0,0 +1,205 @@
|
||||
from aioquic._buffer import Buffer
|
||||
from aioquic.quic.packet import pull_quic_header
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from tempfile import NamedTemporaryFile
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.security import create_quic_security_transport
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from time import time
|
||||
import os
|
||||
import trio
|
||||
|
||||
|
||||
async def test_full_handshake_and_certificate_exchange():
|
||||
"""
|
||||
Test a full handshake to ensure it completes and peer certificates are exchanged.
|
||||
FIXED VERSION: Corrects connection ID management and address handling.
|
||||
"""
|
||||
print("\n=== TESTING FULL HANDSHAKE AND CERTIFICATE EXCHANGE (FIXED) ===")
|
||||
|
||||
# 1. Generate KeyPairs and create libp2p security configs for client and server.
|
||||
client_key_pair = create_new_key_pair()
|
||||
server_key_pair = create_new_key_pair()
|
||||
|
||||
client_security_config = create_quic_security_transport(
|
||||
client_key_pair.private_key, ID.from_pubkey(client_key_pair.public_key)
|
||||
)
|
||||
server_security_config = create_quic_security_transport(
|
||||
server_key_pair.private_key, ID.from_pubkey(server_key_pair.public_key)
|
||||
)
|
||||
print("✅ libp2p security configs created.")
|
||||
|
||||
# 2. Create aioquic configurations with consistent settings
|
||||
client_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-client.log"
|
||||
)
|
||||
client_aioquic_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=client_secrets_log_file,
|
||||
connection_id_length=8, # Set consistent CID length
|
||||
)
|
||||
client_aioquic_config.certificate = client_security_config.tls_config.certificate
|
||||
client_aioquic_config.private_key = client_security_config.tls_config.private_key
|
||||
client_aioquic_config.verify_mode = (
|
||||
client_security_config.create_client_config().verify_mode
|
||||
)
|
||||
|
||||
server_secrets_log_file = NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix="-server.log"
|
||||
)
|
||||
server_aioquic_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=["libp2p"],
|
||||
secrets_log_file=server_secrets_log_file,
|
||||
connection_id_length=8, # Set consistent CID length
|
||||
)
|
||||
server_aioquic_config.certificate = server_security_config.tls_config.certificate
|
||||
server_aioquic_config.private_key = server_security_config.tls_config.private_key
|
||||
server_aioquic_config.verify_mode = (
|
||||
server_security_config.create_server_config().verify_mode
|
||||
)
|
||||
print("✅ aioquic configurations created and configured.")
|
||||
print(f"🔑 Client secrets will be logged to: {client_secrets_log_file.name}")
|
||||
print(f"🔑 Server secrets will be logged to: {server_secrets_log_file.name}")
|
||||
|
||||
# 3. Use consistent addresses - this is crucial!
|
||||
# The client will connect TO the server address, but packets will come FROM client address
|
||||
client_address = ("127.0.0.1", 1234) # Client binds to this
|
||||
server_address = ("127.0.0.1", 4321) # Server binds to this
|
||||
|
||||
# 4. Create client connection and initiate connection
|
||||
client_conn = QuicConnection(configuration=client_aioquic_config)
|
||||
# Client connects to server address - this sets up the initial packet with proper CIDs
|
||||
client_conn.connect(server_address, now=time())
|
||||
print("✅ Client connection initiated.")
|
||||
|
||||
# 5. Get the initial client packet and extract ODCID properly
|
||||
client_datagrams = client_conn.datagrams_to_send(now=time())
|
||||
if not client_datagrams:
|
||||
raise AssertionError("❌ Client did not generate initial packet")
|
||||
|
||||
client_initial_packet = client_datagrams[0][0]
|
||||
header = pull_quic_header(Buffer(data=client_initial_packet), host_cid_length=8)
|
||||
original_dcid = header.destination_cid
|
||||
client_source_cid = header.source_cid
|
||||
|
||||
print(f"📊 Client ODCID: {original_dcid.hex()}")
|
||||
print(f"📊 Client source CID: {client_source_cid.hex()}")
|
||||
|
||||
# 6. Create server connection with the correct ODCID
|
||||
server_conn = QuicConnection(
|
||||
configuration=server_aioquic_config,
|
||||
original_destination_connection_id=original_dcid,
|
||||
)
|
||||
print("✅ Server connection created with correct ODCID.")
|
||||
|
||||
# 7. Feed the initial client packet to server
|
||||
# IMPORTANT: Use client_address as the source for the packet
|
||||
for datagram, _ in client_datagrams:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
|
||||
# 8. Manual handshake loop with proper packet tracking
|
||||
max_duration_s = 3 # Increased timeout
|
||||
start_time = time()
|
||||
packet_count = 0
|
||||
|
||||
while time() - start_time < max_duration_s:
|
||||
# Process client -> server packets
|
||||
client_packets = list(client_conn.datagrams_to_send(now=time()))
|
||||
for datagram, _ in client_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Client -> Server: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
server_conn.receive_datagram(datagram, client_address, now=time())
|
||||
packet_count += 1
|
||||
|
||||
# Process server -> client packets
|
||||
server_packets = list(server_conn.datagrams_to_send(now=time()))
|
||||
for datagram, _ in server_packets:
|
||||
header = pull_quic_header(Buffer(data=datagram))
|
||||
print(
|
||||
f"📤 Server -> Client: src={header.source_cid.hex()}, dst={header.destination_cid.hex()}"
|
||||
)
|
||||
# CRITICAL: Server sends back to client_address, not server_address
|
||||
client_conn.receive_datagram(datagram, server_address, now=time())
|
||||
packet_count += 1
|
||||
|
||||
# Check for completion
|
||||
client_complete = getattr(client_conn, "_handshake_complete", False)
|
||||
server_complete = getattr(server_conn, "_handshake_complete", False)
|
||||
|
||||
print(
|
||||
f"🔄 Handshake status: Client={client_complete}, Server={server_complete}, Packets={packet_count}"
|
||||
)
|
||||
|
||||
if client_complete and server_complete:
|
||||
print("🎉 Handshake completed for both peers!")
|
||||
break
|
||||
|
||||
# If no packets were exchanged in this iteration, wait a bit
|
||||
if not client_packets and not server_packets:
|
||||
await trio.sleep(0.01)
|
||||
|
||||
# Safety check - if too many packets, something is wrong
|
||||
if packet_count > 50:
|
||||
print("⚠️ Too many packets exchanged, possible handshake loop")
|
||||
break
|
||||
|
||||
# 9. Enhanced handshake completion checks
|
||||
client_handshake_complete = getattr(client_conn, "_handshake_complete", False)
|
||||
server_handshake_complete = getattr(server_conn, "_handshake_complete", False)
|
||||
|
||||
# Debug additional state information
|
||||
print(f"🔍 Final client state: {getattr(client_conn, '_state', 'unknown')}")
|
||||
print(f"🔍 Final server state: {getattr(server_conn, '_state', 'unknown')}")
|
||||
|
||||
if hasattr(client_conn, "tls") and client_conn.tls:
|
||||
print(f"🔍 Client TLS state: {getattr(client_conn.tls, 'state', 'unknown')}")
|
||||
if hasattr(server_conn, "tls") and server_conn.tls:
|
||||
print(f"🔍 Server TLS state: {getattr(server_conn.tls, 'state', 'unknown')}")
|
||||
|
||||
# 10. Cleanup and assertions
|
||||
client_secrets_log_file.close()
|
||||
server_secrets_log_file.close()
|
||||
os.unlink(client_secrets_log_file.name)
|
||||
os.unlink(server_secrets_log_file.name)
|
||||
|
||||
# Final assertions
|
||||
assert client_handshake_complete, (
|
||||
f"❌ Client handshake did not complete. "
|
||||
f"State: {getattr(client_conn, '_state', 'unknown')}, "
|
||||
f"Packets: {packet_count}"
|
||||
)
|
||||
assert server_handshake_complete, (
|
||||
f"❌ Server handshake did not complete. "
|
||||
f"State: {getattr(server_conn, '_state', 'unknown')}, "
|
||||
f"Packets: {packet_count}"
|
||||
)
|
||||
print("✅ Handshake completed for both peers.")
|
||||
|
||||
# Certificate exchange verification
|
||||
client_peer_cert = getattr(client_conn.tls, "_peer_certificate", None)
|
||||
server_peer_cert = getattr(server_conn.tls, "_peer_certificate", None)
|
||||
|
||||
assert client_peer_cert is not None, (
|
||||
"❌ Client FAILED to receive server certificate."
|
||||
)
|
||||
print("✅ Client successfully received server certificate.")
|
||||
|
||||
assert server_peer_cert is not None, (
|
||||
"❌ Server FAILED to receive client certificate."
|
||||
)
|
||||
print("✅ Server successfully received client certificate.")
|
||||
|
||||
print("🎉 Test Passed: Full handshake and certificate exchange successful.")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(test_full_handshake_and_certificate_exchange)
|
||||
Reference in New Issue
Block a user