mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-09 14:40:53 +00:00
multiple streams ping, invalid certificate handling
This commit is contained in:
@ -17,9 +17,11 @@ from libp2p.transport.quic.exceptions import (
|
|||||||
QUICConnectionClosedError,
|
QUICConnectionClosedError,
|
||||||
QUICConnectionError,
|
QUICConnectionError,
|
||||||
QUICConnectionTimeoutError,
|
QUICConnectionTimeoutError,
|
||||||
|
QUICPeerVerificationError,
|
||||||
QUICStreamLimitError,
|
QUICStreamLimitError,
|
||||||
QUICStreamTimeoutError,
|
QUICStreamTimeoutError,
|
||||||
)
|
)
|
||||||
|
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||||
from libp2p.transport.quic.stream import QUICStream, StreamDirection
|
from libp2p.transport.quic.stream import QUICStream, StreamDirection
|
||||||
|
|
||||||
|
|
||||||
@ -499,3 +501,43 @@ class TestQUICConnection:
|
|||||||
|
|
||||||
mock_resource_scope.release_memory(2000) # Should not go negative
|
mock_resource_scope.release_memory(2000) # Should not go negative
|
||||||
assert mock_resource_scope.memory_reserved == 0
|
assert mock_resource_scope.memory_reserved == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_invalid_certificate_verification():
|
||||||
|
key_pair1 = create_new_key_pair()
|
||||||
|
key_pair2 = create_new_key_pair()
|
||||||
|
|
||||||
|
peer_id1 = ID.from_pubkey(key_pair1.public_key)
|
||||||
|
peer_id2 = ID.from_pubkey(key_pair2.public_key)
|
||||||
|
|
||||||
|
manager = QUICTLSConfigManager(
|
||||||
|
libp2p_private_key=key_pair1.private_key, peer_id=peer_id1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Match the certificate against a different peer_id
|
||||||
|
with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"):
|
||||||
|
manager.verify_peer_identity(manager.tls_config.certificate, peer_id2)
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.serialization import Encoding
|
||||||
|
|
||||||
|
# --- Corrupt the certificate by tampering the DER bytes ---
|
||||||
|
cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER)
|
||||||
|
corrupted_bytes = bytearray(cert_bytes)
|
||||||
|
|
||||||
|
# Flip some random bytes in the middle of the certificate
|
||||||
|
corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF
|
||||||
|
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
# This will still parse (structurally valid), but the signature
|
||||||
|
# or fingerprint will break
|
||||||
|
corrupted_cert = x509.load_der_x509_certificate(
|
||||||
|
bytes(corrupted_bytes), backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
QUICPeerVerificationError, match="Certificate verification failed"
|
||||||
|
):
|
||||||
|
manager.verify_peer_identity(corrupted_cert, peer_id1)
|
||||||
|
|||||||
@ -13,9 +13,14 @@ This test focuses on identifying where the accept_stream issue occurs.
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import multiaddr
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID
|
||||||
|
from libp2p import new_host
|
||||||
|
from libp2p.abc import INetStream
|
||||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.transport.quic.config import QUICTransportConfig
|
from libp2p.transport.quic.config import QUICTransportConfig
|
||||||
from libp2p.transport.quic.connection import QUICConnection
|
from libp2p.transport.quic.connection import QUICConnection
|
||||||
from libp2p.transport.quic.transport import QUICTransport
|
from libp2p.transport.quic.transport import QUICTransport
|
||||||
@ -320,3 +325,87 @@ class TestBasicQUICFlow:
|
|||||||
)
|
)
|
||||||
|
|
||||||
print("✅ TIMEOUT TEST PASSED!")
|
print("✅ TIMEOUT TEST PASSED!")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_yamux_stress_ping():
|
||||||
|
STREAM_COUNT = 100
|
||||||
|
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||||
|
latencies = []
|
||||||
|
failures = []
|
||||||
|
|
||||||
|
# === Server Setup ===
|
||||||
|
server_host = new_host(listen_addrs=[listen_addr])
|
||||||
|
|
||||||
|
async def handle_ping(stream: INetStream) -> None:
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
payload = await stream.read(PING_LENGTH)
|
||||||
|
if not payload:
|
||||||
|
break
|
||||||
|
await stream.write(payload)
|
||||||
|
except Exception:
|
||||||
|
await stream.reset()
|
||||||
|
|
||||||
|
server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||||
|
|
||||||
|
async with server_host.run(listen_addrs=[listen_addr]):
|
||||||
|
# Give server time to start
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
# === Client Setup ===
|
||||||
|
destination = str(server_host.get_addrs()[0])
|
||||||
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
|
info = info_from_p2p_addr(maddr)
|
||||||
|
|
||||||
|
client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||||
|
client_host = new_host(listen_addrs=[client_listen_addr])
|
||||||
|
|
||||||
|
async with client_host.run(listen_addrs=[client_listen_addr]):
|
||||||
|
await client_host.connect(info)
|
||||||
|
|
||||||
|
async def ping_stream(i: int):
|
||||||
|
try:
|
||||||
|
start = trio.current_time()
|
||||||
|
stream = await client_host.new_stream(
|
||||||
|
info.peer_id, [PING_PROTOCOL_ID]
|
||||||
|
)
|
||||||
|
|
||||||
|
await stream.write(b"\x01" * PING_LENGTH)
|
||||||
|
|
||||||
|
with trio.fail_after(5):
|
||||||
|
response = await stream.read(PING_LENGTH)
|
||||||
|
|
||||||
|
if response == b"\x01" * PING_LENGTH:
|
||||||
|
latency_ms = int((trio.current_time() - start) * 1000)
|
||||||
|
latencies.append(latency_ms)
|
||||||
|
print(f"[Ping #{i}] Latency: {latency_ms} ms")
|
||||||
|
await stream.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Ping #{i}] Failed: {e}")
|
||||||
|
failures.append(i)
|
||||||
|
await stream.reset()
|
||||||
|
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for i in range(STREAM_COUNT):
|
||||||
|
nursery.start_soon(ping_stream, i)
|
||||||
|
|
||||||
|
# === Result Summary ===
|
||||||
|
print("\n📊 Ping Stress Test Summary")
|
||||||
|
print(f"Total Streams Launched: {STREAM_COUNT}")
|
||||||
|
print(f"Successful Pings: {len(latencies)}")
|
||||||
|
print(f"Failed Pings: {len(failures)}")
|
||||||
|
if failures:
|
||||||
|
print(f"❌ Failed stream indices: {failures}")
|
||||||
|
|
||||||
|
# === Assertions ===
|
||||||
|
assert len(latencies) == STREAM_COUNT, (
|
||||||
|
f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}"
|
||||||
|
)
|
||||||
|
assert all(isinstance(x, int) and x >= 0 for x in latencies), (
|
||||||
|
"Invalid latencies"
|
||||||
|
)
|
||||||
|
|
||||||
|
avg_latency = sum(latencies) / len(latencies)
|
||||||
|
print(f"✅ Average Latency: {avg_latency:.2f} ms")
|
||||||
|
assert avg_latency < 1000
|
||||||
|
|||||||
Reference in New Issue
Block a user