Files
py-libp2p/libp2p/transport/quic/connection.py
2025-08-30 14:07:31 +05:30

552 lines
18 KiB
Python

"""
QUIC Connection implementation for py-libp2p.
Uses aioquic's sans-IO core with trio for async operations.
"""
import logging
import socket
import time
from typing import TYPE_CHECKING
from aioquic.quic import (
events,
)
from aioquic.quic.connection import (
QuicConnection,
)
import multiaddr
import trio
from libp2p.abc import (
IMuxedConn,
IMuxedStream,
IRawConnection,
)
from libp2p.custom_types import TQUICStreamHandlerFn
from libp2p.peer.id import (
ID,
)
from .exceptions import (
QUICConnectionError,
QUICStreamError,
)
from .stream import (
QUICStream,
)
if TYPE_CHECKING:
    # Only needed for the "QUICTransport" string annotation below; presumably
    # kept out of runtime imports to avoid a cycle with the transport module.
    from .transport import QUICTransport

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class QUICConnection(IRawConnection, IMuxedConn):
    """
    QUIC connection implementing both raw connection and muxed connection interfaces.

    Uses aioquic's sans-IO core with trio for native async support.
    QUIC natively provides stream multiplexing, so this connection acts as both
    a raw connection (for transport layer) and muxed connection (for upper layers).
    Updated to work properly with the QUIC listener for server-side connections.

    Client connections are brought up with ``connect(nursery)`` (which also
    runs the receive/timer loops) and/or ``start()``; the two can be combined
    in either order without initiating the QUIC handshake twice. Server-side
    connections are marked established as soon as ``start()`` is called.
    """

    def __init__(
        self,
        quic_connection: QuicConnection,
        remote_addr: tuple[str, int],
        peer_id: ID | None,
        local_peer_id: ID,
        is_initiator: bool,
        maddr: multiaddr.Multiaddr,
        transport: "QUICTransport",
    ):
        self._quic = quic_connection
        self._remote_addr = remote_addr
        self._peer_id = peer_id
        self._local_peer_id = local_peer_id
        self.__is_initiator = is_initiator
        self._maddr = maddr
        self._transport = transport

        # Trio networking - socket may be provided by listener
        self._socket: trio.socket.SocketType | None = None
        # True only when this object created ``_socket`` itself (client side).
        # NOTE(review): assumes a listener-provided socket is shared and owned
        # by the listener, so it must not be closed here — confirm in listener.
        self._owns_socket = False
        self._connected_event = trio.Event()
        self._closed_event = trio.Event()

        # Stream management
        self._streams: dict[int, QUICStream] = {}
        self._next_stream_id: int = self._calculate_initial_stream_id()
        self._stream_handler: TQUICStreamHandlerFn | None = None
        self._stream_id_lock = trio.Lock()

        # Connection state
        self._closed = False
        # Guards close() separately from ``_closed``: a remote
        # ConnectionTerminated sets ``_closed``, but close() must still run
        # afterwards to release the socket and streams.
        self._close_called = False
        self._established = False
        self._started = False

        # Background task management
        self._background_tasks_started = False
        self._nursery: trio.Nursery | None = None

        logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})")

    def _calculate_initial_stream_id(self) -> int:
        """
        Calculate the initial stream ID based on QUIC specification.

        QUIC stream IDs:
        - Client-initiated bidirectional: 0, 4, 8, 12, ...
        - Server-initiated bidirectional: 1, 5, 9, 13, ...
        - Client-initiated unidirectional: 2, 6, 10, 14, ...
        - Server-initiated unidirectional: 3, 7, 11, 15, ...

        For libp2p, we primarily use bidirectional streams.
        """
        if self.__is_initiator:
            return 0  # Client starts with 0, then 4, 8, 12...
        else:
            return 1  # Server starts with 1, then 5, 9, 13...

    @property
    def is_initiator(self) -> bool:  # type: ignore
        return self.__is_initiator

    async def start(self) -> None:
        """
        Start the connection and its background tasks.

        This method implements the IMuxedConn.start() interface.
        It should be called to begin processing connection events.

        Raises:
            QUICConnectionError: if the connection is already closed, or if
                client-side initiation fails.
        """
        if self._started:
            logger.warning("Connection already started")
            return

        if self._closed:
            raise QUICConnectionError("Cannot start a closed connection")

        self._started = True
        logger.debug(f"Starting QUIC connection to {self._peer_id}")

        if self.__is_initiator:
            # connect() may already have opened the socket and sent the first
            # flight; aioquic's QuicConnection.connect() must only run once.
            if self._socket is None:
                await self._initiate_connection()
        else:
            # For server connections, we're already connected via the listener
            self._established = True
            self._connected_event.set()

        logger.debug(f"QUIC connection to {self._peer_id} started")

    async def _open_client_socket(self) -> None:
        """
        Create and connect the client UDP socket, then send the first flight.

        On failure the freshly created socket is closed before re-raising so
        it is not leaked.
        """
        self._socket = trio.socket.socket(
            family=socket.AF_INET, type=socket.SOCK_DGRAM
        )
        self._owns_socket = True
        try:
            # Connect the socket to the remote address
            await self._socket.connect(self._remote_addr)
            # Start the connection establishment
            self._quic.connect(self._remote_addr, now=time.time())
            # Send initial packet(s)
            await self._transmit()
        except BaseException:
            self._release_socket()
            raise

    def _release_socket(self) -> None:
        """Close and drop the UDP socket, but only if this connection created it."""
        if self._socket is not None and self._owns_socket:
            self._socket.close()
            self._socket = None
            self._owns_socket = False

    async def _initiate_connection(self) -> None:
        """
        Initiate client-side connection establishment without a nursery.

        Raises:
            QUICConnectionError: wrapping any failure while opening the socket
                or sending the initial packets.
        """
        try:
            await self._open_client_socket()

            # For client connections, we need to manage our own background tasks
            # In a real implementation, this would be managed by the transport
            # Without a nursery (see connect()) nothing reads from the socket.
            if not self._background_tasks_started:
                logger.warning(
                    "Background tasks need nursery - connection may not work properly"
                )

        except Exception as e:
            logger.error(f"Failed to initiate connection: {e}")
            raise QUICConnectionError(f"Connection initiation failed: {e}") from e

    async def connect(self, nursery: trio.Nursery) -> None:
        """
        Establish the QUIC connection using trio.

        Opens the client socket (unless start() already did), starts the
        receive and timer loops in ``nursery`` and waits for the handshake.

        Args:
            nursery: Trio nursery for background tasks

        Raises:
            QUICConnectionError: if called on a server connection, if setup
                fails, or if the connection closes before the handshake
                completes.
        """
        if not self.__is_initiator:
            raise QUICConnectionError(
                "connect() should only be called by client connections"
            )

        try:
            # Store nursery for background tasks (and incoming stream handlers)
            self._nursery = nursery

            if self._socket is None:
                await self._open_client_socket()

            # Start background tasks
            await self._start_background_tasks(nursery)

            # Wait for connection to be established (or to fail)
            await self._wait_for_handshake()

        except Exception as e:
            logger.error(f"Failed to connect: {e}")
            if isinstance(e, QUICConnectionError):
                raise
            raise QUICConnectionError(f"Connection failed: {e}") from e

    async def _wait_for_handshake(self) -> None:
        """
        Wait until the handshake completes or the connection is closed.

        Raises:
            QUICConnectionError: if the connection closed before the handshake.
        """
        if not (self._connected_event.is_set() or self._closed_event.is_set()):
            async with trio.open_nursery() as waiters:

                async def wake_on(event: trio.Event) -> None:
                    await event.wait()
                    # Whichever event fires first stops the other waiter.
                    waiters.cancel_scope.cancel()

                waiters.start_soon(wake_on, self._connected_event)
                waiters.start_soon(wake_on, self._closed_event)

        if not self._established:
            raise QUICConnectionError("Connection closed before handshake completed")

    async def _start_background_tasks(self, nursery: trio.Nursery) -> None:
        """Start background tasks for connection management (idempotent)."""
        if self._background_tasks_started:
            return

        self._background_tasks_started = True

        nursery.start_soon(self._handle_incoming_data)
        nursery.start_soon(self._handle_timer)

    async def _handle_incoming_data(self) -> None:
        """
        Feed incoming UDP datagrams to aioquic until the connection closes.

        Stops on socket closure (ClosedResourceError) or on any other error,
        which is logged.
        """
        while not self._closed:
            try:
                sock = self._socket
                if sock is None:
                    # No socket yet: poll briefly instead of spinning.
                    await trio.sleep(0.001)
                    continue

                # recvfrom blocks and is a trio checkpoint, so no extra sleep
                # is needed between datagrams.
                data, addr = await sock.recvfrom(65536)
                self._quic.receive_datagram(data, addr, now=time.time())
                await self._process_events()
                await self._transmit()

            except trio.ClosedResourceError:
                break
            except Exception as e:
                logger.error(f"Error handling incoming data: {e}")
                break

    async def _handle_timer(self) -> None:
        """
        Drive aioquic's timer (retransmissions, idle timeout, ...).

        Wakes at least every 0.1 s so it notices closure promptly; errors are
        logged and the loop keeps running.
        """
        while not self._closed:
            try:
                timer_at = self._quic.get_timer()
                if timer_at is None:
                    await trio.sleep(0.1)  # No timer set, check again later
                    continue

                now = time.time()
                if timer_at <= now:
                    self._quic.handle_timer(now=now)
                    await self._process_events()
                    await self._transmit()
                    await trio.sleep(0.001)  # Small delay
                else:
                    # Sleep until timer fires, but check periodically
                    sleep_time = min(timer_at - now, 0.1)
                    await trio.sleep(sleep_time)

            except Exception as e:
                logger.error(f"Error in timer handler: {e}")
                await trio.sleep(0.1)

    async def _process_events(self) -> None:
        """Drain and dispatch pending events from the aioquic core."""
        while True:
            event = self._quic.next_event()
            if event is None:
                break

            if isinstance(event, events.ConnectionTerminated):
                logger.info(f"QUIC connection terminated: {event.reason_phrase}")
                # Resources are released later by close(); see _close_called.
                self._closed = True
                self._closed_event.set()
                break

            elif isinstance(event, events.HandshakeCompleted):
                logger.debug("QUIC handshake completed")
                self._established = True
                self._connected_event.set()

            elif isinstance(event, events.StreamDataReceived):
                await self._handle_stream_data(event)

            elif isinstance(event, events.StreamReset):
                await self._handle_stream_reset(event)

    async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
        """
        Route received stream data, creating the stream on first sight.

        Data is delivered to the stream BEFORE the stream handler runs, so a
        handler that reads from the stream can see the first chunk. When a
        nursery is available the handler runs as a separate task, so event
        processing is not blocked while it runs.
        """
        stream_id = event.stream_id
        stream = self._streams.get(stream_id)
        is_new_incoming = False

        if stream is None:
            is_incoming = self._is_incoming_stream(stream_id)
            stream = QUICStream(
                connection=self,
                stream_id=stream_id,
                is_initiator=not is_incoming,
            )
            self._streams[stream_id] = stream
            is_new_incoming = is_incoming

        await stream.handle_data_received(event.data, event.end_stream)

        handler = self._stream_handler
        if is_new_incoming and handler is not None:
            if self._nursery is not None:
                self._nursery.start_soon(self._run_stream_handler, handler, stream)
            else:
                # No nursery (e.g. server side): fall back to running inline.
                await self._run_stream_handler(handler, stream)

    async def _run_stream_handler(
        self, handler: TQUICStreamHandlerFn, stream: QUICStream
    ) -> None:
        """Run the user stream handler, logging instead of propagating errors."""
        try:
            await handler(stream)
        except Exception as e:
            logger.error(f"Error in stream handler: {e}")

    def _is_incoming_stream(self, stream_id: int) -> bool:
        """
        Determine if a stream ID represents a peer-initiated stream.

        The lowest bit of a QUIC stream ID identifies the initiator
        (0 = client, 1 = server), for bidirectional and unidirectional
        streams alike.
        """
        if self.__is_initiator:
            # We're the client, so odd stream IDs are incoming
            return stream_id % 2 == 1
        else:
            # We're the server, so even stream IDs are incoming
            return stream_id % 2 == 0

    async def _handle_stream_reset(self, event: events.StreamReset) -> None:
        """Notify the affected stream of a reset and forget it."""
        stream_id = event.stream_id
        stream = self._streams.get(stream_id)
        if stream is not None:
            await stream.handle_reset(event.error_code)
            # pop, not del: the stream may have been removed during the await.
            self._streams.pop(stream_id, None)

    async def _transmit(self) -> None:
        """Send pending datagrams; a no-op while there is no socket."""
        sock = self._socket
        if sock is None:
            return

        try:
            for data, addr in self._quic.datagrams_to_send(now=time.time()):
                await sock.sendto(data, addr)
        except Exception as e:
            logger.error(f"Failed to send datagram: {e}")

    # IRawConnection interface

    async def write(self, data: bytes) -> None:
        """
        Write data to the connection.

        For QUIC, this creates a new stream for each write operation; the
        stream is closed even if the write fails.
        """
        if self._closed:
            raise QUICConnectionError("Connection is closed")

        stream = await self.open_stream()
        try:
            await stream.write(data)
        finally:
            await stream.close()

    async def read(self, n: int | None = -1) -> bytes:
        """
        Read data from the connection.

        Not supported for QUIC: upper layers use the muxed interface instead.

        Raises:
            QUICConnectionError: if the connection is closed.
            NotImplementedError: always otherwise.
        """
        if self._closed:
            raise QUICConnectionError("Connection is closed")

        raise NotImplementedError(
            "Use muxed connection interface for stream-based reading"
        )

    async def close(self) -> None:
        """
        Close the connection and all streams.

        Safe to call more than once, and still releases resources when the
        peer already terminated the connection. Stream close failures are
        logged and do not prevent the socket from being released.
        """
        if self._close_called:
            return

        self._close_called = True
        self._closed = True
        logger.debug(f"Closing QUIC connection to {self._peer_id}")

        try:
            streams = list(self._streams.values())
            if streams:
                # Close streams concurrently
                async with trio.open_nursery() as nursery:
                    for stream in streams:
                        nursery.start_soon(self._close_stream_quietly, stream)

            # Close QUIC connection and flush the CONNECTION_CLOSE frame(s);
            # _transmit is a no-op when there is no socket.
            self._quic.close()
            await self._transmit()
        finally:
            self._release_socket()
            self._streams.clear()
            self._closed_event.set()

        logger.debug(f"QUIC connection to {self._peer_id} closed")

    async def _close_stream_quietly(self, stream: QUICStream) -> None:
        """Close one stream, logging failures so sibling streams still close."""
        try:
            await stream.close()
        except Exception as e:
            logger.warning(f"Error closing QUIC stream: {e}")

    @property
    def is_closed(self) -> bool:
        """Check if connection is closed."""
        return self._closed

    @property
    def is_established(self) -> bool:
        """Check if connection is established (handshake completed)."""
        return self._established

    @property
    def is_started(self) -> bool:
        """Check if connection has been started."""
        return self._started

    def multiaddr(self) -> multiaddr.Multiaddr:
        """Get the multiaddr for this connection."""
        return self._maddr

    def local_peer_id(self) -> ID:
        """Get the local peer ID."""
        return self._local_peer_id

    def remote_peer_id(self) -> ID | None:
        """Get the remote peer ID."""
        return self._peer_id

    # IMuxedConn interface

    async def open_stream(self) -> IMuxedStream:
        """
        Open a new locally-initiated bidirectional stream on this connection.

        Returns:
            New QUIC stream

        Raises:
            QUICStreamError: if the connection is closed or not started.
        """
        if self._closed:
            raise QUICStreamError("Connection is closed")

        if not self._started:
            raise QUICStreamError("Connection not started")

        async with self._stream_id_lock:
            stream_id = self._next_stream_id
            self._next_stream_id += 4  # Increment by 4 for bidirectional streams

            stream = QUICStream(connection=self, stream_id=stream_id, is_initiator=True)
            self._streams[stream_id] = stream

        logger.debug(f"Opened QUIC stream {stream_id}")
        return stream

    def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
        """
        Set handler for incoming streams.

        Args:
            handler_function: Function to handle new incoming streams
        """
        self._stream_handler = handler_function

    async def accept_stream(self) -> IMuxedStream:
        """
        Accept an incoming stream.

        Not supported: incoming streams are dispatched to the handler
        registered with set_stream_handler().

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Use set_stream_handler for incoming streams")

    async def verify_peer_identity(self) -> None:
        """
        Verify the remote peer's identity using TLS certificate.

        This implements the libp2p TLS handshake verification. Verification
        is currently skipped (with a warning) because certificate peer ID
        extraction is not implemented.

        Raises:
            QUICConnectionError: if the certificate's peer ID does not match
                the expected one.
        """
        try:
            cert_peer_id = self._extract_peer_id_from_cert()

            if self._peer_id and cert_peer_id != self._peer_id:
                raise QUICConnectionError(
                    f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
                )

            if not self._peer_id:
                self._peer_id = cert_peer_id

            logger.debug(f"Verified peer identity: {self._peer_id}")

        except NotImplementedError:
            logger.warning("Peer identity verification not implemented - skipping")
            # For now, we'll skip verification during development

    def _extract_peer_id_from_cert(self) -> ID:
        """
        Extract peer ID from TLS certificate.

        Placeholder: should follow the libp2p TLS specification, where the
        certificate carries the peer ID in a dedicated extension.

        Raises:
            NotImplementedError: always, until implemented.
        """
        raise NotImplementedError("Certificate peer ID extraction not implemented")

    def get_stats(self) -> dict:
        """Get connection statistics."""
        return {
            "peer_id": str(self._peer_id),
            "remote_addr": self._remote_addr,
            "is_initiator": self.__is_initiator,
            "is_established": self._established,
            "is_closed": self._closed,
            "is_started": self._started,
            "active_streams": len(self._streams),
            "next_stream_id": self._next_stream_id,
        }

    def get_remote_address(self) -> tuple[str, int]:
        """Return the remote (host, port) this connection talks to."""
        return self._remote_addr

    def __str__(self) -> str:
        """String representation of the connection."""
        return (
            f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, "
            f"established={self._established}, started={self._started})"
        )