fix: impl quic listener

This commit is contained in:
Akash Mondal
2025-06-10 21:40:21 +00:00
committed by lla-dane
parent 446a22b0f0
commit 54b3055eaa
13 changed files with 1687 additions and 150 deletions

View File

@ -5,17 +5,15 @@ from collections.abc import (
)
from typing import TYPE_CHECKING, NewType, Union, cast
from libp2p.transport.quic.stream import QUICStream
if TYPE_CHECKING:
from libp2p.abc import (
IMuxedConn,
INetStream,
ISecureTransport,
)
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
else:
IMuxedConn = cast(type, object)
INetStream = cast(type, object)
ISecureTransport = cast(type, object)
IMuxedStream = cast(type, object)
from libp2p.io.abc import (
ReadWriteCloser,
@ -37,3 +35,4 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]

View File

@ -8,6 +8,8 @@ from dataclasses import (
)
import ssl
from libp2p.custom_types import TProtocol
@dataclass
class QUICTransportConfig:
@ -39,6 +41,12 @@ class QUICTransportConfig:
max_connections: int = 1000 # Maximum number of connections
connection_timeout: float = 10.0 # Connection establishment timeout
# Protocol identifiers matching go-libp2p
# TODO: UNTIL MUITIADDR REPO IS UPDATED
# PROTOCOL_QUIC_V1: TProtocol = TProtocol("/quic-v1") # RFC 9000
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic") # RFC 9000
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
def __post_init__(self):
"""Validate configuration after initialization."""
if not (self.enable_draft29 or self.enable_v1):

View File

