fix: add basic tests for listener

This commit is contained in:
Akash Mondal
2025-06-12 10:03:08 +00:00
committed by lla-dane
parent 54b3055eaa
commit a3231af714
10 changed files with 892 additions and 100 deletions

View File

@ -7,10 +7,45 @@ from dataclasses import (
field,
)
import ssl
from typing import TypedDict
from libp2p.custom_types import TProtocol
class QUICTransportKwargs(TypedDict, total=False):
"""Type definition for kwargs accepted by new_transport function."""
# Connection settings
idle_timeout: float
max_datagram_size: int
local_port: int | None
# Protocol version support
enable_draft29: bool
enable_v1: bool
# TLS settings
verify_mode: ssl.VerifyMode
alpn_protocols: list[str]
# Performance settings
max_concurrent_streams: int
connection_window: int
stream_window: int
# Logging and debugging
enable_qlog: bool
qlog_dir: str | None
# Connection management
max_connections: int
connection_timeout: float
# Protocol identifiers
PROTOCOL_QUIC_V1: TProtocol
PROTOCOL_QUIC_DRAFT29: TProtocol
@dataclass
class QUICTransportConfig:
"""Configuration for QUIC transport."""
@ -47,7 +82,7 @@ class QUICTransportConfig:
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic") # RFC 9000
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
def __post_init__(self):
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (self.enable_draft29 or self.enable_v1):
raise ValueError("At least one QUIC version must be enabled")

View File

@ -50,7 +50,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
Uses aioquic's sans-IO core with trio for native async support.
QUIC natively provides stream multiplexing, so this connection acts as both
a raw connection (for transport layer) and muxed connection (for upper layers).
Updated to work properly with the QUIC listener for server-side connections.
"""
@ -92,18 +92,20 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._background_tasks_started = False
self._nursery: trio.Nursery | None = None
logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})")
logger.debug(
f"Created QUIC connection to {peer_id} (initiator: {is_initiator})"
)
def _calculate_initial_stream_id(self) -> int:
"""
Calculate the initial stream ID based on QUIC specification.
QUIC stream IDs:
- Client-initiated bidirectional: 0, 4, 8, 12, ...
- Server-initiated bidirectional: 1, 5, 9, 13, ...
- Client-initiated unidirectional: 2, 6, 10, 14, ...
- Server-initiated unidirectional: 3, 7, 11, 15, ...
For libp2p, we primarily use bidirectional streams.
"""
if self.__is_initiator:
@ -118,7 +120,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def start(self) -> None:
"""
Start the connection and its background tasks.
This method implements the IMuxedConn.start() interface.
It should be called to begin processing connection events.
"""
@ -165,7 +167,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
if not self._background_tasks_started:
# We would need a nursery to start background tasks
# This is a limitation of the current design
logger.warning("Background tasks need nursery - connection may not work properly")
logger.warning(
"Background tasks need nursery - connection may not work properly"
)
except Exception as e:
logger.error(f"Failed to initiate connection: {e}")
@ -174,13 +178,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
async def connect(self, nursery: trio.Nursery) -> None:
"""
Establish the QUIC connection using trio.
Args:
nursery: Trio nursery for background tasks
"""
if not self.__is_initiator:
raise QUICConnectionError("connect() should only be called by client connections")
raise QUICConnectionError(
"connect() should only be called by client connections"
)
try:
# Store nursery for background tasks
@ -321,7 +327,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
def _is_incoming_stream(self, stream_id: int) -> bool:
"""
Determine if a stream ID represents an incoming stream.
For bidirectional streams:
- Even IDs are client-initiated
- Odd IDs are server-initiated
@ -463,11 +469,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
# Create stream
stream = QUICStream(
connection=self,
stream_id=stream_id,
is_initiator=True
)
stream = QUICStream(connection=self, stream_id=stream_id, is_initiator=True)
self._streams[stream_id] = stream
@ -530,9 +532,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
# The certificate should contain the peer ID in a specific extension
raise NotImplementedError("Certificate peer ID extraction not implemented")
def get_stats(self) -> dict:
# TODO: Define type for stats
def get_stats(self) -> dict[str, object]:
"""Get connection statistics."""
return {
stats: dict[str, object] = {
"peer_id": str(self._peer_id),
"remote_addr": self._remote_addr,
"is_initiator": self.__is_initiator,
@ -542,10 +545,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
"active_streams": len(self._streams),
"next_stream_id": self._next_stream_id,
}
return stats
def get_remote_address(self):
def get_remote_address(self) -> tuple[str, int]:
return self._remote_addr
def __str__(self) -> str:
"""String representation of the connection."""
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, established={self._established}, started={self._started})"
id = self._peer_id
estb = self._established
stream_len = len(self._streams)
return f"QUICConnection(peer={id}, streams={stream_len}".__add__(
f"established={estb}, started={self._started})"
)

