Files
py-libp2p/libp2p/transport/quic/connection.py
2025-08-31 16:07:41 +00:00

1392 lines
52 KiB
Python

"""
QUIC Connection implementation.
Manages bidirectional QUIC connections with integrated stream multiplexing.
"""
from collections.abc import Awaitable, Callable
import logging
import socket
import time
from typing import TYPE_CHECKING, Any, Optional
from aioquic.quic import events
from aioquic.quic.connection import QuicConnection
from cryptography import x509
import multiaddr
import trio
from libp2p.abc import IMuxedConn, IRawConnection
from libp2p.custom_types import TQUICStreamHandlerFn
from libp2p.peer.id import ID
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
from .exceptions import (
QUICConnectionClosedError,
QUICConnectionError,
QUICConnectionTimeoutError,
QUICErrorContext,
QUICPeerVerificationError,
QUICStreamError,
QUICStreamLimitError,
QUICStreamTimeoutError,
)
from .stream import QUICStream, StreamDirection
if TYPE_CHECKING:
from .security import QUICTLSConfigManager
from .transport import QUICTransport
logger = logging.getLogger(__name__)
class QUICConnection(IRawConnection, IMuxedConn):
"""
QUIC connection implementing both raw connection and muxed connection interfaces.
Uses aioquic's sans-IO core with trio for native async support.
QUIC natively provides stream multiplexing, so this connection acts as both
a raw connection (for transport layer) and muxed connection (for upper layers).
Features:
- Native QUIC stream multiplexing
- Integrated libp2p TLS security with peer identity verification
- Resource-aware stream management
- Comprehensive error handling
- Flow control integration
- Connection migration support
- Performance monitoring
- COMPLETE connection ID management (fixes the original issue)
"""
MAX_CONCURRENT_STREAMS = 256
MAX_INCOMING_STREAMS = 1000
MAX_OUTGOING_STREAMS = 1000
STREAM_ACCEPT_TIMEOUT = 60.0
CONNECTION_HANDSHAKE_TIMEOUT = 60.0
CONNECTION_CLOSE_TIMEOUT = 10.0
def __init__(
self,
quic_connection: QuicConnection,
remote_addr: tuple[str, int],
remote_peer_id: ID | None,
local_peer_id: ID,
is_initiator: bool,
maddr: multiaddr.Multiaddr,
transport: "QUICTransport",
security_manager: Optional["QUICTLSConfigManager"] = None,
resource_scope: Any | None = None,
listener_socket: trio.socket.SocketType | None = None,
):
"""
Initialize QUIC connection with security integration.
Args:
quic_connection: aioquic QuicConnection instance
remote_addr: Remote peer address
remote_peer_id: Remote peer ID (may be None initially)
local_peer_id: Local peer ID
is_initiator: Whether this is the connection initiator
maddr: Multiaddr for this connection
transport: Parent QUIC transport
security_manager: Security manager for TLS/certificate handling
resource_scope: Resource manager scope for tracking
listener_socket: Socket of listener to transmit data
"""
self._quic = quic_connection
self._remote_addr = remote_addr
self._remote_peer_id = remote_peer_id
self._local_peer_id = local_peer_id
self.peer_id = remote_peer_id or local_peer_id
self._is_initiator = is_initiator
self._maddr = maddr
self._transport = transport
self._security_manager = security_manager
self._resource_scope = resource_scope
# Trio networking - socket may be provided by listener
self._socket = listener_socket if listener_socket else None
self._owns_socket = listener_socket is None
self._connected_event = trio.Event()
self._closed_event = trio.Event()
# Stream management
self._streams: dict[int, QUICStream] = {}
self._next_stream_id: int = self._calculate_initial_stream_id()
self._stream_handler: TQUICStreamHandlerFn | None = None
self._stream_id_lock = trio.Lock()
self._stream_count_lock = trio.Lock()
# Stream counting and limits
self._outbound_stream_count = 0
self._inbound_stream_count = 0
# Stream acceptance for incoming streams
self._stream_accept_queue: list[QUICStream] = []
self._stream_accept_event = trio.Event()
self._accept_queue_lock = trio.Lock()
# Connection state
self._closed: bool = False
self._established = False
self._started = False
self._handshake_completed = False
self._peer_verified = False
# Security state
self._peer_certificate: x509.Certificate | None = None
self._handshake_events: list[events.HandshakeCompleted] = []
# Background task management
self._background_tasks_started = False
self._nursery: trio.Nursery | None = None
self._event_processing_task: Any | None = None
self.on_close: Callable[[], Awaitable[None]] | None = None
self.event_started = trio.Event()
# *** NEW: Connection ID tracking - CRITICAL for fixing the original issue ***
self._available_connection_ids: set[bytes] = set()
self._current_connection_id: bytes | None = None
self._retired_connection_ids: set[bytes] = set()
self._connection_id_sequence_numbers: set[int] = set()
# Event processing control
self._event_processing_active = False
self._pending_events: list[events.QuicEvent] = []
# Performance and monitoring
self._connection_start_time = time.time()
self._stats = {
"streams_opened": 0,
"streams_accepted": 0,
"streams_closed": 0,
"streams_reset": 0,
"bytes_sent": 0,
"bytes_received": 0,
"packets_sent": 0,
"packets_received": 0,
# *** NEW: Connection ID statistics ***
"connection_ids_issued": 0,
"connection_ids_retired": 0,
"connection_id_changes": 0,
}
logger.debug(
f"Created QUIC connection to {remote_peer_id} "
f"(initiator: {is_initiator}, addr: {remote_addr}, "
"security: {security_manager is not None})"
)
def _calculate_initial_stream_id(self) -> int:
"""
Calculate the initial stream ID based on QUIC specification.
QUIC stream IDs:
- Client-initiated bidirectional: 0, 4, 8, 12, ...
- Server-initiated bidirectional: 1, 5, 9, 13, ...
- Client-initiated unidirectional: 2, 6, 10, 14, ...
- Server-initiated unidirectional: 3, 7, 11, 15, ...
For libp2p, we primarily use bidirectional streams.
"""
if self._is_initiator:
return 0 # Client starts with 0, then 4, 8, 12...
else:
return 1 # Server starts with 1, then 5, 9, 13...
# Properties
@property
def is_initiator(self) -> bool: # type: ignore
"""Check if this connection is the initiator."""
return self._is_initiator
@property
def is_closed(self) -> bool:
"""Check if connection is closed."""
return self._closed
@property
def is_established(self) -> bool:
"""Check if connection is established (handshake completed)."""
return self._established and self._handshake_completed
@property
def is_started(self) -> bool:
"""Check if connection has been started."""
return self._started
@property
def is_peer_verified(self) -> bool:
"""Check if peer identity has been verified."""
return self._peer_verified
def multiaddr(self) -> multiaddr.Multiaddr:
"""Get the multiaddr for this connection."""
return self._maddr
def local_peer_id(self) -> ID:
"""Get the local peer ID."""
return self._local_peer_id
def remote_peer_id(self) -> ID | None:
"""Get the remote peer ID."""
return self._remote_peer_id
# *** NEW: Connection ID management methods ***
def get_connection_id_stats(self) -> dict[str, Any]:
"""Get connection ID statistics and current state."""
return {
"available_connection_ids": len(self._available_connection_ids),
"current_connection_id": self._current_connection_id.hex()
if self._current_connection_id
else None,
"retired_connection_ids": len(self._retired_connection_ids),
"connection_ids_issued": self._stats["connection_ids_issued"],
"connection_ids_retired": self._stats["connection_ids_retired"],
"connection_id_changes": self._stats["connection_id_changes"],
"available_cid_list": [cid.hex() for cid in self._available_connection_ids],
}
def get_current_connection_id(self) -> bytes | None:
"""Get the current connection ID."""
return self._current_connection_id
# Connection lifecycle methods
async def start(self) -> None:
"""
Start the connection and its background tasks.
This method implements the IMuxedConn.start() interface.
It should be called to begin processing connection events.
"""
if self._started:
logger.warning("Connection already started")
return
if self._closed:
raise QUICConnectionError("Cannot start a closed connection")
self._started = True
self.event_started.set()
logger.debug(f"Starting QUIC connection to {self._remote_peer_id}")
try:
# If this is a client connection, we need to establish the connection
if self._is_initiator:
await self._initiate_connection()
else:
# For server connections, we're already connected via the listener
self._established = True
self._connected_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} started")
except Exception as e:
logger.error(f"Failed to start connection: {e}")
raise QUICConnectionError(f"Connection start failed: {e}") from e
async def _initiate_connection(self) -> None:
"""Initiate client-side connection, reusing listener socket if available."""
try:
with QUICErrorContext("connection_initiation", "connection"):
if not self._socket:
logger.debug("Creating new socket for outbound connection")
self._socket = trio.socket.socket(
family=socket.AF_INET, type=socket.SOCK_DGRAM
)
await self._socket.bind(("0.0.0.0", 0))
self._quic.connect(self._remote_addr, now=time.time())
# Send initial packet(s)
await self._transmit()
logger.debug(f"Initiated QUIC connection to {self._remote_addr}")
except Exception as e:
logger.error(f"Failed to initiate connection: {e}")
raise QUICConnectionError(f"Connection initiation failed: {e}") from e
async def connect(self, nursery: trio.Nursery) -> None:
"""
Establish the QUIC connection using trio nursery for background tasks.
Args:
nursery: Trio nursery for managing connection background tasks
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
self._nursery = nursery
try:
with QUICErrorContext("connection_establishment", "connection"):
# Start the connection if not already started
logger.debug("STARTING TO CONNECT")
if not self._started:
await self.start()
# Start background event processing
if not self._background_tasks_started:
logger.debug("STARTING BACKGROUND TASK")
await self._start_background_tasks()
else:
logger.debug("BACKGROUND TASK ALREADY STARTED")
# Wait for handshake completion with timeout
with trio.move_on_after(
self.CONNECTION_HANDSHAKE_TIMEOUT
) as cancel_scope:
await self._connected_event.wait()
if cancel_scope.cancelled_caught:
raise QUICConnectionTimeoutError(
"Connection handshake timed out after"
f"{self.CONNECTION_HANDSHAKE_TIMEOUT}s"
)
logger.debug(
"QUICConnection: Verifying peer identity with security manager"
)
# Verify peer identity using security manager
peer_id = await self._verify_peer_identity_with_security()
if peer_id:
self.peer_id = peer_id
logger.debug(f"QUICConnection {id(self)}: Peer identity verified")
self._established = True
logger.debug(f"QUIC connection established with {self._remote_peer_id}")
except Exception as e:
logger.error(f"Failed to establish connection: {e}")
await self.close()
raise
async def _start_background_tasks(self) -> None:
"""Start background tasks for connection management."""
if self._background_tasks_started or not self._nursery:
return
self._background_tasks_started = True
if self._is_initiator:
self._nursery.start_soon(async_fn=self._client_packet_receiver)
self._nursery.start_soon(async_fn=self._event_processing_loop)
self._nursery.start_soon(async_fn=self._periodic_maintenance)
logger.debug("Started background tasks for QUIC connection")
async def _event_processing_loop(self) -> None:
"""Main event processing loop for the connection."""
logger.debug(
f"Started QUIC event processing loop for connection id: {id(self)} "
f"and local peer id {str(self.local_peer_id())}"
)
try:
while not self._closed:
# Process QUIC events
await self._process_quic_events()
# Handle timer events
await self._handle_timer_events()
# Transmit any pending data
await self._transmit()
# Short sleep to prevent busy waiting
await trio.sleep(0.01)
except Exception as e:
logger.error(f"Error in event processing loop: {e}")
await self._handle_connection_error(e)
finally:
logger.debug("QUIC event processing loop finished")
async def _periodic_maintenance(self) -> None:
"""Perform periodic connection maintenance."""
try:
while not self._closed:
# Update connection statistics
self._update_stats()
# Check for idle streams that can be cleaned up
await self._cleanup_idle_streams()
# *** NEW: Log connection ID status periodically ***
if logger.isEnabledFor(logging.DEBUG):
cid_stats = self.get_connection_id_stats()
logger.debug(f"Connection ID stats: {cid_stats}")
# Sleep for maintenance interval
await trio.sleep(30.0) # 30 seconds
except Exception as e:
logger.error(f"Error in periodic maintenance: {e}")
async def _client_packet_receiver(self) -> None:
"""Receive packets for client connections."""
logger.debug("Starting client packet receiver")
logger.debug("Started QUIC client packet receiver")
try:
while not self._closed and self._socket:
try:
# Receive UDP packets
data, addr = await self._socket.recvfrom(65536)
logger.debug(f"Client received {len(data)} bytes from {addr}")
# Feed packet to QUIC connection
self._quic.receive_datagram(data, addr, now=time.time())
# Process any events that result from the packet
await self._process_quic_events()
# Send any response packets
await self._transmit()
except trio.ClosedResourceError:
logger.debug("Client socket closed")
break
except Exception as e:
logger.error(f"Error receiving client packet: {e}")
await trio.sleep(0.01)
except trio.Cancelled:
logger.debug("Client packet receiver cancelled")
raise
finally:
logger.debug("Client packet receiver terminated")
# Security and identity methods
async def _verify_peer_identity_with_security(self) -> ID | None:
"""
Verify peer identity using integrated security manager.
Raises:
QUICPeerVerificationError: If peer verification fails
"""
logger.debug("VERIFYING PEER IDENTITY")
if not self._security_manager:
logger.debug("No security manager available for peer verification")
return None
try:
# Extract peer certificate from TLS handshake
await self._extract_peer_certificate()
if not self._peer_certificate:
logger.debug("No peer certificate available for verification")
return None
# Validate certificate format and accessibility
if not self._validate_peer_certificate():
logger.debug("Validation Failed for peer cerificate")
raise QUICPeerVerificationError("Peer certificate validation failed")
# Verify peer identity using security manager
verified_peer_id = self._security_manager.verify_peer_identity(
self._peer_certificate,
self._remote_peer_id, # Expected peer ID for outbound connections
)
# Update peer ID if it wasn't known (inbound connections)
if not self._remote_peer_id:
self._remote_peer_id = verified_peer_id
logger.debug(f"Discovered peer ID from certificate: {verified_peer_id}")
elif self._remote_peer_id != verified_peer_id:
raise QUICPeerVerificationError(
f"Peer ID mismatch: expected {self._remote_peer_id}, "
"got {verified_peer_id}"
)
self._peer_verified = True
logger.debug(f"Peer identity verified successfully: {verified_peer_id}")
return verified_peer_id
except QUICPeerVerificationError:
# Re-raise verification errors as-is
raise
except Exception as e:
# Wrap other errors in verification error
raise QUICPeerVerificationError(f"Peer verification failed: {e}") from e
async def _extract_peer_certificate(self) -> None:
"""Extract peer certificate from completed TLS handshake."""
try:
# Get peer certificate from aioquic TLS context
if self._quic.tls:
tls_context = self._quic.tls
if tls_context._peer_certificate:
# aioquic stores the peer certificate as cryptography
# x509.Certificate
self._peer_certificate = tls_context._peer_certificate
logger.debug(
f"Extracted peer certificate: {self._peer_certificate.subject}"
)
else:
logger.debug("No peer certificate found in TLS context")
else:
logger.debug("No TLS context available for certificate extraction")
except Exception as e:
logger.warning(f"Failed to extract peer certificate: {e}")
# Try alternative approach - check if certificate is in handshake events
try:
# Some versions of aioquic might expose certificate differently
config = self._quic.configuration
if hasattr(config, "certificate") and config.certificate:
# This would be the local certificate, not peer certificate
# but we can use it for debugging
logger.debug("Found local certificate in configuration")
except Exception as inner_e:
logger.error(
f"Alternative certificate extraction also failed: {inner_e}"
)
async def get_peer_certificate(self) -> x509.Certificate | None:
"""
Get the peer's TLS certificate.
Returns:
The peer's X.509 certificate, or None if not available
"""
# If we don't have a certificate yet, try to extract it
if not self._peer_certificate and self._handshake_completed:
await self._extract_peer_certificate()
return self._peer_certificate
def _validate_peer_certificate(self) -> bool:
"""
Validate that the peer certificate is properly formatted and accessible.
Returns:
True if certificate is valid and accessible, False otherwise
"""
if not self._peer_certificate:
return False
try:
# Basic validation - try to access certificate properties
subject = self._peer_certificate.subject
serial_number = self._peer_certificate.serial_number
logger.debug(
f"Certificate validation - Subject: {subject}, Serial: {serial_number}"
)
return True
except Exception as e:
logger.error(f"Certificate validation failed: {e}")
return False
def get_security_manager(self) -> Optional["QUICTLSConfigManager"]:
"""Get the security manager for this connection."""
return self._security_manager
def get_security_info(self) -> dict[str, Any]:
"""Get security-related information about the connection."""
info: dict[str, bool | Any | None] = {
"peer_verified": self._peer_verified,
"handshake_complete": self._handshake_completed,
"peer_id": str(self._remote_peer_id) if self._remote_peer_id else None,
"local_peer_id": str(self._local_peer_id),
"is_initiator": self._is_initiator,
"has_certificate": self._peer_certificate is not None,
"security_manager_available": self._security_manager is not None,
}
# Add certificate details if available
if self._peer_certificate:
try:
info.update(
{
"certificate_subject": str(self._peer_certificate.subject),
"certificate_issuer": str(self._peer_certificate.issuer),
"certificate_serial": str(self._peer_certificate.serial_number),
"certificate_not_before": (
self._peer_certificate.not_valid_before.isoformat()
),
"certificate_not_after": (
self._peer_certificate.not_valid_after.isoformat()
),
}
)
except Exception as e:
info["certificate_error"] = str(e)
# Add TLS context debug info
try:
if hasattr(self._quic, "tls") and self._quic.tls:
tls_info = {
"tls_context_available": True,
"tls_state": getattr(self._quic.tls, "state", None),
}
# Check for peer certificate in TLS context
if hasattr(self._quic.tls, "_peer_certificate"):
tls_info["tls_peer_certificate_available"] = (
self._quic.tls._peer_certificate is not None
)
info["tls_debug"] = tls_info
else:
info["tls_debug"] = {"tls_context_available": False}
except Exception as e:
info["tls_debug"] = {"error": str(e)}
return info
# Stream management methods (IMuxedConn interface)
async def open_stream(self, timeout: float = 5.0) -> QUICStream:
"""
Open a new outbound stream
Args:
timeout: Timeout for stream creation
Returns:
New QUIC stream
Raises:
QUICStreamLimitError: Too many concurrent streams
QUICConnectionClosedError: Connection is closed
QUICStreamTimeoutError: Stream creation timed out
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
if not self._started:
raise QUICConnectionError("Connection not started")
# Check stream limits
async with self._stream_count_lock:
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
f"Maximum outbound streams ({self.MAX_OUTGOING_STREAMS}) reached"
)
with trio.move_on_after(timeout):
async with self._stream_id_lock:
# Generate next stream ID
stream_id = self._next_stream_id
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.OUTBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
async with self._stream_count_lock:
self._outbound_stream_count += 1
self._stats["streams_opened"] += 1
logger.debug(f"Opened outbound QUIC stream {stream_id}")
return stream
raise QUICStreamTimeoutError(f"Stream creation timed out after {timeout}s")
async def accept_stream(self, timeout: float | None = None) -> QUICStream:
"""
Accept incoming stream.
Args:
timeout: Optional timeout. If None, waits indefinitely.
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
if timeout is not None:
with trio.move_on_after(timeout):
return await self._accept_stream_impl()
# Timeout occurred
if self._closed_event.is_set() or self._closed:
raise MuxedConnUnavailable("QUIC connection closed during timeout")
else:
raise QUICStreamTimeoutError(
f"Stream accept timed out after {timeout}s"
)
else:
# No timeout - wait indefinitely
return await self._accept_stream_impl()
async def _accept_stream_impl(self) -> QUICStream:
while True:
if self._closed:
raise MuxedConnUnavailable("QUIC connection is closed")
async with self._accept_queue_lock:
if self._stream_accept_queue:
stream = self._stream_accept_queue.pop(0)
logger.debug(f"Accepted inbound stream {stream.stream_id}")
return stream
if self._closed:
raise MuxedConnUnavailable("Connection closed while accepting stream")
# Wait for new streams indefinitely
await self._stream_accept_event.wait()
raise QUICConnectionError("Error occurred while waiting to accept stream")
def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
"""
Set handler for incoming streams.
Args:
handler_function: Function to handle new incoming streams
"""
self._stream_handler = handler_function
logger.debug("Set stream handler for incoming streams")
def _remove_stream(self, stream_id: int) -> None:
"""
Remove stream from connection registry.
Called by stream cleanup process.
"""
if stream_id in self._streams:
stream = self._streams.pop(stream_id)
# Update stream counts asynchronously
async def update_counts() -> None:
async with self._stream_count_lock:
if stream.direction == StreamDirection.OUTBOUND:
self._outbound_stream_count = max(
0, self._outbound_stream_count - 1
)
else:
self._inbound_stream_count = max(
0, self._inbound_stream_count - 1
)
self._stats["streams_closed"] += 1
# Schedule count update if we're in a trio context
if self._nursery:
self._nursery.start_soon(update_counts)
logger.debug(f"Removed stream {stream_id} from connection")
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
if self._event_processing_active:
return # Prevent recursion
self._event_processing_active = True
try:
events_processed = 0
while True:
event = self._quic.next_event()
if event is None:
break
events_processed += 1
await self._handle_quic_event(event)
if events_processed > 0:
logger.debug(f"Processed {events_processed} QUIC events")
finally:
self._event_processing_active = False
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
"""Handle a single QUIC event with COMPLETE event type coverage."""
logger.debug(f"Handling QUIC event: {type(event).__name__}")
logger.debug(f"QUIC event: {type(event).__name__}")
try:
if isinstance(event, events.ConnectionTerminated):
await self._handle_connection_terminated(event)
elif isinstance(event, events.HandshakeCompleted):
await self._handle_handshake_completed(event)
elif isinstance(event, events.StreamDataReceived):
await self._handle_stream_data(event)
elif isinstance(event, events.StreamReset):
await self._handle_stream_reset(event)
elif isinstance(event, events.DatagramFrameReceived):
await self._handle_datagram_received(event)
# *** NEW: Connection ID event handlers - CRITICAL FIX ***
elif isinstance(event, events.ConnectionIdIssued):
await self._handle_connection_id_issued(event)
elif isinstance(event, events.ConnectionIdRetired):
await self._handle_connection_id_retired(event)
# *** NEW: Additional event handlers for completeness ***
elif isinstance(event, events.PingAcknowledged):
await self._handle_ping_acknowledged(event)
elif isinstance(event, events.ProtocolNegotiated):
await self._handle_protocol_negotiated(event)
elif isinstance(event, events.StopSendingReceived):
await self._handle_stop_sending_received(event)
else:
logger.debug(f"Unhandled QUIC event type: {type(event).__name__}")
logger.debug(f"Unhandled QUIC event: {type(event).__name__}")
except Exception as e:
logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
async def _handle_connection_id_issued(
self, event: events.ConnectionIdIssued
) -> None:
"""
Handle new connection ID issued by peer.
This is the CRITICAL missing functionality that was causing your issue!
"""
logger.debug(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
logger.debug(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
# Add to available connection IDs
self._available_connection_ids.add(event.connection_id)
# If we don't have a current connection ID, use this one
if self._current_connection_id is None:
self._current_connection_id = event.connection_id
logger.debug(
f"🆔 Set current connection ID to: {event.connection_id.hex()}"
)
logger.debug(
f"🆔 Set current connection ID to: {event.connection_id.hex()}"
)
# Update statistics
self._stats["connection_ids_issued"] += 1
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
async def _handle_connection_id_retired(
self, event: events.ConnectionIdRetired
) -> None:
"""
Handle connection ID retirement.
This handles when the peer tells us to stop using a connection ID.
"""
logger.debug(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
logger.debug(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
# Remove from available IDs and add to retired set
self._available_connection_ids.discard(event.connection_id)
self._retired_connection_ids.add(event.connection_id)
# If this was our current connection ID, switch to another
if self._current_connection_id == event.connection_id:
if self._available_connection_ids:
self._current_connection_id = next(iter(self._available_connection_ids))
if self._current_connection_id:
logger.debug(
"Switching to new connection ID: "
f"{self._current_connection_id.hex()}"
)
self._stats["connection_id_changes"] += 1
else:
logger.warning("⚠️ No available connection IDs after retirement!")
logger.debug("⚠️ No available connection IDs after retirement!")
else:
self._current_connection_id = None
logger.warning("⚠️ No available connection IDs after retirement!")
logger.debug("⚠️ No available connection IDs after retirement!")
# Update statistics
self._stats["connection_ids_retired"] += 1
async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None:
"""Handle ping acknowledgment."""
logger.debug(f"Ping acknowledged: uid={event.uid}")
async def _handle_protocol_negotiated(
self, event: events.ProtocolNegotiated
) -> None:
"""Handle protocol negotiation completion."""
logger.debug(f"Protocol negotiated: {event.alpn_protocol}")
async def _handle_stop_sending_received(
self, event: events.StopSendingReceived
) -> None:
"""Handle stop sending request from peer."""
logger.debug(
"Stop sending received: "
f"stream_id={event.stream_id}, error_code={event.error_code}"
)
if event.stream_id in self._streams:
stream: QUICStream = self._streams[event.stream_id]
# Handle stop sending on the stream if method exists
await stream.handle_stop_sending(event.error_code)
async def _handle_handshake_completed(
self, event: events.HandshakeCompleted
) -> None:
"""Handle handshake completion with security integration."""
logger.debug("QUIC handshake completed")
self._handshake_completed = True
# Store handshake event for security verification
self._handshake_events.append(event)
# Try to extract certificate information after handshake
await self._extract_peer_certificate()
logger.debug("✅ Setting connected event")
self._connected_event.set()
async def _handle_connection_terminated(
self, event: events.ConnectionTerminated
) -> None:
"""Handle connection termination."""
logger.debug(f"QUIC connection terminated: {event.reason_phrase}")
# Close all streams
for stream in list(self._streams.values()):
if event.error_code:
await stream.handle_reset(event.error_code)
else:
await stream.close()
self._streams.clear()
self._closed = True
self._closed_event.set()
self._stream_accept_event.set()
logger.debug(f"Woke up pending accept_stream() calls, {id(self)}")
await self._notify_parent_of_termination()
async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
"""Handle stream data events - create streams and add to accept queue."""
stream_id = event.stream_id
self._stats["bytes_received"] += len(event.data)
try:
if stream_id not in self._streams:
if self._is_incoming_stream(stream_id):
logger.debug(f"Creating new incoming stream {stream_id}")
from .stream import QUICStream, StreamDirection
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
# Store the stream
self._streams[stream_id] = stream
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Added stream {stream_id} to accept queue")
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_opened"] += 1
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
return
stream = self._streams[stream_id]
await stream.handle_data_received(event.data, event.end_stream)
except Exception as e:
logger.error(f"Error handling stream data for stream {stream_id}: {e}")
logger.debug(f"❌ STREAM_DATA: Error: {e}")
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
"""Get existing stream or create new inbound stream."""
if stream_id in self._streams:
return self._streams[stream_id]
# Check if this is an incoming stream
is_incoming = self._is_incoming_stream(stream_id)
if not is_incoming:
# This shouldn't happen - outbound streams should be created by open_stream
raise QUICStreamError(
f"Received data for unknown outbound stream {stream_id}"
)
# Check stream limits for incoming streams
async with self._stream_count_lock:
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting incoming stream {stream_id}: limit reached")
# Send reset to reject the stream
self._quic.reset_stream(
stream_id, error_code=0x04
) # STREAM_LIMIT_ERROR
await self._transmit()
raise QUICStreamLimitError("Too many inbound streams")
# Create new inbound stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue and notify handler
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
# Handle directly with stream handler if available
if self._stream_handler:
try:
if self._nursery:
self._nursery.start_soon(self._stream_handler, stream)
else:
await self._stream_handler(stream)
except Exception as e:
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
logger.debug(f"Created inbound stream {stream_id}")
return stream
def _is_incoming_stream(self, stream_id: int) -> bool:
"""
Determine if a stream ID represents an incoming stream.
For bidirectional streams:
- Even IDs are client-initiated
- Odd IDs are server-initiated
"""
if self._is_initiator:
# We're the client, so odd stream IDs are incoming
return stream_id % 2 == 1
else:
# We're the server, so even stream IDs are incoming
return stream_id % 2 == 0
async def _handle_stream_reset(self, event: events.StreamReset) -> None:
"""Stream reset handling."""
stream_id = event.stream_id
self._stats["streams_reset"] += 1
if stream_id in self._streams:
try:
stream = self._streams[stream_id]
await stream.handle_reset(event.error_code)
logger.debug(
f"Handled reset for stream {stream_id}"
f"with error code {event.error_code}"
)
except Exception as e:
logger.error(f"Error handling stream reset for {stream_id}: {e}")
# Force remove the stream
self._remove_stream(stream_id)
else:
logger.debug(f"Received reset for unknown stream {stream_id}")
async def _handle_datagram_received(
self, event: events.DatagramFrameReceived
) -> None:
"""Handle datagram frame (if using QUIC datagrams)."""
logger.debug(f"Datagram frame received: size={len(event.data)}")
# For now, just log. Could be extended for custom datagram handling
async def _handle_timer_events(self) -> None:
"""Handle QUIC timer events."""
timer = self._quic.get_timer()
if timer is not None:
now = time.time()
if timer <= now:
self._quic.handle_timer(now=now)
# Network transmission
async def _transmit(self) -> None:
"""Transmit pending QUIC packets using available socket."""
sock = self._socket
if not sock:
logger.debug("No socket to transmit")
return
try:
current_time = time.time()
datagrams = self._quic.datagrams_to_send(now=current_time)
for data, addr in datagrams:
await sock.sendto(data, addr)
# Update stats if available
if hasattr(self, "_stats"):
self._stats["packets_sent"] += 1
self._stats["bytes_sent"] += len(data)
except Exception as e:
logger.error(f"Transmission error: {e}")
await self._handle_connection_error(e)
# Additional methods for stream data processing
async def _process_quic_event(self, event: events.QuicEvent) -> None:
"""Process a single QUIC event."""
await self._handle_quic_event(event)
async def _transmit_pending_data(self) -> None:
"""Transmit any pending data."""
await self._transmit()
# Error handling
async def _handle_connection_error(self, error: Exception) -> None:
"""Handle connection-level errors."""
logger.error(f"Connection error: {error}")
if not self._closed:
try:
await self.close()
except Exception as close_error:
logger.error(f"Error during connection close: {close_error}")
# Connection close
async def close(self) -> None:
"""Connection close with proper stream cleanup."""
if self._closed:
return
self._closed = True
logger.debug(f"Closing QUIC connection to {self._remote_peer_id}")
try:
# Close all streams gracefully
stream_close_tasks = []
for stream in list(self._streams.values()):
if stream.can_write() or stream.can_read():
stream_close_tasks.append(stream.close)
if stream_close_tasks and self._nursery:
try:
# Close streams concurrently with timeout
with trio.move_on_after(self.CONNECTION_CLOSE_TIMEOUT):
async with trio.open_nursery() as close_nursery:
for task in stream_close_tasks:
close_nursery.start_soon(task)
except Exception as e:
logger.warning(f"Error during graceful stream close: {e}")
# Force reset remaining streams
for stream in self._streams.values():
try:
await stream.reset(error_code=0)
except Exception:
pass
if self.on_close:
await self.on_close()
# Close QUIC connection
self._quic.close()
if self._socket:
await self._transmit() # Send close frames
# Close socket
if self._socket and self._owns_socket:
self._socket.close()
self._socket = None
self._streams.clear()
self._closed_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
except Exception as e:
logger.error(f"Error during connection close: {e}")
async def _notify_parent_of_termination(self) -> None:
"""
Notify the parent listener/transport to remove this connection from tracking.
This ensures that terminated connections are cleaned up from the
'established connections' list.
"""
try:
if self._transport:
await self._transport._cleanup_terminated_connection(self)
logger.debug("Notified transport of connection termination")
return
for listener in self._transport._listeners:
try:
await listener._remove_connection_by_object(self)
logger.debug(
"Found and notified listener of connection termination"
)
return
except Exception:
continue
# Method 4: Use connection ID if we have one (most reliable)
if self._current_connection_id:
await self._cleanup_by_connection_id(self._current_connection_id)
return
logger.warning(
"Could not notify parent of connection termination - no"
f" parent reference found for conn host {self._quic.host_cid.hex()}"
)
except Exception as e:
logger.error(f"Error notifying parent of connection termination: {e}")
async def _cleanup_by_connection_id(self, connection_id: bytes) -> None:
"""Cleanup using connection ID as a fallback method."""
try:
for listener in self._transport._listeners:
for tracked_cid, tracked_conn in list(listener._connections.items()):
if tracked_conn is self:
await listener._remove_connection(tracked_cid)
logger.debug(f"Removed connection {tracked_cid.hex()}")
return
logger.debug("Fallback cleanup by connection ID completed")
except Exception as e:
logger.error(f"Error in fallback cleanup: {e}")
# IRawConnection interface (for compatibility)
def get_remote_address(self) -> tuple[str, int]:
return self._remote_addr
async def write(self, data: bytes) -> None:
"""
Write data to the connection.
For QUIC, this creates a new stream for each write operation.
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
stream = await self.open_stream()
try:
await stream.write(data)
await stream.close_write()
except Exception:
await stream.reset()
raise
async def read(self, n: int | None = -1) -> bytes:
"""
Read data from the stream.
Args:
n: Maximum number of bytes to read. -1 means read all available.
Returns:
Data bytes read from the stream.
Raises:
QUICStreamClosedError: If stream is closed for reading.
QUICStreamResetError: If stream was reset.
QUICStreamTimeoutError: If read timeout occurs.
"""
# It's here for interface compatibility but should not be used
raise NotImplementedError(
"Use streams for reading data from QUIC connections. "
"Call accept_stream() or open_stream() instead."
)
# Utility and monitoring methods
def get_stream_stats(self) -> dict[str, Any]:
"""Get stream statistics for monitoring."""
return {
"total_streams": len(self._streams),
"outbound_streams": self._outbound_stream_count,
"inbound_streams": self._inbound_stream_count,
"max_streams": self.MAX_CONCURRENT_STREAMS,
"stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS,
"stats": self._stats.copy(),
}
def get_active_streams(self) -> list[QUICStream]:
"""Get list of active streams."""
return [stream for stream in self._streams.values() if not stream.is_closed()]
def get_streams_by_protocol(self, protocol: str) -> list[QUICStream]:
"""Get streams filtered by protocol."""
return [
stream
for stream in self._streams.values()
if hasattr(stream, "protocol")
and stream.protocol == protocol
and not stream.is_closed()
]
def _update_stats(self) -> None:
"""Update connection statistics."""
# Add any periodic stats updates here
pass
async def _cleanup_idle_streams(self) -> None:
"""Clean up idle streams that are no longer needed."""
current_time = time.time()
streams_to_cleanup = []
for stream in self._streams.values():
if stream.is_closed():
# Check if stream has been closed for a while
if hasattr(stream, "_timeline") and stream._timeline.closed_at:
if current_time - stream._timeline.closed_at > 60: # 1 minute
streams_to_cleanup.append(stream.stream_id)
for stream_id in streams_to_cleanup:
self._remove_stream(int(stream_id))
# String representation
def __repr__(self) -> str:
current_cid: str | None = (
self._current_connection_id.hex() if self._current_connection_id else None
)
return (
f"QUICConnection(peer={self._remote_peer_id}, "
f"addr={self._remote_addr}, "
f"initiator={self._is_initiator}, "
f"verified={self._peer_verified}, "
f"established={self._established}, "
f"streams={len(self._streams)}, "
f"current_cid={current_cid})"
)
def __str__(self) -> str:
return f"QUICConnection({self._remote_peer_id})"