@ -6,6 +6,7 @@ Uses aioquic's sans-IO core with trio for async operations.
import logging
import socket
import time
from typing import TYPE_CHECKING
from aioquic.quic import (
events,
@ -21,9 +22,7 @@ from libp2p.abc import (
IMuxedStream,
IRawConnection,
)
from libp2p.custom_types import (
StreamHandlerFn,
)
from libp2p.custom_types import TQUICStreamHandlerFn
from libp2p.peer.id import (
ID,
)
@ -35,9 +34,11 @@ from .exceptions import (
from .stream import (
QUICStream,
)
from .transport import (
QUICTransport,
)
if TYPE_CHECKING:
from .transport import (
QUICTransport,
)
logger = logging.getLogger(__name__)
@ -49,76 +50,177 @@ class QUICConnection(IRawConnection, IMuxedConn):
Uses aioquic's sans-IO core with trio for native async support.
QUIC natively provides stream multiplexing, so this connection acts as both
a raw connection (for transport layer) and muxed connection (for upper layers).
Updated to work properly with the QUIC listener for server-side connections.
"""
def __init__(
self,
quic_connection: QuicConnection,
remote_addr: tuple[str, int],
peer_id: ID,
peer_id: ID | None,
local_peer_id: ID,
initiator: bool,
is_initiator: bool,
maddr: multiaddr.Multiaddr,
transport: QUICTransport,
transport: "QUICTransport",
):
self._quic = quic_connection
self._remote_addr = remote_addr
self._peer_id = peer_id
self._local_peer_id = local_peer_id
self.__is_initiator = initiator
self.__is_initiator = is_initiator
self._maddr = maddr
self._transport = transport
# Trio networking
# Trio networking - socket may be provided by listener
self._socket: trio.socket.SocketType | None = None
self._connected_event = trio.Event()
self._closed_event = trio.Event()
# Stream management
self._streams: dict[int, QUICStream] = {}
self._next_stream_id: int = (
0 if initiator else 1
) # Even for initiator, odd for responder
self._stream_handler: StreamHandlerFn | None = None
self._next_stream_id: int = self._calculate_initial_stream_id()
self._stream_handler: TQUICStreamHandlerFn | None = None
self._stream_id_lock = trio.Lock()
# Connection state
self._closed = False
self._timer_task = None
self._established = False
self._started = False
logger.debug(f"Created QUIC connection to {peer_id}")
# Background task management
self._background_tasks_started = False
self._nursery: trio.Nursery | None = None
logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})")
def _calculate_initial_stream_id(self) -> int:
"""
Calculate the initial stream ID based on QUIC specification.
QUIC stream IDs:
- Client-initiated bidirectional: 0, 4, 8, 12, ...
- Server-initiated bidirectional: 1, 5, 9, 13, ...
- Client-initiated unidirectional: 2, 6, 10, 14, ...
- Server-initiated unidirectional: 3, 7, 11, 15, ...
For libp2p, we primarily use bidirectional streams.
"""
if self.__is_initiator:
return 0 # Client starts with 0, then 4, 8, 12...
else:
return 1 # Server starts with 1, then 5, 9, 13...
@property
def is_initiator(self) -> bool: # type: ignore
return self.__is_initiator
async def connect(self) -> None:
"""Establish the QUIC connection using trio."""
async def start(self) -> None:
"""
Start the connection and its background tasks.
This method implements the IMuxedConn.start() interface.
It should be called to begin processing connection events.
"""
if self._started:
logger.warning("Connection already started")
return
if self._closed:
raise QUICConnectionError("Cannot start a closed connection")
self._started = True
logger.debug(f"Starting QUIC connection to {self._peer_id}")
# If this is a client connection, we need to establish the connection
if self.__is_initiator:
await self._initiate_connection()
else:
# For server connections, we're already connected via the listener
self._established = True
self._connected_event.set()
logger.debug(f"QUIC connection to {self._peer_id} started")
async def _initiate_connection(self) -> None:
"""Initiate client-side connection establishment."""
try:
# Create UDP socket using trio
self._socket = trio.socket.socket(
family=socket.AF_INET, type=socket.SOCK_DGRAM
)
# Connect the socket to the remote address
await self._socket.connect(self._remote_addr)
# Start the connection establishment
self._quic.connect(self._remote_addr, now=time.time())
# Send initial packet(s)
await self._transmit()
# Start background tasks using trio nursery
async with trio.open_nursery() as nursery:
nursery.start_soon(
self._handle_incoming_data, None, "QUIC INCOMING DATA"
)
nursery.start_soon(self._handle_timer, None, "QUIC TIMER HANDLER")
# For client connections, we need to manage our own background tasks
# In a real implementation, this would be managed by the transport
# For now, we'll start them here
if not self._background_tasks_started:
# We would need a nursery to start background tasks
# This is a limitation of the current design
logger.warning("Background tasks need nursery - connection may not work properly")
# Wait for connection to be established
await self._connected_event.wait()
except Exception as e:
logger.error(f"Failed to initiate connection: {e}")
raise QUICConnectionError(f"Connection initiation failed: {e}") from e
async def connect(self, nursery: trio.Nursery) -> None:
"""
Establish the QUIC connection using trio.
Args:
nursery: Trio nursery for background tasks
"""
if not self.__is_initiator:
raise QUICConnectionError("connect() should only be called by client connections")
try:
# Store nursery for background tasks
self._nursery = nursery
# Create UDP socket using trio
self._socket = trio.socket.socket(
family=socket.AF_INET, type=socket.SOCK_DGRAM
)
# Connect the socket to the remote address
await self._socket.connect(self._remote_addr)
# Start the connection establishment
self._quic.connect(self._remote_addr, now=time.time())
# Send initial packet(s)
await self._transmit()
# Start background tasks
await self._start_background_tasks(nursery)
# Wait for connection to be established
await self._connected_event.wait()
except Exception as e:
logger.error(f"Failed to connect: {e}")
raise QUICConnectionError(f"Connection failed: {e}") from e
async def _start_background_tasks(self, nursery: trio.Nursery) -> None:
"""Start background tasks for connection management."""
if self._background_tasks_started:
return
self._background_tasks_started = True
# Start background tasks
nursery.start_soon(self._handle_incoming_data)
nursery.start_soon(self._handle_timer)
async def _handle_incoming_data(self) -> None:
"""Handle incoming UDP datagrams in trio."""
while not self._closed:
@ -128,6 +230,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._quic.receive_datagram(data, addr, now=time.time())
await self._process_events()
await self._transmit()
# Small delay to prevent busy waiting
await trio.sleep(0.001)
except trio.ClosedResourceError:
break
except Exception as e:
@ -137,18 +243,26 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def _handle_timer(self) -> None:
"""Handle QUIC timer events in trio."""
while not self._closed:
timer_at = self._quic.get_timer()
if timer_at is None:
await trio.sleep(1.0) # No timer set, check again later
continue
try:
timer_at = self._quic.get_timer()
if timer_at is None:
await trio.sleep(0.1) # No timer set, check again later
continue
now = time.time()
if timer_at <= now:
self._quic.handle_timer(now=now)
await self._process_events()
await self._transmit()
else:
await trio.sleep(timer_at - now)
now = time.time()
if timer_at <= now:
self._quic.handle_timer(now=now)
await self._process_events()
await self._transmit()
await trio.sleep(0.001) # Small delay
else:
# Sleep until timer fires, but check periodically
sleep_time = min(timer_at - now, 0.1)
await trio.sleep(sleep_time)
except Exception as e:
logger.error(f"Error in timer handler: {e}")
await trio.sleep(0.1)
async def _process_events(self) -> None:
"""Process QUIC events from aioquic core."""
@ -165,6 +279,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
elif isinstance(event, events.HandshakeCompleted):
logger.debug("QUIC handshake completed")
self._established = True
self._connected_event.set()
elif isinstance(event, events.StreamDataReceived):
@ -177,25 +292,47 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Handle incoming stream data."""
stream_id = event.stream_id
# Get or create stream
if stream_id not in self._streams:
# Create new stream for incoming data
# Determine if this is an incoming stream
is_incoming = self._is_incoming_stream(stream_id)
stream = QUICStream(
connection=self,
stream_id=stream_id,
is_initiator=False, # pyrefly: ignore
is_initiator=not is_incoming,
)
self._streams[stream_id] = stream
# Notify stream handler if available
if self._stream_handler:
# Use trio nursery to start stream handler
async with trio.open_nursery() as nursery:
nursery.start_soon(self._stream_handler, stream)
# Notify stream handler for incoming streams
if is_incoming and self._stream_handler:
# Start stream handler in background
# In a real implementation, you might want to use the nursery
# passed to the connection, but for now we'll handle it directly
try:
await self._stream_handler(stream)
except Exception as e:
logger.error(f"Error in stream handler: {e}")
# Forward data to stream
stream = self._streams[stream_id]
await stream.handle_data_received(event.data, event.end_stream)
def _is_incoming_stream(self, stream_id: int) -> bool:
"""
Determine if a stream ID represents an incoming stream.
For bidirectional streams:
- Even IDs are client-initiated
- Odd IDs are server-initiated
"""
if self.__is_initiator:
# We're the client, so odd stream IDs are incoming
return stream_id % 2 == 1
else:
# We're the server, so even stream IDs are incoming
return stream_id % 2 == 0
async def _handle_stream_reset(self, event: events.StreamReset) -> None:
"""Handle stream reset."""
stream_id = event.stream_id
@ -210,15 +347,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
if socket is None:
return
for data, addr in self._quic.datagrams_to_send(now=time.time()):
try:
try:
for data, addr in self._quic.datagrams_to_send(now=time.time()):
await socket.sendto(data, addr)
except Exception as e:
logger.error(f"Failed to send datagram: {e}")
except Exception as e:
logger.error(f"Failed to send datagram: {e}")
# IRawConnection interface
async def write(self, data: bytes):
async def write(self, data: bytes) -> None:
"""
Write data to the connection.
For QUIC, this creates a new stream for each write operation.
@ -230,7 +367,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
await stream.write(data)
await stream.close()
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int | None = -1) -> bytes:
"""
Read data from the connection.
For QUIC, this reads from the next available stream.
@ -252,14 +389,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._closed = True
logger.debug(f"Closing QUIC connection to {self._peer_id}")
# Close all streams using trio nursery
async with trio.open_nursery() as nursery:
for stream in self._streams.values():
nursery.start_soon(stream.close)
# Close all streams
stream_close_tasks = []
for stream in list(self._streams.values()):
stream_close_tasks.append(stream.close())
if stream_close_tasks:
# Close streams concurrently
async with trio.open_nursery() as nursery:
for task in stream_close_tasks:
nursery.start_soon(lambda t=task: t)
# Close QUIC connection
self._quic.close()
await self._transmit() # Send close frames
if self._socket:
await self._transmit() # Send close frames
# Close socket
if self._socket:
@ -275,6 +419,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Check if connection is closed."""
return self._closed
@property
def is_established(self) -> bool:
"""Check if connection is established (handshake completed)."""
return self._established
@property
def is_started(self) -> bool:
"""Check if connection has been started."""
return self._started
def multiaddr(self) -> multiaddr.Multiaddr:
"""Get the multiaddr for this connection."""
return self._maddr
@ -283,6 +437,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""Get the local peer ID."""
return self._local_peer_id
def remote_peer_id(self) -> ID | None:
"""Get the remote peer ID."""
return self._peer_id
# IMuxedConn interface
async def open_stream(self) -> IMuxedStream:
@ -296,23 +454,27 @@ class QUICConnection(IRawConnection, IMuxedConn):
if self._closed:
raise QUICStreamError("Connection is closed")
# Generate next stream ID
stream_id = self._next_stream_id
self._next_stream_id += (
2 # Increment by 2 to maintain initiator/responder distinction
)
if not self._started:
raise QUICStreamError("Connection not started")
# Create stream
stream = QUICStream(
connection=self, stream_id=stream_id, is_initiator=True
) # pyrefly: ignore
async with self._stream_id_lock:
# Generate next stream ID
stream_id = self._next_stream_id
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
self._streams[stream_id] = stream
# Create stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
is_initiator=True
)
self._streams[stream_id] = stream
logger.debug(f"Opened QUIC stream {stream_id}")
return stream
def set_stream_handler(self, handler_function: StreamHandlerFn) -> None:
def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
"""
Set handler for incoming streams.
@ -341,17 +503,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
"""
# Extract peer ID from TLS certificate
# This should match the expected peer ID
cert_peer_id = self._extract_peer_id_from_cert()
try:
cert_peer_id = self._extract_peer_id_from_cert()
if self._peer_id and cert_peer_id != self._peer_id:
raise QUICConnectionError(
f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
)
if self._peer_id and cert_peer_id != self._peer_id:
raise QUICConnectionError(
f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
)
if not self._peer_id:
self._peer_id = cert_peer_id
if not self._peer_id:
self._peer_id = cert_peer_id
logger.debug(f"Verified peer identity: {self._peer_id}")
logger.debug(f"Verified peer identity: {self._peer_id}")
except NotImplementedError:
logger.warning("Peer identity verification not implemented - skipping")
# For now, we'll skip verification during development
def _extract_peer_id_from_cert(self) -> ID:
"""Extract peer ID from TLS certificate."""
@ -363,6 +530,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
# The certificate should contain the peer ID in a specific extension
raise NotImplementedError("Certificate peer ID extraction not implemented")
def get_stats(self) -> dict:
"""Get connection statistics."""
return {
"peer_id": str(self._peer_id),
"remote_addr": self._remote_addr,
"is_initiator": self.__is_initiator,
"is_established": self._established,
"is_closed": self._closed,
"is_started": self._started,
"active_streams": len(self._streams),
"next_stream_id": self._next_stream_id,
}
def get_remote_address(self):
return self._remote_addr
def __str__(self) -> str:
"""String representation of the connection."""
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)})"
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, established={self._established}, started={self._started})"

