diff --git a/tests/core/transport/quic/test_connection.py b/tests/core/transport/quic/test_connection.py index 687e4ec0..06e304a9 100644 --- a/tests/core/transport/quic/test_connection.py +++ b/tests/core/transport/quic/test_connection.py @@ -17,9 +17,11 @@ from libp2p.transport.quic.exceptions import ( QUICConnectionClosedError, QUICConnectionError, QUICConnectionTimeoutError, + QUICPeerVerificationError, QUICStreamLimitError, QUICStreamTimeoutError, ) +from libp2p.transport.quic.security import QUICTLSConfigManager from libp2p.transport.quic.stream import QUICStream, StreamDirection @@ -499,3 +501,43 @@ class TestQUICConnection: mock_resource_scope.release_memory(2000) # Should not go negative assert mock_resource_scope.memory_reserved == 0 + + +@pytest.mark.trio +async def test_invalid_certificate_verification(): + key_pair1 = create_new_key_pair() + key_pair2 = create_new_key_pair() + + peer_id1 = ID.from_pubkey(key_pair1.public_key) + peer_id2 = ID.from_pubkey(key_pair2.public_key) + + manager = QUICTLSConfigManager( + libp2p_private_key=key_pair1.private_key, peer_id=peer_id1 + ) + + # Match the certificate against a different peer_id + with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"): + manager.verify_peer_identity(manager.tls_config.certificate, peer_id2) + + from cryptography.hazmat.primitives.serialization import Encoding + + # --- Corrupt the certificate by tampering the DER bytes --- + cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER) + corrupted_bytes = bytearray(cert_bytes) + + # Flip some random bytes in the middle of the certificate + corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF + + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + + # This will still parse (structurally valid), but the signature + # or fingerprint will break + corrupted_cert = x509.load_der_x509_certificate( + bytes(corrupted_bytes), backend=default_backend() + ) + + with pytest.raises( + QUICPeerVerificationError, match="Certificate verification failed" + ): + manager.verify_peer_identity(corrupted_cert, peer_id1) diff --git a/tests/core/transport/quic/test_integration.py b/tests/core/transport/quic/test_integration.py index dfa28565..4edddf07 100644 --- a/tests/core/transport/quic/test_integration.py +++ b/tests/core/transport/quic/test_integration.py @@ -13,9 +13,14 @@ This test focuses on identifying where the accept_stream issue occurs. import logging import pytest +import multiaddr import trio +from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID +from libp2p import new_host +from libp2p.abc import INetStream from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.transport.quic.config import QUICTransportConfig from libp2p.transport.quic.connection import QUICConnection from libp2p.transport.quic.transport import QUICTransport @@ -320,3 +325,87 @@ class TestBasicQUICFlow: ) print("āœ… TIMEOUT TEST PASSED!") + + +@pytest.mark.trio +async def test_yamux_stress_ping(): + STREAM_COUNT = 100 + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + latencies = [] + failures = [] + + # === Server Setup === + server_host = new_host(listen_addrs=[listen_addr]) + + async def handle_ping(stream: INetStream) -> None: + try: + while True: + payload = await stream.read(PING_LENGTH) + if not payload: + break + await stream.write(payload) + except Exception: + await stream.reset() + + server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping) + + async with server_host.run(listen_addrs=[listen_addr]): + # Give server time to start + await trio.sleep(0.1) + + # === Client Setup === + destination = str(server_host.get_addrs()[0]) + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + + client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + client_host = new_host(listen_addrs=[client_listen_addr]) + + async with client_host.run(listen_addrs=[client_listen_addr]): + await client_host.connect(info) + + async def ping_stream(i: int): + try: + start = trio.current_time() + stream = await client_host.new_stream( + info.peer_id, [PING_PROTOCOL_ID] + ) + + await stream.write(b"\x01" * PING_LENGTH) + + with trio.fail_after(5): + response = await stream.read(PING_LENGTH) + + if response == b"\x01" * PING_LENGTH: + latency_ms = int((trio.current_time() - start) * 1000) + latencies.append(latency_ms) + print(f"[Ping #{i}] Latency: {latency_ms} ms") + await stream.close() + except Exception as e: + print(f"[Ping #{i}] Failed: {e}") + failures.append(i) + await stream.reset() + + async with trio.open_nursery() as nursery: + for i in range(STREAM_COUNT): + nursery.start_soon(ping_stream, i) + + # === Result Summary === + print("\nšŸ“Š Ping Stress Test Summary") + print(f"Total Streams Launched: {STREAM_COUNT}") + print(f"Successful Pings: {len(latencies)}") + print(f"Failed Pings: {len(failures)}") + if failures: + print(f"āŒ Failed stream indices: {failures}") + + # === Assertions === + assert len(latencies) == STREAM_COUNT, ( + f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}" + ) + assert all(isinstance(x, int) and x >= 0 for x in latencies), ( + "Invalid latencies" + ) + + avg_latency = sum(latencies) / len(latencies) + print(f"āœ… Average Latency: {avg_latency:.2f} ms") + assert avg_latency < 1000