View File

@ -8,7 +8,7 @@ import copy
import logging
import socket
import time
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from aioquic.quic import events
from aioquic.quic.configuration import QuicConfiguration
@ -49,7 +49,7 @@ class QUICListener(IListener):
self,
transport: "QUICTransport",
handler_function: THandler,
quic_configs: Dict[TProtocol, QuicConfiguration],
quic_configs: dict[TProtocol, QuicConfiguration],
config: QUICTransportConfig,
):
"""
@ -72,8 +72,8 @@ class QUICListener(IListener):
self._bound_addresses: list[Multiaddr] = []
# Connection management
self._connections: Dict[tuple[str, int], QUICConnection] = {}
self._pending_connections: Dict[tuple[str, int], QuicConnection] = {}
self._connections: dict[tuple[str, int], QUICConnection] = {}
self._pending_connections: dict[tuple[str, int], QuicConnection] = {}
self._connection_lock = trio.Lock()
# Listener state
@ -104,6 +104,7 @@ class QUICListener(IListener):
Raises:
QUICListenError: If failed to start listening
"""
if not is_quic_multiaddr(maddr):
raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}")
@ -133,11 +134,11 @@ class QUICListener(IListener):
self._listening = True
# Start background tasks directly in the provided nursery
# This ensures proper cancellation when the nursery exits
# This e per cancellation when the nursery exits
nursery.start_soon(self._handle_incoming_packets)
nursery.start_soon(self._manage_connections)
print(f"QUIC listener started on {actual_maddr}")
logger.info(f"QUIC listener started on {actual_maddr}")
return True
except trio.Cancelled:
@ -190,7 +191,8 @@ class QUICListener(IListener):
try:
while self._listening and self._socket:
try:
# Receive UDP packet (this blocks until packet arrives or socket closes)
# Receive UDP packet
# (this blocks until packet arrives or socket closes)
data, addr = await self._socket.recvfrom(65536)
self._stats["bytes_received"] += len(data)
self._stats["packets_processed"] += 1
@ -208,10 +210,9 @@ class QUICListener(IListener):
# Continue processing other packets
await trio.sleep(0.01)
except trio.Cancelled:
print("PACKET HANDLER CANCELLED - FORCIBLY CLOSING SOCKET")
logger.info("Received Cancel, stopping handling incoming packets")
raise
finally:
print("PACKET HANDLER FINISHED")
logger.debug("Packet handling loop terminated")
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
@ -456,10 +457,7 @@ class QUICListener(IListener):
except Exception as e:
logger.error(f"Error in connection management: {e}")
except trio.Cancelled:
print("CONNECTION MANAGER CANCELLED")
raise
finally:
print("CONNECTION MANAGER FINISHED")
async def _cleanup_closed_connections(self) -> None:
"""Remove closed connections from tracking."""
@ -500,20 +498,20 @@ class QUICListener(IListener):
self._closed = True
self._listening = False
print("Closing QUIC listener")
logger.debug("Closing QUIC listener")
# CRITICAL: Close socket FIRST to unblock recvfrom()
await self._cleanup_socket()
print("SOCKET CLEANUP COMPLETE")
logger.debug("SOCKET CLEANUP COMPLETE")
# Close all connections WITHOUT using the lock during shutdown
# (avoid deadlock if background tasks are cancelled while holding lock)
connections_to_close = list(self._connections.values())
pending_to_close = list(self._pending_connections.values())
print(
f"CLOSING {len(connections_to_close)} connections and {len(pending_to_close)} pending"
logger.debug(
f"CLOSING {connections_to_close} connections and {pending_to_close} pending"
)
# Close active connections
@ -533,10 +531,7 @@ class QUICListener(IListener):
# Clear the dictionaries without lock (we're shutting down)
self._connections.clear()
self._pending_connections.clear()
if self._nursery:
print("TASKS", len(self._nursery.child_tasks))
print("QUIC listener closed")
logger.debug("QUIC listener closed")
async def _cleanup_socket(self) -> None:
"""Clean up the UDP socket."""
@ -562,7 +557,7 @@ class QUICListener(IListener):
"""Check if the listener is actively listening."""
return self._listening and not self._closed
def get_stats(self) -> dict:
def get_stats(self) -> dict[str, int]:
"""Get listener statistics."""
stats = self._stats.copy()
stats.update(
@ -576,4 +571,6 @@ class QUICListener(IListener):
def __str__(self) -> str:
"""String representation of the listener."""
return f"QUICListener(addrs={self._bound_addresses}, connections={len(self._connections)})"
addr = self._bound_addresses
conn_count = len(self._connections)
return f"QUICListener(addrs={addr}, connections={conn_count})"

View File

@ -7,7 +7,6 @@ Full implementation will be in Module 5.
from dataclasses import dataclass
import os
import tempfile
from typing import Optional
from libp2p.crypto.keys import PrivateKey
from libp2p.peer.id import ID
@ -21,7 +20,7 @@ class TLSConfig:
cert_file: str
key_file: str
ca_file: Optional[str] = None
ca_file: str | None = None
def generate_libp2p_tls_config(private_key: PrivateKey, peer_id: ID) -> TLSConfig:

View File

@ -116,7 +116,8 @@ class QUICStream(IMuxedStream):
"""
Reset the stream
"""
self.handle_reset(0)
await self.handle_reset(0)
return
def get_remote_address(self) -> tuple[str, int] | None:
return self._connection._remote_addr

View File

@ -15,9 +15,9 @@ from aioquic.quic.connection import (
)
import multiaddr
import trio
from typing_extensions import Unpack
from libp2p.abc import (
IListener,
IRawConnection,
ITransport,
)
@ -28,6 +28,7 @@ from libp2p.custom_types import THandler, TProtocol
from libp2p.peer.id import (
ID,
)
from libp2p.transport.quic.config import QUICTransportKwargs
from libp2p.transport.quic.utils import (
is_quic_multiaddr,
multiaddr_to_quic_version,
@ -131,7 +132,10 @@ class QUICTransport(ITransport):
# # This follows the libp2p TLS spec for peer identity verification
# tls_config = generate_libp2p_tls_config(self._private_key, self._peer_id)
# config.load_cert_chain(certfile=tls_config.cert_file, keyfile=tls_config.key_file)
# config.load_cert_chain(
# certfile=tls_config.cert_file,
# keyfile=tls_config.key_file
# )
# if tls_config.ca_file:
# config.load_verify_locations(tls_config.ca_file)
@ -210,7 +214,7 @@ class QUICTransport(ITransport):
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
raise QUICDialError(f"Dial failed: {e}") from e
def create_listener(self, handler_function: THandler) -> IListener:
def create_listener(self, handler_function: THandler) -> QUICListener:
"""
Create a QUIC listener.
@ -298,12 +302,18 @@ class QUICTransport(ITransport):
logger.info("QUIC transport closed")
def get_stats(self) -> dict:
def get_stats(self) -> dict[str, int | list[str] | object]:
"""Get transport statistics."""
stats = {
protocols = self.protocols()
str_protocols = []
for proto in protocols:
str_protocols.append(str(proto))
stats: dict[str, int | list[str] | object] = {
"active_connections": len(self._connections),
"active_listeners": len(self._listeners),
"supported_protocols": self.protocols(),
"supported_protocols": str_protocols,
}
# Aggregate listener stats
@ -324,7 +334,9 @@ class QUICTransport(ITransport):
def new_transport(
private_key: PrivateKey, config: QUICTransportConfig | None = None, **kwargs
private_key: PrivateKey,
config: QUICTransportConfig | None = None,
**kwargs: Unpack[QUICTransportKwargs],
) -> QUICTransport:
"""
Factory function to create a new QUIC transport.

View File

@ -3,8 +3,6 @@ Multiaddr utilities for QUIC transport.
Handles QUIC-specific multiaddr parsing and validation.
"""
from typing import Tuple
import multiaddr
from libp2p.custom_types import TProtocol
@ -54,7 +52,7 @@ def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
return False
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]:
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
"""
Extract host and port from a QUIC multiaddr.
@ -78,20 +76,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> Tuple[str, int]:
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except ValueError:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except ValueError:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP)
# The the package is exposed by types not availble
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str)
except ValueError:
pass