View File

@ -0,0 +1,579 @@
"""
QUIC Listener implementation for py-libp2p.
Based on go-libp2p and js-libp2p QUIC listener patterns.
Uses aioquic's server-side QUIC implementation with trio.
"""
import copy
import logging
import socket
import time
from typing import TYPE_CHECKING, Dict
from aioquic.quic import events
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from multiaddr import Multiaddr
import trio
from libp2p.abc import IListener
from libp2p.custom_types import THandler, TProtocol
from .config import QUICTransportConfig
from .connection import QUICConnection
from .exceptions import QUICListenError
from .utils import (
create_quic_multiaddr,
is_quic_multiaddr,
multiaddr_to_quic_version,
quic_multiaddr_to_endpoint,
)
if TYPE_CHECKING:
from .transport import QUICTransport
logger = logging.getLogger(__name__)
logger.setLevel("DEBUG")
class QUICListener(IListener):
"""
QUIC Listener implementation following libp2p listener interface.
Handles incoming QUIC connections, manages server-side handshakes,
and integrates with the libp2p connection handler system.
Based on go-libp2p and js-libp2p listener patterns.
"""
def __init__(
self,
transport: "QUICTransport",
handler_function: THandler,
quic_configs: Dict[TProtocol, QuicConfiguration],
config: QUICTransportConfig,
):
"""
Initialize QUIC listener.
Args:
transport: Parent QUIC transport
handler_function: Function to handle new connections
quic_configs: QUIC configurations for different versions
config: QUIC transport configuration
"""
self._transport = transport
self._handler = handler_function
self._quic_configs = quic_configs
self._config = config
# Network components
self._socket: trio.socket.SocketType | None = None
self._bound_addresses: list[Multiaddr] = []
# Connection management
self._connections: Dict[tuple[str, int], QUICConnection] = {}
self._pending_connections: Dict[tuple[str, int], QuicConnection] = {}
self._connection_lock = trio.Lock()
# Listener state
self._closed = False
self._listening = False
self._nursery: trio.Nursery | None = None
# Performance tracking
self._stats = {
"connections_accepted": 0,
"connections_rejected": 0,
"bytes_received": 0,
"packets_processed": 0,
}
logger.debug("Initialized QUIC listener")
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
"""
Start listening on the given multiaddr.
Args:
maddr: Multiaddr to listen on
nursery: Trio nursery for managing background tasks
Returns:
True if listening started successfully
Raises:
QUICListenError: If failed to start listening
"""
if not is_quic_multiaddr(maddr):
raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}")
if self._listening:
raise QUICListenError("Already listening")
try:
# Extract host and port from multiaddr
host, port = quic_multiaddr_to_endpoint(maddr)
quic_version = multiaddr_to_quic_version(maddr)
# Validate QUIC version support
if quic_version not in self._quic_configs:
raise QUICListenError(f"Unsupported QUIC version: {quic_version}")
# Create and bind UDP socket
self._socket = await self._create_and_bind_socket(host, port)
actual_port = self._socket.getsockname()[1]
# Update multiaddr with actual bound port
actual_maddr = create_quic_multiaddr(host, actual_port, f"/{quic_version}")
self._bound_addresses = [actual_maddr]
# Store nursery reference and set listening state
self._nursery = nursery
self._listening = True
# Start background tasks directly in the provided nursery
# This ensures proper cancellation when the nursery exits
nursery.start_soon(self._handle_incoming_packets)
nursery.start_soon(self._manage_connections)
print(f"QUIC listener started on {actual_maddr}")
return True
except trio.Cancelled:
print("CLOSING LISTENER")
raise
except Exception as e:
logger.error(f"Failed to start QUIC listener on {maddr}: {e}")
await self._cleanup_socket()
raise QUICListenError(f"Listen failed: {e}") from e
async def _create_and_bind_socket(
self, host: str, port: int
) -> trio.socket.SocketType:
"""Create and bind UDP socket for QUIC."""
try:
# Determine address family
try:
import ipaddress
ip = ipaddress.ip_address(host)
family = socket.AF_INET if ip.version == 4 else socket.AF_INET6
except ValueError:
# Assume IPv4 for hostnames
family = socket.AF_INET
# Create UDP socket
sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
# Set socket options for better performance
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, "SO_REUSEPORT"):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
# Bind to address
await sock.bind((host, port))
logger.debug(f"Created and bound UDP socket to {host}:{port}")
return sock
except Exception as e:
raise QUICListenError(f"Failed to create socket: {e}") from e
async def _handle_incoming_packets(self) -> None:
"""
Handle incoming UDP packets and route to appropriate connections.
This is the main packet processing loop.
"""
logger.debug("Started packet handling loop")
try:
while self._listening and self._socket:
try:
# Receive UDP packet (this blocks until packet arrives or socket closes)
data, addr = await self._socket.recvfrom(65536)
self._stats["bytes_received"] += len(data)
self._stats["packets_processed"] += 1
# Process packet asynchronously to avoid blocking
if self._nursery:
self._nursery.start_soon(self._process_packet, data, addr)
except trio.ClosedResourceError:
# Socket was closed, exit gracefully
logger.debug("Socket closed, exiting packet handler")
break
except Exception as e:
logger.error(f"Error receiving packet: {e}")
# Continue processing other packets
await trio.sleep(0.01)
except trio.Cancelled:
print("PACKET HANDLER CANCELLED - FORCIBLY CLOSING SOCKET")
raise
finally:
print("PACKET HANDLER FINISHED")
logger.debug("Packet handling loop terminated")
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
"""
Process a single incoming packet.
Routes to existing connection or creates new connection.
Args:
data: Raw UDP packet data
addr: Source address (host, port)
"""
try:
async with self._connection_lock:
# Check if we have an existing connection for this address
if addr in self._connections:
connection = self._connections[addr]
await self._route_to_connection(connection, data, addr)
elif addr in self._pending_connections:
# Handle packet for pending connection
quic_conn = self._pending_connections[addr]
await self._handle_pending_connection(quic_conn, data, addr)
else:
# New connection
await self._handle_new_connection(data, addr)
except Exception as e:
logger.error(f"Error processing packet from {addr}: {e}")
async def _route_to_connection(
self, connection: QUICConnection, data: bytes, addr: tuple[str, int]
) -> None:
"""Route packet to existing connection."""
try:
# Feed data to the connection's QUIC instance
connection._quic.receive_datagram(data, addr, now=time.time())
# Process events and handle responses
await connection._process_events()
await connection._transmit()
except Exception as e:
logger.error(f"Error routing packet to connection {addr}: {e}")
# Remove problematic connection
await self._remove_connection(addr)
async def _handle_pending_connection(
self, quic_conn: QuicConnection, data: bytes, addr: tuple[str, int]
) -> None:
"""Handle packet for a pending (handshaking) connection."""
try:
# Feed data to QUIC connection
quic_conn.receive_datagram(data, addr, now=time.time())
# Process events
await self._process_quic_events(quic_conn, addr)
# Send any outgoing packets
await self._transmit_for_connection(quic_conn)
except Exception as e:
logger.error(f"Error handling pending connection {addr}: {e}")
# Remove from pending connections
self._pending_connections.pop(addr, None)
async def _handle_new_connection(self, data: bytes, addr: tuple[str, int]) -> None:
"""
Handle a new incoming connection.
Creates a new QUIC connection and starts handshake.
Args:
data: Initial packet data
addr: Source address
"""
try:
# Determine QUIC version from packet
# For now, use the first available configuration
# TODO: Implement proper version negotiation
quic_version = next(iter(self._quic_configs.keys()))
config = self._quic_configs[quic_version]
# Create server-side QUIC configuration
server_config = copy.deepcopy(config)
server_config.is_client = False
# Create QUIC connection
quic_conn = QuicConnection(configuration=server_config)
# Store as pending connection
self._pending_connections[addr] = quic_conn
# Process initial packet
quic_conn.receive_datagram(data, addr, now=time.time())
await self._process_quic_events(quic_conn, addr)
await self._transmit_for_connection(quic_conn)
logger.debug(f"Started handshake for new connection from {addr}")
except Exception as e:
logger.error(f"Error handling new connection from {addr}: {e}")
self._stats["connections_rejected"] += 1
async def _process_quic_events(
self, quic_conn: QuicConnection, addr: tuple[str, int]
) -> None:
"""Process QUIC events for a connection."""
while True:
event = quic_conn.next_event()
if event is None:
break
if isinstance(event, events.ConnectionTerminated):
logger.debug(
f"Connection from {addr} terminated: {event.reason_phrase}"
)
await self._remove_connection(addr)
break
elif isinstance(event, events.HandshakeCompleted):
logger.debug(f"Handshake completed for {addr}")
await self._promote_pending_connection(quic_conn, addr)
elif isinstance(event, events.StreamDataReceived):
# Forward to established connection if available
if addr in self._connections:
connection = self._connections[addr]
await connection._handle_stream_data(event)
elif isinstance(event, events.StreamReset):
# Forward to established connection if available
if addr in self._connections:
connection = self._connections[addr]
await connection._handle_stream_reset(event)
async def _promote_pending_connection(
self, quic_conn: QuicConnection, addr: tuple[str, int]
) -> None:
"""
Promote a pending connection to an established connection.
Called after successful handshake completion.
Args:
quic_conn: Established QUIC connection
addr: Remote address
"""
try:
# Remove from pending connections
self._pending_connections.pop(addr, None)
# Create multiaddr for this connection
host, port = addr
# Use the first supported QUIC version for now
quic_version = next(iter(self._quic_configs.keys()))
remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}")
# Create libp2p connection wrapper
connection = QUICConnection(
quic_connection=quic_conn,
remote_addr=addr,
peer_id=None, # Will be determined during identity verification
local_peer_id=self._transport._peer_id,
is_initiator=False, # We're the server
maddr=remote_maddr,
transport=self._transport,
)
# Store the connection
self._connections[addr] = connection
# Start connection management tasks
if self._nursery:
self._nursery.start_soon(connection._handle_incoming_data)
self._nursery.start_soon(connection._handle_timer)
# TODO: Verify peer identity
# await connection.verify_peer_identity()
# Call the connection handler
if self._nursery:
self._nursery.start_soon(
self._handle_new_established_connection, connection
)
self._stats["connections_accepted"] += 1
logger.info(f"Accepted new QUIC connection from {addr}")
except Exception as e:
logger.error(f"Error promoting connection from {addr}: {e}")
# Clean up
await self._remove_connection(addr)
self._stats["connections_rejected"] += 1
async def _handle_new_established_connection(
self, connection: QUICConnection
) -> None:
"""
Handle a newly established connection by calling the user handler.
Args:
connection: Established QUIC connection
"""
try:
# Call the connection handler provided by the transport
await self._handler(connection)
except Exception as e:
logger.error(f"Error in connection handler: {e}")
# Close the problematic connection
await connection.close()
async def _transmit_for_connection(self, quic_conn: QuicConnection) -> None:
"""Send pending datagrams for a QUIC connection."""
sock = self._socket
if not sock:
return
for data, addr in quic_conn.datagrams_to_send(now=time.time()):
try:
await sock.sendto(data, addr)
except Exception as e:
logger.error(f"Failed to send datagram to {addr}: {e}")
async def _manage_connections(self) -> None:
"""
Background task to manage connection lifecycle.
Handles cleanup of closed/idle connections.
"""
try:
while not self._closed:
try:
# Sleep for a short interval
await trio.sleep(1.0)
# Clean up closed connections
await self._cleanup_closed_connections()
# Handle connection timeouts
await self._handle_connection_timeouts()
except Exception as e:
logger.error(f"Error in connection management: {e}")
except trio.Cancelled:
print("CONNECTION MANAGER CANCELLED")
raise
finally:
print("CONNECTION MANAGER FINISHED")
async def _cleanup_closed_connections(self) -> None:
"""Remove closed connections from tracking."""
async with self._connection_lock:
closed_addrs = []
for addr, connection in self._connections.items():
if connection.is_closed:
closed_addrs.append(addr)
for addr in closed_addrs:
self._connections.pop(addr, None)
logger.debug(f"Cleaned up closed connection from {addr}")
async def _handle_connection_timeouts(self) -> None:
"""Handle connection timeouts and cleanup."""
# TODO: Implement connection timeout handling
# Check for idle connections and close them
pass
async def _remove_connection(self, addr: tuple[str, int]) -> None:
"""Remove a connection from tracking."""
async with self._connection_lock:
# Remove from active connections
connection = self._connections.pop(addr, None)
if connection:
await connection.close()
# Remove from pending connections
quic_conn = self._pending_connections.pop(addr, None)
if quic_conn:
quic_conn.close()
async def close(self) -> None:
"""Close the listener and cleanup resources."""
if self._closed:
return
self._closed = True
self._listening = False
print("Closing QUIC listener")
# CRITICAL: Close socket FIRST to unblock recvfrom()
await self._cleanup_socket()
print("SOCKET CLEANUP COMPLETE")
# Close all connections WITHOUT using the lock during shutdown
# (avoid deadlock if background tasks are cancelled while holding lock)
connections_to_close = list(self._connections.values())
pending_to_close = list(self._pending_connections.values())
print(
f"CLOSING {len(connections_to_close)} connections and {len(pending_to_close)} pending"
)
# Close active connections
for connection in connections_to_close:
try:
await connection.close()
except Exception as e:
print(f"Error closing connection: {e}")
# Close pending connections
for quic_conn in pending_to_close:
try:
quic_conn.close()
except Exception as e:
print(f"Error closing pending connection: {e}")
# Clear the dictionaries without lock (we're shutting down)
self._connections.clear()
self._pending_connections.clear()
if self._nursery:
print("TASKS", len(self._nursery.child_tasks))
print("QUIC listener closed")
async def _cleanup_socket(self) -> None:
"""Clean up the UDP socket."""
if self._socket:
try:
self._socket.close()
except Exception as e:
logger.error(f"Error closing socket: {e}")
finally:
self._socket = None
def get_addrs(self) -> tuple[Multiaddr, ...]:
"""
Get the addresses this listener is bound to.
Returns:
Tuple of bound multiaddrs
"""
return tuple(self._bound_addresses)
def is_listening(self) -> bool:
"""Check if the listener is actively listening."""
return self._listening and not self._closed
def get_stats(self) -> dict:
"""Get listener statistics."""
stats = self._stats.copy()
stats.update(
{
"active_connections": len(self._connections),
"pending_connections": len(self._pending_connections),
"is_listening": self.is_listening(),
}
)
return stats
def __str__(self) -> str:
"""String representation of the listener."""
return f"QUICListener(addrs={self._bound_addresses}, connections={len(self._connections)})"

