# py-libp2p/libp2p/transport/quic/connection.py
"""
QUIC Connection implementation.
Manages bidirectional QUIC connections with integrated stream multiplexing.
"""
from collections import defaultdict
from collections.abc import Awaitable, Callable
import logging
import socket
import time
from typing import TYPE_CHECKING, Any, Optional, cast
from aioquic.quic import events
from aioquic.quic.connection import QuicConnection
from aioquic.quic.events import QuicEvent
from cryptography import x509
import multiaddr
import trio
from libp2p.abc import IMuxedConn, IRawConnection
from libp2p.custom_types import TQUICStreamHandlerFn
from libp2p.peer.id import ID
from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
from .exceptions import (
QUICConnectionClosedError,
QUICConnectionError,
QUICConnectionTimeoutError,
QUICErrorContext,
QUICPeerVerificationError,
QUICStreamError,
QUICStreamLimitError,
QUICStreamTimeoutError,
)
from .stream import QUICStream, StreamDirection
if TYPE_CHECKING:
from .security import QUICTLSConfigManager
from .transport import QUICTransport
logger = logging.getLogger(__name__)
class QUICConnection(IRawConnection, IMuxedConn):
"""
QUIC connection implementing both raw connection and muxed connection interfaces.
Uses aioquic's sans-IO core with trio for native async support.
QUIC natively provides stream multiplexing, so this connection acts as both
a raw connection (for transport layer) and muxed connection (for upper layers).
Features:
- Native QUIC stream multiplexing
- Integrated libp2p TLS security with peer identity verification
- Resource-aware stream management
- Comprehensive error handling
- Flow control integration
- Connection migration support
- Performance monitoring
- Complete connection ID management (issuance, retirement, migration)
"""
def __init__(
self,
quic_connection: QuicConnection,
remote_addr: tuple[str, int],
remote_peer_id: ID | None,
local_peer_id: ID,
is_initiator: bool,
maddr: multiaddr.Multiaddr,
transport: "QUICTransport",
security_manager: Optional["QUICTLSConfigManager"] = None,
resource_scope: Any | None = None,
listener_socket: trio.socket.SocketType | None = None,
):
"""
Initialize QUIC connection with security integration.
Args:
quic_connection: aioquic QuicConnection instance
remote_addr: Remote peer address
remote_peer_id: Remote peer ID (may be None initially)
local_peer_id: Local peer ID
is_initiator: Whether this is the connection initiator
maddr: Multiaddr for this connection
transport: Parent QUIC transport
security_manager: Security manager for TLS/certificate handling
resource_scope: Resource manager scope for tracking
listener_socket: Socket of listener to transmit data
"""
self._quic = quic_connection
self._remote_addr = remote_addr
self._remote_peer_id = remote_peer_id
self._local_peer_id = local_peer_id
self.peer_id = remote_peer_id or local_peer_id
self._is_initiator = is_initiator
self._maddr = maddr
self._transport = transport
self._security_manager = security_manager
self._resource_scope = resource_scope
# Trio networking - socket may be provided by listener
self._socket = listener_socket if listener_socket else None
self._owns_socket = listener_socket is None
self._connected_event = trio.Event()
self._closed_event = trio.Event()
self._streams: dict[int, QUICStream] = {}
self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups
self._next_stream_id: int = self._calculate_initial_stream_id()
self._stream_handler: TQUICStreamHandlerFn | None = None
# Single lock for all stream operations
self._stream_lock = trio.Lock()
# Stream counting and limits
self._outbound_stream_count = 0
self._inbound_stream_count = 0
# Stream acceptance for incoming streams
self._stream_accept_queue: list[QUICStream] = []
self._stream_accept_event = trio.Event()
# Connection state
self._closed: bool = False
self._established = False
self._started = False
self._handshake_completed = False
self._peer_verified = False
# Security state
self._peer_certificate: x509.Certificate | None = None
self._handshake_events: list[events.HandshakeCompleted] = []
# Background task management
self._background_tasks_started = False
self._nursery: trio.Nursery | None = None
self._event_processing_task: Any | None = None
self.on_close: Callable[[], Awaitable[None]] | None = None
self.event_started = trio.Event()
self._available_connection_ids: set[bytes] = set()
self._current_connection_id: bytes | None = None
self._retired_connection_ids: set[bytes] = set()
self._connection_id_sequence_numbers: set[int] = set()
# Event processing control with batching
self._event_processing_active = False
self._event_batch: list[events.QuicEvent] = []
self._event_batch_size = 10
self._last_event_time = 0.0
# Set quic connection configuration
self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT
self.MAX_INCOMING_STREAMS = transport._config.MAX_INCOMING_STREAMS
self.MAX_OUTGOING_STREAMS = transport._config.MAX_OUTGOING_STREAMS
self.CONNECTION_HANDSHAKE_TIMEOUT = (
transport._config.CONNECTION_HANDSHAKE_TIMEOUT
)
self.MAX_CONCURRENT_STREAMS = transport._config.MAX_CONCURRENT_STREAMS
# Performance and monitoring
self._connection_start_time = time.time()
self._stats = {
"streams_opened": 0,
"streams_accepted": 0,
"streams_closed": 0,
"streams_reset": 0,
"bytes_sent": 0,
"bytes_received": 0,
"packets_sent": 0,
"packets_received": 0,
"connection_ids_issued": 0,
"connection_ids_retired": 0,
"connection_id_changes": 0,
}
        logger.debug(
            f"Created QUIC connection to {remote_peer_id} "
            f"(initiator: {is_initiator}, addr: {remote_addr}, "
            f"security: {security_manager is not None})"
        )
def _calculate_initial_stream_id(self) -> int:
"""
Calculate the initial stream ID based on QUIC specification.
QUIC stream IDs:
- Client-initiated bidirectional: 0, 4, 8, 12, ...
- Server-initiated bidirectional: 1, 5, 9, 13, ...
- Client-initiated unidirectional: 2, 6, 10, 14, ...
- Server-initiated unidirectional: 3, 7, 11, 15, ...
For libp2p, we primarily use bidirectional streams.
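
        Example: as the initiator this connection starts at stream ID 0 and
        its outbound bidirectional IDs advance 0, 4, 8, ...; as a server it
        starts at 1 and advances 1, 5, 9, ...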
"""
if self._is_initiator:
return 0
else:
return 1
@property
def is_initiator(self) -> bool: # type: ignore
"""Check if this connection is the initiator."""
return self._is_initiator
@property
def is_closed(self) -> bool:
"""Check if connection is closed."""
return self._closed
@property
def is_established(self) -> bool:
"""Check if connection is established (handshake completed)."""
return self._established and self._handshake_completed
@property
def is_started(self) -> bool:
"""Check if connection has been started."""
return self._started
@property
def is_peer_verified(self) -> bool:
"""Check if peer identity has been verified."""
return self._peer_verified
def multiaddr(self) -> multiaddr.Multiaddr:
"""Get the multiaddr for this connection."""
return self._maddr
def local_peer_id(self) -> ID:
"""Get the local peer ID."""
return self._local_peer_id
def remote_peer_id(self) -> ID | None:
"""Get the remote peer ID."""
return self._remote_peer_id
def get_connection_id_stats(self) -> dict[str, Any]:
"""Get connection ID statistics and current state."""
return {
"available_connection_ids": len(self._available_connection_ids),
"current_connection_id": self._current_connection_id.hex()
if self._current_connection_id
else None,
"retired_connection_ids": len(self._retired_connection_ids),
"connection_ids_issued": self._stats["connection_ids_issued"],
"connection_ids_retired": self._stats["connection_ids_retired"],
"connection_id_changes": self._stats["connection_id_changes"],
"available_cid_list": [cid.hex() for cid in self._available_connection_ids],
}
def get_current_connection_id(self) -> bytes | None:
"""Get the current connection ID."""
return self._current_connection_id
# Fast stream lookup with caching
def _get_stream_fast(self, stream_id: int) -> QUICStream | None:
"""Get stream with caching for performance."""
# Try cache first
stream = self._stream_cache.get(stream_id)
if stream is not None:
return stream
# Fallback to main dict
stream = self._streams.get(stream_id)
if stream is not None:
self._stream_cache[stream_id] = stream
return stream
# Connection lifecycle methods
async def start(self) -> None:
"""
Start the connection and its background tasks.
This method implements the IMuxedConn.start() interface.
It should be called to begin processing connection events.
"""
if self._started:
logger.warning("Connection already started")
return
if self._closed:
raise QUICConnectionError("Cannot start a closed connection")
self._started = True
self.event_started.set()
logger.debug(f"Starting QUIC connection to {self._remote_peer_id}")
try:
# If this is a client connection, we need to establish the connection
if self._is_initiator:
await self._initiate_connection()
else:
# For server connections, we're already connected via the listener
self._established = True
self._connected_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} started")
except Exception as e:
logger.error(f"Failed to start connection: {e}")
raise QUICConnectionError(f"Connection start failed: {e}") from e
async def _initiate_connection(self) -> None:
"""Initiate client-side connection, reusing listener socket if available."""
try:
with QUICErrorContext("connection_initiation", "connection"):
if not self._socket:
logger.debug("Creating new socket for outbound connection")
self._socket = trio.socket.socket(
family=socket.AF_INET, type=socket.SOCK_DGRAM
)
await self._socket.bind(("0.0.0.0", 0))
self._quic.connect(self._remote_addr, now=time.time())
# Send initial packet(s)
await self._transmit()
logger.debug(f"Initiated QUIC connection to {self._remote_addr}")
except Exception as e:
logger.error(f"Failed to initiate connection: {e}")
raise QUICConnectionError(f"Connection initiation failed: {e}") from e
async def connect(self, nursery: trio.Nursery) -> None:
"""
Establish the QUIC connection using trio nursery for background tasks.
Args:
nursery: Trio nursery for managing connection background tasks
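
        Example (sketch, assuming a freshly constructed connection ``conn``):
            async with trio.open_nursery() as nursery:
                await conn.connect(nursery)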
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
self._nursery = nursery
try:
with QUICErrorContext("connection_establishment", "connection"):
# Start the connection if not already started
logger.debug("STARTING TO CONNECT")
if not self._started:
await self.start()
# Start background event processing
if not self._background_tasks_started:
logger.debug("STARTING BACKGROUND TASK")
await self._start_background_tasks()
else:
logger.debug("BACKGROUND TASK ALREADY STARTED")
# Wait for handshake completion with timeout
with trio.move_on_after(
self.CONNECTION_HANDSHAKE_TIMEOUT
) as cancel_scope:
await self._connected_event.wait()
if cancel_scope.cancelled_caught:
                    raise QUICConnectionTimeoutError(
                        "Connection handshake timed out after "
                        f"{self.CONNECTION_HANDSHAKE_TIMEOUT}s"
                    )
logger.debug(
"QUICConnection: Verifying peer identity with security manager"
)
# Verify peer identity using security manager
peer_id = await self._verify_peer_identity_with_security()
if peer_id:
self.peer_id = peer_id
logger.debug(f"QUICConnection {id(self)}: Peer identity verified")
self._established = True
logger.debug(f"QUIC connection established with {self._remote_peer_id}")
except Exception as e:
logger.error(f"Failed to establish connection: {e}")
await self.close()
raise
async def _start_background_tasks(self) -> None:
"""Start background tasks for connection management."""
if self._background_tasks_started or not self._nursery:
return
self._background_tasks_started = True
        if self._is_initiator:
            self._nursery.start_soon(self._client_packet_receiver)
        self._nursery.start_soon(self._event_processing_loop)
        self._nursery.start_soon(self._periodic_maintenance)
logger.debug("Started background tasks for QUIC connection")
async def _event_processing_loop(self) -> None:
"""Main event processing loop for the connection."""
logger.debug(
f"Started QUIC event processing loop for connection id: {id(self)} "
f"and local peer id {str(self.local_peer_id())}"
)
try:
while not self._closed:
# Batch process events
await self._process_quic_events_batched()
# Handle timer events
await self._handle_timer_events()
# Transmit any pending data
await self._transmit()
# Short sleep to prevent busy waiting
await trio.sleep(0.01)
except Exception as e:
logger.error(f"Error in event processing loop: {e}")
await self._handle_connection_error(e)
finally:
logger.debug("QUIC event processing loop finished")
async def _periodic_maintenance(self) -> None:
"""Perform periodic connection maintenance."""
try:
while not self._closed:
# Update connection statistics
self._update_stats()
# Check for idle streams that can be cleaned up
await self._cleanup_idle_streams()
if logger.isEnabledFor(logging.DEBUG):
cid_stats = self.get_connection_id_stats()
logger.debug(f"Connection ID stats: {cid_stats}")
# Clean cache periodically
await self._cleanup_cache()
# Sleep for maintenance interval
await trio.sleep(30.0) # 30 seconds
except Exception as e:
logger.error(f"Error in periodic maintenance: {e}")
async def _cleanup_cache(self) -> None:
"""Clean up stream cache periodically to prevent memory leaks."""
if len(self._stream_cache) > 100: # Arbitrary threshold
# Remove closed streams from cache
closed_stream_ids = [
sid for sid, stream in self._stream_cache.items() if stream.is_closed()
]
for sid in closed_stream_ids:
self._stream_cache.pop(sid, None)
async def _client_packet_receiver(self) -> None:
"""Receive packets for client connections."""
logger.debug("Starting client packet receiver")
logger.debug("Started QUIC client packet receiver")
try:
while not self._closed and self._socket:
try:
# Receive UDP packets
data, addr = await self._socket.recvfrom(65536)
logger.debug(f"Client received {len(data)} bytes from {addr}")
# Feed packet to QUIC connection
self._quic.receive_datagram(data, addr, now=time.time())
# Batch process events
await self._process_quic_events_batched()
# Send any response packets
await self._transmit()
except trio.ClosedResourceError:
logger.debug("Client socket closed")
break
except Exception as e:
logger.error(f"Error receiving client packet: {e}")
await trio.sleep(0.01)
except trio.Cancelled:
logger.debug("Client packet receiver cancelled")
raise
finally:
logger.debug("Client packet receiver terminated")
# Security and identity methods
async def _verify_peer_identity_with_security(self) -> ID | None:
"""
Verify peer identity using integrated security manager.
Raises:
QUICPeerVerificationError: If peer verification fails
"""
logger.debug("VERIFYING PEER IDENTITY")
if not self._security_manager:
logger.debug("No security manager available for peer verification")
return None
try:
# Extract peer certificate from TLS handshake
await self._extract_peer_certificate()
if not self._peer_certificate:
logger.debug("No peer certificate available for verification")
return None
# Validate certificate format and accessibility
if not self._validate_peer_certificate():
logger.debug("Validation Failed for peer cerificate")
raise QUICPeerVerificationError("Peer certificate validation failed")
# Verify peer identity using security manager
verified_peer_id = self._security_manager.verify_peer_identity(
self._peer_certificate,
self._remote_peer_id, # Expected peer ID for outbound connections
)
# Update peer ID if it wasn't known (inbound connections)
if not self._remote_peer_id:
self._remote_peer_id = verified_peer_id
logger.debug(f"Discovered peer ID from certificate: {verified_peer_id}")
            elif self._remote_peer_id != verified_peer_id:
                raise QUICPeerVerificationError(
                    f"Peer ID mismatch: expected {self._remote_peer_id}, "
                    f"got {verified_peer_id}"
                )
self._peer_verified = True
logger.debug(f"Peer identity verified successfully: {verified_peer_id}")
return verified_peer_id
except QUICPeerVerificationError:
# Re-raise verification errors as-is
raise
except Exception as e:
# Wrap other errors in verification error
raise QUICPeerVerificationError(f"Peer verification failed: {e}") from e
async def _extract_peer_certificate(self) -> None:
"""Extract peer certificate from completed TLS handshake."""
try:
# Get peer certificate from aioquic TLS context
if self._quic.tls:
tls_context = self._quic.tls
if tls_context._peer_certificate:
# aioquic stores the peer certificate as cryptography
# x509.Certificate
self._peer_certificate = tls_context._peer_certificate
logger.debug(
f"Extracted peer certificate: {self._peer_certificate.subject}"
)
else:
logger.debug("No peer certificate found in TLS context")
else:
logger.debug("No TLS context available for certificate extraction")
except Exception as e:
logger.warning(f"Failed to extract peer certificate: {e}")
# Try alternative approach - check if certificate is in handshake events
try:
# Some versions of aioquic might expose certificate differently
config = self._quic.configuration
if hasattr(config, "certificate") and config.certificate:
# This would be the local certificate, not peer certificate
# but we can use it for debugging
logger.debug("Found local certificate in configuration")
except Exception as inner_e:
logger.error(
f"Alternative certificate extraction also failed: {inner_e}"
)
async def get_peer_certificate(self) -> x509.Certificate | None:
"""
Get the peer's TLS certificate.
Returns:
The peer's X.509 certificate, or None if not available
"""
# If we don't have a certificate yet, try to extract it
if not self._peer_certificate and self._handshake_completed:
await self._extract_peer_certificate()
return self._peer_certificate
def _validate_peer_certificate(self) -> bool:
"""
Validate that the peer certificate is properly formatted and accessible.
Returns:
True if certificate is valid and accessible, False otherwise
"""
if not self._peer_certificate:
return False
try:
# Basic validation - try to access certificate properties
subject = self._peer_certificate.subject
serial_number = self._peer_certificate.serial_number
logger.debug(
f"Certificate validation - Subject: {subject}, Serial: {serial_number}"
)
return True
except Exception as e:
logger.error(f"Certificate validation failed: {e}")
return False
def get_security_manager(self) -> Optional["QUICTLSConfigManager"]:
"""Get the security manager for this connection."""
return self._security_manager
def get_security_info(self) -> dict[str, Any]:
"""Get security-related information about the connection."""
        info: dict[str, Any] = {
"peer_verified": self._peer_verified,
"handshake_complete": self._handshake_completed,
"peer_id": str(self._remote_peer_id) if self._remote_peer_id else None,
"local_peer_id": str(self._local_peer_id),
"is_initiator": self._is_initiator,
"has_certificate": self._peer_certificate is not None,
"security_manager_available": self._security_manager is not None,
}
# Add certificate details if available
if self._peer_certificate:
try:
info.update(
{
"certificate_subject": str(self._peer_certificate.subject),
"certificate_issuer": str(self._peer_certificate.issuer),
"certificate_serial": str(self._peer_certificate.serial_number),
"certificate_not_before": (
self._peer_certificate.not_valid_before.isoformat()
),
"certificate_not_after": (
self._peer_certificate.not_valid_after.isoformat()
),
}
)
except Exception as e:
info["certificate_error"] = str(e)
# Add TLS context debug info
try:
if hasattr(self._quic, "tls") and self._quic.tls:
tls_info = {
"tls_context_available": True,
"tls_state": getattr(self._quic.tls, "state", None),
}
# Check for peer certificate in TLS context
if hasattr(self._quic.tls, "_peer_certificate"):
tls_info["tls_peer_certificate_available"] = (
self._quic.tls._peer_certificate is not None
)
info["tls_debug"] = tls_info
else:
info["tls_debug"] = {"tls_context_available": False}
except Exception as e:
info["tls_debug"] = {"error": str(e)}
return info
# Stream management methods (IMuxedConn interface)
async def open_stream(self, timeout: float = 5.0) -> QUICStream:
"""
Open a new outbound stream
Args:
timeout: Timeout for stream creation
Returns:
New QUIC stream
Raises:
QUICStreamLimitError: Too many concurrent streams
QUICConnectionClosedError: Connection is closed
QUICStreamTimeoutError: Stream creation timed out
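
        Example (sketch, assuming an established connection ``conn``):
            stream = await conn.open_stream(timeout=5.0)
            await stream.write(b"hello")
            await stream.close_write()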
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
if not self._started:
raise QUICConnectionError("Connection not started")
# Use single lock for all stream operations
with trio.move_on_after(timeout):
async with self._stream_lock:
# Check stream limits inside lock
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
raise QUICStreamLimitError(
"Maximum outbound streams "
f"({self.MAX_OUTGOING_STREAMS}) reached"
)
# Generate next stream ID
stream_id = self._next_stream_id
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.OUTBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
self._outbound_stream_count += 1
self._stats["streams_opened"] += 1
logger.debug(f"Opened outbound QUIC stream {stream_id}")
return stream
raise QUICStreamTimeoutError(f"Stream creation timed out after {timeout}s")
async def accept_stream(self, timeout: float | None = None) -> QUICStream:
"""
Accept incoming stream.
Args:
timeout: Optional timeout. If None, waits indefinitely.
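
        Example (sketch; ``process`` is a hypothetical application coroutine):
            stream = await conn.accept_stream(timeout=30.0)
            await process(stream)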
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
if timeout is not None:
with trio.move_on_after(timeout):
return await self._accept_stream_impl()
# Timeout occurred
if self._closed_event.is_set() or self._closed:
raise MuxedConnUnavailable("QUIC connection closed during timeout")
else:
raise QUICStreamTimeoutError(
f"Stream accept timed out after {timeout}s"
)
else:
# No timeout - wait indefinitely
return await self._accept_stream_impl()
    async def _accept_stream_impl(self) -> QUICStream:
        while True:
            if self._closed:
                raise MuxedConnUnavailable("QUIC connection is closed")
            # Use single lock for stream acceptance
            async with self._stream_lock:
                if self._stream_accept_queue:
                    stream = self._stream_accept_queue.pop(0)
                    logger.debug(f"Accepted inbound stream {stream.stream_id}")
                    return stream
                # Queue drained: re-arm the event so we block until
                # _create_inbound_stream() or connection close sets it again,
                # instead of spinning on an event that stays set forever.
                self._stream_accept_event = trio.Event()
            if self._closed:
                raise MuxedConnUnavailable("Connection closed while accepting stream")
            # Wait for new streams
            await self._stream_accept_event.wait()
def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
"""
Set handler for incoming streams.
Args:
handler_function: Function to handle new incoming streams
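
        Example (hedged sketch; assumes QUICStream exposes read()/write()):
            async def echo_handler(stream: QUICStream) -> None:
                await stream.write(await stream.read(1024))

            conn.set_stream_handler(echo_handler)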
"""
self._stream_handler = handler_function
logger.debug("Set stream handler for incoming streams")
def _remove_stream(self, stream_id: int) -> None:
"""
Remove stream from connection registry.
Called by stream cleanup process.
"""
if stream_id in self._streams:
stream = self._streams.pop(stream_id)
# Remove from cache too
self._stream_cache.pop(stream_id, None)
# Update stream counts asynchronously
async def update_counts() -> None:
async with self._stream_lock:
if stream.direction == StreamDirection.OUTBOUND:
self._outbound_stream_count = max(
0, self._outbound_stream_count - 1
)
else:
self._inbound_stream_count = max(
0, self._inbound_stream_count - 1
)
self._stats["streams_closed"] += 1
# Schedule count update if we're in a trio context
if self._nursery:
self._nursery.start_soon(update_counts)
logger.debug(f"Removed stream {stream_id} from connection")
# Batched event processing to reduce overhead
async def _process_quic_events_batched(self) -> None:
"""Process QUIC events in batches for better performance."""
if self._event_processing_active:
return # Prevent recursion
self._event_processing_active = True
try:
current_time = time.time()
events_processed = 0
# Collect events into batch
while events_processed < self._event_batch_size:
event = self._quic.next_event()
if event is None:
break
self._event_batch.append(event)
events_processed += 1
# Process batch if we have events or timeout
if self._event_batch and (
len(self._event_batch) >= self._event_batch_size
or current_time - self._last_event_time > 0.01 # 10ms timeout
):
await self._process_event_batch()
self._event_batch.clear()
self._last_event_time = current_time
finally:
self._event_processing_active = False
async def _process_event_batch(self) -> None:
"""Process a batch of events efficiently."""
if not self._event_batch:
return
# Group events by type for batch processing where possible
events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list)
for event in self._event_batch:
events_by_type[type(event).__name__].append(event)
# Process events by type
        for event_type, event_list in events_by_type.items():
            if event_type == events.StreamDataReceived.__name__:
                await self._handle_stream_data_batch(
                    cast(list[events.StreamDataReceived], event_list)
                )
else:
# Process other events individually
for event in event_list:
await self._handle_quic_event(event)
logger.debug(f"Processed batch of {len(self._event_batch)} events")
async def _handle_stream_data_batch(
self, events_list: list[events.StreamDataReceived]
) -> None:
"""Handle stream data events in batch for better performance."""
        # Group by stream ID (every event in this batch is StreamDataReceived)
        events_by_stream: defaultdict[int, list[events.StreamDataReceived]] = (
            defaultdict(list)
        )
        for event in events_list:
            events_by_stream[event.stream_id].append(event)
# Process each stream's events
for stream_id, stream_events in events_by_stream.items():
stream = self._get_stream_fast(stream_id) # Use fast lookup
if not stream:
if self._is_incoming_stream(stream_id):
try:
stream = await self._create_inbound_stream(stream_id)
except QUICStreamLimitError:
# Reset stream if we can't handle it
self._quic.reset_stream(stream_id, error_code=0x04)
await self._transmit()
continue
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
continue
            # Process all events for this stream
            for received_event in stream_events:
                self._stats["bytes_received"] += len(received_event.data)
                await stream.handle_data_received(
                    received_event.data, received_event.end_stream
                )
async def _create_inbound_stream(self, stream_id: int) -> QUICStream:
"""Create inbound stream with proper limit checking."""
async with self._stream_lock:
# Double-check stream doesn't exist
existing_stream = self._streams.get(stream_id)
if existing_stream:
return existing_stream
# Check limits
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
logger.warning(f"Rejecting inbound stream {stream_id}: limit reached")
raise QUICStreamLimitError("Too many inbound streams")
# Create stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
self._streams[stream_id] = stream
self._stream_cache[stream_id] = stream # Add to cache
self._inbound_stream_count += 1
self._stats["streams_accepted"] += 1
# Add to accept queue
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
logger.debug(f"Created inbound stream {stream_id}")
return stream
async def _process_quic_events(self) -> None:
"""Process all pending QUIC events."""
# Delegate to batched processing for better performance
await self._process_quic_events_batched()
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
"""Handle a single QUIC event with COMPLETE event type coverage."""
logger.debug(f"Handling QUIC event: {type(event).__name__}")
logger.debug(f"QUIC event: {type(event).__name__}")
try:
if isinstance(event, events.ConnectionTerminated):
await self._handle_connection_terminated(event)
elif isinstance(event, events.HandshakeCompleted):
await self._handle_handshake_completed(event)
elif isinstance(event, events.StreamDataReceived):
await self._handle_stream_data(event)
elif isinstance(event, events.StreamReset):
await self._handle_stream_reset(event)
elif isinstance(event, events.DatagramFrameReceived):
await self._handle_datagram_received(event)
            # Connection ID lifecycle events
elif isinstance(event, events.ConnectionIdIssued):
await self._handle_connection_id_issued(event)
elif isinstance(event, events.ConnectionIdRetired):
await self._handle_connection_id_retired(event)
            # Additional event handlers
elif isinstance(event, events.PingAcknowledged):
await self._handle_ping_acknowledged(event)
elif isinstance(event, events.ProtocolNegotiated):
await self._handle_protocol_negotiated(event)
elif isinstance(event, events.StopSendingReceived):
await self._handle_stop_sending_received(event)
else:
logger.debug(f"Unhandled QUIC event type: {type(event).__name__}")
logger.debug(f"Unhandled QUIC event: {type(event).__name__}")
except Exception as e:
logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
async def _handle_connection_id_issued(
self, event: events.ConnectionIdIssued
) -> None:
"""
Handle new connection ID issued by peer.
This is the CRITICAL missing functionality that was causing your issue!
"""
logger.debug(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
logger.debug(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
# Add to available connection IDs
self._available_connection_ids.add(event.connection_id)
# If we don't have a current connection ID, use this one
if self._current_connection_id is None:
self._current_connection_id = event.connection_id
            logger.debug(
                f"🆔 Set current connection ID to: {event.connection_id.hex()}"
            )
# Update statistics
self._stats["connection_ids_issued"] += 1
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
async def _handle_connection_id_retired(
self, event: events.ConnectionIdRetired
) -> None:
"""
Handle connection ID retirement.
This handles when the peer tells us to stop using a connection ID.
"""
logger.debug(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
# Remove from available IDs and add to retired set
self._available_connection_ids.discard(event.connection_id)
self._retired_connection_ids.add(event.connection_id)
        # If this was our current connection ID, switch to another
        if self._current_connection_id == event.connection_id:
            if self._available_connection_ids:
                self._current_connection_id = next(
                    iter(self._available_connection_ids)
                )
                logger.debug(
                    "Switching to new connection ID: "
                    f"{self._current_connection_id.hex()}"
                )
                self._stats["connection_id_changes"] += 1
            else:
                self._current_connection_id = None
                logger.warning("⚠️ No available connection IDs after retirement!")
# Update statistics
self._stats["connection_ids_retired"] += 1
async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None:
"""Handle ping acknowledgment."""
logger.debug(f"Ping acknowledged: uid={event.uid}")
async def _handle_protocol_negotiated(
self, event: events.ProtocolNegotiated
) -> None:
"""Handle protocol negotiation completion."""
logger.debug(f"Protocol negotiated: {event.alpn_protocol}")
async def _handle_stop_sending_received(
self, event: events.StopSendingReceived
) -> None:
"""Handle stop sending request from peer."""
logger.debug(
"Stop sending received: "
f"stream_id={event.stream_id}, error_code={event.error_code}"
)
# Use fast lookup
stream = self._get_stream_fast(event.stream_id)
if stream:
            # Propagate the stop-sending request to the stream
            await stream.handle_stop_sending(event.error_code)
async def _handle_handshake_completed(
self, event: events.HandshakeCompleted
) -> None:
"""Handle handshake completion with security integration."""
logger.debug("QUIC handshake completed")
self._handshake_completed = True
# Store handshake event for security verification
self._handshake_events.append(event)
# Try to extract certificate information after handshake
await self._extract_peer_certificate()
logger.debug("✅ Setting connected event")
self._connected_event.set()
async def _handle_connection_terminated(
self, event: events.ConnectionTerminated
) -> None:
"""Handle connection termination."""
logger.debug(f"QUIC connection terminated: {event.reason_phrase}")
# Close all streams
for stream in list(self._streams.values()):
if event.error_code:
await stream.handle_reset(event.error_code)
else:
await stream.close()
self._streams.clear()
self._stream_cache.clear() # Clear cache too
self._closed = True
self._closed_event.set()
self._stream_accept_event.set()
logger.debug(f"Woke up pending accept_stream() calls, {id(self)}")
await self._notify_parent_of_termination()
async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
"""Handle stream data events - create streams and add to accept queue."""
stream_id = event.stream_id
self._stats["bytes_received"] += len(event.data)
try:
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if not stream:
if self._is_incoming_stream(stream_id):
logger.debug(f"Creating new incoming stream {stream_id}")
stream = await self._create_inbound_stream(stream_id)
else:
logger.error(
f"Unexpected outbound stream {stream_id} in data event"
)
return
await stream.handle_data_received(event.data, event.end_stream)
        except Exception as e:
            logger.error(f"Error handling stream data for stream {stream_id}: {e}")
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
"""Get existing stream or create new inbound stream."""
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
return stream
# Check if this is an incoming stream
is_incoming = self._is_incoming_stream(stream_id)
if not is_incoming:
# This shouldn't happen - outbound streams should be created by open_stream
raise QUICStreamError(
f"Received data for unknown outbound stream {stream_id}"
)
# Create new inbound stream
return await self._create_inbound_stream(stream_id)
def _is_incoming_stream(self, stream_id: int) -> bool:
"""
Determine if a stream ID represents an incoming stream.
For bidirectional streams:
- Even IDs are client-initiated
- Odd IDs are server-initiated
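
        Example: a client (initiator) treats odd IDs (1, 3, 5, ...) as
        incoming, while a server treats even IDs (0, 2, 4, ...) as incoming.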
"""
if self._is_initiator:
# We're the client, so odd stream IDs are incoming
return stream_id % 2 == 1
else:
# We're the server, so even stream IDs are incoming
return stream_id % 2 == 0
async def _handle_stream_reset(self, event: events.StreamReset) -> None:
"""Stream reset handling."""
stream_id = event.stream_id
self._stats["streams_reset"] += 1
# Use fast lookup
stream = self._get_stream_fast(stream_id)
if stream:
try:
await stream.handle_reset(event.error_code)
                logger.debug(
                    f"Handled reset for stream {stream_id} "
                    f"with error code {event.error_code}"
                )
except Exception as e:
logger.error(f"Error handling stream reset for {stream_id}: {e}")
# Force remove the stream
self._remove_stream(stream_id)
else:
logger.debug(f"Received reset for unknown stream {stream_id}")
async def _handle_datagram_received(
self, event: events.DatagramFrameReceived
) -> None:
"""Handle datagram frame (if using QUIC datagrams)."""
logger.debug(f"Datagram frame received: size={len(event.data)}")
# For now, just log. Could be extended for custom datagram handling
async def _handle_timer_events(self) -> None:
"""Handle QUIC timer events."""
timer = self._quic.get_timer()
if timer is not None:
now = time.time()
if timer <= now:
self._quic.handle_timer(now=now)
# Network transmission
async def _transmit(self) -> None:
"""Transmit pending QUIC packets using available socket."""
sock = self._socket
if not sock:
logger.debug("No socket to transmit")
return
try:
current_time = time.time()
datagrams = self._quic.datagrams_to_send(now=current_time)
# Batch stats updates
packet_count = 0
total_bytes = 0
for data, addr in datagrams:
await sock.sendto(data, addr)
packet_count += 1
total_bytes += len(data)
# Update stats in batch
if packet_count > 0:
self._stats["packets_sent"] += packet_count
self._stats["bytes_sent"] += total_bytes
except Exception as e:
logger.error(f"Transmission error: {e}")
await self._handle_connection_error(e)
# Additional methods for stream data processing
async def _process_quic_event(self, event: events.QuicEvent) -> None:
"""Process a single QUIC event."""
await self._handle_quic_event(event)
async def _transmit_pending_data(self) -> None:
"""Transmit any pending data."""
await self._transmit()
# Error handling
async def _handle_connection_error(self, error: Exception) -> None:
"""Handle connection-level errors."""
logger.error(f"Connection error: {error}")
if not self._closed:
try:
await self.close()
except Exception as close_error:
logger.error(f"Error during connection close: {close_error}")
# Connection close
async def close(self) -> None:
"""Connection close with proper stream cleanup."""
if self._closed:
return
self._closed = True
logger.debug(f"Closing QUIC connection to {self._remote_peer_id}")
try:
# Close all streams gracefully
stream_close_tasks = []
for stream in list(self._streams.values()):
if stream.can_write() or stream.can_read():
stream_close_tasks.append(stream.close)
if stream_close_tasks and self._nursery:
try:
# Close streams concurrently with timeout
with trio.move_on_after(self.CONNECTION_CLOSE_TIMEOUT):
async with trio.open_nursery() as close_nursery:
for task in stream_close_tasks:
close_nursery.start_soon(task)
except Exception as e:
logger.warning(f"Error during graceful stream close: {e}")
# Force reset remaining streams
for stream in self._streams.values():
try:
await stream.reset(error_code=0)
except Exception:
pass
if self.on_close:
await self.on_close()
# Close QUIC connection
self._quic.close()
if self._socket:
await self._transmit() # Send close frames
# Close socket
if self._socket and self._owns_socket:
self._socket.close()
self._socket = None
self._streams.clear()
self._stream_cache.clear() # Clear cache
self._closed_event.set()
logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
except Exception as e:
logger.error(f"Error during connection close: {e}")
    async def _notify_parent_of_termination(self) -> None:
        """
        Notify the parent listener/transport to remove this connection from tracking.

        This ensures that terminated connections are cleaned up from the
        'established connections' list.
        """
        try:
            if self._transport:
                # Preferred path: let the transport clean up its own tracking.
                try:
                    await self._transport._cleanup_terminated_connection(self)
                    logger.debug("Notified transport of connection termination")
                    return
                except Exception:
                    logger.debug(
                        "Transport cleanup failed, falling back to listeners"
                    )
                # Fallback: ask each listener to drop this connection object.
                for listener in self._transport._listeners:
                    try:
                        await listener._remove_connection_by_object(self)
                        logger.debug(
                            "Found and notified listener of connection termination"
                        )
                        return
                    except Exception:
                        continue
            # Last resort: clean up by connection ID (most reliable lookup key).
            if self._current_connection_id:
                await self._cleanup_by_connection_id(self._current_connection_id)
                return
            logger.warning(
                "Could not notify parent of connection termination - no"
                f" parent reference found for conn host {self._quic.host_cid.hex()}"
            )
        except Exception as e:
            logger.error(f"Error notifying parent of connection termination: {e}")
    async def _cleanup_by_connection_id(self, connection_id: bytes) -> None:
        """Cleanup using connection ID as a fallback method."""
        try:
            for listener in self._transport._listeners:
                for tracked_cid, tracked_conn in list(listener._connections.items()):
                    if tracked_conn is self or tracked_cid == connection_id:
                        await listener._remove_connection(tracked_cid)
                        logger.debug(f"Removed connection {tracked_cid.hex()}")
                        return
            logger.debug("Fallback cleanup by connection ID completed")
        except Exception as e:
            logger.error(f"Error in fallback cleanup: {e}")
# IRawConnection interface (for compatibility)
def get_remote_address(self) -> tuple[str, int]:
return self._remote_addr
async def write(self, data: bytes) -> None:
"""
Write data to the connection.
For QUIC, this creates a new stream for each write operation.
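
        Example (sketch):
            await conn.write(b"ping")  # opens a stream, writes, half-closes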
"""
if self._closed:
raise QUICConnectionClosedError("Connection is closed")
stream = await self.open_stream()
try:
await stream.write(data)
await stream.close_write()
except Exception:
await stream.reset()
raise
    async def read(self, n: int | None = -1) -> bytes:
        """
        Read data from the connection.

        Not supported for QUIC: data is carried on individual streams, not on
        the connection itself. This method exists only for IRawConnection
        compatibility.

        Raises:
            NotImplementedError: Always. Use accept_stream() or open_stream()
                and read from the resulting stream instead.
        """
        raise NotImplementedError(
            "Use streams for reading data from QUIC connections. "
            "Call accept_stream() or open_stream() instead."
        )
# Utility and monitoring methods
def get_stream_stats(self) -> dict[str, Any]:
"""Get stream statistics for monitoring."""
return {
"total_streams": len(self._streams),
"outbound_streams": self._outbound_stream_count,
"inbound_streams": self._inbound_stream_count,
"max_streams": self.MAX_CONCURRENT_STREAMS,
"stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS,
"stats": self._stats.copy(),
"cache_size": len(
self._stream_cache
), # Include cache metrics for monitoring
}
def get_active_streams(self) -> list[QUICStream]:
"""Get list of active streams."""
return [stream for stream in self._streams.values() if not stream.is_closed()]
def get_streams_by_protocol(self, protocol: str) -> list[QUICStream]:
"""Get streams filtered by protocol."""
return [
stream
for stream in self._streams.values()
if hasattr(stream, "protocol")
and stream.protocol == protocol
and not stream.is_closed()
]
def _update_stats(self) -> None:
"""Update connection statistics."""
# Add any periodic stats updates here
pass
async def _cleanup_idle_streams(self) -> None:
"""Clean up idle streams that are no longer needed."""
current_time = time.time()
streams_to_cleanup = []
for stream in self._streams.values():
if stream.is_closed():
# Check if stream has been closed for a while
if hasattr(stream, "_timeline") and stream._timeline.closed_at:
if current_time - stream._timeline.closed_at > 60: # 1 minute
streams_to_cleanup.append(stream.stream_id)
for stream_id in streams_to_cleanup:
self._remove_stream(int(stream_id))
# String representation
def __repr__(self) -> str:
current_cid: str | None = (
self._current_connection_id.hex() if self._current_connection_id else None
)
return (
f"QUICConnection(peer={self._remote_peer_id}, "
f"addr={self._remote_addr}, "
f"initiator={self._is_initiator}, "
f"verified={self._peer_verified}, "
f"established={self._established}, "
f"streams={len(self._streams)}, "
f"current_cid={current_cid})"
)
def __str__(self) -> str:
return f"QUICConnection({self._remote_peer_id})"