fix: add quic utils test and improve connection performance

This commit is contained in:
Akash Mondal
2025-09-04 21:25:13 +00:00
parent 2ee3e0b054
commit 2fe5882013
5 changed files with 525 additions and 455 deletions

View File

@ -3,14 +3,16 @@ QUIC Connection implementation.
Manages bidirectional QUIC connections with integrated stream multiplexing.
"""
from collections import defaultdict
from collections.abc import Awaitable, Callable
import logging
import socket
import time
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, cast
from aioquic.quic import events
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent
from cryptography import x509
import multiaddr
import trio
@ -104,12 +106,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._connected_event = trio.Event()
self._closed_event = trio.Event()
# Stream management
self._streams: dict[int, QUICStream] = {}
self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups
self._next_stream_id: int = self._calculate_initial_stream_id()
self._stream_handler: TQUICStreamHandlerFn | None = None
self._stream_id_lock = trio.Lock()
self._stream_count_lock = trio.Lock()
# Single lock for all stream operations
self._stream_lock = trio.Lock()
# Stream counting and limits
self._outbound_stream_count = 0
@ -118,7 +121,6 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Stream acceptance for incoming streams
self._stream_accept_queue: list[QUICStream] = []
self._stream_accept_event = trio.Event()
self._accept_queue_lock = trio.Lock()
# Connection state
self._closed: bool = False
@ -143,9 +145,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._retired_connection_ids: set[bytes] = set()
self._connection_id_sequence_numbers: set[int] = set()
# Event processing control
# Event processing control with batching
self._event_processing_active = False
self._pending_events: list[events.QuicEvent] = []
self._event_batch: list[events.QuicEvent] = []
self._event_batch_size = 10
self._last_event_time = 0.0
# Set quic connection configuration
self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT
@ -250,6 +254,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Get the current connection ID."""
return self._current_connection_id
# Fast stream lookup with caching
def _get_stream_fast(self, stream_id: int) -> QUICStream | None:
"""Get stream with caching for performance."""
# Try cache first
stream = self._stream_cache.get(stream_id)
if stream is not None:
return stream
# Fallback to main dict
stream = self._streams.get(stream_id)
if stream is not None:
self._stream_cache[stream_id] = stream
return stream
# Connection lifecycle methods
async def start(self) -> None:
@ -389,8 +408,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
try:
while not self._closed:
# Process QUIC events
await self._process_quic_events()
# Batch process events
await self._process_quic_events_batched()
# Handle timer events
await self._handle_timer_events()
@ -421,12 +440,25 @@ class QUICConnection(IRawConnection, IMuxedConn):
cid_stats = self.get_connection_id_stats()
logger.debug(f"Connection ID stats: {cid_stats}")
# Clean cache periodically
await self._cleanup_cache()
# Sleep for maintenance interval
await trio.sleep(30.0) # 30 seconds
except Exception as e:
logger.error(f"Error in periodic maintenance: {e}")
async def _cleanup_cache(self) -> None:
"""Clean up stream cache periodically to prevent memory leaks."""
if len(self._stream_cache) > 100: # Arbitrary threshold
# Remove closed streams from cache
closed_stream_ids = [
sid for sid, stream in self._stream_cache.items() if stream.is_closed()
]
for sid in closed_stream_ids:
self._stream_cache.pop(sid, None)
async def _client_packet_receiver(self) -> None:
"""Receive packets for client connections."""
logger.debug("Starting client packet receiver")
@ -442,8 +474,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Feed packet to QUIC connection
self._quic.receive_datagram(data, addr, now=time.time())
# Process any events that result from the packet
await self._process_quic_events()
# Batch process events
await self._process_quic_events_batched()
# Send any response packets
await self._transmit()
@ -675,15 +707,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
if not self._started:
raise QUICConnectionError("Connection not started")
# Check stream limits
async with self._stream_count_lock:
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
f"Maximum outbound streams ({self.MAX_OUTGOING_STREAMS}) reached"
)
# Use single lock for all stream operations
with trio.move_on_after(timeout):
async with self._stream_id_lock:
async with self._stream_lock:
# Check stream limits inside lock
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
"Maximum outbound streams "
f"({self.MAX_OUTGOING_STREAMS}) reached"
)
# Generate next stream ID
stream_id = self._next_stream_id
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
@ -697,10 +730,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
async with self._stream_count_lock:
self._outbound_stream_count += 1
self._stats["streams_opened"] += 1
self._outbound_stream_count += 1
self._stats["streams_opened"] += 1
logger.debug(f"Opened outbound QUIC stream {stream_id}")
return stream
@ -737,7 +770,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
if self._closed:
raise MuxedConnUnavailable("QUIC connection is closed")
async with self._accept_queue_lock:
# Use single lock for stream acceptance
async with self._stream_lock:
if self._stream_accept_queue:
stream = self._stream_accept_queue.pop(0)
logger.debug(f"Accepted inbound stream {stream.stream_id}")
@ -769,10 +803,12 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""
if stream_id in self._streams:
stream = self._streams.pop(stream_id)
# Remove from cache too
self._stream_cache.pop(stream_id, None)
# Update stream counts asynchronously
async def update_counts() -> None:
async with self._stream_count_lock:
async with self._stream_lock:
if stream.direction == StreamDirection.OUTBOUND:
self._outbound_stream_count = max(
0, self._outbound_stream_count - 1
@ -789,29 +825,140 @@ class QUICConnection(IRawConnection, IMuxedConn):
logger.debug(f"Removed stream {stream_id} from connection")
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
# Batched event processing to reduce overhead
async def _process_quic_events_batched(self) -> None:
"""Process QUIC events in batches for better performance."""
if self._event_processing_active:
return # Prevent recursion
self._event_processing_active = True
try:
current_time = time.time()
events_processed = 0
while True:
# Collect events into batch
while events_processed < self._event_batch_size:
event = self._quic.next_event()
if event is None:
break
self._event_batch.append(event)
events_processed += 1
await self._handle_quic_event(event)
if events_processed > 0:
logger.debug(f"Processed {events_processed} QUIC events")
# Process batch if we have events or timeout
if self._event_batch and (
len(self._event_batch) >= self._event_batch_size
or current_time - self._last_event_time > 0.01 # 10ms timeout
):
await self._process_event_batch()
self._event_batch.clear()
self._last_event_time = current_time
finally:
self._event_processing_active = False
async def _process_event_batch(self) -> None:
"""Process a batch of events efficiently."""
if not self._event_batch:
return
# Group events by type for batch processing where possible
events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list)
for event in self._event_batch:
events_by_type[type(event).__name__].append(event)
# Process events by type
for event_type, event_list in events_by_type.items():
if event_type == type(events.StreamDataReceived).__name__:
await self._handle_stream_data_batch(
cast(list[events.StreamDataReceived], event_list)
)
else:
# Process other events individually
for event in event_list:
await self._handle_quic_event(event)
logger.debug(f"Processed batch of {len(self._event_batch)} events")
async def _handle_stream_data_batch(
self, events_list: list[events.StreamDataReceived]
) -> None:
"""Handle stream data events in batch for better performance."""
# Group by stream ID
events_by_stream: defaultdict[int, list[QuicEvent]] = defaultdict(list)
for event in events_list:
events_by_stream[event.stream_id].append(event)
# Process each stream's events
for stream_id, stream_events in events_by_stream.items():
stream = self._get_stream_fast(stream_id) # Use fast lookup
if not stream:
if self._is_incoming_stream(stream_id):
try:
stream = await self._create_inbound_stream(stream_id)
except QUICStreamLimitError:
# Reset stream if we can't handle it
self._quic.reset_stream(stream_id, error_code=0x04)
await self._transmit()
continue
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
continue
# Process all events for this stream
for received_event in stream_events:
if hasattr(received_event, "data"):
self._stats["bytes_received"] += len(received_event.data) # type: ignore
if hasattr(received_event, "end_stream"):
await stream.handle_data_received(
received_event.data, # type: ignore
received_event.end_stream, # type: ignore
)
async def _create_inbound_stream(self, stream_id: int) -> QUICStream:
"""Create inbound stream with proper limit checking."""
async with self._stream_lock:
# Double-check stream doesn't exist
existing_stream = self._streams.get(stream_id)
if existing_stream:
return existing_stream
# Check limits
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting inbound stream {stream_id}: limit reached")
raise QUICStreamLimitError("Too many inbound streams")
# Create stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Created inbound stream {stream_id}")
return stream
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
# Delegate to batched processing for better performance
await self._process_quic_events_batched()
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
"""Handle a single QUIC event with COMPLETE event type coverage."""
logger.debug(f"Handling QUIC event: {type(event).__name__}")
@ -929,8 +1076,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"stream_id={event.stream_id}, error_code={event.error_code}"
)
if event.stream_id in self._streams:
stream: QUICStream = self._streams[event.stream_id]
# Use fast lookup
stream = self._get_stream_fast(event.stream_id)
if stream:
# Handle stop sending on the stream if method exists
await stream.handle_stop_sending(event.error_code)
@ -964,6 +1112,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
await stream.close()
self._streams.clear()
self._stream_cache.clear() # Clear cache too
self._closed = True
self._closed_event.set()
@ -978,39 +1127,19 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._stats["bytes_received"] += len(event.data)
try:
if stream_id not in self._streams:
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if not stream:
if self._is_incoming_stream(stream_id):
logger.debug(f"Creating new incoming stream {stream_id}")
from .stream import QUICStream, StreamDirection
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
# Store the stream
self._streams[stream_id] = stream
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Added stream {stream_id} to accept queue")
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_opened"] += 1
stream = await self._create_inbound_stream(stream_id)
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
return
stream = self._streams[stream_id]
await stream.handle_data_received(event.data, event.end_stream)
except Exception as e:
@ -1019,8 +1148,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
"""Get existing stream or create new inbound stream."""
if stream_id in self._streams:
return self._streams[stream_id]
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
return stream
# Check if this is an incoming stream
is_incoming = self._is_incoming_stream(stream_id)
@ -1031,49 +1162,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"Received data for unknown outbound stream {stream_id}"
)
# Check stream limits for incoming streams
async with self._stream_count_lock:
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting incoming stream {stream_id}: limit reached")
# Send reset to reject the stream
self._quic.reset_stream(
stream_id, error_code=0x04
) # STREAM_LIMIT_ERROR
await self._transmit()
raise QUICStreamLimitError("Too many inbound streams")
# Create new inbound stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue and notify handler
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
# Handle directly with stream handler if available
if self._stream_handler:
try:
if self._nursery:
self._nursery.start_soon(self._stream_handler, stream)
else:
await self._stream_handler(stream)
except Exception as e:
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
logger.debug(f"Created inbound stream {stream_id}")
return stream
return await self._create_inbound_stream(stream_id)
def _is_incoming_stream(self, stream_id: int) -> bool:
"""
@ -1095,9 +1185,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
stream_id = event.stream_id
self._stats["streams_reset"] += 1
if stream_id in self._streams:
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
try:
stream = self._streams[stream_id]
await stream.handle_reset(event.error_code)
logger.debug(
f"Handled reset for stream {stream_id}"
@ -1137,12 +1228,20 @@ class QUICConnection(IRawConnection, IMuxedConn):
try:
current_time = time.time()
datagrams = self._quic.datagrams_to_send(now=current_time)
# Batch stats updates
packet_count = 0
total_bytes = 0
for data, addr in datagrams:
await sock.sendto(data, addr)
# Update stats if available
if hasattr(self, "_stats"):
self._stats["packets_sent"] += 1
self._stats["bytes_sent"] += len(data)
packet_count += 1
total_bytes += len(data)
# Update stats in batch
if packet_count > 0:
self._stats["packets_sent"] += packet_count
self._stats["bytes_sent"] += total_bytes
except Exception as e:
logger.error(f"Transmission error: {e}")
@ -1217,6 +1316,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._socket = None
self._streams.clear()
self._stream_cache.clear() # Clear cache
self._closed_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
@ -1328,6 +1428,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
"max_streams": self.MAX_CONCURRENT_STREAMS,
"stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS,
"stats": self._stats.copy(),
"cache_size": len(
self._stream_cache
), # Include cache metrics for monitoring
}
def get_active_streams(self) -> list[QUICStream]:

View File

@ -267,56 +267,37 @@ class QUICListener(IListener):
return value, 8
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
"""Process incoming QUIC packet with fine-grained locking."""
"""Process incoming QUIC packet with optimized routing."""
try:
self._stats["packets_processed"] += 1
self._stats["bytes_received"] += len(data)
logger.debug(f"Processing packet of {len(data)} bytes from {addr}")
# Parse packet header OUTSIDE the lock
packet_info = self.parse_quic_packet(data)
if packet_info is None:
logger.error(f"Failed to parse packet header quic packet from {addr}")
self._stats["invalid_packets"] += 1
return
dest_cid = packet_info.destination_cid
connection_obj = None
pending_quic_conn = None
# Single lock acquisition with all lookups
async with self._connection_lock:
if dest_cid in self._connections:
connection_obj = self._connections[dest_cid]
logger.debug(f"Routing to established connection {dest_cid.hex()}")
connection_obj = self._connections.get(dest_cid)
pending_quic_conn = self._pending_connections.get(dest_cid)
elif dest_cid in self._pending_connections:
pending_quic_conn = self._pending_connections[dest_cid]
logger.debug(f"Routing to pending connection {dest_cid.hex()}")
else:
# Check if this is a new connection
if packet_info.packet_type.name == "INITIAL":
logger.debug(
f"Received INITIAL Packet Creating new conn for {addr}"
)
# Create new connection INSIDE the lock for safety
if not connection_obj and not pending_quic_conn:
if packet_info.packet_type == QuicPacketType.INITIAL:
pending_quic_conn = await self._handle_new_connection(
data, addr, packet_info
)
else:
return
# CRITICAL: Process packets OUTSIDE the lock to prevent deadlock
# Process outside the lock
if connection_obj:
# Handle established connection
await self._handle_established_connection_packet(
connection_obj, data, addr, dest_cid
)
elif pending_quic_conn:
# Handle pending connection
await self._handle_pending_connection_packet(
pending_quic_conn, data, addr, dest_cid
)
@ -431,6 +412,7 @@ class QUICListener(IListener):
f"No configuration found for version 0x{packet_info.version:08x}"
)
await self._send_version_negotiation(addr, packet_info.source_cid)
return None
if not quic_config:
raise QUICListenError("Cannot determine QUIC configuration")

View File

@ -108,21 +108,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except ValueError:
except Exception:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except ValueError:
except Exception:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str)
except ValueError:
except Exception:
pass
if host is None or port is None:
@ -203,8 +203,7 @@ def create_quic_multiaddr(
if version == "quic-v1" or version == "/quic-v1":
quic_proto = QUIC_V1_PROTOCOL
elif version == "quic" or version == "/quic":
# This is DRAFT Protocol
quic_proto = QUIC_V1_PROTOCOL
quic_proto = QUIC_DRAFT29_PROTOCOL
else:
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")

View File

@ -192,7 +192,7 @@ class TestQUICConnection:
await trio.sleep(10) # Longer than timeout
with patch.object(
quic_connection._stream_id_lock, "acquire", side_effect=slow_acquire
quic_connection._stream_lock, "acquire", side_effect=slow_acquire
):
with pytest.raises(
QUICStreamTimeoutError, match="Stream creation timed out"

View File

@ -3,333 +3,319 @@ Test suite for QUIC multiaddr utilities.
Focused tests covering essential functionality required for QUIC transport.
"""
# TODO: Enable this test after multiaddr repo supports protocol quic-v1
# import pytest
# from multiaddr import Multiaddr
# from libp2p.custom_types import TProtocol
# from libp2p.transport.quic.exceptions import (
# QUICInvalidMultiaddrError,
# QUICUnsupportedVersionError,
# )
# from libp2p.transport.quic.utils import (
# create_quic_multiaddr,
# get_alpn_protocols,
# is_quic_multiaddr,
# multiaddr_to_quic_version,
# normalize_quic_multiaddr,
# quic_multiaddr_to_endpoint,
# quic_version_to_wire_format,
# )
# class TestIsQuicMultiaddr:
# """Test QUIC multiaddr detection."""
# def test_valid_quic_v1_multiaddrs(self):
# """Test valid QUIC v1 multiaddrs are detected."""
# valid_addrs = [
# "/ip4/127.0.0.1/udp/4001/quic-v1",
# "/ip4/192.168.1.1/udp/8080/quic-v1",
# "/ip6/::1/udp/4001/quic-v1",
# "/ip6/2001:db8::1/udp/5000/quic-v1",
# ]
# for addr_str in valid_addrs:
# maddr = Multiaddr(addr_str)
# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
# def test_valid_quic_draft29_multiaddrs(self):
# """Test valid QUIC draft-29 multiaddrs are detected."""
# valid_addrs = [
# "/ip4/127.0.0.1/udp/4001/quic",
# "/ip4/10.0.0.1/udp/9000/quic",
# "/ip6/::1/udp/4001/quic",
# "/ip6/fe80::1/udp/6000/quic",
# ]
# for addr_str in valid_addrs:
# maddr = Multiaddr(addr_str)
# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
# def test_invalid_multiaddrs(self):
# """Test non-QUIC multiaddrs are not detected."""
# invalid_addrs = [
# "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
# "/ip4/127.0.0.1/udp/4001", # UDP without QUIC
# "/ip4/127.0.0.1/udp/4001/ws", # WebSocket
# "/ip4/127.0.0.1/quic-v1", # Missing UDP
# "/udp/4001/quic-v1", # Missing IP
# "/dns4/example.com/tcp/443/tls", # Completely different
# ]
# for addr_str in invalid_addrs:
# maddr = Multiaddr(addr_str)
# assert not is_quic_multiaddr(maddr),
# f"Should not detect {addr_str} as QUIC"
# def test_malformed_multiaddrs(self):
# """Test malformed multiaddrs don't crash."""
# # These should not raise exceptions, just return False
# malformed = [
# Multiaddr("/ip4/127.0.0.1"),
# Multiaddr("/invalid"),
# ]
# for maddr in malformed:
# assert not is_quic_multiaddr(maddr)
# class TestQuicMultiaddrToEndpoint:
# """Test endpoint extraction from QUIC multiaddrs."""
# def test_ipv4_extraction(self):
# """Test IPv4 host/port extraction."""
# test_cases = [
# ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
# ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
# ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
# ]
# for addr_str, expected in test_cases:
# maddr = Multiaddr(addr_str)
# result = quic_multiaddr_to_endpoint(maddr)
# assert result == expected, f"Failed for {addr_str}"
# def test_ipv6_extraction(self):
# """Test IPv6 host/port extraction."""
# test_cases = [
# ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
# ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
# ]
# for addr_str, expected in test_cases:
# maddr = Multiaddr(addr_str)
# result = quic_multiaddr_to_endpoint(maddr)
# assert result == expected, f"Failed for {addr_str}"
# def test_invalid_multiaddr_raises_error(self):
# """Test invalid multiaddrs raise appropriate errors."""
# invalid_addrs = [
# "/ip4/127.0.0.1/tcp/4001", # Not QUIC
# "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
# ]
# for addr_str in invalid_addrs:
# maddr = Multiaddr(addr_str)
# with pytest.raises(QUICInvalidMultiaddrError):
# quic_multiaddr_to_endpoint(maddr)
# class TestMultiaddrToQuicVersion:
# """Test QUIC version extraction."""
# def test_quic_v1_detection(self):
# """Test QUIC v1 version detection."""
# addrs = [
# "/ip4/127.0.0.1/udp/4001/quic-v1",
# "/ip6/::1/udp/5000/quic-v1",
# ]
# for addr_str in addrs:
# maddr = Multiaddr(addr_str)
# version = multiaddr_to_quic_version(maddr)
# assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
# def test_quic_draft29_detection(self):
# """Test QUIC draft-29 version detection."""
# addrs = [
# "/ip4/127.0.0.1/udp/4001/quic",
# "/ip6/::1/udp/5000/quic",
# ]
# for addr_str in addrs:
# maddr = Multiaddr(addr_str)
# version = multiaddr_to_quic_version(maddr)
# assert version == "quic", f"Should detect quic for {addr_str}"
# def test_non_quic_raises_error(self):
# """Test non-QUIC multiaddrs raise error."""
# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
# with pytest.raises(QUICInvalidMultiaddrError):
# multiaddr_to_quic_version(maddr)
# class TestCreateQuicMultiaddr:
# """Test QUIC multiaddr creation."""
# def test_ipv4_creation(self):
# """Test IPv4 QUIC multiaddr creation."""
# test_cases = [
# ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
# ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
# ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
# ]
# for host, port, version, expected in test_cases:
# result = create_quic_multiaddr(host, port, version)
# assert str(result) == expected
# def test_ipv6_creation(self):
# """Test IPv6 QUIC multiaddr creation."""
# test_cases = [
# ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
# ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
# ]
# for host, port, version, expected in test_cases:
# result = create_quic_multiaddr(host, port, version)
# assert str(result) == expected
# def test_default_version(self):
# """Test default version is quic-v1."""
# result = create_quic_multiaddr("127.0.0.1", 4001)
# expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
# assert str(result) == expected
# def test_invalid_inputs_raise_errors(self):
# """Test invalid inputs raise appropriate errors."""
# # Invalid IP
# with pytest.raises(QUICInvalidMultiaddrError):
# create_quic_multiaddr("invalid-ip", 4001)
# # Invalid port
# with pytest.raises(QUICInvalidMultiaddrError):
# create_quic_multiaddr("127.0.0.1", 70000)
# with pytest.raises(QUICInvalidMultiaddrError):
# create_quic_multiaddr("127.0.0.1", -1)
# # Invalid version
# with pytest.raises(QUICInvalidMultiaddrError):
# create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
# class TestQuicVersionToWireFormat:
# """Test QUIC version to wire format conversion."""
# def test_supported_versions(self):
# """Test supported version conversions."""
# test_cases = [
# ("quic-v1", 0x00000001), # RFC 9000
# ("quic", 0xFF00001D), # draft-29
# ]
# for version, expected_wire in test_cases:
# result = quic_version_to_wire_format(TProtocol(version))
# assert result == expected_wire, f"Failed for version {version}"
# def test_unsupported_version_raises_error(self):
# """Test unsupported versions raise error."""
# with pytest.raises(QUICUnsupportedVersionError):
# quic_version_to_wire_format(TProtocol("unsupported-version"))
# class TestGetAlpnProtocols:
# """Test ALPN protocol retrieval."""
# def test_returns_libp2p_protocols(self):
# """Test returns expected libp2p ALPN protocols."""
# protocols = get_alpn_protocols()
# assert protocols == ["libp2p"]
# assert isinstance(protocols, list)
# def test_returns_copy(self):
# """Test returns a copy, not the original list."""
# protocols1 = get_alpn_protocols()
# protocols2 = get_alpn_protocols()
# # Modify one list
# protocols1.append("test")
# # Other list should be unchanged
# assert protocols2 == ["libp2p"]
# class TestNormalizeQuicMultiaddr:
# """Test QUIC multiaddr normalization."""
# def test_already_normalized(self):
# """Test already normalized multiaddrs pass through."""
# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
# maddr = Multiaddr(addr_str)
import pytest
from multiaddr import Multiaddr
from libp2p.custom_types import TProtocol
from libp2p.transport.quic.exceptions import (
QUICInvalidMultiaddrError,
QUICUnsupportedVersionError,
)
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
get_alpn_protocols,
is_quic_multiaddr,
multiaddr_to_quic_version,
normalize_quic_multiaddr,
quic_multiaddr_to_endpoint,
quic_version_to_wire_format,
)
class TestIsQuicMultiaddr:
"""Test QUIC multiaddr detection."""
def test_valid_quic_v1_multiaddrs(self):
"""Test valid QUIC v1 multiaddrs are detected."""
valid_addrs = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip4/192.168.1.1/udp/8080/quic-v1",
"/ip6/::1/udp/4001/quic-v1",
"/ip6/2001:db8::1/udp/5000/quic-v1",
]
for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
def test_valid_quic_draft29_multiaddrs(self):
"""Test valid QUIC draft-29 multiaddrs are detected."""
valid_addrs = [
"/ip4/127.0.0.1/udp/4001/quic",
"/ip4/10.0.0.1/udp/9000/quic",
"/ip6/::1/udp/4001/quic",
"/ip6/fe80::1/udp/6000/quic",
]
for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
def test_invalid_multiaddrs(self):
"""Test non-QUIC multiaddrs are not detected."""
invalid_addrs = [
"/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
"/ip4/127.0.0.1/udp/4001", # UDP without QUIC
"/ip4/127.0.0.1/udp/4001/ws", # WebSocket
"/ip4/127.0.0.1/quic-v1", # Missing UDP
"/udp/4001/quic-v1", # Missing IP
"/dns4/example.com/tcp/443/tls", # Completely different
]
for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
class TestQuicMultiaddrToEndpoint:
"""Test endpoint extraction from QUIC multiaddrs."""
def test_ipv4_extraction(self):
"""Test IPv4 host/port extraction."""
test_cases = [
("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
]
for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
assert result == expected, f"Failed for {addr_str}"
def test_ipv6_extraction(self):
"""Test IPv6 host/port extraction."""
test_cases = [
("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
]
for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
assert result == expected, f"Failed for {addr_str}"
def test_invalid_multiaddr_raises_error(self):
"""Test invalid multiaddrs raise appropriate errors."""
invalid_addrs = [
"/ip4/127.0.0.1/tcp/4001", # Not QUIC
"/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
]
for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
with pytest.raises(QUICInvalidMultiaddrError):
quic_multiaddr_to_endpoint(maddr)
class TestMultiaddrToQuicVersion:
"""Test QUIC version extraction."""
def test_quic_v1_detection(self):
"""Test QUIC v1 version detection."""
addrs = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip6/::1/udp/5000/quic-v1",
]
for addr_str in addrs:
maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
def test_quic_draft29_detection(self):
"""Test QUIC draft-29 version detection."""
addrs = [
"/ip4/127.0.0.1/udp/4001/quic",
"/ip6/::1/udp/5000/quic",
]
for addr_str in addrs:
maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic", f"Should detect quic for {addr_str}"
def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICInvalidMultiaddrError):
multiaddr_to_quic_version(maddr)
class TestCreateQuicMultiaddr:
"""Test QUIC multiaddr creation."""
def test_ipv4_creation(self):
"""Test IPv4 QUIC multiaddr creation."""
test_cases = [
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
]
for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
def test_ipv6_creation(self):
"""Test IPv6 QUIC multiaddr creation."""
test_cases = [
("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
]
for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
def test_default_version(self):
"""Test default version is quic-v1."""
result = create_quic_multiaddr("127.0.0.1", 4001)
expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
assert str(result) == expected
def test_invalid_inputs_raise_errors(self):
"""Test invalid inputs raise appropriate errors."""
# Invalid IP
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("invalid-ip", 4001)
# Invalid port
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", 70000)
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", -1)
# Invalid version
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
class TestQuicVersionToWireFormat:
"""Test QUIC version to wire format conversion."""
def test_supported_versions(self):
"""Test supported version conversions."""
test_cases = [
("quic-v1", 0x00000001), # RFC 9000
("quic", 0xFF00001D), # draft-29
]
for version, expected_wire in test_cases:
result = quic_version_to_wire_format(TProtocol(version))
assert result == expected_wire, f"Failed for version {version}"
def test_unsupported_version_raises_error(self):
"""Test unsupported versions raise error."""
with pytest.raises(QUICUnsupportedVersionError):
quic_version_to_wire_format(TProtocol("unsupported-version"))
class TestGetAlpnProtocols:
"""Test ALPN protocol retrieval."""
def test_returns_libp2p_protocols(self):
"""Test returns expected libp2p ALPN protocols."""
protocols = get_alpn_protocols()
assert protocols == ["libp2p"]
assert isinstance(protocols, list)
def test_returns_copy(self):
"""Test returns a copy, not the original list."""
protocols1 = get_alpn_protocols()
protocols2 = get_alpn_protocols()
# Modify one list
protocols1.append("test")
# Other list should be unchanged
assert protocols2 == ["libp2p"]
class TestNormalizeQuicMultiaddr:
"""Test QUIC multiaddr normalization."""
def test_already_normalized(self):
"""Test already normalized multiaddrs pass through."""
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
# result = normalize_quic_multiaddr(maddr)
# assert str(result) == addr_str
# def test_normalize_different_versions(self):
# """Test normalization works for different QUIC versions."""
# test_cases = [
# "/ip4/127.0.0.1/udp/4001/quic-v1",
# "/ip4/127.0.0.1/udp/4001/quic",
# "/ip6/::1/udp/5000/quic-v1",
# ]
# for addr_str in test_cases:
# maddr = Multiaddr(addr_str)
# result = normalize_quic_multiaddr(maddr)
# # Should be valid QUIC multiaddr
# assert is_quic_multiaddr(result)
# # Should be parseable
# host, port = quic_multiaddr_to_endpoint(result)
# version = multiaddr_to_quic_version(result)
result = normalize_quic_multiaddr(maddr)
assert str(result) == addr_str
def test_normalize_different_versions(self):
"""Test normalization works for different QUIC versions."""
test_cases = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip4/127.0.0.1/udp/4001/quic",
"/ip6/::1/udp/5000/quic-v1",
]
for addr_str in test_cases:
maddr = Multiaddr(addr_str)
result = normalize_quic_multiaddr(maddr)
# Should be valid QUIC multiaddr
assert is_quic_multiaddr(result)
# Should be parseable
host, port = quic_multiaddr_to_endpoint(result)
version = multiaddr_to_quic_version(result)
# # Should match original
# orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
# orig_version = multiaddr_to_quic_version(maddr)
# Should match original
orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
orig_version = multiaddr_to_quic_version(maddr)
# assert host == orig_host
# assert port == orig_port
# assert version == orig_version
assert host == orig_host
assert port == orig_port
assert version == orig_version
# def test_non_quic_raises_error(self):
# """Test non-QUIC multiaddrs raise error."""
# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
# with pytest.raises(QUICInvalidMultiaddrError):
# normalize_quic_multiaddr(maddr)
def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICInvalidMultiaddrError):
normalize_quic_multiaddr(maddr)
# class TestIntegration:
# """Integration tests for utility functions working together."""
class TestIntegration:
"""Integration tests for utility functions working together."""
# def test_round_trip_conversion(self):
# """Test creating and parsing multiaddrs works correctly."""
# test_cases = [
# ("127.0.0.1", 4001, "quic-v1"),
# ("::1", 5000, "quic"),
# ("192.168.1.100", 8080, "quic-v1"),
# ]
def test_round_trip_conversion(self):
"""Test creating and parsing multiaddrs works correctly."""
test_cases = [
("127.0.0.1", 4001, "quic-v1"),
("::1", 5000, "quic"),
("192.168.1.100", 8080, "quic-v1"),
]
# for host, port, version in test_cases:
# # Create multiaddr
# maddr = create_quic_multiaddr(host, port, version)
for host, port, version in test_cases:
# Create multiaddr
maddr = create_quic_multiaddr(host, port, version)
# # Should be detected as QUIC
# assert is_quic_multiaddr(maddr)
# # Should extract original values
# extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
# extracted_version = multiaddr_to_quic_version(maddr)
# Should be detected as QUIC
assert is_quic_multiaddr(maddr)
# Should extract original values
extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
extracted_version = multiaddr_to_quic_version(maddr)
# assert extracted_host == host
# assert extracted_port == port
# assert extracted_version == version
assert extracted_host == host
assert extracted_port == port
assert extracted_version == version
# # Should normalize to same value
# normalized = normalize_quic_multiaddr(maddr)
# assert str(normalized) == str(maddr)
# Should normalize to same value
normalized = normalize_quic_multiaddr(maddr)
assert str(normalized) == str(maddr)
# def test_wire_format_integration(self):
# """Test wire format conversion works with version detection."""
# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
# maddr = Multiaddr(addr_str)
def test_wire_format_integration(self):
"""Test wire format conversion works with version detection."""
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
# # Extract version and convert to wire format
# version = multiaddr_to_quic_version(maddr)
# wire_format = quic_version_to_wire_format(version)
# Extract version and convert to wire format
version = multiaddr_to_quic_version(maddr)
wire_format = quic_version_to_wire_format(version)
# # Should be QUIC v1 wire format
# assert wire_format == 0x00000001
# Should be QUIC v1 wire format
assert wire_format == 0x00000001