View File

@ -0,0 +1,123 @@
"""
Basic QUIC Security implementation for Module 1.
This provides minimal TLS configuration for QUIC transport.
Full implementation will be in Module 5.
"""
from dataclasses import dataclass
import os
import tempfile
from typing import Optional
from libp2p.crypto.keys import PrivateKey
from libp2p.peer.id import ID
from .exceptions import QUICSecurityError
@dataclass
class TLSConfig:
"""TLS configuration for QUIC transport."""
cert_file: str
key_file: str
ca_file: Optional[str] = None
def generate_libp2p_tls_config(private_key: PrivateKey, peer_id: ID) -> TLSConfig:
"""
Generate TLS configuration with libp2p peer identity.
This is a basic implementation for Module 1.
Full implementation with proper libp2p TLS spec compliance
will be provided in Module 5.
Args:
private_key: libp2p private key
peer_id: libp2p peer ID
Returns:
TLS configuration
Raises:
QUICSecurityError: If TLS configuration generation fails
"""
try:
# TODO: Implement proper libp2p TLS certificate generation
# This should follow the libp2p TLS specification:
# https://github.com/libp2p/specs/blob/master/tls/tls.md
# For now, create a basic self-signed certificate
# This is a placeholder implementation
# Create temporary files for cert and key
with tempfile.NamedTemporaryFile(
mode="w", suffix=".pem", delete=False
) as cert_file:
cert_path = cert_file.name
# Write placeholder certificate
cert_file.write(_generate_placeholder_cert(peer_id))
with tempfile.NamedTemporaryFile(
mode="w", suffix=".key", delete=False
) as key_file:
key_path = key_file.name
# Write placeholder private key
key_file.write(_generate_placeholder_key(private_key))
return TLSConfig(cert_file=cert_path, key_file=key_path)
except Exception as e:
raise QUICSecurityError(f"Failed to generate TLS config: {e}") from e
def _generate_placeholder_cert(peer_id: ID) -> str:
"""
Generate a placeholder certificate.
This is a temporary implementation for Module 1.
Real implementation will embed the peer ID in the certificate
following the libp2p TLS specification.
"""
# This is a placeholder - real implementation needed
return f"""-----BEGIN CERTIFICATE-----
# Placeholder certificate for peer {peer_id}
# TODO: Implement proper libp2p TLS certificate generation
# This should embed the peer ID in a certificate extension
# according to the libp2p TLS specification
-----END CERTIFICATE-----"""
def _generate_placeholder_key(private_key: PrivateKey) -> str:
"""
Generate a placeholder private key.
This is a temporary implementation for Module 1.
Real implementation will use the actual libp2p private key.
"""
# This is a placeholder - real implementation needed
return """-----BEGIN PRIVATE KEY-----
# Placeholder private key
# TODO: Convert libp2p private key to TLS-compatible format
-----END PRIVATE KEY-----"""
def cleanup_tls_config(config: TLSConfig) -> None:
"""
Clean up temporary TLS files.
Args:
config: TLS configuration to clean up
"""
try:
if os.path.exists(config.cert_file):
os.unlink(config.cert_file)
if os.path.exists(config.key_file):
os.unlink(config.key_file)
if config.ca_file and os.path.exists(config.ca_file):
os.unlink(config.ca_file)
except Exception:
# Ignore cleanup errors
pass

