fix: add quic utils test and improve connection performance

This commit is contained in:
Akash Mondal
2025-09-04 21:25:13 +00:00
parent 2ee3e0b054
commit 2fe5882013
5 changed files with 525 additions and 455 deletions

View File

@ -3,14 +3,16 @@ QUIC Connection implementation.
Manages bidirectional QUIC connections with integrated stream multiplexing.
"""
from collections import defaultdict
from collections.abc import Awaitable, Callable
import logging
import socket
import time
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, cast
from aioquic.quic import events
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent
from cryptography import x509
import multiaddr
import trio
@ -104,12 +106,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._connected_event = trio.Event()
self._closed_event = trio.Event()
# Stream management
self._streams: dict[int, QUICStream] = {}
self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups
self._next_stream_id: int = self._calculate_initial_stream_id()
self._stream_handler: TQUICStreamHandlerFn | None = None
self._stream_id_lock = trio.Lock()
self._stream_count_lock = trio.Lock()
# Single lock for all stream operations
self._stream_lock = trio.Lock()
# Stream counting and limits
self._outbound_stream_count = 0
@ -118,7 +121,6 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Stream acceptance for incoming streams
self._stream_accept_queue: list[QUICStream] = []
self._stream_accept_event = trio.Event()
self._accept_queue_lock = trio.Lock()
# Connection state
self._closed: bool = False
@ -143,9 +145,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._retired_connection_ids: set[bytes] = set()
self._connection_id_sequence_numbers: set[int] = set()
# Event processing control
# Event processing control with batching
self._event_processing_active = False
self._pending_events: list[events.QuicEvent] = []
self._event_batch: list[events.QuicEvent] = []
self._event_batch_size = 10
self._last_event_time = 0.0
# Set quic connection configuration
self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT
@ -250,6 +254,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Get the current connection ID."""
return self._current_connection_id
# Fast stream lookup with caching
def _get_stream_fast(self, stream_id: int) -> QUICStream | None:
"""Get stream with caching for performance."""
# Try cache first
stream = self._stream_cache.get(stream_id)
if stream is not None:
return stream
# Fallback to main dict
stream = self._streams.get(stream_id)
if stream is not None:
self._stream_cache[stream_id] = stream
return stream
# Connection lifecycle methods
async def start(self) -> None:
@ -389,8 +408,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
try:
while not self._closed:
# Process QUIC events
await self._process_quic_events()
# Batch process events
await self._process_quic_events_batched()
# Handle timer events
await self._handle_timer_events()
@ -421,12 +440,25 @@ class QUICConnection(IRawConnection, IMuxedConn):
cid_stats = self.get_connection_id_stats()
logger.debug(f"Connection ID stats: {cid_stats}")
# Clean cache periodically
await self._cleanup_cache()
# Sleep for maintenance interval
await trio.sleep(30.0) # 30 seconds
except Exception as e:
logger.error(f"Error in periodic maintenance: {e}")
async def _cleanup_cache(self) -> None:
"""Clean up stream cache periodically to prevent memory leaks."""
if len(self._stream_cache) > 100: # Arbitrary threshold
# Remove closed streams from cache
closed_stream_ids = [
sid for sid, stream in self._stream_cache.items() if stream.is_closed()
]
for sid in closed_stream_ids:
self._stream_cache.pop(sid, None)
async def _client_packet_receiver(self) -> None:
"""Receive packets for client connections."""
logger.debug("Starting client packet receiver")
@ -442,8 +474,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Feed packet to QUIC connection
self._quic.receive_datagram(data, addr, now=time.time())
# Process any events that result from the packet
await self._process_quic_events()
# Batch process events
await self._process_quic_events_batched()
# Send any response packets
await self._transmit()
@ -675,15 +707,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
if not self._started:
raise QUICConnectionError("Connection not started")
# Check stream limits
async with self._stream_count_lock:
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
f"Maximum outbound streams ({self.MAX_OUTGOING_STREAMS}) reached"
)
# Use single lock for all stream operations
with trio.move_on_after(timeout):
async with self._stream_id_lock:
async with self._stream_lock:
# Check stream limits inside lock
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
"Maximum outbound streams "
f"({self.MAX_OUTGOING_STREAMS}) reached"
)
# Generate next stream ID
stream_id = self._next_stream_id
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
@ -697,10 +730,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
async with self._stream_count_lock:
self._outbound_stream_count += 1
self._stats["streams_opened"] += 1
self._outbound_stream_count += 1
self._stats["streams_opened"] += 1
logger.debug(f"Opened outbound QUIC stream {stream_id}")
return stream
@ -737,7 +770,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
if self._closed:
raise MuxedConnUnavailable("QUIC connection is closed")
async with self._accept_queue_lock:
# Use single lock for stream acceptance
async with self._stream_lock:
if self._stream_accept_queue:
stream = self._stream_accept_queue.pop(0)
logger.debug(f"Accepted inbound stream {stream.stream_id}")
@ -769,10 +803,12 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""
if stream_id in self._streams:
stream = self._streams.pop(stream_id)
# Remove from cache too
self._stream_cache.pop(stream_id, None)
# Update stream counts asynchronously
async def update_counts() -> None:
async with self._stream_count_lock:
async with self._stream_lock:
if stream.direction == StreamDirection.OUTBOUND:
self._outbound_stream_count = max(
0, self._outbound_stream_count - 1
@ -789,29 +825,140 @@ class QUICConnection(IRawConnection, IMuxedConn):
logger.debug(f"Removed stream {stream_id} from connection")
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
# Batched event processing to reduce overhead
async def _process_quic_events_batched(self) -> None:
"""Process QUIC events in batches for better performance."""
if self._event_processing_active:
return # Prevent recursion
self._event_processing_active = True
try:
current_time = time.time()
events_processed = 0
while True:
# Collect events into batch
while events_processed < self._event_batch_size:
event = self._quic.next_event()
if event is None:
break
self._event_batch.append(event)
events_processed += 1
await self._handle_quic_event(event)
if events_processed > 0:
logger.debug(f"Processed {events_processed} QUIC events")
# Process batch if we have events or timeout
if self._event_batch and (
len(self._event_batch) >= self._event_batch_size
or current_time - self._last_event_time > 0.01 # 10ms timeout
):
await self._process_event_batch()
self._event_batch.clear()
self._last_event_time = current_time
finally:
self._event_processing_active = False
async def _process_event_batch(self) -> None:
"""Process a batch of events efficiently."""
if not self._event_batch:
return
# Group events by type for batch processing where possible
events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list)
for event in self._event_batch:
events_by_type[type(event).__name__].append(event)
# Process events by type
for event_type, event_list in events_by_type.items():
if event_type == type(events.StreamDataReceived).__name__:
await self._handle_stream_data_batch(
cast(list[events.StreamDataReceived], event_list)
)
else:
# Process other events individually
for event in event_list:
await self._handle_quic_event(event)
logger.debug(f"Processed batch of {len(self._event_batch)} events")
async def _handle_stream_data_batch(
self, events_list: list[events.StreamDataReceived]
) -> None:
"""Handle stream data events in batch for better performance."""
# Group by stream ID
events_by_stream: defaultdict[int, list[QuicEvent]] = defaultdict(list)
for event in events_list:
events_by_stream[event.stream_id].append(event)
# Process each stream's events
for stream_id, stream_events in events_by_stream.items():
stream = self._get_stream_fast(stream_id) # Use fast lookup
if not stream:
if self._is_incoming_stream(stream_id):
try:
stream = await self._create_inbound_stream(stream_id)
except QUICStreamLimitError:
# Reset stream if we can't handle it
self._quic.reset_stream(stream_id, error_code=0x04)
await self._transmit()
continue
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
continue
# Process all events for this stream
for received_event in stream_events:
if hasattr(received_event, "data"):
self._stats["bytes_received"] += len(received_event.data) # type: ignore
if hasattr(received_event, "end_stream"):
await stream.handle_data_received(
received_event.data, # type: ignore
received_event.end_stream, # type: ignore
)
async def _create_inbound_stream(self, stream_id: int) -> QUICStream:
"""Create inbound stream with proper limit checking."""
async with self._stream_lock:
# Double-check stream doesn't exist
existing_stream = self._streams.get(stream_id)
if existing_stream:
return existing_stream
# Check limits
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting inbound stream {stream_id}: limit reached")
raise QUICStreamLimitError("Too many inbound streams")
# Create stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Created inbound stream {stream_id}")
return stream
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
# Delegate to batched processing for better performance
await self._process_quic_events_batched()
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
"""Handle a single QUIC event with COMPLETE event type coverage."""
logger.debug(f"Handling QUIC event: {type(event).__name__}")
@ -929,8 +1076,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"stream_id={event.stream_id}, error_code={event.error_code}"
)
if event.stream_id in self._streams:
stream: QUICStream = self._streams[event.stream_id]
# Use fast lookup
stream = self._get_stream_fast(event.stream_id)
if stream:
# Handle stop sending on the stream if method exists
await stream.handle_stop_sending(event.error_code)
@ -964,6 +1112,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
await stream.close()
self._streams.clear()
self._stream_cache.clear() # Clear cache too
self._closed = True
self._closed_event.set()
@ -978,39 +1127,19 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._stats["bytes_received"] += len(event.data)
try:
if stream_id not in self._streams:
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if not stream:
if self._is_incoming_stream(stream_id):
logger.debug(f"Creating new incoming stream {stream_id}")
from .stream import QUICStream, StreamDirection
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
# Store the stream
self._streams[stream_id] = stream
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Added stream {stream_id} to accept queue")
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_opened"] += 1
stream = await self._create_inbound_stream(stream_id)
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
return
stream = self._streams[stream_id]
await stream.handle_data_received(event.data, event.end_stream)
except Exception as e:
@ -1019,8 +1148,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
"""Get existing stream or create new inbound stream."""
if stream_id in self._streams:
return self._streams[stream_id]
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
return stream
# Check if this is an incoming stream
is_incoming = self._is_incoming_stream(stream_id)
@ -1031,49 +1162,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"Received data for unknown outbound stream {stream_id}"
)
# Check stream limits for incoming streams
async with self._stream_count_lock:
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting incoming stream {stream_id}: limit reached")
# Send reset to reject the stream
self._quic.reset_stream(
stream_id, error_code=0x04
) # STREAM_LIMIT_ERROR
await self._transmit()
raise QUICStreamLimitError("Too many inbound streams")
# Create new inbound stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue and notify handler
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
# Handle directly with stream handler if available
if self._stream_handler:
try:
if self._nursery:
self._nursery.start_soon(self._stream_handler, stream)
else:
await self._stream_handler(stream)
except Exception as e:
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
logger.debug(f"Created inbound stream {stream_id}")
return stream
return await self._create_inbound_stream(stream_id)
def _is_incoming_stream(self, stream_id: int) -> bool:
"""
@ -1095,9 +1185,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
stream_id = event.stream_id
self._stats["streams_reset"] += 1
if stream_id in self._streams:
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
try:
stream = self._streams[stream_id]
await stream.handle_reset(event.error_code)
logger.debug(
f"Handled reset for stream {stream_id}"
@ -1137,12 +1228,20 @@ class QUICConnection(IRawConnection, IMuxedConn):
try:
current_time = time.time()
datagrams = self._quic.datagrams_to_send(now=current_time)
# Batch stats updates
packet_count = 0
total_bytes = 0
for data, addr in datagrams:
await sock.sendto(data, addr)
# Update stats if available
if hasattr(self, "_stats"):
self._stats["packets_sent"] += 1
self._stats["bytes_sent"] += len(data)
packet_count += 1
total_bytes += len(data)
# Update stats in batch
if packet_count > 0:
self._stats["packets_sent"] += packet_count
self._stats["bytes_sent"] += total_bytes
except Exception as e:
logger.error(f"Transmission error: {e}")
@ -1217,6 +1316,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._socket = None
self._streams.clear()
self._stream_cache.clear() # Clear cache
self._closed_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
@ -1328,6 +1428,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
"max_streams": self.MAX_CONCURRENT_STREAMS,
"stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS,
"stats": self._stats.copy(),
"cache_size": len(
self._stream_cache
), # Include cache metrics for monitoring
}
def get_active_streams(self) -> list[QUICStream]:

View File

@ -267,56 +267,37 @@ class QUICListener(IListener):
return value, 8
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
"""Process incoming QUIC packet with fine-grained locking."""
"""Process incoming QUIC packet with optimized routing."""
try:
self._stats["packets_processed"] += 1
self._stats["bytes_received"] += len(data)
logger.debug(f"Processing packet of {len(data)} bytes from {addr}")
# Parse packet header OUTSIDE the lock
packet_info = self.parse_quic_packet(data)
if packet_info is None:
logger.error(f"Failed to parse packet header quic packet from {addr}")
self._stats["invalid_packets"] += 1
return
dest_cid = packet_info.destination_cid
connection_obj = None
pending_quic_conn = None
# Single lock acquisition with all lookups
async with self._connection_lock:
if dest_cid in self._connections:
connection_obj = self._connections[dest_cid]
logger.debug(f"Routing to established connection {dest_cid.hex()}")
connection_obj = self._connections.get(dest_cid)
pending_quic_conn = self._pending_connections.get(dest_cid)
elif dest_cid in self._pending_connections:
pending_quic_conn = self._pending_connections[dest_cid]
logger.debug(f"Routing to pending connection {dest_cid.hex()}")
else:
# Check if this is a new connection
if packet_info.packet_type.name == "INITIAL":
logger.debug(
f"Received INITIAL Packet Creating new conn for {addr}"
)
# Create new connection INSIDE the lock for safety
if not connection_obj and not pending_quic_conn:
if packet_info.packet_type == QuicPacketType.INITIAL:
pending_quic_conn = await self._handle_new_connection(
data, addr, packet_info
)
else:
return
# CRITICAL: Process packets OUTSIDE the lock to prevent deadlock
# Process outside the lock
if connection_obj:
# Handle established connection
await self._handle_established_connection_packet(
connection_obj, data, addr, dest_cid
)
elif pending_quic_conn:
# Handle pending connection
await self._handle_pending_connection_packet(
pending_quic_conn, data, addr, dest_cid
)
@ -431,6 +412,7 @@ class QUICListener(IListener):
f"No configuration found for version 0x{packet_info.version:08x}"
)
await self._send_version_negotiation(addr, packet_info.source_cid)
return None
if not quic_config:
raise QUICListenError("Cannot determine QUIC configuration")

View File

@ -108,21 +108,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except ValueError:
except Exception:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except ValueError:
except Exception:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str)
except ValueError:
except Exception:
pass
if host is None or port is None:
@ -203,8 +203,7 @@ def create_quic_multiaddr(
if version == "quic-v1" or version == "/quic-v1":
quic_proto = QUIC_V1_PROTOCOL
elif version == "quic" or version == "/quic":
# This is DRAFT Protocol
quic_proto = QUIC_V1_PROTOCOL
quic_proto = QUIC_DRAFT29_PROTOCOL
else:
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")