mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: impl quic listener
This commit is contained in:
@ -5,17 +5,15 @@ from collections.abc import (
|
||||
)
|
||||
from typing import TYPE_CHECKING, NewType, Union, cast
|
||||
|
||||
from libp2p.transport.quic.stream import QUICStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import (
|
||||
IMuxedConn,
|
||||
INetStream,
|
||||
ISecureTransport,
|
||||
)
|
||||
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
|
||||
else:
|
||||
IMuxedConn = cast(type, object)
|
||||
INetStream = cast(type, object)
|
||||
ISecureTransport = cast(type, object)
|
||||
|
||||
IMuxedStream = cast(type, object)
|
||||
|
||||
from libp2p.io.abc import (
|
||||
ReadWriteCloser,
|
||||
@ -37,3 +35,4 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
||||
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
|
||||
|
||||
@ -8,6 +8,8 @@ from dataclasses import (
|
||||
)
|
||||
import ssl
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class QUICTransportConfig:
|
||||
@ -39,6 +41,12 @@ class QUICTransportConfig:
|
||||
max_connections: int = 1000 # Maximum number of connections
|
||||
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||
|
||||
# Protocol identifiers matching go-libp2p
|
||||
# TODO: UNTIL MUITIADDR REPO IS UPDATED
|
||||
# PROTOCOL_QUIC_V1: TProtocol = TProtocol("/quic-v1") # RFC 9000
|
||||
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic") # RFC 9000
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not (self.enable_draft29 or self.enable_v1):
|
||||
|
||||
@ -6,6 +6,7 @@ Uses aioquic's sans-IO core with trio for async operations.
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from aioquic.quic import (
|
||||
events,
|
||||
@ -21,9 +22,7 @@ from libp2p.abc import (
|
||||
IMuxedStream,
|
||||
IRawConnection,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
StreamHandlerFn,
|
||||
)
|
||||
from libp2p.custom_types import TQUICStreamHandlerFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -35,9 +34,11 @@ from .exceptions import (
|
||||
from .stream import (
|
||||
QUICStream,
|
||||
)
|
||||
from .transport import (
|
||||
QUICTransport,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transport import (
|
||||
QUICTransport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -49,76 +50,177 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
Uses aioquic's sans-IO core with trio for native async support.
|
||||
QUIC natively provides stream multiplexing, so this connection acts as both
|
||||
a raw connection (for transport layer) and muxed connection (for upper layers).
|
||||
|
||||
Updated to work properly with the QUIC listener for server-side connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quic_connection: QuicConnection,
|
||||
remote_addr: tuple[str, int],
|
||||
peer_id: ID,
|
||||
peer_id: ID | None,
|
||||
local_peer_id: ID,
|
||||
initiator: bool,
|
||||
is_initiator: bool,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
transport: QUICTransport,
|
||||
transport: "QUICTransport",
|
||||
):
|
||||
self._quic = quic_connection
|
||||
self._remote_addr = remote_addr
|
||||
self._peer_id = peer_id
|
||||
self._local_peer_id = local_peer_id
|
||||
self.__is_initiator = initiator
|
||||
self.__is_initiator = is_initiator
|
||||
self._maddr = maddr
|
||||
self._transport = transport
|
||||
|
||||
# Trio networking
|
||||
# Trio networking - socket may be provided by listener
|
||||
self._socket: trio.socket.SocketType | None = None
|
||||
self._connected_event = trio.Event()
|
||||
self._closed_event = trio.Event()
|
||||
|
||||
# Stream management
|
||||
self._streams: dict[int, QUICStream] = {}
|
||||
self._next_stream_id: int = (
|
||||
0 if initiator else 1
|
||||
) # Even for initiator, odd for responder
|
||||
self._stream_handler: StreamHandlerFn | None = None
|
||||
self._next_stream_id: int = self._calculate_initial_stream_id()
|
||||
self._stream_handler: TQUICStreamHandlerFn | None = None
|
||||
self._stream_id_lock = trio.Lock()
|
||||
|
||||
# Connection state
|
||||
self._closed = False
|
||||
self._timer_task = None
|
||||
self._established = False
|
||||
self._started = False
|
||||
|
||||
logger.debug(f"Created QUIC connection to {peer_id}")
|
||||
# Background task management
|
||||
self._background_tasks_started = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})")
|
||||
|
||||
def _calculate_initial_stream_id(self) -> int:
|
||||
"""
|
||||
Calculate the initial stream ID based on QUIC specification.
|
||||
|
||||
QUIC stream IDs:
|
||||
- Client-initiated bidirectional: 0, 4, 8, 12, ...
|
||||
- Server-initiated bidirectional: 1, 5, 9, 13, ...
|
||||
- Client-initiated unidirectional: 2, 6, 10, 14, ...
|
||||
- Server-initiated unidirectional: 3, 7, 11, 15, ...
|
||||
|
||||
For libp2p, we primarily use bidirectional streams.
|
||||
"""
|
||||
if self.__is_initiator:
|
||||
return 0 # Client starts with 0, then 4, 8, 12...
|
||||
else:
|
||||
return 1 # Server starts with 1, then 5, 9, 13...
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool: # type: ignore
|
||||
return self.__is_initiator
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish the QUIC connection using trio."""
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the connection and its background tasks.
|
||||
|
||||
This method implements the IMuxedConn.start() interface.
|
||||
It should be called to begin processing connection events.
|
||||
"""
|
||||
if self._started:
|
||||
logger.warning("Connection already started")
|
||||
return
|
||||
|
||||
if self._closed:
|
||||
raise QUICConnectionError("Cannot start a closed connection")
|
||||
|
||||
self._started = True
|
||||
logger.debug(f"Starting QUIC connection to {self._peer_id}")
|
||||
|
||||
# If this is a client connection, we need to establish the connection
|
||||
if self.__is_initiator:
|
||||
await self._initiate_connection()
|
||||
else:
|
||||
# For server connections, we're already connected via the listener
|
||||
self._established = True
|
||||
self._connected_event.set()
|
||||
|
||||
logger.debug(f"QUIC connection to {self._peer_id} started")
|
||||
|
||||
async def _initiate_connection(self) -> None:
|
||||
"""Initiate client-side connection establishment."""
|
||||
try:
|
||||
# Create UDP socket using trio
|
||||
self._socket = trio.socket.socket(
|
||||
family=socket.AF_INET, type=socket.SOCK_DGRAM
|
||||
)
|
||||
|
||||
# Connect the socket to the remote address
|
||||
await self._socket.connect(self._remote_addr)
|
||||
|
||||
# Start the connection establishment
|
||||
self._quic.connect(self._remote_addr, now=time.time())
|
||||
|
||||
# Send initial packet(s)
|
||||
await self._transmit()
|
||||
|
||||
# Start background tasks using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
self._handle_incoming_data, None, "QUIC INCOMING DATA"
|
||||
)
|
||||
nursery.start_soon(self._handle_timer, None, "QUIC TIMER HANDLER")
|
||||
# For client connections, we need to manage our own background tasks
|
||||
# In a real implementation, this would be managed by the transport
|
||||
# For now, we'll start them here
|
||||
if not self._background_tasks_started:
|
||||
# We would need a nursery to start background tasks
|
||||
# This is a limitation of the current design
|
||||
logger.warning("Background tasks need nursery - connection may not work properly")
|
||||
|
||||
# Wait for connection to be established
|
||||
await self._connected_event.wait()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate connection: {e}")
|
||||
raise QUICConnectionError(f"Connection initiation failed: {e}") from e
|
||||
|
||||
async def connect(self, nursery: trio.Nursery) -> None:
|
||||
"""
|
||||
Establish the QUIC connection using trio.
|
||||
|
||||
Args:
|
||||
nursery: Trio nursery for background tasks
|
||||
|
||||
"""
|
||||
if not self.__is_initiator:
|
||||
raise QUICConnectionError("connect() should only be called by client connections")
|
||||
|
||||
try:
|
||||
# Store nursery for background tasks
|
||||
self._nursery = nursery
|
||||
|
||||
# Create UDP socket using trio
|
||||
self._socket = trio.socket.socket(
|
||||
family=socket.AF_INET, type=socket.SOCK_DGRAM
|
||||
)
|
||||
|
||||
# Connect the socket to the remote address
|
||||
await self._socket.connect(self._remote_addr)
|
||||
|
||||
# Start the connection establishment
|
||||
self._quic.connect(self._remote_addr, now=time.time())
|
||||
|
||||
# Send initial packet(s)
|
||||
await self._transmit()
|
||||
|
||||
# Start background tasks
|
||||
await self._start_background_tasks(nursery)
|
||||
|
||||
# Wait for connection to be established
|
||||
await self._connected_event.wait()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect: {e}")
|
||||
raise QUICConnectionError(f"Connection failed: {e}") from e
|
||||
|
||||
async def _start_background_tasks(self, nursery: trio.Nursery) -> None:
|
||||
"""Start background tasks for connection management."""
|
||||
if self._background_tasks_started:
|
||||
return
|
||||
|
||||
self._background_tasks_started = True
|
||||
|
||||
# Start background tasks
|
||||
nursery.start_soon(self._handle_incoming_data)
|
||||
nursery.start_soon(self._handle_timer)
|
||||
|
||||
async def _handle_incoming_data(self) -> None:
|
||||
"""Handle incoming UDP datagrams in trio."""
|
||||
while not self._closed:
|
||||
@ -128,6 +230,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._quic.receive_datagram(data, addr, now=time.time())
|
||||
await self._process_events()
|
||||
await self._transmit()
|
||||
|
||||
# Small delay to prevent busy waiting
|
||||
await trio.sleep(0.001)
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
break
|
||||
except Exception as e:
|
||||
@ -137,18 +243,26 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
async def _handle_timer(self) -> None:
|
||||
"""Handle QUIC timer events in trio."""
|
||||
while not self._closed:
|
||||
timer_at = self._quic.get_timer()
|
||||
if timer_at is None:
|
||||
await trio.sleep(1.0) # No timer set, check again later
|
||||
continue
|
||||
try:
|
||||
timer_at = self._quic.get_timer()
|
||||
if timer_at is None:
|
||||
await trio.sleep(0.1) # No timer set, check again later
|
||||
continue
|
||||
|
||||
now = time.time()
|
||||
if timer_at <= now:
|
||||
self._quic.handle_timer(now=now)
|
||||
await self._process_events()
|
||||
await self._transmit()
|
||||
else:
|
||||
await trio.sleep(timer_at - now)
|
||||
now = time.time()
|
||||
if timer_at <= now:
|
||||
self._quic.handle_timer(now=now)
|
||||
await self._process_events()
|
||||
await self._transmit()
|
||||
await trio.sleep(0.001) # Small delay
|
||||
else:
|
||||
# Sleep until timer fires, but check periodically
|
||||
sleep_time = min(timer_at - now, 0.1)
|
||||
await trio.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in timer handler: {e}")
|
||||
await trio.sleep(0.1)
|
||||
|
||||
async def _process_events(self) -> None:
|
||||
"""Process QUIC events from aioquic core."""
|
||||
@ -165,6 +279,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
logger.debug("QUIC handshake completed")
|
||||
self._established = True
|
||||
self._connected_event.set()
|
||||
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
@ -177,25 +292,47 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Handle incoming stream data."""
|
||||
stream_id = event.stream_id
|
||||
|
||||
# Get or create stream
|
||||
if stream_id not in self._streams:
|
||||
# Create new stream for incoming data
|
||||
# Determine if this is an incoming stream
|
||||
is_incoming = self._is_incoming_stream(stream_id)
|
||||
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
is_initiator=False, # pyrefly: ignore
|
||||
is_initiator=not is_incoming,
|
||||
)
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
# Notify stream handler if available
|
||||
if self._stream_handler:
|
||||
# Use trio nursery to start stream handler
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._stream_handler, stream)
|
||||
# Notify stream handler for incoming streams
|
||||
if is_incoming and self._stream_handler:
|
||||
# Start stream handler in background
|
||||
# In a real implementation, you might want to use the nursery
|
||||
# passed to the connection, but for now we'll handle it directly
|
||||
try:
|
||||
await self._stream_handler(stream)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream handler: {e}")
|
||||
|
||||
# Forward data to stream
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_data_received(event.data, event.end_stream)
|
||||
|
||||
def _is_incoming_stream(self, stream_id: int) -> bool:
|
||||
"""
|
||||
Determine if a stream ID represents an incoming stream.
|
||||
|
||||
For bidirectional streams:
|
||||
- Even IDs are client-initiated
|
||||
- Odd IDs are server-initiated
|
||||
"""
|
||||
if self.__is_initiator:
|
||||
# We're the client, so odd stream IDs are incoming
|
||||
return stream_id % 2 == 1
|
||||
else:
|
||||
# We're the server, so even stream IDs are incoming
|
||||
return stream_id % 2 == 0
|
||||
|
||||
async def _handle_stream_reset(self, event: events.StreamReset) -> None:
|
||||
"""Handle stream reset."""
|
||||
stream_id = event.stream_id
|
||||
@ -210,15 +347,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if socket is None:
|
||||
return
|
||||
|
||||
for data, addr in self._quic.datagrams_to_send(now=time.time()):
|
||||
try:
|
||||
try:
|
||||
for data, addr in self._quic.datagrams_to_send(now=time.time()):
|
||||
await socket.sendto(data, addr)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send datagram: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send datagram: {e}")
|
||||
|
||||
# IRawConnection interface
|
||||
|
||||
async def write(self, data: bytes):
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the connection.
|
||||
For QUIC, this creates a new stream for each write operation.
|
||||
@ -230,7 +367,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await stream.write(data)
|
||||
await stream.close()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int | None = -1) -> bytes:
|
||||
"""
|
||||
Read data from the connection.
|
||||
For QUIC, this reads from the next available stream.
|
||||
@ -252,14 +389,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._closed = True
|
||||
logger.debug(f"Closing QUIC connection to {self._peer_id}")
|
||||
|
||||
# Close all streams using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
for stream in self._streams.values():
|
||||
nursery.start_soon(stream.close)
|
||||
# Close all streams
|
||||
stream_close_tasks = []
|
||||
for stream in list(self._streams.values()):
|
||||
stream_close_tasks.append(stream.close())
|
||||
|
||||
if stream_close_tasks:
|
||||
# Close streams concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
for task in stream_close_tasks:
|
||||
nursery.start_soon(lambda t=task: t)
|
||||
|
||||
# Close QUIC connection
|
||||
self._quic.close()
|
||||
await self._transmit() # Send close frames
|
||||
if self._socket:
|
||||
await self._transmit() # Send close frames
|
||||
|
||||
# Close socket
|
||||
if self._socket:
|
||||
@ -275,6 +419,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Check if connection is closed."""
|
||||
return self._closed
|
||||
|
||||
@property
|
||||
def is_established(self) -> bool:
|
||||
"""Check if connection is established (handshake completed)."""
|
||||
return self._established
|
||||
|
||||
@property
|
||||
def is_started(self) -> bool:
|
||||
"""Check if connection has been started."""
|
||||
return self._started
|
||||
|
||||
def multiaddr(self) -> multiaddr.Multiaddr:
|
||||
"""Get the multiaddr for this connection."""
|
||||
return self._maddr
|
||||
@ -283,6 +437,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Get the local peer ID."""
|
||||
return self._local_peer_id
|
||||
|
||||
def remote_peer_id(self) -> ID | None:
|
||||
"""Get the remote peer ID."""
|
||||
return self._peer_id
|
||||
|
||||
# IMuxedConn interface
|
||||
|
||||
async def open_stream(self) -> IMuxedStream:
|
||||
@ -296,23 +454,27 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._closed:
|
||||
raise QUICStreamError("Connection is closed")
|
||||
|
||||
# Generate next stream ID
|
||||
stream_id = self._next_stream_id
|
||||
self._next_stream_id += (
|
||||
2 # Increment by 2 to maintain initiator/responder distinction
|
||||
)
|
||||
if not self._started:
|
||||
raise QUICStreamError("Connection not started")
|
||||
|
||||
# Create stream
|
||||
stream = QUICStream(
|
||||
connection=self, stream_id=stream_id, is_initiator=True
|
||||
) # pyrefly: ignore
|
||||
async with self._stream_id_lock:
|
||||
# Generate next stream ID
|
||||
stream_id = self._next_stream_id
|
||||
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
# Create stream
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
is_initiator=True
|
||||
)
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
logger.debug(f"Opened QUIC stream {stream_id}")
|
||||
return stream
|
||||
|
||||
def set_stream_handler(self, handler_function: StreamHandlerFn) -> None:
|
||||
def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
|
||||
"""
|
||||
Set handler for incoming streams.
|
||||
|
||||
@ -341,17 +503,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""
|
||||
# Extract peer ID from TLS certificate
|
||||
# This should match the expected peer ID
|
||||
cert_peer_id = self._extract_peer_id_from_cert()
|
||||
try:
|
||||
cert_peer_id = self._extract_peer_id_from_cert()
|
||||
|
||||
if self._peer_id and cert_peer_id != self._peer_id:
|
||||
raise QUICConnectionError(
|
||||
f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
|
||||
)
|
||||
if self._peer_id and cert_peer_id != self._peer_id:
|
||||
raise QUICConnectionError(
|
||||
f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
|
||||
)
|
||||
|
||||
if not self._peer_id:
|
||||
self._peer_id = cert_peer_id
|
||||
if not self._peer_id:
|
||||
self._peer_id = cert_peer_id
|
||||
|
||||
logger.debug(f"Verified peer identity: {self._peer_id}")
|
||||
logger.debug(f"Verified peer identity: {self._peer_id}")
|
||||
|
||||
except NotImplementedError:
|
||||
logger.warning("Peer identity verification not implemented - skipping")
|
||||
# For now, we'll skip verification during development
|
||||
|
||||
def _extract_peer_id_from_cert(self) -> ID:
|
||||
"""Extract peer ID from TLS certificate."""
|
||||
@ -363,6 +530,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# The certificate should contain the peer ID in a specific extension
|
||||
raise NotImplementedError("Certificate peer ID extraction not implemented")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get connection statistics."""
|
||||
return {
|
||||
"peer_id": str(self._peer_id),
|
||||
"remote_addr": self._remote_addr,
|
||||
"is_initiator": self.__is_initiator,
|
||||
"is_established": self._established,
|
||||
"is_closed": self._closed,
|
||||
"is_started": self._started,
|
||||
"active_streams": len(self._streams),
|
||||
"next_stream_id": self._next_stream_id,
|
||||
}
|
||||
|
||||
def get_remote_address(self):
|
||||
return self._remote_addr
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the connection."""
|
||||
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)})"
|
||||
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, established={self._established}, started={self._started})"
|
||||
|
||||
579
libp2p/transport/quic/listener.py
Normal file
579
libp2p/transport/quic/listener.py
Normal file
@ -0,0 +1,579 @@
|
||||
"""
|
||||
QUIC Listener implementation for py-libp2p.
|
||||
Based on go-libp2p and js-libp2p QUIC listener patterns.
|
||||
Uses aioquic's server-side QUIC implementation with trio.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from aioquic.quic import events
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler, TProtocol
|
||||
|
||||
from .config import QUICTransportConfig
|
||||
from .connection import QUICConnection
|
||||
from .exceptions import QUICListenError
|
||||
from .utils import (
|
||||
create_quic_multiaddr,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transport import QUICTransport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
|
||||
class QUICListener(IListener):
|
||||
"""
|
||||
QUIC Listener implementation following libp2p listener interface.
|
||||
|
||||
Handles incoming QUIC connections, manages server-side handshakes,
|
||||
and integrates with the libp2p connection handler system.
|
||||
Based on go-libp2p and js-libp2p listener patterns.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: "QUICTransport",
|
||||
handler_function: THandler,
|
||||
quic_configs: Dict[TProtocol, QuicConfiguration],
|
||||
config: QUICTransportConfig,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC listener.
|
||||
|
||||
Args:
|
||||
transport: Parent QUIC transport
|
||||
handler_function: Function to handle new connections
|
||||
quic_configs: QUIC configurations for different versions
|
||||
config: QUIC transport configuration
|
||||
|
||||
"""
|
||||
self._transport = transport
|
||||
self._handler = handler_function
|
||||
self._quic_configs = quic_configs
|
||||
self._config = config
|
||||
|
||||
# Network components
|
||||
self._socket: trio.socket.SocketType | None = None
|
||||
self._bound_addresses: list[Multiaddr] = []
|
||||
|
||||
# Connection management
|
||||
self._connections: Dict[tuple[str, int], QUICConnection] = {}
|
||||
self._pending_connections: Dict[tuple[str, int], QuicConnection] = {}
|
||||
self._connection_lock = trio.Lock()
|
||||
|
||||
# Listener state
|
||||
self._closed = False
|
||||
self._listening = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
# Performance tracking
|
||||
self._stats = {
|
||||
"connections_accepted": 0,
|
||||
"connections_rejected": 0,
|
||||
"bytes_received": 0,
|
||||
"packets_processed": 0,
|
||||
}
|
||||
|
||||
logger.debug("Initialized QUIC listener")
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
"""
|
||||
Start listening on the given multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to listen on
|
||||
nursery: Trio nursery for managing background tasks
|
||||
|
||||
Returns:
|
||||
True if listening started successfully
|
||||
|
||||
Raises:
|
||||
QUICListenError: If failed to start listening
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}")
|
||||
|
||||
if self._listening:
|
||||
raise QUICListenError("Already listening")
|
||||
|
||||
try:
|
||||
# Extract host and port from multiaddr
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
quic_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
# Validate QUIC version support
|
||||
if quic_version not in self._quic_configs:
|
||||
raise QUICListenError(f"Unsupported QUIC version: {quic_version}")
|
||||
|
||||
# Create and bind UDP socket
|
||||
self._socket = await self._create_and_bind_socket(host, port)
|
||||
actual_port = self._socket.getsockname()[1]
|
||||
|
||||
# Update multiaddr with actual bound port
|
||||
actual_maddr = create_quic_multiaddr(host, actual_port, f"/{quic_version}")
|
||||
self._bound_addresses = [actual_maddr]
|
||||
|
||||
# Store nursery reference and set listening state
|
||||
self._nursery = nursery
|
||||
self._listening = True
|
||||
|
||||
# Start background tasks directly in the provided nursery
|
||||
# This ensures proper cancellation when the nursery exits
|
||||
nursery.start_soon(self._handle_incoming_packets)
|
||||
nursery.start_soon(self._manage_connections)
|
||||
|
||||
print(f"QUIC listener started on {actual_maddr}")
|
||||
return True
|
||||
|
||||
except trio.Cancelled:
|
||||
print("CLOSING LISTENER")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start QUIC listener on {maddr}: {e}")
|
||||
await self._cleanup_socket()
|
||||
raise QUICListenError(f"Listen failed: {e}") from e
|
||||
|
||||
async def _create_and_bind_socket(
|
||||
self, host: str, port: int
|
||||
) -> trio.socket.SocketType:
|
||||
"""Create and bind UDP socket for QUIC."""
|
||||
try:
|
||||
# Determine address family
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
ip = ipaddress.ip_address(host)
|
||||
family = socket.AF_INET if ip.version == 4 else socket.AF_INET6
|
||||
except ValueError:
|
||||
# Assume IPv4 for hostnames
|
||||
family = socket.AF_INET
|
||||
|
||||
# Create UDP socket
|
||||
sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM)
|
||||
|
||||
# Set socket options for better performance
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
if hasattr(socket, "SO_REUSEPORT"):
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
|
||||
# Bind to address
|
||||
await sock.bind((host, port))
|
||||
|
||||
logger.debug(f"Created and bound UDP socket to {host}:{port}")
|
||||
return sock
|
||||
|
||||
except Exception as e:
|
||||
raise QUICListenError(f"Failed to create socket: {e}") from e
|
||||
|
||||
async def _handle_incoming_packets(self) -> None:
|
||||
"""
|
||||
Handle incoming UDP packets and route to appropriate connections.
|
||||
This is the main packet processing loop.
|
||||
"""
|
||||
logger.debug("Started packet handling loop")
|
||||
|
||||
try:
|
||||
while self._listening and self._socket:
|
||||
try:
|
||||
# Receive UDP packet (this blocks until packet arrives or socket closes)
|
||||
data, addr = await self._socket.recvfrom(65536)
|
||||
self._stats["bytes_received"] += len(data)
|
||||
self._stats["packets_processed"] += 1
|
||||
|
||||
# Process packet asynchronously to avoid blocking
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(self._process_packet, data, addr)
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
# Socket was closed, exit gracefully
|
||||
logger.debug("Socket closed, exiting packet handler")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving packet: {e}")
|
||||
# Continue processing other packets
|
||||
await trio.sleep(0.01)
|
||||
except trio.Cancelled:
|
||||
print("PACKET HANDLER CANCELLED - FORCIBLY CLOSING SOCKET")
|
||||
raise
|
||||
finally:
|
||||
print("PACKET HANDLER FINISHED")
|
||||
logger.debug("Packet handling loop terminated")
|
||||
|
||||
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""
|
||||
Process a single incoming packet.
|
||||
Routes to existing connection or creates new connection.
|
||||
|
||||
Args:
|
||||
data: Raw UDP packet data
|
||||
addr: Source address (host, port)
|
||||
|
||||
"""
|
||||
try:
|
||||
async with self._connection_lock:
|
||||
# Check if we have an existing connection for this address
|
||||
if addr in self._connections:
|
||||
connection = self._connections[addr]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
elif addr in self._pending_connections:
|
||||
# Handle packet for pending connection
|
||||
quic_conn = self._pending_connections[addr]
|
||||
await self._handle_pending_connection(quic_conn, data, addr)
|
||||
else:
|
||||
# New connection
|
||||
await self._handle_new_connection(data, addr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing packet from {addr}: {e}")
|
||||
|
||||
async def _route_to_connection(
|
||||
self, connection: QUICConnection, data: bytes, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""Route packet to existing connection."""
|
||||
try:
|
||||
# Feed data to the connection's QUIC instance
|
||||
connection._quic.receive_datagram(data, addr, now=time.time())
|
||||
|
||||
# Process events and handle responses
|
||||
await connection._process_events()
|
||||
await connection._transmit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error routing packet to connection {addr}: {e}")
|
||||
# Remove problematic connection
|
||||
await self._remove_connection(addr)
|
||||
|
||||
async def _handle_pending_connection(
|
||||
self, quic_conn: QuicConnection, data: bytes, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""Handle packet for a pending (handshaking) connection."""
|
||||
try:
|
||||
# Feed data to QUIC connection
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
|
||||
# Process events
|
||||
await self._process_quic_events(quic_conn, addr)
|
||||
|
||||
# Send any outgoing packets
|
||||
await self._transmit_for_connection(quic_conn)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling pending connection {addr}: {e}")
|
||||
# Remove from pending connections
|
||||
self._pending_connections.pop(addr, None)
|
||||
|
||||
async def _handle_new_connection(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""
|
||||
Handle a new incoming connection.
|
||||
Creates a new QUIC connection and starts handshake.
|
||||
|
||||
Args:
|
||||
data: Initial packet data
|
||||
addr: Source address
|
||||
|
||||
"""
|
||||
try:
|
||||
# Determine QUIC version from packet
|
||||
# For now, use the first available configuration
|
||||
# TODO: Implement proper version negotiation
|
||||
quic_version = next(iter(self._quic_configs.keys()))
|
||||
config = self._quic_configs[quic_version]
|
||||
|
||||
# Create server-side QUIC configuration
|
||||
server_config = copy.deepcopy(config)
|
||||
server_config.is_client = False
|
||||
|
||||
# Create QUIC connection
|
||||
quic_conn = QuicConnection(configuration=server_config)
|
||||
|
||||
# Store as pending connection
|
||||
self._pending_connections[addr] = quic_conn
|
||||
|
||||
# Process initial packet
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
await self._process_quic_events(quic_conn, addr)
|
||||
await self._transmit_for_connection(quic_conn)
|
||||
|
||||
logger.debug(f"Started handshake for new connection from {addr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling new connection from {addr}: {e}")
|
||||
self._stats["connections_rejected"] += 1
|
||||
|
||||
async def _process_quic_events(
|
||||
self, quic_conn: QuicConnection, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""Process QUIC events for a connection."""
|
||||
while True:
|
||||
event = quic_conn.next_event()
|
||||
if event is None:
|
||||
break
|
||||
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
logger.debug(
|
||||
f"Connection from {addr} terminated: {event.reason_phrase}"
|
||||
)
|
||||
await self._remove_connection(addr)
|
||||
break
|
||||
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
logger.debug(f"Handshake completed for {addr}")
|
||||
await self._promote_pending_connection(quic_conn, addr)
|
||||
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
# Forward to established connection if available
|
||||
if addr in self._connections:
|
||||
connection = self._connections[addr]
|
||||
await connection._handle_stream_data(event)
|
||||
|
||||
elif isinstance(event, events.StreamReset):
|
||||
# Forward to established connection if available
|
||||
if addr in self._connections:
|
||||
connection = self._connections[addr]
|
||||
await connection._handle_stream_reset(event)
|
||||
|
||||
async def _promote_pending_connection(
|
||||
self, quic_conn: QuicConnection, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""
|
||||
Promote a pending connection to an established connection.
|
||||
Called after successful handshake completion.
|
||||
|
||||
Args:
|
||||
quic_conn: Established QUIC connection
|
||||
addr: Remote address
|
||||
|
||||
"""
|
||||
try:
|
||||
# Remove from pending connections
|
||||
self._pending_connections.pop(addr, None)
|
||||
|
||||
# Create multiaddr for this connection
|
||||
host, port = addr
|
||||
# Use the first supported QUIC version for now
|
||||
quic_version = next(iter(self._quic_configs.keys()))
|
||||
remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}")
|
||||
|
||||
# Create libp2p connection wrapper
|
||||
connection = QUICConnection(
|
||||
quic_connection=quic_conn,
|
||||
remote_addr=addr,
|
||||
peer_id=None, # Will be determined during identity verification
|
||||
local_peer_id=self._transport._peer_id,
|
||||
is_initiator=False, # We're the server
|
||||
maddr=remote_maddr,
|
||||
transport=self._transport,
|
||||
)
|
||||
|
||||
# Store the connection
|
||||
self._connections[addr] = connection
|
||||
|
||||
# Start connection management tasks
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(connection._handle_incoming_data)
|
||||
self._nursery.start_soon(connection._handle_timer)
|
||||
|
||||
# TODO: Verify peer identity
|
||||
# await connection.verify_peer_identity()
|
||||
|
||||
# Call the connection handler
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(
|
||||
self._handle_new_established_connection, connection
|
||||
)
|
||||
|
||||
self._stats["connections_accepted"] += 1
|
||||
logger.info(f"Accepted new QUIC connection from {addr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error promoting connection from {addr}: {e}")
|
||||
# Clean up
|
||||
await self._remove_connection(addr)
|
||||
self._stats["connections_rejected"] += 1
|
||||
|
||||
async def _handle_new_established_connection(
|
||||
self, connection: QUICConnection
|
||||
) -> None:
|
||||
"""
|
||||
Handle a newly established connection by calling the user handler.
|
||||
|
||||
Args:
|
||||
connection: Established QUIC connection
|
||||
|
||||
"""
|
||||
try:
|
||||
# Call the connection handler provided by the transport
|
||||
await self._handler(connection)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in connection handler: {e}")
|
||||
# Close the problematic connection
|
||||
await connection.close()
|
||||
|
||||
async def _transmit_for_connection(self, quic_conn: QuicConnection) -> None:
|
||||
"""Send pending datagrams for a QUIC connection."""
|
||||
sock = self._socket
|
||||
if not sock:
|
||||
return
|
||||
|
||||
for data, addr in quic_conn.datagrams_to_send(now=time.time()):
|
||||
try:
|
||||
await sock.sendto(data, addr)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send datagram to {addr}: {e}")
|
||||
|
||||
async def _manage_connections(self) -> None:
|
||||
"""
|
||||
Background task to manage connection lifecycle.
|
||||
Handles cleanup of closed/idle connections.
|
||||
"""
|
||||
try:
|
||||
while not self._closed:
|
||||
try:
|
||||
# Sleep for a short interval
|
||||
await trio.sleep(1.0)
|
||||
|
||||
# Clean up closed connections
|
||||
await self._cleanup_closed_connections()
|
||||
|
||||
# Handle connection timeouts
|
||||
await self._handle_connection_timeouts()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in connection management: {e}")
|
||||
except trio.Cancelled:
|
||||
print("CONNECTION MANAGER CANCELLED")
|
||||
raise
|
||||
finally:
|
||||
print("CONNECTION MANAGER FINISHED")
|
||||
|
||||
async def _cleanup_closed_connections(self) -> None:
|
||||
"""Remove closed connections from tracking."""
|
||||
async with self._connection_lock:
|
||||
closed_addrs = []
|
||||
|
||||
for addr, connection in self._connections.items():
|
||||
if connection.is_closed:
|
||||
closed_addrs.append(addr)
|
||||
|
||||
for addr in closed_addrs:
|
||||
self._connections.pop(addr, None)
|
||||
logger.debug(f"Cleaned up closed connection from {addr}")
|
||||
|
||||
async def _handle_connection_timeouts(self) -> None:
|
||||
"""Handle connection timeouts and cleanup."""
|
||||
# TODO: Implement connection timeout handling
|
||||
# Check for idle connections and close them
|
||||
pass
|
||||
|
||||
async def _remove_connection(self, addr: tuple[str, int]) -> None:
|
||||
"""Remove a connection from tracking."""
|
||||
async with self._connection_lock:
|
||||
# Remove from active connections
|
||||
connection = self._connections.pop(addr, None)
|
||||
if connection:
|
||||
await connection.close()
|
||||
|
||||
# Remove from pending connections
|
||||
quic_conn = self._pending_connections.pop(addr, None)
|
||||
if quic_conn:
|
||||
quic_conn.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the listener and cleanup resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
self._listening = False
|
||||
print("Closing QUIC listener")
|
||||
|
||||
# CRITICAL: Close socket FIRST to unblock recvfrom()
|
||||
await self._cleanup_socket()
|
||||
|
||||
print("SOCKET CLEANUP COMPLETE")
|
||||
|
||||
# Close all connections WITHOUT using the lock during shutdown
|
||||
# (avoid deadlock if background tasks are cancelled while holding lock)
|
||||
connections_to_close = list(self._connections.values())
|
||||
pending_to_close = list(self._pending_connections.values())
|
||||
|
||||
print(
|
||||
f"CLOSING {len(connections_to_close)} connections and {len(pending_to_close)} pending"
|
||||
)
|
||||
|
||||
# Close active connections
|
||||
for connection in connections_to_close:
|
||||
try:
|
||||
await connection.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing connection: {e}")
|
||||
|
||||
# Close pending connections
|
||||
for quic_conn in pending_to_close:
|
||||
try:
|
||||
quic_conn.close()
|
||||
except Exception as e:
|
||||
print(f"Error closing pending connection: {e}")
|
||||
|
||||
# Clear the dictionaries without lock (we're shutting down)
|
||||
self._connections.clear()
|
||||
self._pending_connections.clear()
|
||||
if self._nursery:
|
||||
print("TASKS", len(self._nursery.child_tasks))
|
||||
|
||||
print("QUIC listener closed")
|
||||
|
||||
async def _cleanup_socket(self) -> None:
|
||||
"""Clean up the UDP socket."""
|
||||
if self._socket:
|
||||
try:
|
||||
self._socket.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing socket: {e}")
|
||||
finally:
|
||||
self._socket = None
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
"""
|
||||
Get the addresses this listener is bound to.
|
||||
|
||||
Returns:
|
||||
Tuple of bound multiaddrs
|
||||
|
||||
"""
|
||||
return tuple(self._bound_addresses)
|
||||
|
||||
def is_listening(self) -> bool:
|
||||
"""Check if the listener is actively listening."""
|
||||
return self._listening and not self._closed
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get listener statistics."""
|
||||
stats = self._stats.copy()
|
||||
stats.update(
|
||||
{
|
||||
"active_connections": len(self._connections),
|
||||
"pending_connections": len(self._pending_connections),
|
||||
"is_listening": self.is_listening(),
|
||||
}
|
||||
)
|
||||
return stats
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the listener."""
|
||||
return f"QUICListener(addrs={self._bound_addresses}, connections={len(self._connections)})"
|
||||
123
libp2p/transport/quic/security.py
Normal file
123
libp2p/transport/quic/security.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""
|
||||
Basic QUIC Security implementation for Module 1.
|
||||
This provides minimal TLS configuration for QUIC transport.
|
||||
Full implementation will be in Module 5.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .exceptions import QUICSecurityError
|
||||
|
||||
|
||||
@dataclass
|
||||
class TLSConfig:
|
||||
"""TLS configuration for QUIC transport."""
|
||||
|
||||
cert_file: str
|
||||
key_file: str
|
||||
ca_file: Optional[str] = None
|
||||
|
||||
|
||||
def generate_libp2p_tls_config(private_key: PrivateKey, peer_id: ID) -> TLSConfig:
|
||||
"""
|
||||
Generate TLS configuration with libp2p peer identity.
|
||||
|
||||
This is a basic implementation for Module 1.
|
||||
Full implementation with proper libp2p TLS spec compliance
|
||||
will be provided in Module 5.
|
||||
|
||||
Args:
|
||||
private_key: libp2p private key
|
||||
peer_id: libp2p peer ID
|
||||
|
||||
Returns:
|
||||
TLS configuration
|
||||
|
||||
Raises:
|
||||
QUICSecurityError: If TLS configuration generation fails
|
||||
|
||||
"""
|
||||
try:
|
||||
# TODO: Implement proper libp2p TLS certificate generation
|
||||
# This should follow the libp2p TLS specification:
|
||||
# https://github.com/libp2p/specs/blob/master/tls/tls.md
|
||||
|
||||
# For now, create a basic self-signed certificate
|
||||
# This is a placeholder implementation
|
||||
|
||||
# Create temporary files for cert and key
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".pem", delete=False
|
||||
) as cert_file:
|
||||
cert_path = cert_file.name
|
||||
# Write placeholder certificate
|
||||
cert_file.write(_generate_placeholder_cert(peer_id))
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".key", delete=False
|
||||
) as key_file:
|
||||
key_path = key_file.name
|
||||
# Write placeholder private key
|
||||
key_file.write(_generate_placeholder_key(private_key))
|
||||
|
||||
return TLSConfig(cert_file=cert_path, key_file=key_path)
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Failed to generate TLS config: {e}") from e
|
||||
|
||||
|
||||
def _generate_placeholder_cert(peer_id: ID) -> str:
|
||||
"""
|
||||
Generate a placeholder certificate.
|
||||
|
||||
This is a temporary implementation for Module 1.
|
||||
Real implementation will embed the peer ID in the certificate
|
||||
following the libp2p TLS specification.
|
||||
"""
|
||||
# This is a placeholder - real implementation needed
|
||||
return f"""-----BEGIN CERTIFICATE-----
|
||||
# Placeholder certificate for peer {peer_id}
|
||||
# TODO: Implement proper libp2p TLS certificate generation
|
||||
# This should embed the peer ID in a certificate extension
|
||||
# according to the libp2p TLS specification
|
||||
-----END CERTIFICATE-----"""
|
||||
|
||||
|
||||
def _generate_placeholder_key(private_key: PrivateKey) -> str:
|
||||
"""
|
||||
Generate a placeholder private key.
|
||||
|
||||
This is a temporary implementation for Module 1.
|
||||
Real implementation will use the actual libp2p private key.
|
||||
"""
|
||||
# This is a placeholder - real implementation needed
|
||||
return """-----BEGIN PRIVATE KEY-----
|
||||
# Placeholder private key
|
||||
# TODO: Convert libp2p private key to TLS-compatible format
|
||||
-----END PRIVATE KEY-----"""
|
||||
|
||||
|
||||
def cleanup_tls_config(config: TLSConfig) -> None:
|
||||
"""
|
||||
Clean up temporary TLS files.
|
||||
|
||||
Args:
|
||||
config: TLS configuration to clean up
|
||||
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(config.cert_file):
|
||||
os.unlink(config.cert_file)
|
||||
if os.path.exists(config.key_file):
|
||||
os.unlink(config.key_file)
|
||||
if config.ca_file and os.path.exists(config.ca_file):
|
||||
os.unlink(config.ca_file)
|
||||
except Exception:
|
||||
# Ignore cleanup errors
|
||||
pass
|
||||
@ -5,16 +5,17 @@ QUIC Stream implementation
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IMuxedStream,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import IMuxedStream
|
||||
|
||||
from .connection import QUICConnection
|
||||
else:
|
||||
IMuxedStream = cast(type, object)
|
||||
|
||||
from .connection import (
|
||||
QUICConnection,
|
||||
)
|
||||
from .exceptions import (
|
||||
QUICStreamError,
|
||||
)
|
||||
@ -41,7 +42,7 @@ class QUICStream(IMuxedStream):
|
||||
self._receive_event = trio.Event()
|
||||
self._close_event = trio.Event()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int | None = -1) -> bytes:
|
||||
"""Read data from the stream."""
|
||||
if self._closed:
|
||||
raise QUICStreamError("Stream is closed")
|
||||
|
||||
@ -14,9 +14,6 @@ from aioquic.quic.connection import (
|
||||
QuicConnection,
|
||||
)
|
||||
import multiaddr
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
@ -27,9 +24,15 @@ from libp2p.abc import (
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
)
|
||||
from libp2p.custom_types import THandler, TProtocol
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
QUICTransportConfig,
|
||||
@ -41,21 +44,16 @@ from .exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
)
|
||||
from .listener import (
|
||||
QUICListener,
|
||||
)
|
||||
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QUICListener(IListener):
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
return False
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
return ()
|
||||
|
||||
|
||||
class QUICTransport(ITransport):
|
||||
"""
|
||||
QUIC Transport implementation following libp2p transport interface.
|
||||
@ -65,10 +63,6 @@ class QUICTransport(ITransport):
|
||||
go-libp2p and js-libp2p implementations.
|
||||
"""
|
||||
|
||||
# Protocol identifiers matching go-libp2p
|
||||
PROTOCOL_QUIC_V1 = "/quic-v1" # RFC 9000
|
||||
PROTOCOL_QUIC_DRAFT29 = "/quic" # draft-29
|
||||
|
||||
def __init__(
|
||||
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
|
||||
):
|
||||
@ -89,7 +83,7 @@ class QUICTransport(ITransport):
|
||||
self._listeners: list[QUICListener] = []
|
||||
|
||||
# QUIC configurations for different versions
|
||||
self._quic_configs: dict[str, QuicConfiguration] = {}
|
||||
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
|
||||
self._setup_quic_configurations()
|
||||
|
||||
# Resource management
|
||||
@ -110,35 +104,36 @@ class QUICTransport(ITransport):
|
||||
)
|
||||
|
||||
# Add TLS certificate generated from libp2p private key
|
||||
self._setup_tls_configuration(base_config)
|
||||
# self._setup_tls_configuration(base_config)
|
||||
|
||||
# QUIC v1 (RFC 9000) configuration
|
||||
quic_v1_config = copy.deepcopy(base_config)
|
||||
quic_v1_config.supported_versions = [0x00000001] # QUIC v1
|
||||
self._quic_configs[self.PROTOCOL_QUIC_V1] = quic_v1_config
|
||||
self._quic_configs[QUIC_V1_PROTOCOL] = quic_v1_config
|
||||
|
||||
# QUIC draft-29 configuration for compatibility
|
||||
if self._config.enable_draft29:
|
||||
draft29_config = copy.deepcopy(base_config)
|
||||
draft29_config.supported_versions = [0xFF00001D] # draft-29
|
||||
self._quic_configs[self.PROTOCOL_QUIC_DRAFT29] = draft29_config
|
||||
self._quic_configs[QUIC_DRAFT29_PROTOCOL] = draft29_config
|
||||
|
||||
def _setup_tls_configuration(self, config: QuicConfiguration) -> None:
|
||||
"""
|
||||
Setup TLS configuration with libp2p identity integration.
|
||||
Similar to go-libp2p's certificate generation approach.
|
||||
"""
|
||||
from .security import (
|
||||
generate_libp2p_tls_config,
|
||||
)
|
||||
# TODO: SETUP TLS LISTENER
|
||||
# def _setup_tls_configuration(self, config: QuicConfiguration) -> None:
|
||||
# """
|
||||
# Setup TLS configuration with libp2p identity integration.
|
||||
# Similar to go-libp2p's certificate generation approach.
|
||||
# """
|
||||
# from .security import (
|
||||
# generate_libp2p_tls_config,
|
||||
# )
|
||||
|
||||
# Generate TLS certificate with embedded libp2p peer ID
|
||||
# This follows the libp2p TLS spec for peer identity verification
|
||||
tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id)
|
||||
# # Generate TLS certificate with embedded libp2p peer ID
|
||||
# # This follows the libp2p TLS spec for peer identity verification
|
||||
# tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id)
|
||||
|
||||
config.load_cert_chain(tls_config.cert_file, tls_config.key_file)
|
||||
if tls_config.ca_file:
|
||||
config.load_verify_locations(tls_config.ca_file)
|
||||
# config.load_cert_chain(certfile=tls_config.cert_file, keyfile=tls_config.key_file)
|
||||
# if tls_config.ca_file:
|
||||
# config.load_verify_locations(tls_config.ca_file)
|
||||
|
||||
async def dial(
|
||||
self, maddr: multiaddr.Multiaddr, peer_id: ID | None = None
|
||||
@ -196,14 +191,17 @@ class QUICTransport(ITransport):
|
||||
)
|
||||
|
||||
# Establish connection using trio
|
||||
await connection.connect()
|
||||
# We need a nursery for this - in real usage, this would be provided
|
||||
# by the caller or we'd use a transport-level nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
await connection.connect(nursery)
|
||||
|
||||
# Store connection for management
|
||||
conn_id = f"{host}:{port}:{peer_id}"
|
||||
self._connections[conn_id] = connection
|
||||
|
||||
# Perform libp2p handshake verification
|
||||
await connection.verify_peer_identity()
|
||||
# await connection.verify_peer_identity()
|
||||
|
||||
logger.info(f"Successfully dialed QUIC connection to {peer_id}")
|
||||
return connection
|
||||
@ -212,9 +210,7 @@ class QUICTransport(ITransport):
|
||||
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
|
||||
raise QUICDialError(f"Dial failed: {e}") from e
|
||||
|
||||
def create_listener(
|
||||
self, handler_function: Callable[[ReadWriteCloser], None]
|
||||
) -> IListener:
|
||||
def create_listener(self, handler_function: THandler) -> IListener:
|
||||
"""
|
||||
Create a QUIC listener.
|
||||
|
||||
@ -224,20 +220,22 @@ class QUICTransport(ITransport):
|
||||
Returns:
|
||||
QUIC listener instance
|
||||
|
||||
Raises:
|
||||
QUICListenError: If transport is closed
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICListenError("Transport is closed")
|
||||
|
||||
# TODO: Create QUIC Listener
|
||||
# listener = QUICListener(
|
||||
# transport=self,
|
||||
# handler_function=handler_function,
|
||||
# quic_configs=self._quic_configs,
|
||||
# config=self._config,
|
||||
# )
|
||||
listener = QUICListener()
|
||||
listener = QUICListener(
|
||||
transport=self,
|
||||
handler_function=handler_function,
|
||||
quic_configs=self._quic_configs,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
self._listeners.append(listener)
|
||||
logger.debug("Created QUIC listener")
|
||||
return listener
|
||||
|
||||
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
|
||||
@ -253,7 +251,7 @@ class QUICTransport(ITransport):
|
||||
"""
|
||||
return is_quic_multiaddr(maddr)
|
||||
|
||||
def protocols(self) -> list[str]:
|
||||
def protocols(self) -> list[TProtocol]:
|
||||
"""
|
||||
Get supported protocol identifiers.
|
||||
|
||||
@ -261,9 +259,9 @@ class QUICTransport(ITransport):
|
||||
List of supported protocol strings
|
||||
|
||||
"""
|
||||
protocols = [self.PROTOCOL_QUIC_V1]
|
||||
protocols = [QUIC_V1_PROTOCOL]
|
||||
if self._config.enable_draft29:
|
||||
protocols.append(self.PROTOCOL_QUIC_DRAFT29)
|
||||
protocols.append(QUIC_DRAFT29_PROTOCOL)
|
||||
return protocols
|
||||
|
||||
def listen_order(self) -> int:
|
||||
@ -300,6 +298,26 @@ class QUICTransport(ITransport):
|
||||
|
||||
logger.info("QUIC transport closed")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get transport statistics."""
|
||||
stats = {
|
||||
"active_connections": len(self._connections),
|
||||
"active_listeners": len(self._listeners),
|
||||
"supported_protocols": self.protocols(),
|
||||
}
|
||||
|
||||
# Aggregate listener stats
|
||||
listener_stats = {}
|
||||
for i, listener in enumerate(self._listeners):
|
||||
listener_stats[f"listener_{i}"] = listener.get_stats()
|
||||
|
||||
if listener_stats:
|
||||
# TODO: Fix type of listener_stats
|
||||
# type: ignore
|
||||
stats["listeners"] = listener_stats
|
||||
|
||||
return stats
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the transport."""
|
||||
return f"QUICTransport(peer_id={self._peer_id}, protocols={self.protocols()})"
|
||||
|
||||
223
libp2p/transport/quic/utils.py
Normal file
223
libp2p/transport/quic/utils.py
Normal file
@ -0,0 +1,223 @@
|
||||
"""
|
||||
Multiaddr utilities for QUIC transport.
|
||||
Handles QUIC-specific multiaddr parsing and validation.
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
from .config import QUICTransportConfig
|
||||
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
UDP_PROTOCOL = "udp"
|
||||
IP4_PROTOCOL = "ip4"
|
||||
IP6_PROTOCOL = "ip6"
|
||||
|
||||
|
||||
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if a multiaddr represents a QUIC address.
|
||||
|
||||
Valid QUIC multiaddrs:
|
||||
- /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
- /ip4/127.0.0.1/udp/4001/quic
|
||||
- /ip6/::1/udp/4001/quic-v1
|
||||
- /ip6/::1/udp/4001/quic
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if the multiaddr represents a QUIC address
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get protocol names from the multiaddr string
|
||||
addr_str = str(maddr)
|
||||
|
||||
# Check for required components
|
||||
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
|
||||
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
|
||||
has_quic = (
|
||||
addr_str.endswith(f"/{QUIC_V1_PROTOCOL}")
|
||||
or addr_str.endswith(f"/{QUIC_DRAFT29_PROTOCOL}")
|
||||
or addr_str.endswith("/quic")
|
||||
)
|
||||
|
||||
return has_ip and has_udp and has_quic
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]:
|
||||
"""
|
||||
Extract host and port from a QUIC multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Tuple of (host, port)
|
||||
|
||||
Raises:
|
||||
ValueError: If multiaddr is not a valid QUIC address
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise ValueError(f"Not a valid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
# Use multiaddr's value_for_protocol method to extract values
|
||||
host = None
|
||||
port = None
|
||||
|
||||
# Try to get IPv4 address
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Try to get IPv6 address if IPv4 not found
|
||||
if host is None:
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Get UDP port
|
||||
try:
|
||||
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP)
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if host is None or port is None:
|
||||
raise ValueError(f"Could not extract host/port from {maddr}")
|
||||
|
||||
return host, port
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse QUIC multiaddr {maddr}: {e}") from e
|
||||
|
||||
|
||||
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
|
||||
"""
|
||||
Determine QUIC version from multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
QUIC version identifier ("/quic-v1" or "/quic")
|
||||
|
||||
Raises:
|
||||
ValueError: If multiaddr doesn't contain QUIC protocol
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
|
||||
return QUIC_V1_PROTOCOL # RFC 9000
|
||||
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
|
||||
return QUIC_DRAFT29_PROTOCOL # draft-29
|
||||
else:
|
||||
raise ValueError(f"No QUIC protocol found in {maddr}")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to determine QUIC version from {maddr}: {e}") from e
|
||||
|
||||
|
||||
def create_quic_multiaddr(
|
||||
host: str, port: int, version: str = "/quic-v1"
|
||||
) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Create a QUIC multiaddr from host, port, and version.
|
||||
|
||||
Args:
|
||||
host: IP address (IPv4 or IPv6)
|
||||
port: UDP port number
|
||||
version: QUIC version ("/quic-v1" or "/quic")
|
||||
|
||||
Returns:
|
||||
QUIC multiaddr
|
||||
|
||||
Raises:
|
||||
ValueError: If invalid parameters provided
|
||||
|
||||
"""
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
# Determine IP version
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
ip_proto = IP4_PROTOCOL
|
||||
else:
|
||||
ip_proto = IP6_PROTOCOL
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid IP address: {host}")
|
||||
|
||||
# Validate port
|
||||
if not (0 <= port <= 65535):
|
||||
raise ValueError(f"Invalid port: {port}")
|
||||
|
||||
# Validate QUIC version
|
||||
if version not in ["/quic-v1", "/quic"]:
|
||||
raise ValueError(f"Invalid QUIC version: {version}")
|
||||
|
||||
# Construct multiaddr
|
||||
quic_proto = (
|
||||
QUIC_V1_PROTOCOL if version == "/quic-v1" else QUIC_DRAFT29_PROTOCOL
|
||||
)
|
||||
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
|
||||
|
||||
return multiaddr.Multiaddr(addr_str)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to create QUIC multiaddr: {e}") from e
|
||||
|
||||
|
||||
def is_quic_v1_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""Check if multiaddr uses QUIC v1 (RFC 9000)."""
|
||||
try:
|
||||
return multiaddr_to_quic_version(maddr) == "/quic-v1"
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def is_quic_draft29_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""Check if multiaddr uses QUIC draft-29."""
|
||||
try:
|
||||
return multiaddr_to_quic_version(maddr) == "/quic"
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Normalize a QUIC multiaddr to canonical form.
|
||||
|
||||
Args:
|
||||
maddr: Input QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Normalized multiaddr
|
||||
|
||||
Raises:
|
||||
ValueError: If not a valid QUIC multiaddr
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise ValueError(f"Not a QUIC multiaddr: {maddr}")
|
||||
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
return create_quic_multiaddr(host, port, version)
|
||||
@ -16,6 +16,7 @@ maintainers = [
|
||||
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
|
||||
]
|
||||
dependencies = [
|
||||
"aioquic>=1.2.0",
|
||||
"base58>=1.0.3",
|
||||
"coincurve>=10.0.0",
|
||||
"exceptiongroup>=1.2.0; python_version < '3.11'",
|
||||
|
||||
119
tests/core/transport/quic/test_connection.py
Normal file
119
tests/core/transport/quic/test_connection.py
Normal file
@ -0,0 +1,119 @@
|
||||
from unittest.mock import (
|
||||
Mock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.exceptions import QUICStreamError
|
||||
|
||||
|
||||
class TestQUICConnection:
|
||||
"""Test suite for QUIC connection functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
"""Create mock aioquic QuicConnection."""
|
||||
mock = Mock()
|
||||
mock.next_event.return_value = None
|
||||
mock.datagrams_to_send.return_value = []
|
||||
mock.get_timer.return_value = None
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(self, mock_quic_connection):
|
||||
"""Create test QUIC connection."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
def test_connection_initialization(self, quic_connection):
|
||||
"""Test connection initialization."""
|
||||
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
|
||||
assert quic_connection.is_initiator is True
|
||||
assert not quic_connection.is_closed
|
||||
assert not quic_connection.is_established
|
||||
assert len(quic_connection._streams) == 0
|
||||
|
||||
def test_stream_id_calculation(self):
|
||||
"""Test stream ID calculation for client/server."""
|
||||
# Client connection (initiator)
|
||||
client_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
peer_id=None,
|
||||
local_peer_id=Mock(),
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
assert client_conn._next_stream_id == 0 # Client starts with 0
|
||||
|
||||
# Server connection (not initiator)
|
||||
server_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
peer_id=None,
|
||||
local_peer_id=Mock(),
|
||||
is_initiator=False,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
assert server_conn._next_stream_id == 1 # Server starts with 1
|
||||
|
||||
def test_incoming_stream_detection(self, quic_connection):
|
||||
"""Test incoming stream detection logic."""
|
||||
# For client (initiator), odd stream IDs are incoming
|
||||
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
|
||||
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_stats(self, quic_connection):
|
||||
"""Test connection statistics."""
|
||||
stats = quic_connection.get_stats()
|
||||
|
||||
expected_keys = [
|
||||
"peer_id",
|
||||
"remote_addr",
|
||||
"is_initiator",
|
||||
"is_established",
|
||||
"is_closed",
|
||||
"active_streams",
|
||||
"next_stream_id",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in stats
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_close(self, quic_connection):
|
||||
"""Test connection close functionality."""
|
||||
assert not quic_connection.is_closed
|
||||
|
||||
await quic_connection.close()
|
||||
|
||||
assert quic_connection.is_closed
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_operations_on_closed_connection(self, quic_connection):
|
||||
"""Test stream operations on closed connection."""
|
||||
await quic_connection.close()
|
||||
|
||||
with pytest.raises(QUICStreamError, match="Connection is closed"):
|
||||
await quic_connection.open_stream()
|
||||
171
tests/core/transport/quic/test_listener.py
Normal file
171
tests/core/transport/quic/test_listener.py
Normal file
@ -0,0 +1,171 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICListenError,
|
||||
)
|
||||
from libp2p.transport.quic.listener import QUICListener
|
||||
from libp2p.transport.quic.transport import (
|
||||
QUICTransport,
|
||||
QUICTransportConfig,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
quic_multiaddr_to_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class TestQUICListener:
|
||||
"""Test suite for QUIC listener functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def private_key(self):
|
||||
"""Generate test private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Generate test transport configuration."""
|
||||
return QUICTransportConfig(idle_timeout=10.0)
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self, private_key, transport_config):
|
||||
"""Create test transport instance."""
|
||||
return QUICTransport(private_key, transport_config)
|
||||
|
||||
@pytest.fixture
|
||||
def connection_handler(self):
|
||||
"""Mock connection handler."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self, transport, connection_handler):
|
||||
"""Create test listener."""
|
||||
return transport.create_listener(connection_handler)
|
||||
|
||||
def test_listener_creation(self, transport, connection_handler):
|
||||
"""Test listener creation."""
|
||||
listener = transport.create_listener(connection_handler)
|
||||
|
||||
assert isinstance(listener, QUICListener)
|
||||
assert listener._transport == transport
|
||||
assert listener._handler == connection_handler
|
||||
assert not listener._listening
|
||||
assert not listener._closed
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_invalid_multiaddr(self, listener: QUICListener):
|
||||
"""Test listener with invalid multiaddr."""
|
||||
async with trio.open_nursery() as nursery:
|
||||
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
|
||||
with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"):
|
||||
await listener.listen(invalid_addr, nursery)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_basic_lifecycle(self, listener: QUICListener):
|
||||
"""Test basic listener lifecycle."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start listening
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
assert listener.is_listening()
|
||||
|
||||
# Check bound addresses
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) == 1
|
||||
|
||||
# Check stats
|
||||
stats = listener.get_stats()
|
||||
assert stats["is_listening"] is True
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["pending_connections"] == 0
|
||||
|
||||
# Close listener
|
||||
await listener.close()
|
||||
assert not listener.is_listening()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_double_listen(self, listener: QUICListener):
|
||||
"""Test that double listen raises error."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic")
|
||||
|
||||
# The nursery is the outer context
|
||||
async with trio.open_nursery() as nursery:
|
||||
# The try/finally is now INSIDE the nursery scope
|
||||
try:
|
||||
# The listen method creates the socket and starts background tasks
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
await trio.sleep(0.01)
|
||||
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) > 0
|
||||
print("ADDRS 1: ", len(addrs))
|
||||
print("TEST LOGIC FINISHED")
|
||||
|
||||
async with trio.open_nursery() as nursery2:
|
||||
with pytest.raises(QUICListenError, match="Already listening"):
|
||||
await listener.listen(listen_addr, nursery2)
|
||||
finally:
|
||||
# This block runs BEFORE the 'async with nursery' exits.
|
||||
print("INNER FINALLY: Closing listener to release socket...")
|
||||
|
||||
# This closes the socket and sets self._listening = False,
|
||||
# which helps the background tasks terminate cleanly.
|
||||
await listener.close()
|
||||
print("INNER FINALLY: Listener closed.")
|
||||
|
||||
# By the time we get here, the listener and its tasks have been fully
|
||||
# shut down, allowing the nursery to exit without hanging.
|
||||
print("TEST COMPLETED SUCCESSFULLY.")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_port_binding(self, listener: QUICListener):
|
||||
"""Test listener port binding and cleanup."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
# The nursery is the outer context
|
||||
async with trio.open_nursery() as nursery:
|
||||
# The try/finally is now INSIDE the nursery scope
|
||||
try:
|
||||
# The listen method creates the socket and starts background tasks
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
await trio.sleep(0.5)
|
||||
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) > 0
|
||||
print("TEST LOGIC FINISHED")
|
||||
|
||||
finally:
|
||||
# This block runs BEFORE the 'async with nursery' exits.
|
||||
print("INNER FINALLY: Closing listener to release socket...")
|
||||
|
||||
# This closes the socket and sets self._listening = False,
|
||||
# which helps the background tasks terminate cleanly.
|
||||
await listener.close()
|
||||
print("INNER FINALLY: Listener closed.")
|
||||
|
||||
# By the time we get here, the listener and its tasks have been fully
|
||||
# shut down, allowing the nursery to exit without hanging.
|
||||
print("TEST COMPLETED SUCCESSFULLY.")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_stats_tracking(self, listener):
|
||||
"""Test listener statistics tracking."""
|
||||
initial_stats = listener.get_stats()
|
||||
|
||||
# All counters should start at 0
|
||||
assert initial_stats["connections_accepted"] == 0
|
||||
assert initial_stats["connections_rejected"] == 0
|
||||
assert initial_stats["bytes_received"] == 0
|
||||
assert initial_stats["packets_processed"] == 0
|
||||
@ -7,6 +7,7 @@ import pytest
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
@ -23,7 +24,7 @@ class TestQUICTransport:
|
||||
@pytest.fixture
|
||||
def private_key(self):
|
||||
"""Generate test private key."""
|
||||
return create_new_key_pair()
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
@ -33,7 +34,7 @@ class TestQUICTransport:
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self, private_key, transport_config):
|
||||
def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig):
|
||||
"""Create test transport instance."""
|
||||
return QUICTransport(private_key, transport_config)
|
||||
|
||||
@ -47,18 +48,35 @@ class TestQUICTransport:
|
||||
def test_supported_protocols(self, transport):
|
||||
"""Test supported protocol identifiers."""
|
||||
protocols = transport.protocols()
|
||||
assert "/quic-v1" in protocols
|
||||
assert "/quic" in protocols # draft-29
|
||||
# TODO: Update when quic-v1 compatible
|
||||
# assert "quic-v1" in protocols
|
||||
assert "quic" in protocols # draft-29
|
||||
|
||||
def test_can_dial_quic_addresses(self, transport):
|
||||
def test_can_dial_quic_addresses(self, transport: QUICTransport):
|
||||
"""Test multiaddr compatibility checking."""
|
||||
import multiaddr
|
||||
|
||||
# Valid QUIC addresses
|
||||
valid_addrs = [
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1"),
|
||||
multiaddr.Multiaddr("/ip4/192.168.1.1/udp/8080/quic"),
|
||||
multiaddr.Multiaddr("/ip6/::1/udp/4001/quic-v1"),
|
||||
# TODO: Update Multiaddr package to accept quic-v1
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
]
|
||||
|
||||
for addr in valid_addrs:
|
||||
@ -93,7 +111,7 @@ class TestQUICTransport:
|
||||
await transport.close()
|
||||
|
||||
with pytest.raises(QUICDialError, match="Transport is closed"):
|
||||
await transport.dial(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1"))
|
||||
await transport.dial(multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"))
|
||||
|
||||
def test_create_listener_closed_transport(self, transport):
|
||||
"""Test creating listener with closed transport raises error."""
|
||||
|
||||
94
tests/core/transport/quic/test_utils.py
Normal file
94
tests/core/transport/quic/test_utils.py
Normal file
@ -0,0 +1,94 @@
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
)
|
||||
|
||||
|
||||
class TestQUICUtils:
|
||||
"""Test suite for QUIC utility functions."""
|
||||
|
||||
def test_is_quic_multiaddr(self):
|
||||
"""Test QUIC multiaddr validation."""
|
||||
# Valid QUIC multiaddrs
|
||||
valid = [
|
||||
# TODO: Update Multiaddr package to accept quic-v1
|
||||
Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
]
|
||||
|
||||
for addr in valid:
|
||||
assert is_quic_multiaddr(addr)
|
||||
|
||||
# Invalid multiaddrs
|
||||
invalid = [
|
||||
Multiaddr("/ip4/127.0.0.1/tcp/4001"),
|
||||
Multiaddr("/ip4/127.0.0.1/udp/4001"),
|
||||
Multiaddr("/ip4/127.0.0.1/udp/4001/ws"),
|
||||
]
|
||||
|
||||
for addr in invalid:
|
||||
assert not is_quic_multiaddr(addr)
|
||||
|
||||
def test_quic_multiaddr_to_endpoint(self):
|
||||
"""Test multiaddr to endpoint conversion."""
|
||||
addr = Multiaddr("/ip4/192.168.1.100/udp/4001/quic")
|
||||
host, port = quic_multiaddr_to_endpoint(addr)
|
||||
|
||||
assert host == "192.168.1.100"
|
||||
assert port == 4001
|
||||
|
||||
# Test IPv6
|
||||
# TODO: Update Multiaddr project to handle ip6
|
||||
# addr6 = Multiaddr("/ip6/::1/udp/8080/quic")
|
||||
# host6, port6 = quic_multiaddr_to_endpoint(addr6)
|
||||
|
||||
# assert host6 == "::1"
|
||||
# assert port6 == 8080
|
||||
|
||||
def test_create_quic_multiaddr(self):
|
||||
"""Test QUIC multiaddr creation."""
|
||||
# IPv4
|
||||
addr = create_quic_multiaddr("127.0.0.1", 4001, "/quic")
|
||||
assert str(addr) == "/ip4/127.0.0.1/udp/4001/quic"
|
||||
|
||||
# IPv6
|
||||
addr6 = create_quic_multiaddr("::1", 8080, "/quic")
|
||||
assert str(addr6) == "/ip6/::1/udp/8080/quic"
|
||||
|
||||
def test_multiaddr_to_quic_version(self):
|
||||
"""Test QUIC version extraction."""
|
||||
addr = Multiaddr("/ip4/127.0.0.1/udp/4001/quic")
|
||||
version = multiaddr_to_quic_version(addr)
|
||||
assert version in ["quic", "quic-v1"] # Depending on implementation
|
||||
|
||||
def test_invalid_multiaddr_operations(self):
|
||||
"""Test error handling for invalid multiaddrs."""
|
||||
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
quic_multiaddr_to_endpoint(invalid_addr)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
multiaddr_to_quic_version(invalid_addr)
|
||||
Reference in New Issue
Block a user