View File

@ -5,16 +5,17 @@ QUIC Stream implementation
from types import (
TracebackType,
)
from typing import TYPE_CHECKING, cast
import trio
from libp2p.abc import (
IMuxedStream,
)
if TYPE_CHECKING:
from libp2p.abc import IMuxedStream
from .connection import QUICConnection
else:
IMuxedStream = cast(type, object)
from .connection import (
QUICConnection,
)
from .exceptions import (
QUICStreamError,
)
@ -41,7 +42,7 @@ class QUICStream(IMuxedStream):
self._receive_event = trio.Event()
self._close_event = trio.Event()
async def read(self, n: int = -1) -> bytes:
async def read(self, n: int | None = -1) -> bytes:
"""Read data from the stream."""
if self._closed:
raise QUICStreamError("Stream is closed")

View File

@ -14,9 +14,6 @@ from aioquic.quic.connection import (
QuicConnection,
)
import multiaddr
from multiaddr import (
Multiaddr,
)
import trio
from libp2p.abc import (
@ -27,9 +24,15 @@ from libp2p.abc import (
from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import THandler, TProtocol
from libp2p.peer.id import (
ID,
)
from libp2p.transport.quic.utils import (
is_quic_multiaddr,
multiaddr_to_quic_version,
quic_multiaddr_to_endpoint,
)
from .config import (
QUICTransportConfig,
@ -41,21 +44,16 @@ from .exceptions import (
QUICDialError,
QUICListenError,
)
from .listener import (
QUICListener,
)
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
logger = logging.getLogger(__name__)
class QUICListener(IListener):
async def close(self):
pass
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
return False
def get_addrs(self) -> tuple[Multiaddr, ...]:
return ()
class QUICTransport(ITransport):
"""
QUIC Transport implementation following libp2p transport interface.
@ -65,10 +63,6 @@ class QUICTransport(ITransport):
go-libp2p and js-libp2p implementations.
"""
# Protocol identifiers matching go-libp2p
PROTOCOL_QUIC_V1 = "/quic-v1" # RFC 9000
PROTOCOL_QUIC_DRAFT29 = "/quic" # draft-29
def __init__(
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
):
@ -89,7 +83,7 @@ class QUICTransport(ITransport):
self._listeners: list[QUICListener] = []
# QUIC configurations for different versions
self._quic_configs: dict[str, QuicConfiguration] = {}
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
self._setup_quic_configurations()
# Resource management
@ -110,35 +104,36 @@ class QUICTransport(ITransport):
)
# Add TLS certificate generated from libp2p private key
self._setup_tls_configuration(base_config)
# self._setup_tls_configuration(base_config)
# QUIC v1 (RFC 9000) configuration
quic_v1_config = copy.deepcopy(base_config)
quic_v1_config.supported_versions = [0x00000001] # QUIC v1
self._quic_configs[self.PROTOCOL_QUIC_V1] = quic_v1_config
self._quic_configs[QUIC_V1_PROTOCOL] = quic_v1_config
# QUIC draft-29 configuration for compatibility
if self._config.enable_draft29:
draft29_config = copy.deepcopy(base_config)
draft29_config.supported_versions = [0xFF00001D] # draft-29
self._quic_configs[self.PROTOCOL_QUIC_DRAFT29] = draft29_config
self._quic_configs[QUIC_DRAFT29_PROTOCOL] = draft29_config
def _setup_tls_configuration(self, config: QuicConfiguration) -> None:
"""
Setup TLS configuration with libp2p identity integration.
Similar to go-libp2p's certificate generation approach.
"""
from .security import (
generate_libp2p_tls_config,
)
# TODO: SETUP TLS LISTENER
# def _setup_tls_configuration(self, config: QuicConfiguration) -> None:
# """
# Setup TLS configuration with libp2p identity integration.
# Similar to go-libp2p's certificate generation approach.
# """
# from .security import (
# generate_libp2p_tls_config,
# )
# Generate TLS certificate with embedded libp2p peer ID
# This follows the libp2p TLS spec for peer identity verification
tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id)
# # Generate TLS certificate with embedded libp2p peer ID
# # This follows the libp2p TLS spec for peer identity verification
# tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id)
config.load_cert_chain(tls_config.cert_file, tls_config.key_file)
if tls_config.ca_file:
config.load_verify_locations(tls_config.ca_file)
# config.load_cert_chain(certfile=tls_config.cert_file, keyfile=tls_config.key_file)
# if tls_config.ca_file:
# config.load_verify_locations(tls_config.ca_file)
async def dial(
self, maddr: multiaddr.Multiaddr, peer_id: ID | None = None
@ -196,14 +191,17 @@ class QUICTransport(ITransport):
)
# Establish connection using trio
await connection.connect()
# We need a nursery for this - in real usage, this would be provided
# by the caller or we'd use a transport-level nursery
async with trio.open_nursery() as nursery:
await connection.connect(nursery)
# Store connection for management
conn_id = f"{host}:{port}:{peer_id}"
self._connections[conn_id] = connection
# Perform libp2p handshake verification
await connection.verify_peer_identity()
# await connection.verify_peer_identity()
logger.info(f"Successfully dialed QUIC connection to {peer_id}")
return connection
@ -212,9 +210,7 @@ class QUICTransport(ITransport):
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
raise QUICDialError(f"Dial failed: {e}") from e
def create_listener(
self, handler_function: Callable[[ReadWriteCloser], None]
) -> IListener:
def create_listener(self, handler_function: THandler) -> IListener:
"""
Create a QUIC listener.
@ -224,20 +220,22 @@ class QUICTransport(ITransport):
Returns:
QUIC listener instance
Raises:
QUICListenError: If transport is closed
"""
if self._closed:
raise QUICListenError("Transport is closed")
# TODO: Create QUIC Listener
# listener = QUICListener(
# transport=self,
# handler_function=handler_function,
# quic_configs=self._quic_configs,
# config=self._config,
# )
listener = QUICListener()
listener = QUICListener(
transport=self,
handler_function=handler_function,
quic_configs=self._quic_configs,
config=self._config,
)
self._listeners.append(listener)
logger.debug("Created QUIC listener")
return listener
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
@ -253,7 +251,7 @@ class QUICTransport(ITransport):
"""
return is_quic_multiaddr(maddr)
def protocols(self) -> list[str]:
def protocols(self) -> list[TProtocol]:
"""
Get supported protocol identifiers.
@ -261,9 +259,9 @@ class QUICTransport(ITransport):
List of supported protocol strings
"""
protocols = [self.PROTOCOL_QUIC_V1]
protocols = [QUIC_V1_PROTOCOL]
if self._config.enable_draft29:
protocols.append(self.PROTOCOL_QUIC_DRAFT29)
protocols.append(QUIC_DRAFT29_PROTOCOL)
return protocols
def listen_order(self) -> int:
@ -300,6 +298,26 @@ class QUICTransport(ITransport):
logger.info("QUIC transport closed")
def get_stats(self) -> dict:
"""Get transport statistics."""
stats = {
"active_connections": len(self._connections),
"active_listeners": len(self._listeners),
"supported_protocols": self.protocols(),
}
# Aggregate listener stats
listener_stats = {}
for i, listener in enumerate(self._listeners):
listener_stats[f"listener_{i}"] = listener.get_stats()
if listener_stats:
# TODO: Fix type of listener_stats
# type: ignore
stats["listeners"] = listener_stats
return stats
def __str__(self) -> str:
"""String representation of the transport."""
return f"QUICTransport(peer_id={self._peer_id}, protocols={self.protocols()})"

View File

@ -0,0 +1,223 @@
"""
Multiaddr utilities for QUIC transport.
Handles QUIC-specific multiaddr parsing and validation.
"""
from typing import Tuple
import multiaddr
from libp2p.custom_types import TProtocol
from .config import QUICTransportConfig
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
UDP_PROTOCOL = "udp"
IP4_PROTOCOL = "ip4"
IP6_PROTOCOL = "ip6"
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
"""
Check if a multiaddr represents a QUIC address.
Valid QUIC multiaddrs:
- /ip4/127.0.0.1/udp/4001/quic-v1
- /ip4/127.0.0.1/udp/4001/quic
- /ip6/::1/udp/4001/quic-v1
- /ip6/::1/udp/4001/quic
Args:
maddr: Multiaddr to check
Returns:
True if the multiaddr represents a QUIC address
"""
try:
# Get protocol names from the multiaddr string
addr_str = str(maddr)
# Check for required components
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
has_quic = (
addr_str.endswith(f"/{QUIC_V1_PROTOCOL}")
or addr_str.endswith(f"/{QUIC_DRAFT29_PROTOCOL}")
or addr_str.endswith("/quic")
)
return has_ip and has_udp and has_quic
except Exception:
return False
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]:
"""
Extract host and port from a QUIC multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
Tuple of (host, port)
Raises:
ValueError: If multiaddr is not a valid QUIC address
"""
if not is_quic_multiaddr(maddr):
raise ValueError(f"Not a valid QUIC multiaddr: {maddr}")
try:
# Use multiaddr's value_for_protocol method to extract values
host = None
port = None
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except ValueError:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except ValueError:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP)
port = int(port_str)
except ValueError:
pass
if host is None or port is None:
raise ValueError(f"Could not extract host/port from {maddr}")
return host, port
except Exception as e:
raise ValueError(f"Failed to parse QUIC multiaddr {maddr}: {e}") from e
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
"""
Determine QUIC version from multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
QUIC version identifier ("/quic-v1" or "/quic")
Raises:
ValueError: If multiaddr doesn't contain QUIC protocol
"""
try:
addr_str = str(maddr)
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
return QUIC_V1_PROTOCOL # RFC 9000
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
return QUIC_DRAFT29_PROTOCOL # draft-29
else:
raise ValueError(f"No QUIC protocol found in {maddr}")
except Exception as e:
raise ValueError(f"Failed to determine QUIC version from {maddr}: {e}") from e
def create_quic_multiaddr(
host: str, port: int, version: str = "/quic-v1"
) -> multiaddr.Multiaddr:
"""
Create a QUIC multiaddr from host, port, and version.
Args:
host: IP address (IPv4 or IPv6)
port: UDP port number
version: QUIC version ("/quic-v1" or "/quic")
Returns:
QUIC multiaddr
Raises:
ValueError: If invalid parameters provided
"""
try:
import ipaddress
# Determine IP version
try:
ip = ipaddress.ip_address(host)
if isinstance(ip, ipaddress.IPv4Address):
ip_proto = IP4_PROTOCOL
else:
ip_proto = IP6_PROTOCOL
except ValueError:
raise ValueError(f"Invalid IP address: {host}")
# Validate port
if not (0 <= port <= 65535):
raise ValueError(f"Invalid port: {port}")
# Validate QUIC version
if version not in ["/quic-v1", "/quic"]:
raise ValueError(f"Invalid QUIC version: {version}")
# Construct multiaddr
quic_proto = (
QUIC_V1_PROTOCOL if version == "/quic-v1" else QUIC_DRAFT29_PROTOCOL
)
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
return multiaddr.Multiaddr(addr_str)
except Exception as e:
raise ValueError(f"Failed to create QUIC multiaddr: {e}") from e
def is_quic_v1_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
"""Check if multiaddr uses QUIC v1 (RFC 9000)."""
try:
return multiaddr_to_quic_version(maddr) == "/quic-v1"
except ValueError:
return False
def is_quic_draft29_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
"""Check if multiaddr uses QUIC draft-29."""
try:
return multiaddr_to_quic_version(maddr) == "/quic"
except ValueError:
return False
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
"""
Normalize a QUIC multiaddr to canonical form.
Args:
maddr: Input QUIC multiaddr
Returns:
Normalized multiaddr
Raises:
ValueError: If not a valid QUIC multiaddr
"""
if not is_quic_multiaddr(maddr):
raise ValueError(f"Not a QUIC multiaddr: {maddr}")
host, port = quic_multiaddr_to_endpoint(maddr)
version = multiaddr_to_quic_version(maddr)
return create_quic_multiaddr(host, port, version)

View File

@ -16,6 +16,7 @@ maintainers = [
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
]
dependencies = [
"aioquic>=1.2.0",
"base58>=1.0.3",
"coincurve>=10.0.0",
"exceptiongroup>=1.2.0; python_version < '3.11'",

View File

@ -0,0 +1,119 @@
from unittest.mock import (
Mock,
)
import pytest
from multiaddr.multiaddr import Multiaddr
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.peer.id import ID
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.exceptions import QUICStreamError
class TestQUICConnection:
"""Test suite for QUIC connection functionality."""
@pytest.fixture
def mock_quic_connection(self):
"""Create mock aioquic QuicConnection."""
mock = Mock()
mock.next_event.return_value = None
mock.datagrams_to_send.return_value = []
mock.get_timer.return_value = None
return mock
@pytest.fixture
def quic_connection(self, mock_quic_connection):
"""Create test QUIC connection."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
def test_connection_initialization(self, quic_connection):
"""Test connection initialization."""
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
assert quic_connection.is_initiator is True
assert not quic_connection.is_closed
assert not quic_connection.is_established
assert len(quic_connection._streams) == 0
def test_stream_id_calculation(self):
"""Test stream ID calculation for client/server."""
# Client connection (initiator)
client_conn = QUICConnection(
quic_connection=Mock(),
remote_addr=("127.0.0.1", 4001),
peer_id=None,
local_peer_id=Mock(),
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
assert client_conn._next_stream_id == 0 # Client starts with 0
# Server connection (not initiator)
server_conn = QUICConnection(
quic_connection=Mock(),
remote_addr=("127.0.0.1", 4001),
peer_id=None,
local_peer_id=Mock(),
is_initiator=False,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
assert server_conn._next_stream_id == 1 # Server starts with 1
def test_incoming_stream_detection(self, quic_connection):
"""Test incoming stream detection logic."""
# For client (initiator), odd stream IDs are incoming
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
@pytest.mark.trio
async def test_connection_stats(self, quic_connection):
"""Test connection statistics."""
stats = quic_connection.get_stats()
expected_keys = [
"peer_id",
"remote_addr",
"is_initiator",
"is_established",
"is_closed",
"active_streams",
"next_stream_id",
]
for key in expected_keys:
assert key in stats
@pytest.mark.trio
async def test_connection_close(self, quic_connection):
"""Test connection close functionality."""
assert not quic_connection.is_closed
await quic_connection.close()
assert quic_connection.is_closed
@pytest.mark.trio
async def test_stream_operations_on_closed_connection(self, quic_connection):
"""Test stream operations on closed connection."""
await quic_connection.close()
with pytest.raises(QUICStreamError, match="Connection is closed"):
await quic_connection.open_stream()

View File

@ -0,0 +1,171 @@
from unittest.mock import AsyncMock
import pytest
from multiaddr.multiaddr import Multiaddr
import trio
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.transport.quic.exceptions import (
QUICListenError,
)
from libp2p.transport.quic.listener import QUICListener
from libp2p.transport.quic.transport import (
QUICTransport,
QUICTransportConfig,
)
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
quic_multiaddr_to_endpoint,
)
class TestQUICListener:
"""Test suite for QUIC listener functionality."""
@pytest.fixture
def private_key(self):
"""Generate test private key."""
return create_new_key_pair().private_key
@pytest.fixture
def transport_config(self):
"""Generate test transport configuration."""
return QUICTransportConfig(idle_timeout=10.0)
@pytest.fixture
def transport(self, private_key, transport_config):
"""Create test transport instance."""
return QUICTransport(private_key, transport_config)
@pytest.fixture
def connection_handler(self):
"""Mock connection handler."""
return AsyncMock()
@pytest.fixture
def listener(self, transport, connection_handler):
"""Create test listener."""
return transport.create_listener(connection_handler)
def test_listener_creation(self, transport, connection_handler):
"""Test listener creation."""
listener = transport.create_listener(connection_handler)
assert isinstance(listener, QUICListener)
assert listener._transport == transport
assert listener._handler == connection_handler
assert not listener._listening
assert not listener._closed
@pytest.mark.trio
async def test_listener_invalid_multiaddr(self, listener: QUICListener):
"""Test listener with invalid multiaddr."""
async with trio.open_nursery() as nursery:
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"):
await listener.listen(invalid_addr, nursery)
@pytest.mark.trio
async def test_listener_basic_lifecycle(self, listener: QUICListener):
"""Test basic listener lifecycle."""
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random
async with trio.open_nursery() as nursery:
# Start listening
success = await listener.listen(listen_addr, nursery)
assert success
assert listener.is_listening()
# Check bound addresses
addrs = listener.get_addrs()
assert len(addrs) == 1
# Check stats
stats = listener.get_stats()
assert stats["is_listening"] is True
assert stats["active_connections"] == 0
assert stats["pending_connections"] == 0
# Close listener
await listener.close()
assert not listener.is_listening()
@pytest.mark.trio
async def test_listener_double_listen(self, listener: QUICListener):
"""Test that double listen raises error."""
listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic")
# The nursery is the outer context
async with trio.open_nursery() as nursery:
# The try/finally is now INSIDE the nursery scope
try:
# The listen method creates the socket and starts background tasks
success = await listener.listen(listen_addr, nursery)
assert success
await trio.sleep(0.01)
addrs = listener.get_addrs()
assert len(addrs) > 0
print("ADDRS 1: ", len(addrs))
print("TEST LOGIC FINISHED")
async with trio.open_nursery() as nursery2:
with pytest.raises(QUICListenError, match="Already listening"):
await listener.listen(listen_addr, nursery2)
finally:
# This block runs BEFORE the 'async with nursery' exits.
print("INNER FINALLY: Closing listener to release socket...")
# This closes the socket and sets self._listening = False,
# which helps the background tasks terminate cleanly.
await listener.close()
print("INNER FINALLY: Listener closed.")
# By the time we get here, the listener and its tasks have been fully
# shut down, allowing the nursery to exit without hanging.
print("TEST COMPLETED SUCCESSFULLY.")
@pytest.mark.trio
async def test_listener_port_binding(self, listener: QUICListener):
"""Test listener port binding and cleanup."""
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
# The nursery is the outer context
async with trio.open_nursery() as nursery:
# The try/finally is now INSIDE the nursery scope
try:
# The listen method creates the socket and starts background tasks
success = await listener.listen(listen_addr, nursery)
assert success
await trio.sleep(0.5)
addrs = listener.get_addrs()
assert len(addrs) > 0
print("TEST LOGIC FINISHED")
finally:
# This block runs BEFORE the 'async with nursery' exits.
print("INNER FINALLY: Closing listener to release socket...")
# This closes the socket and sets self._listening = False,
# which helps the background tasks terminate cleanly.
await listener.close()
print("INNER FINALLY: Listener closed.")
# By the time we get here, the listener and its tasks have been fully
# shut down, allowing the nursery to exit without hanging.
print("TEST COMPLETED SUCCESSFULLY.")
@pytest.mark.trio
async def test_listener_stats_tracking(self, listener):
"""Test listener statistics tracking."""
initial_stats = listener.get_stats()
# All counters should start at 0
assert initial_stats["connections_accepted"] == 0
assert initial_stats["connections_rejected"] == 0
assert initial_stats["bytes_received"] == 0
assert initial_stats["packets_processed"] == 0

View File

@ -7,6 +7,7 @@ import pytest
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.crypto.keys import PrivateKey
from libp2p.transport.quic.exceptions import (
QUICDialError,
QUICListenError,
@ -23,7 +24,7 @@ class TestQUICTransport:
@pytest.fixture
def private_key(self):
"""Generate test private key."""
return create_new_key_pair()
return create_new_key_pair().private_key
@pytest.fixture
def transport_config(self):
@ -33,7 +34,7 @@ class TestQUICTransport:
)
@pytest.fixture
def transport(self, private_key, transport_config):
def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig):
"""Create test transport instance."""
return QUICTransport(private_key, transport_config)
@ -47,18 +48,35 @@ class TestQUICTransport:
def test_supported_protocols(self, transport):
"""Test supported protocol identifiers."""
protocols = transport.protocols()
assert "/quic-v1" in protocols
assert "/quic" in protocols # draft-29
# TODO: Update when quic-v1 compatible
# assert "quic-v1" in protocols
assert "quic" in protocols # draft-29
def test_can_dial_quic_addresses(self, transport):
def test_can_dial_quic_addresses(self, transport: QUICTransport):
"""Test multiaddr compatibility checking."""
import multiaddr
# Valid QUIC addresses
valid_addrs = [
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1"),
multiaddr.Multiaddr("/ip4/192.168.1.1/udp/8080/quic"),
multiaddr.Multiaddr("/ip6/::1/udp/4001/quic-v1"),
# TODO: Update Multiaddr package to accept quic-v1
multiaddr.Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
multiaddr.Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
multiaddr.Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
]
for addr in valid_addrs:
@ -93,7 +111,7 @@ class TestQUICTransport:
await transport.close()
with pytest.raises(QUICDialError, match="Transport is closed"):
await transport.dial(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1"))
await transport.dial(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"))
def test_create_listener_closed_transport(self, transport):
"""Test creating listener with closed transport raises error."""

View File

@ -0,0 +1,94 @@
import pytest
from multiaddr.multiaddr import Multiaddr
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
is_quic_multiaddr,
multiaddr_to_quic_version,
quic_multiaddr_to_endpoint,
)
class TestQUICUtils:
"""Test suite for QUIC utility functions."""
def test_is_quic_multiaddr(self):
"""Test QUIC multiaddr validation."""
# Valid QUIC multiaddrs
valid = [
# TODO: Update Multiaddr package to accept quic-v1
Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
]
for addr in valid:
assert is_quic_multiaddr(addr)
# Invalid multiaddrs
invalid = [
Multiaddr("/ip4/127.0.0.1/tcp/4001"),
Multiaddr("/ip4/127.0.0.1/udp/4001"),
Multiaddr("/ip4/127.0.0.1/udp/4001/ws"),
]
for addr in invalid:
assert not is_quic_multiaddr(addr)
def test_quic_multiaddr_to_endpoint(self):
"""Test multiaddr to endpoint conversion."""
addr = Multiaddr("/ip4/192.168.1.100/udp/4001/quic")
host, port = quic_multiaddr_to_endpoint(addr)
assert host == "192.168.1.100"
assert port == 4001
# Test IPv6
# TODO: Update Multiaddr project to handle ip6
# addr6 = Multiaddr("/ip6/::1/udp/8080/quic")
# host6, port6 = quic_multiaddr_to_endpoint(addr6)
# assert host6 == "::1"
# assert port6 == 8080
def test_create_quic_multiaddr(self):
"""Test QUIC multiaddr creation."""
# IPv4
addr = create_quic_multiaddr("127.0.0.1", 4001, "/quic")
assert str(addr) == "/ip4/127.0.0.1/udp/4001/quic"
# IPv6
addr6 = create_quic_multiaddr("::1", 8080, "/quic")
assert str(addr6) == "/ip6/::1/udp/8080/quic"
def test_multiaddr_to_quic_version(self):
"""Test QUIC version extraction."""
addr = Multiaddr("/ip4/127.0.0.1/udp/4001/quic")
version = multiaddr_to_quic_version(addr)
assert version in ["quic", "quic-v1"] # Depending on implementation
def test_invalid_multiaddr_operations(self):
"""Test error handling for invalid multiaddrs."""
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(ValueError):
quic_multiaddr_to_endpoint(invalid_addr)
with pytest.raises(ValueError):
multiaddr_to_quic_version(invalid_addr)