fix: try to fix connection id updation

This commit is contained in:
Akash Mondal
2025-06-20 11:52:51 +00:00
committed by lla-dane
parent 6633eb01d4
commit e2fee14bc5
8 changed files with 1305 additions and 79 deletions

View File

@ -7,7 +7,7 @@ import logging
import socket
from sys import stdout
import time
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Set
from aioquic.quic import events
from aioquic.quic.connection import QuicConnection
@ -60,6 +60,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
- Flow control integration
- Connection migration support
- Performance monitoring
- COMPLETE connection ID management (fixes the original issue)
"""
# Configuration constants based on research
@ -144,6 +145,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._nursery: trio.Nursery | None = None
self._event_processing_task: Any | None = None
# *** NEW: Connection ID tracking - CRITICAL for fixing the original issue ***
self._available_connection_ids: Set[bytes] = set()
self._current_connection_id: Optional[bytes] = None
self._retired_connection_ids: Set[bytes] = set()
self._connection_id_sequence_numbers: Set[int] = set()
# Event processing control
self._event_processing_active = False
self._pending_events: list[events.QuicEvent] = []
# Performance and monitoring
self._connection_start_time = time.time()
self._stats = {
@ -155,6 +166,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
"bytes_received": 0,
"packets_sent": 0,
"packets_received": 0,
# *** NEW: Connection ID statistics ***
"connection_ids_issued": 0,
"connection_ids_retired": 0,
"connection_id_changes": 0,
}
logger.debug(
@ -219,6 +234,25 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Get the remote peer ID."""
return self._peer_id
# *** NEW: Connection ID management methods ***
def get_connection_id_stats(self) -> dict[str, Any]:
"""Get connection ID statistics and current state."""
return {
"available_connection_ids": len(self._available_connection_ids),
"current_connection_id": self._current_connection_id.hex()
if self._current_connection_id
else None,
"retired_connection_ids": len(self._retired_connection_ids),
"connection_ids_issued": self._stats["connection_ids_issued"],
"connection_ids_retired": self._stats["connection_ids_retired"],
"connection_id_changes": self._stats["connection_id_changes"],
"available_cid_list": [cid.hex() for cid in self._available_connection_ids],
}
def get_current_connection_id(self) -> Optional[bytes]:
"""Get the current connection ID."""
return self._current_connection_id
# Connection lifecycle methods
async def start(self) -> None:
@ -379,6 +413,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Check for idle streams that can be cleaned up
await self._cleanup_idle_streams()
# *** NEW: Log connection ID status periodically ***
if logger.isEnabledFor(logging.DEBUG):
cid_stats = self.get_connection_id_stats()
logger.debug(f"Connection ID stats: {cid_stats}")
# Sleep for maintenance interval
await trio.sleep(30.0) # 30 seconds
@ -752,36 +791,155 @@ class QUICConnection(IRawConnection, IMuxedConn):
logger.debug(f"Removed stream {stream_id} from connection")
# QUIC event handling
# *** UPDATED: Complete QUIC event handling - FIXES THE ORIGINAL ISSUE ***
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
while True:
event = self._quic.next_event()
if event is None:
break
if self._event_processing_active:
return # Prevent recursion
try:
self._event_processing_active = True
try:
events_processed = 0
while True:
event = self._quic.next_event()
if event is None:
break
events_processed += 1
await self._handle_quic_event(event)
except Exception as e:
logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
if events_processed > 0:
logger.debug(f"Processed {events_processed} QUIC events")
finally:
self._event_processing_active = False
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
"""Handle a single QUIC event."""
"""Handle a single QUIC event with COMPLETE event type coverage."""
logger.debug(f"Handling QUIC event: {type(event).__name__}")
print(f"QUIC event: {type(event).__name__}")
if isinstance(event, events.ConnectionTerminated):
await self._handle_connection_terminated(event)
elif isinstance(event, events.HandshakeCompleted):
await self._handle_handshake_completed(event)
elif isinstance(event, events.StreamDataReceived):
await self._handle_stream_data(event)
elif isinstance(event, events.StreamReset):
await self._handle_stream_reset(event)
elif isinstance(event, events.DatagramFrameReceived):
await self._handle_datagram_received(event)
else:
logger.debug(f"Unhandled QUIC event: {type(event).__name__}")
print(f"Unhandled QUIC event: {type(event).__name__}")
try:
if isinstance(event, events.ConnectionTerminated):
await self._handle_connection_terminated(event)
elif isinstance(event, events.HandshakeCompleted):
await self._handle_handshake_completed(event)
elif isinstance(event, events.StreamDataReceived):
await self._handle_stream_data(event)
elif isinstance(event, events.StreamReset):
await self._handle_stream_reset(event)
elif isinstance(event, events.DatagramFrameReceived):
await self._handle_datagram_received(event)
# *** NEW: Connection ID event handlers - CRITICAL FIX ***
elif isinstance(event, events.ConnectionIdIssued):
await self._handle_connection_id_issued(event)
elif isinstance(event, events.ConnectionIdRetired):
await self._handle_connection_id_retired(event)
# *** NEW: Additional event handlers for completeness ***
elif isinstance(event, events.PingAcknowledged):
await self._handle_ping_acknowledged(event)
elif isinstance(event, events.ProtocolNegotiated):
await self._handle_protocol_negotiated(event)
elif isinstance(event, events.StopSendingReceived):
await self._handle_stop_sending_received(event)
else:
logger.debug(f"Unhandled QUIC event type: {type(event).__name__}")
print(f"Unhandled QUIC event: {type(event).__name__}")
except Exception as e:
logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
# *** NEW: Connection ID event handlers - THE MAIN FIX ***
async def _handle_connection_id_issued(
self, event: events.ConnectionIdIssued
) -> None:
"""
Handle new connection ID issued by peer.
This is the CRITICAL missing functionality that was causing your issue!
"""
logger.info(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
print(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
# Add to available connection IDs
self._available_connection_ids.add(event.connection_id)
# If we don't have a current connection ID, use this one
if self._current_connection_id is None:
self._current_connection_id = event.connection_id
logger.info(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
print(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
# Update statistics
self._stats["connection_ids_issued"] += 1
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
print(f"Available connection IDs: {len(self._available_connection_ids)}")
async def _handle_connection_id_retired(
self, event: events.ConnectionIdRetired
) -> None:
"""
Handle connection ID retirement.
This handles when the peer tells us to stop using a connection ID.
"""
logger.info(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
print(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
# Remove from available IDs and add to retired set
self._available_connection_ids.discard(event.connection_id)
self._retired_connection_ids.add(event.connection_id)
# If this was our current connection ID, switch to another
if self._current_connection_id == event.connection_id:
if self._available_connection_ids:
self._current_connection_id = next(iter(self._available_connection_ids))
logger.info(
f"🆔 Switched to new connection ID: {self._current_connection_id.hex()}"
)
print(
f"🆔 Switched to new connection ID: {self._current_connection_id.hex()}"
)
self._stats["connection_id_changes"] += 1
else:
self._current_connection_id = None
logger.warning("⚠️ No available connection IDs after retirement!")
print("⚠️ No available connection IDs after retirement!")
# Update statistics
self._stats["connection_ids_retired"] += 1
# *** NEW: Additional event handlers for completeness ***
async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None:
"""Handle ping acknowledgment."""
logger.debug(f"Ping acknowledged: uid={event.uid}")
async def _handle_protocol_negotiated(
self, event: events.ProtocolNegotiated
) -> None:
"""Handle protocol negotiation completion."""
logger.info(f"Protocol negotiated: {event.alpn_protocol}")
async def _handle_stop_sending_received(
self, event: events.StopSendingReceived
) -> None:
"""Handle stop sending request from peer."""
logger.debug(
f"Stop sending received: stream_id={event.stream_id}, error_code={event.error_code}"
)
if event.stream_id in self._streams:
stream = self._streams[event.stream_id]
# Handle stop sending on the stream if method exists
if hasattr(stream, "handle_stop_sending"):
await stream.handle_stop_sending(event.error_code)
# *** EXISTING event handlers (unchanged) ***
async def _handle_handshake_completed(
self, event: events.HandshakeCompleted
@ -930,9 +1088,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def _handle_datagram_received(
self, event: events.DatagramFrameReceived
) -> None:
"""Handle received datagrams."""
# For future datagram support
logger.debug(f"Received datagram: {len(event.data)} bytes")
"""Handle datagram frame (if using QUIC datagrams)."""
logger.debug(f"Datagram frame received: size={len(event.data)}")
# For now, just log. Could be extended for custom datagram handling
async def _handle_timer_events(self) -> None:
"""Handle QUIC timer events."""
@ -961,6 +1119,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
logger.error(f"Failed to send datagram: {e}")
await self._handle_connection_error(e)
# Additional methods for stream data processing
async def _process_quic_event(self, event):
"""Process a single QUIC event."""
await self._handle_quic_event(event)
async def _transmit_pending_data(self):
"""Transmit any pending data."""
await self._transmit()
# Error handling
async def _handle_connection_error(self, error: Exception) -> None:
@ -1046,16 +1213,24 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def read(self, n: int | None = -1) -> bytes:
"""
Read data from the connection.
For QUIC, this reads from the next available stream.
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
Read data from the stream.
# For raw connection interface, we need to handle this differently
# In practice, upper layers will use the muxed connection interface
Args:
n: Maximum number of bytes to read. -1 means read all available.
Returns:
Data bytes read from the stream.
Raises:
QUICStreamClosedError: If stream is closed for reading.
QUICStreamResetError: If stream was reset.
QUICStreamTimeoutError: If read timeout occurs.
"""
# This method doesn't make sense for a muxed connection
# It's here for interface compatibility but should not be used
raise NotImplementedError(
"Use muxed connection interface for stream-based reading"
"Use streams for reading data from QUIC connections. "
"Call accept_stream() or open_stream() instead."
)
# Utility and monitoring methods
@ -1080,7 +1255,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
return [
stream
for stream in self._streams.values()
if stream.protocol == protocol and not stream.is_closed()
if hasattr(stream, "protocol")
and stream.protocol == protocol
and not stream.is_closed()
]
def _update_stats(self) -> None:
@ -1112,7 +1289,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
f"initiator={self.__is_initiator}, "
f"verified={self._peer_verified}, "
f"established={self._established}, "
f"streams={len(self._streams)})"
f"streams={len(self._streams)}, "
f"current_cid={self._current_connection_id.hex() if self._current_connection_id else None})"
)
def __str__(self) -> str: