fix: add quic utils test and improve connection performance

This commit is contained in:
Akash Mondal
2025-09-04 21:25:13 +00:00
parent 2ee3e0b054
commit 2fe5882013
5 changed files with 525 additions and 455 deletions

View File

@ -3,14 +3,16 @@ QUIC Connection implementation.
Manages bidirectional QUIC connections with integrated stream multiplexing. Manages bidirectional QUIC connections with integrated stream multiplexing.
""" """
from collections import defaultdict
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
import logging import logging
import socket import socket
import time import time
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional, cast
from aioquic.quic import events from aioquic.quic import events
from aioquic.quic.connection import QuicConnection from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent
from cryptography import x509 from cryptography import x509
import multiaddr import multiaddr
import trio import trio
@ -104,12 +106,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._connected_event = trio.Event() self._connected_event = trio.Event()
self._closed_event = trio.Event() self._closed_event = trio.Event()
# Stream management
self._streams: dict[int, QUICStream] = {} self._streams: dict[int, QUICStream] = {}
self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups
self._next_stream_id: int = self._calculate_initial_stream_id() self._next_stream_id: int = self._calculate_initial_stream_id()
self._stream_handler: TQUICStreamHandlerFn | None = None self._stream_handler: TQUICStreamHandlerFn | None = None
self._stream_id_lock = trio.Lock()
self._stream_count_lock = trio.Lock() # Single lock for all stream operations
self._stream_lock = trio.Lock()
# Stream counting and limits # Stream counting and limits
self._outbound_stream_count = 0 self._outbound_stream_count = 0
@ -118,7 +121,6 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Stream acceptance for incoming streams # Stream acceptance for incoming streams
self._stream_accept_queue: list[QUICStream] = [] self._stream_accept_queue: list[QUICStream] = []
self._stream_accept_event = trio.Event() self._stream_accept_event = trio.Event()
self._accept_queue_lock = trio.Lock()
# Connection state # Connection state
self._closed: bool = False self._closed: bool = False
@ -143,9 +145,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._retired_connection_ids: set[bytes] = set() self._retired_connection_ids: set[bytes] = set()
self._connection_id_sequence_numbers: set[int] = set() self._connection_id_sequence_numbers: set[int] = set()
# Event processing control # Event processing control with batching
self._event_processing_active = False self._event_processing_active = False
self._pending_events: list[events.QuicEvent] = [] self._event_batch: list[events.QuicEvent] = []
self._event_batch_size = 10
self._last_event_time = 0.0
# Set quic connection configuration # Set quic connection configuration
self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT
@ -250,6 +254,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Get the current connection ID.""" """Get the current connection ID."""
return self._current_connection_id return self._current_connection_id
# Fast stream lookup with caching
def _get_stream_fast(self, stream_id: int) -> QUICStream | None:
"""Get stream with caching for performance."""
# Try cache first
stream = self._stream_cache.get(stream_id)
if stream is not None:
return stream
# Fallback to main dict
stream = self._streams.get(stream_id)
if stream is not None:
self._stream_cache[stream_id] = stream
return stream
# Connection lifecycle methods # Connection lifecycle methods
async def start(self) -> None: async def start(self) -> None:
@ -389,8 +408,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
try: try:
while not self._closed: while not self._closed:
# Process QUIC events # Batch process events
await self._process_quic_events() await self._process_quic_events_batched()
# Handle timer events # Handle timer events
await self._handle_timer_events() await self._handle_timer_events()
@ -421,12 +440,25 @@ class QUICConnection(IRawConnection, IMuxedConn):
cid_stats = self.get_connection_id_stats() cid_stats = self.get_connection_id_stats()
logger.debug(f"Connection ID stats: {cid_stats}") logger.debug(f"Connection ID stats: {cid_stats}")
# Clean cache periodically
await self._cleanup_cache()
# Sleep for maintenance interval # Sleep for maintenance interval
await trio.sleep(30.0) # 30 seconds await trio.sleep(30.0) # 30 seconds
except Exception as e: except Exception as e:
logger.error(f"Error in periodic maintenance: {e}") logger.error(f"Error in periodic maintenance: {e}")
async def _cleanup_cache(self) -> None:
"""Clean up stream cache periodically to prevent memory leaks."""
if len(self._stream_cache) > 100: # Arbitrary threshold
# Remove closed streams from cache
closed_stream_ids = [
sid for sid, stream in self._stream_cache.items() if stream.is_closed()
]
for sid in closed_stream_ids:
self._stream_cache.pop(sid, None)
async def _client_packet_receiver(self) -> None: async def _client_packet_receiver(self) -> None:
"""Receive packets for client connections.""" """Receive packets for client connections."""
logger.debug("Starting client packet receiver") logger.debug("Starting client packet receiver")
@ -442,8 +474,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Feed packet to QUIC connection # Feed packet to QUIC connection
self._quic.receive_datagram(data, addr, now=time.time()) self._quic.receive_datagram(data, addr, now=time.time())
# Process any events that result from the packet # Batch process events
await self._process_quic_events() await self._process_quic_events_batched()
# Send any response packets # Send any response packets
await self._transmit() await self._transmit()
@ -675,15 +707,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
if not self._started: if not self._started:
raise QUICConnectionError("Connection not started") raise QUICConnectionError("Connection not started")
# Check stream limits # Use single lock for all stream operations
async with self._stream_count_lock:
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
f"Maximum outbound streams ({self.MAX_OUTGOING_STREAMS}) reached"
)
with trio.move_on_after(timeout): with trio.move_on_after(timeout):
async with self._stream_id_lock: async with self._stream_lock:
# Check stream limits inside lock
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
"Maximum outbound streams "
f"({self.MAX_OUTGOING_STREAMS}) reached"
)
# Generate next stream ID # Generate next stream ID
stream_id = self._next_stream_id stream_id = self._next_stream_id
self._next_stream_id += 4 # Increment by 4 for bidirectional streams self._next_stream_id += 4 # Increment by 4 for bidirectional streams
@ -697,10 +730,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
) )
self._streams[stream_id] = stream self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
async with self._stream_count_lock: self._outbound_stream_count += 1
self._outbound_stream_count += 1 self._stats["streams_opened"] += 1
self._stats["streams_opened"] += 1
logger.debug(f"Opened outbound QUIC stream {stream_id}") logger.debug(f"Opened outbound QUIC stream {stream_id}")
return stream return stream
@ -737,7 +770,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
if self._closed: if self._closed:
raise MuxedConnUnavailable("QUIC connection is closed") raise MuxedConnUnavailable("QUIC connection is closed")
async with self._accept_queue_lock: # Use single lock for stream acceptance
async with self._stream_lock:
if self._stream_accept_queue: if self._stream_accept_queue:
stream = self._stream_accept_queue.pop(0) stream = self._stream_accept_queue.pop(0)
logger.debug(f"Accepted inbound stream {stream.stream_id}") logger.debug(f"Accepted inbound stream {stream.stream_id}")
@ -769,10 +803,12 @@ class QUICConnection(IRawConnection, IMuxedConn):
""" """
if stream_id in self._streams: if stream_id in self._streams:
stream = self._streams.pop(stream_id) stream = self._streams.pop(stream_id)
# Remove from cache too
self._stream_cache.pop(stream_id, None)
# Update stream counts asynchronously # Update stream counts asynchronously
async def update_counts() -> None: async def update_counts() -> None:
async with self._stream_count_lock: async with self._stream_lock:
if stream.direction == StreamDirection.OUTBOUND: if stream.direction == StreamDirection.OUTBOUND:
self._outbound_stream_count = max( self._outbound_stream_count = max(
0, self._outbound_stream_count - 1 0, self._outbound_stream_count - 1
@ -789,29 +825,140 @@ class QUICConnection(IRawConnection, IMuxedConn):
logger.debug(f"Removed stream {stream_id} from connection") logger.debug(f"Removed stream {stream_id} from connection")
async def _process_quic_events(self) -> None: # Batched event processing to reduce overhead
"""Process all pending QUIC events.""" async def _process_quic_events_batched(self) -> None:
"""Process QUIC events in batches for better performance."""
if self._event_processing_active: if self._event_processing_active:
return # Prevent recursion return # Prevent recursion
self._event_processing_active = True self._event_processing_active = True
try: try:
current_time = time.time()
events_processed = 0 events_processed = 0
while True:
# Collect events into batch
while events_processed < self._event_batch_size:
event = self._quic.next_event() event = self._quic.next_event()
if event is None: if event is None:
break break
self._event_batch.append(event)
events_processed += 1 events_processed += 1
await self._handle_quic_event(event)
if events_processed > 0: # Process batch if we have events or timeout
logger.debug(f"Processed {events_processed} QUIC events") if self._event_batch and (
len(self._event_batch) >= self._event_batch_size
or current_time - self._last_event_time > 0.01 # 10ms timeout
):
await self._process_event_batch()
self._event_batch.clear()
self._last_event_time = current_time
finally: finally:
self._event_processing_active = False self._event_processing_active = False
async def _process_event_batch(self) -> None:
"""Process a batch of events efficiently."""
if not self._event_batch:
return
# Group events by type for batch processing where possible
events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list)
for event in self._event_batch:
events_by_type[type(event).__name__].append(event)
# Process events by type
for event_type, event_list in events_by_type.items():
if event_type == type(events.StreamDataReceived).__name__:
await self._handle_stream_data_batch(
cast(list[events.StreamDataReceived], event_list)
)
else:
# Process other events individually
for event in event_list:
await self._handle_quic_event(event)
logger.debug(f"Processed batch of {len(self._event_batch)} events")
async def _handle_stream_data_batch(
self, events_list: list[events.StreamDataReceived]
) -> None:
"""Handle stream data events in batch for better performance."""
# Group by stream ID
events_by_stream: defaultdict[int, list[QuicEvent]] = defaultdict(list)
for event in events_list:
events_by_stream[event.stream_id].append(event)
# Process each stream's events
for stream_id, stream_events in events_by_stream.items():
stream = self._get_stream_fast(stream_id) # Use fast lookup
if not stream:
if self._is_incoming_stream(stream_id):
try:
stream = await self._create_inbound_stream(stream_id)
except QUICStreamLimitError:
# Reset stream if we can't handle it
self._quic.reset_stream(stream_id, error_code=0x04)
await self._transmit()
continue
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
continue
# Process all events for this stream
for received_event in stream_events:
if hasattr(received_event, "data"):
self._stats["bytes_received"] += len(received_event.data) # type: ignore
if hasattr(received_event, "end_stream"):
await stream.handle_data_received(
received_event.data, # type: ignore
received_event.end_stream, # type: ignore
)
async def _create_inbound_stream(self, stream_id: int) -> QUICStream:
"""Create inbound stream with proper limit checking."""
async with self._stream_lock:
# Double-check stream doesn't exist
existing_stream = self._streams.get(stream_id)
if existing_stream:
return existing_stream
# Check limits
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting inbound stream {stream_id}: limit reached")
raise QUICStreamLimitError("Too many inbound streams")
# Create stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Created inbound stream {stream_id}")
return stream
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
# Delegate to batched processing for better performance
await self._process_quic_events_batched()
async def _handle_quic_event(self, event: events.QuicEvent) -> None: async def _handle_quic_event(self, event: events.QuicEvent) -> None:
"""Handle a single QUIC event with COMPLETE event type coverage.""" """Handle a single QUIC event with COMPLETE event type coverage."""
logger.debug(f"Handling QUIC event: {type(event).__name__}") logger.debug(f"Handling QUIC event: {type(event).__name__}")
@ -929,8 +1076,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"stream_id={event.stream_id}, error_code={event.error_code}" f"stream_id={event.stream_id}, error_code={event.error_code}"
) )
if event.stream_id in self._streams: # Use fast lookup
stream: QUICStream = self._streams[event.stream_id] stream = self._get_stream_fast(event.stream_id)
if stream:
# Handle stop sending on the stream if method exists # Handle stop sending on the stream if method exists
await stream.handle_stop_sending(event.error_code) await stream.handle_stop_sending(event.error_code)
@ -964,6 +1112,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
await stream.close() await stream.close()
self._streams.clear() self._streams.clear()
self._stream_cache.clear() # Clear cache too
self._closed = True self._closed = True
self._closed_event.set() self._closed_event.set()
@ -978,39 +1127,19 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._stats["bytes_received"] += len(event.data) self._stats["bytes_received"] += len(event.data)
try: try:
if stream_id not in self._streams: # Use fast lookup
stream = self._get_stream_fast(stream_id)
if not stream:
if self._is_incoming_stream(stream_id): if self._is_incoming_stream(stream_id):
logger.debug(f"Creating new incoming stream {stream_id}") logger.debug(f"Creating new incoming stream {stream_id}")
stream = await self._create_inbound_stream(stream_id)
from .stream import QUICStream, StreamDirection
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
# Store the stream
self._streams[stream_id] = stream
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Added stream {stream_id} to accept queue")
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_opened"] += 1
else: else:
logger.error( logger.error(
f"Unexpected outbound stream {stream_id} in data event" f"Unexpected outbound stream {stream_id} in data event"
) )
return return
stream = self._streams[stream_id]
await stream.handle_data_received(event.data, event.end_stream) await stream.handle_data_received(event.data, event.end_stream)
except Exception as e: except Exception as e:
@ -1019,8 +1148,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def _get_or_create_stream(self, stream_id: int) -> QUICStream: async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
"""Get existing stream or create new inbound stream.""" """Get existing stream or create new inbound stream."""
if stream_id in self._streams: # Use fast lookup
return self._streams[stream_id] stream = self._get_stream_fast(stream_id)
if stream:
return stream
# Check if this is an incoming stream # Check if this is an incoming stream
is_incoming = self._is_incoming_stream(stream_id) is_incoming = self._is_incoming_stream(stream_id)
@ -1031,49 +1162,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"Received data for unknown outbound stream {stream_id}" f"Received data for unknown outbound stream {stream_id}"
) )
# Check stream limits for incoming streams
async with self._stream_count_lock:
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting incoming stream {stream_id}: limit reached")
# Send reset to reject the stream
self._quic.reset_stream(
stream_id, error_code=0x04
) # STREAM_LIMIT_ERROR
await self._transmit()
raise QUICStreamLimitError("Too many inbound streams")
# Create new inbound stream # Create new inbound stream
stream = QUICStream( return await self._create_inbound_stream(stream_id)
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue and notify handler
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
# Handle directly with stream handler if available
if self._stream_handler:
try:
if self._nursery:
self._nursery.start_soon(self._stream_handler, stream)
else:
await self._stream_handler(stream)
except Exception as e:
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
logger.debug(f"Created inbound stream {stream_id}")
return stream
def _is_incoming_stream(self, stream_id: int) -> bool: def _is_incoming_stream(self, stream_id: int) -> bool:
""" """
@ -1095,9 +1185,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
stream_id = event.stream_id stream_id = event.stream_id
self._stats["streams_reset"] += 1 self._stats["streams_reset"] += 1
if stream_id in self._streams: # Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
try: try:
stream = self._streams[stream_id]
await stream.handle_reset(event.error_code) await stream.handle_reset(event.error_code)
logger.debug( logger.debug(
f"Handled reset for stream {stream_id}" f"Handled reset for stream {stream_id}"
@ -1137,12 +1228,20 @@ class QUICConnection(IRawConnection, IMuxedConn):
try: try:
current_time = time.time() current_time = time.time()
datagrams = self._quic.datagrams_to_send(now=current_time) datagrams = self._quic.datagrams_to_send(now=current_time)
# Batch stats updates
packet_count = 0
total_bytes = 0
for data, addr in datagrams: for data, addr in datagrams:
await sock.sendto(data, addr) await sock.sendto(data, addr)
# Update stats if available packet_count += 1
if hasattr(self, "_stats"): total_bytes += len(data)
self._stats["packets_sent"] += 1
self._stats["bytes_sent"] += len(data) # Update stats in batch
if packet_count > 0:
self._stats["packets_sent"] += packet_count
self._stats["bytes_sent"] += total_bytes
except Exception as e: except Exception as e:
logger.error(f"Transmission error: {e}") logger.error(f"Transmission error: {e}")
@ -1217,6 +1316,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._socket = None self._socket = None
self._streams.clear() self._streams.clear()
self._stream_cache.clear() # Clear cache
self._closed_event.set() self._closed_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} closed") logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
@ -1328,6 +1428,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
"max_streams": self.MAX_CONCURRENT_STREAMS, "max_streams": self.MAX_CONCURRENT_STREAMS,
"stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS, "stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS,
"stats": self._stats.copy(), "stats": self._stats.copy(),
"cache_size": len(
self._stream_cache
), # Include cache metrics for monitoring
} }
def get_active_streams(self) -> list[QUICStream]: def get_active_streams(self) -> list[QUICStream]:

View File

@ -267,56 +267,37 @@ class QUICListener(IListener):
return value, 8 return value, 8
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
"""Process incoming QUIC packet with fine-grained locking.""" """Process incoming QUIC packet with optimized routing."""
try: try:
self._stats["packets_processed"] += 1 self._stats["packets_processed"] += 1
self._stats["bytes_received"] += len(data) self._stats["bytes_received"] += len(data)
logger.debug(f"Processing packet of {len(data)} bytes from {addr}")
# Parse packet header OUTSIDE the lock
packet_info = self.parse_quic_packet(data) packet_info = self.parse_quic_packet(data)
if packet_info is None: if packet_info is None:
logger.error(f"Failed to parse packet header quic packet from {addr}")
self._stats["invalid_packets"] += 1 self._stats["invalid_packets"] += 1
return return
dest_cid = packet_info.destination_cid dest_cid = packet_info.destination_cid
connection_obj = None
pending_quic_conn = None
# Single lock acquisition with all lookups
async with self._connection_lock: async with self._connection_lock:
if dest_cid in self._connections: connection_obj = self._connections.get(dest_cid)
connection_obj = self._connections[dest_cid] pending_quic_conn = self._pending_connections.get(dest_cid)
logger.debug(f"Routing to established connection {dest_cid.hex()}")
elif dest_cid in self._pending_connections: if not connection_obj and not pending_quic_conn:
pending_quic_conn = self._pending_connections[dest_cid] if packet_info.packet_type == QuicPacketType.INITIAL:
logger.debug(f"Routing to pending connection {dest_cid.hex()}")
else:
# Check if this is a new connection
if packet_info.packet_type.name == "INITIAL":
logger.debug(
f"Received INITIAL Packet Creating new conn for {addr}"
)
# Create new connection INSIDE the lock for safety
pending_quic_conn = await self._handle_new_connection( pending_quic_conn = await self._handle_new_connection(
data, addr, packet_info data, addr, packet_info
) )
else: else:
return return
# CRITICAL: Process packets OUTSIDE the lock to prevent deadlock # Process outside the lock
if connection_obj: if connection_obj:
# Handle established connection
await self._handle_established_connection_packet( await self._handle_established_connection_packet(
connection_obj, data, addr, dest_cid connection_obj, data, addr, dest_cid
) )
elif pending_quic_conn: elif pending_quic_conn:
# Handle pending connection
await self._handle_pending_connection_packet( await self._handle_pending_connection_packet(
pending_quic_conn, data, addr, dest_cid pending_quic_conn, data, addr, dest_cid
) )
@ -431,6 +412,7 @@ class QUICListener(IListener):
f"No configuration found for version 0x{packet_info.version:08x}" f"No configuration found for version 0x{packet_info.version:08x}"
) )
await self._send_version_negotiation(addr, packet_info.source_cid) await self._send_version_negotiation(addr, packet_info.source_cid)
return None
if not quic_config: if not quic_config:
raise QUICListenError("Cannot determine QUIC configuration") raise QUICListenError("Cannot determine QUIC configuration")

View File

@ -108,21 +108,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
# Try to get IPv4 address # Try to get IPv4 address
try: try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except ValueError: except Exception:
pass pass
# Try to get IPv6 address if IPv4 not found # Try to get IPv6 address if IPv4 not found
if host is None: if host is None:
try: try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except ValueError: except Exception:
pass pass
# Get UDP port # Get UDP port
try: try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str) port = int(port_str)
except ValueError: except Exception:
pass pass
if host is None or port is None: if host is None or port is None:
@ -203,8 +203,7 @@ def create_quic_multiaddr(
if version == "quic-v1" or version == "/quic-v1": if version == "quic-v1" or version == "/quic-v1":
quic_proto = QUIC_V1_PROTOCOL quic_proto = QUIC_V1_PROTOCOL
elif version == "quic" or version == "/quic": elif version == "quic" or version == "/quic":
# This is DRAFT Protocol quic_proto = QUIC_DRAFT29_PROTOCOL
quic_proto = QUIC_V1_PROTOCOL
else: else:
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}") raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")

View File

@ -192,7 +192,7 @@ class TestQUICConnection:
await trio.sleep(10) # Longer than timeout await trio.sleep(10) # Longer than timeout
with patch.object( with patch.object(
quic_connection._stream_id_lock, "acquire", side_effect=slow_acquire quic_connection._stream_lock, "acquire", side_effect=slow_acquire
): ):
with pytest.raises( with pytest.raises(
QUICStreamTimeoutError, match="Stream creation timed out" QUICStreamTimeoutError, match="Stream creation timed out"

View File

@ -3,333 +3,319 @@ Test suite for QUIC multiaddr utilities.
Focused tests covering essential functionality required for QUIC transport. Focused tests covering essential functionality required for QUIC transport.
""" """
# TODO: Enable this test after multiaddr repo supports protocol quic-v1 import pytest
from multiaddr import Multiaddr
# import pytest
# from multiaddr import Multiaddr from libp2p.custom_types import TProtocol
from libp2p.transport.quic.exceptions import (
# from libp2p.custom_types import TProtocol QUICInvalidMultiaddrError,
# from libp2p.transport.quic.exceptions import ( QUICUnsupportedVersionError,
# QUICInvalidMultiaddrError, )
# QUICUnsupportedVersionError, from libp2p.transport.quic.utils import (
# ) create_quic_multiaddr,
# from libp2p.transport.quic.utils import ( get_alpn_protocols,
# create_quic_multiaddr, is_quic_multiaddr,
# get_alpn_protocols, multiaddr_to_quic_version,
# is_quic_multiaddr, normalize_quic_multiaddr,
# multiaddr_to_quic_version, quic_multiaddr_to_endpoint,
# normalize_quic_multiaddr, quic_version_to_wire_format,
# quic_multiaddr_to_endpoint, )
# quic_version_to_wire_format,
# )
class TestIsQuicMultiaddr:
"""Test QUIC multiaddr detection."""
# class TestIsQuicMultiaddr:
# """Test QUIC multiaddr detection.""" def test_valid_quic_v1_multiaddrs(self):
"""Test valid QUIC v1 multiaddrs are detected."""
# def test_valid_quic_v1_multiaddrs(self): valid_addrs = [
# """Test valid QUIC v1 multiaddrs are detected.""" "/ip4/127.0.0.1/udp/4001/quic-v1",
# valid_addrs = [ "/ip4/192.168.1.1/udp/8080/quic-v1",
# "/ip4/127.0.0.1/udp/4001/quic-v1", "/ip6/::1/udp/4001/quic-v1",
# "/ip4/192.168.1.1/udp/8080/quic-v1", "/ip6/2001:db8::1/udp/5000/quic-v1",
# "/ip6/::1/udp/4001/quic-v1", ]
# "/ip6/2001:db8::1/udp/5000/quic-v1",
# ] for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
# for addr_str in valid_addrs: assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
# maddr = Multiaddr(addr_str)
# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" def test_valid_quic_draft29_multiaddrs(self):
"""Test valid QUIC draft-29 multiaddrs are detected."""
# def test_valid_quic_draft29_multiaddrs(self): valid_addrs = [
# """Test valid QUIC draft-29 multiaddrs are detected.""" "/ip4/127.0.0.1/udp/4001/quic",
# valid_addrs = [ "/ip4/10.0.0.1/udp/9000/quic",
# "/ip4/127.0.0.1/udp/4001/quic", "/ip6/::1/udp/4001/quic",
# "/ip4/10.0.0.1/udp/9000/quic", "/ip6/fe80::1/udp/6000/quic",
# "/ip6/::1/udp/4001/quic", ]
# "/ip6/fe80::1/udp/6000/quic",
# ] for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
# for addr_str in valid_addrs: assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
# maddr = Multiaddr(addr_str)
# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" def test_invalid_multiaddrs(self):
"""Test non-QUIC multiaddrs are not detected."""
# def test_invalid_multiaddrs(self): invalid_addrs = [
# """Test non-QUIC multiaddrs are not detected.""" "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
# invalid_addrs = [ "/ip4/127.0.0.1/udp/4001", # UDP without QUIC
# "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC "/ip4/127.0.0.1/udp/4001/ws", # WebSocket
# "/ip4/127.0.0.1/udp/4001", # UDP without QUIC "/ip4/127.0.0.1/quic-v1", # Missing UDP
# "/ip4/127.0.0.1/udp/4001/ws", # WebSocket "/udp/4001/quic-v1", # Missing IP
# "/ip4/127.0.0.1/quic-v1", # Missing UDP "/dns4/example.com/tcp/443/tls", # Completely different
# "/udp/4001/quic-v1", # Missing IP ]
# "/dns4/example.com/tcp/443/tls", # Completely different
# ] for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
# for addr_str in invalid_addrs: assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
# maddr = Multiaddr(addr_str)
# assert not is_quic_multiaddr(maddr),
# f"Should not detect {addr_str} as QUIC" class TestQuicMultiaddrToEndpoint:
"""Test endpoint extraction from QUIC multiaddrs."""
# def test_malformed_multiaddrs(self):
# """Test malformed multiaddrs don't crash.""" def test_ipv4_extraction(self):
# # These should not raise exceptions, just return False """Test IPv4 host/port extraction."""
# malformed = [ test_cases = [
# Multiaddr("/ip4/127.0.0.1"), ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
# Multiaddr("/invalid"), ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
# ] ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
]
# for maddr in malformed:
# assert not is_quic_multiaddr(maddr) for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
# class TestQuicMultiaddrToEndpoint: assert result == expected, f"Failed for {addr_str}"
# """Test endpoint extraction from QUIC multiaddrs."""
def test_ipv6_extraction(self):
# def test_ipv4_extraction(self): """Test IPv6 host/port extraction."""
# """Test IPv4 host/port extraction.""" test_cases = [
# test_cases = [ ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
# ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)), ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
# ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)), ]
# ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
# ] for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
# for addr_str, expected in test_cases: result = quic_multiaddr_to_endpoint(maddr)
# maddr = Multiaddr(addr_str) assert result == expected, f"Failed for {addr_str}"
# result = quic_multiaddr_to_endpoint(maddr)
# assert result == expected, f"Failed for {addr_str}" def test_invalid_multiaddr_raises_error(self):
"""Test invalid multiaddrs raise appropriate errors."""
# def test_ipv6_extraction(self): invalid_addrs = [
# """Test IPv6 host/port extraction.""" "/ip4/127.0.0.1/tcp/4001", # Not QUIC
# test_cases = [ "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
# ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)), ]
# ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
# ] for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
# for addr_str, expected in test_cases: with pytest.raises(QUICInvalidMultiaddrError):
# maddr = Multiaddr(addr_str) quic_multiaddr_to_endpoint(maddr)
# result = quic_multiaddr_to_endpoint(maddr)
# assert result == expected, f"Failed for {addr_str}"
class TestMultiaddrToQuicVersion:
# def test_invalid_multiaddr_raises_error(self): """Test QUIC version extraction."""
# """Test invalid multiaddrs raise appropriate errors."""
# invalid_addrs = [ def test_quic_v1_detection(self):
# "/ip4/127.0.0.1/tcp/4001", # Not QUIC """Test QUIC v1 version detection."""
# "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol addrs = [
# ] "/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip6/::1/udp/5000/quic-v1",
# for addr_str in invalid_addrs: ]
# maddr = Multiaddr(addr_str)
# with pytest.raises(QUICInvalidMultiaddrError): for addr_str in addrs:
# quic_multiaddr_to_endpoint(maddr) maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
# class TestMultiaddrToQuicVersion:
# """Test QUIC version extraction.""" def test_quic_draft29_detection(self):
"""Test QUIC draft-29 version detection."""
# def test_quic_v1_detection(self): addrs = [
# """Test QUIC v1 version detection.""" "/ip4/127.0.0.1/udp/4001/quic",
# addrs = [ "/ip6/::1/udp/5000/quic",
# "/ip4/127.0.0.1/udp/4001/quic-v1", ]
# "/ip6/::1/udp/5000/quic-v1",
# ] for addr_str in addrs:
maddr = Multiaddr(addr_str)
# for addr_str in addrs: version = multiaddr_to_quic_version(maddr)
# maddr = Multiaddr(addr_str) assert version == "quic", f"Should detect quic for {addr_str}"
# version = multiaddr_to_quic_version(maddr)
# assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}" def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
# def test_quic_draft29_detection(self): maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
# """Test QUIC draft-29 version detection.""" with pytest.raises(QUICInvalidMultiaddrError):
# addrs = [ multiaddr_to_quic_version(maddr)
# "/ip4/127.0.0.1/udp/4001/quic",
# "/ip6/::1/udp/5000/quic",
# ] class TestCreateQuicMultiaddr:
"""Test QUIC multiaddr creation."""
# for addr_str in addrs:
# maddr = Multiaddr(addr_str) def test_ipv4_creation(self):
# version = multiaddr_to_quic_version(maddr) """Test IPv4 QUIC multiaddr creation."""
# assert version == "quic", f"Should detect quic for {addr_str}" test_cases = [
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
# def test_non_quic_raises_error(self): ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
# """Test non-QUIC multiaddrs raise error.""" ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") ]
# with pytest.raises(QUICInvalidMultiaddrError):
# multiaddr_to_quic_version(maddr) for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
# class TestCreateQuicMultiaddr:
# """Test QUIC multiaddr creation.""" def test_ipv6_creation(self):
"""Test IPv6 QUIC multiaddr creation."""
# def test_ipv4_creation(self): test_cases = [
# """Test IPv4 QUIC multiaddr creation.""" ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
# test_cases = [ ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
# ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"), ]
# ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
# ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"), for host, port, version, expected in test_cases:
# ] result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
# for host, port, version, expected in test_cases:
# result = create_quic_multiaddr(host, port, version) def test_default_version(self):
# assert str(result) == expected """Test default version is quic-v1."""
result = create_quic_multiaddr("127.0.0.1", 4001)
# def test_ipv6_creation(self): expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
# """Test IPv6 QUIC multiaddr creation.""" assert str(result) == expected
# test_cases = [
# ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"), def test_invalid_inputs_raise_errors(self):
# ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"), """Test invalid inputs raise appropriate errors."""
# ] # Invalid IP
with pytest.raises(QUICInvalidMultiaddrError):
# for host, port, version, expected in test_cases: create_quic_multiaddr("invalid-ip", 4001)
# result = create_quic_multiaddr(host, port, version)
# assert str(result) == expected # Invalid port
with pytest.raises(QUICInvalidMultiaddrError):
# def test_default_version(self): create_quic_multiaddr("127.0.0.1", 70000)
# """Test default version is quic-v1."""
# result = create_quic_multiaddr("127.0.0.1", 4001) with pytest.raises(QUICInvalidMultiaddrError):
# expected = "/ip4/127.0.0.1/udp/4001/quic-v1" create_quic_multiaddr("127.0.0.1", -1)
# assert str(result) == expected
# Invalid version
# def test_invalid_inputs_raise_errors(self): with pytest.raises(QUICInvalidMultiaddrError):
# """Test invalid inputs raise appropriate errors.""" create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
# # Invalid IP
# with pytest.raises(QUICInvalidMultiaddrError):
# create_quic_multiaddr("invalid-ip", 4001) class TestQuicVersionToWireFormat:
"""Test QUIC version to wire format conversion."""
# # Invalid port
# with pytest.raises(QUICInvalidMultiaddrError): def test_supported_versions(self):
# create_quic_multiaddr("127.0.0.1", 70000) """Test supported version conversions."""
test_cases = [
# with pytest.raises(QUICInvalidMultiaddrError): ("quic-v1", 0x00000001), # RFC 9000
# create_quic_multiaddr("127.0.0.1", -1) ("quic", 0xFF00001D), # draft-29
]
# # Invalid version
# with pytest.raises(QUICInvalidMultiaddrError): for version, expected_wire in test_cases:
# create_quic_multiaddr("127.0.0.1", 4001, "invalid-version") result = quic_version_to_wire_format(TProtocol(version))
assert result == expected_wire, f"Failed for version {version}"
# class TestQuicVersionToWireFormat: def test_unsupported_version_raises_error(self):
# """Test QUIC version to wire format conversion.""" """Test unsupported versions raise error."""
with pytest.raises(QUICUnsupportedVersionError):
# def test_supported_versions(self): quic_version_to_wire_format(TProtocol("unsupported-version"))
# """Test supported version conversions."""
# test_cases = [
# ("quic-v1", 0x00000001), # RFC 9000 class TestGetAlpnProtocols:
# ("quic", 0xFF00001D), # draft-29 """Test ALPN protocol retrieval."""
# ]
def test_returns_libp2p_protocols(self):
# for version, expected_wire in test_cases: """Test returns expected libp2p ALPN protocols."""
# result = quic_version_to_wire_format(TProtocol(version)) protocols = get_alpn_protocols()
# assert result == expected_wire, f"Failed for version {version}" assert protocols == ["libp2p"]
assert isinstance(protocols, list)
# def test_unsupported_version_raises_error(self):
# """Test unsupported versions raise error.""" def test_returns_copy(self):
# with pytest.raises(QUICUnsupportedVersionError): """Test returns a copy, not the original list."""
# quic_version_to_wire_format(TProtocol("unsupported-version")) protocols1 = get_alpn_protocols()
protocols2 = get_alpn_protocols()
# class TestGetAlpnProtocols: # Modify one list
# """Test ALPN protocol retrieval.""" protocols1.append("test")
# def test_returns_libp2p_protocols(self): # Other list should be unchanged
# """Test returns expected libp2p ALPN protocols.""" assert protocols2 == ["libp2p"]
# protocols = get_alpn_protocols()
# assert protocols == ["libp2p"]
# assert isinstance(protocols, list) class TestNormalizeQuicMultiaddr:
"""Test QUIC multiaddr normalization."""
# def test_returns_copy(self):
# """Test returns a copy, not the original list.""" def test_already_normalized(self):
# protocols1 = get_alpn_protocols() """Test already normalized multiaddrs pass through."""
# protocols2 = get_alpn_protocols() addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
# # Modify one list
# protocols1.append("test")
# # Other list should be unchanged
# assert protocols2 == ["libp2p"]
# class TestNormalizeQuicMultiaddr:
# """Test QUIC multiaddr normalization."""
# def test_already_normalized(self):
# """Test already normalized multiaddrs pass through."""
# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
# maddr = Multiaddr(addr_str)
# result = normalize_quic_multiaddr(maddr) result = normalize_quic_multiaddr(maddr)
# assert str(result) == addr_str assert str(result) == addr_str
# def test_normalize_different_versions(self): def test_normalize_different_versions(self):
# """Test normalization works for different QUIC versions.""" """Test normalization works for different QUIC versions."""
# test_cases = [ test_cases = [
# "/ip4/127.0.0.1/udp/4001/quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1",
# "/ip4/127.0.0.1/udp/4001/quic", "/ip4/127.0.0.1/udp/4001/quic",
# "/ip6/::1/udp/5000/quic-v1", "/ip6/::1/udp/5000/quic-v1",
# ] ]
# for addr_str in test_cases: for addr_str in test_cases:
# maddr = Multiaddr(addr_str) maddr = Multiaddr(addr_str)
# result = normalize_quic_multiaddr(maddr) result = normalize_quic_multiaddr(maddr)
# # Should be valid QUIC multiaddr # Should be valid QUIC multiaddr
# assert is_quic_multiaddr(result) assert is_quic_multiaddr(result)
# # Should be parseable # Should be parseable
# host, port = quic_multiaddr_to_endpoint(result) host, port = quic_multiaddr_to_endpoint(result)
# version = multiaddr_to_quic_version(result) version = multiaddr_to_quic_version(result)
# # Should match original # Should match original
# orig_host, orig_port = quic_multiaddr_to_endpoint(maddr) orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
# orig_version = multiaddr_to_quic_version(maddr) orig_version = multiaddr_to_quic_version(maddr)
# assert host == orig_host assert host == orig_host
# assert port == orig_port assert port == orig_port
# assert version == orig_version assert version == orig_version
# def test_non_quic_raises_error(self): def test_non_quic_raises_error(self):
# """Test non-QUIC multiaddrs raise error.""" """Test non-QUIC multiaddrs raise error."""
# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
# with pytest.raises(QUICInvalidMultiaddrError): with pytest.raises(QUICInvalidMultiaddrError):
# normalize_quic_multiaddr(maddr) normalize_quic_multiaddr(maddr)
# class TestIntegration: class TestIntegration:
# """Integration tests for utility functions working together.""" """Integration tests for utility functions working together."""
# def test_round_trip_conversion(self): def test_round_trip_conversion(self):
# """Test creating and parsing multiaddrs works correctly.""" """Test creating and parsing multiaddrs works correctly."""
# test_cases = [ test_cases = [
# ("127.0.0.1", 4001, "quic-v1"), ("127.0.0.1", 4001, "quic-v1"),
# ("::1", 5000, "quic"), ("::1", 5000, "quic"),
# ("192.168.1.100", 8080, "quic-v1"), ("192.168.1.100", 8080, "quic-v1"),
# ] ]
# for host, port, version in test_cases: for host, port, version in test_cases:
# # Create multiaddr # Create multiaddr
# maddr = create_quic_multiaddr(host, port, version) maddr = create_quic_multiaddr(host, port, version)
# # Should be detected as QUIC # Should be detected as QUIC
# assert is_quic_multiaddr(maddr) assert is_quic_multiaddr(maddr)
# # Should extract original values # Should extract original values
# extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr) extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
# extracted_version = multiaddr_to_quic_version(maddr) extracted_version = multiaddr_to_quic_version(maddr)
# assert extracted_host == host assert extracted_host == host
# assert extracted_port == port assert extracted_port == port
# assert extracted_version == version assert extracted_version == version
# # Should normalize to same value # Should normalize to same value
# normalized = normalize_quic_multiaddr(maddr) normalized = normalize_quic_multiaddr(maddr)
# assert str(normalized) == str(maddr) assert str(normalized) == str(maddr)
# def test_wire_format_integration(self): def test_wire_format_integration(self):
# """Test wire format conversion works with version detection.""" """Test wire format conversion works with version detection."""
# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
# maddr = Multiaddr(addr_str) maddr = Multiaddr(addr_str)
# # Extract version and convert to wire format # Extract version and convert to wire format
# version = multiaddr_to_quic_version(maddr) version = multiaddr_to_quic_version(maddr)
# wire_format = quic_version_to_wire_format(version) wire_format = quic_version_to_wire_format(version)
# # Should be QUIC v1 wire format # Should be QUIC v1 wire format
# assert wire_format == 0x00000001 assert wire_format == 0x00000001