mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 23:51:07 +00:00
fix: impl quic listener
This commit is contained in:
@ -6,6 +6,7 @@ Uses aioquic's sans-IO core with trio for async operations.
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from aioquic.quic import (
|
||||
events,
|
||||
@ -21,9 +22,7 @@ from libp2p.abc import (
|
||||
IMuxedStream,
|
||||
IRawConnection,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
StreamHandlerFn,
|
||||
)
|
||||
from libp2p.custom_types import TQUICStreamHandlerFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -35,9 +34,11 @@ from .exceptions import (
|
||||
from .stream import (
|
||||
QUICStream,
|
||||
)
|
||||
from .transport import (
|
||||
QUICTransport,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transport import (
|
||||
QUICTransport,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -49,76 +50,177 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
Uses aioquic's sans-IO core with trio for native async support.
|
||||
QUIC natively provides stream multiplexing, so this connection acts as both
|
||||
a raw connection (for transport layer) and muxed connection (for upper layers).
|
||||
|
||||
Updated to work properly with the QUIC listener for server-side connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quic_connection: QuicConnection,
|
||||
remote_addr: tuple[str, int],
|
||||
peer_id: ID,
|
||||
peer_id: ID | None,
|
||||
local_peer_id: ID,
|
||||
initiator: bool,
|
||||
is_initiator: bool,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
transport: QUICTransport,
|
||||
transport: "QUICTransport",
|
||||
):
|
||||
self._quic = quic_connection
|
||||
self._remote_addr = remote_addr
|
||||
self._peer_id = peer_id
|
||||
self._local_peer_id = local_peer_id
|
||||
self.__is_initiator = initiator
|
||||
self.__is_initiator = is_initiator
|
||||
self._maddr = maddr
|
||||
self._transport = transport
|
||||
|
||||
# Trio networking
|
||||
# Trio networking - socket may be provided by listener
|
||||
self._socket: trio.socket.SocketType | None = None
|
||||
self._connected_event = trio.Event()
|
||||
self._closed_event = trio.Event()
|
||||
|
||||
# Stream management
|
||||
self._streams: dict[int, QUICStream] = {}
|
||||
self._next_stream_id: int = (
|
||||
0 if initiator else 1
|
||||
) # Even for initiator, odd for responder
|
||||
self._stream_handler: StreamHandlerFn | None = None
|
||||
self._next_stream_id: int = self._calculate_initial_stream_id()
|
||||
self._stream_handler: TQUICStreamHandlerFn | None = None
|
||||
self._stream_id_lock = trio.Lock()
|
||||
|
||||
# Connection state
|
||||
self._closed = False
|
||||
self._timer_task = None
|
||||
self._established = False
|
||||
self._started = False
|
||||
|
||||
logger.debug(f"Created QUIC connection to {peer_id}")
|
||||
# Background task management
|
||||
self._background_tasks_started = False
|
||||
self._nursery: trio.Nursery | None = None
|
||||
|
||||
logger.debug(f"Created QUIC connection to {peer_id} (initiator: {is_initiator})")
|
||||
|
||||
def _calculate_initial_stream_id(self) -> int:
|
||||
"""
|
||||
Calculate the initial stream ID based on QUIC specification.
|
||||
|
||||
QUIC stream IDs:
|
||||
- Client-initiated bidirectional: 0, 4, 8, 12, ...
|
||||
- Server-initiated bidirectional: 1, 5, 9, 13, ...
|
||||
- Client-initiated unidirectional: 2, 6, 10, 14, ...
|
||||
- Server-initiated unidirectional: 3, 7, 11, 15, ...
|
||||
|
||||
For libp2p, we primarily use bidirectional streams.
|
||||
"""
|
||||
if self.__is_initiator:
|
||||
return 0 # Client starts with 0, then 4, 8, 12...
|
||||
else:
|
||||
return 1 # Server starts with 1, then 5, 9, 13...
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool: # type: ignore
|
||||
return self.__is_initiator
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish the QUIC connection using trio."""
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the connection and its background tasks.
|
||||
|
||||
This method implements the IMuxedConn.start() interface.
|
||||
It should be called to begin processing connection events.
|
||||
"""
|
||||
if self._started:
|
||||
logger.warning("Connection already started")
|
||||
return
|
||||
|
||||
if self._closed:
|
||||
raise QUICConnectionError("Cannot start a closed connection")
|
||||
|
||||
self._started = True
|
||||
logger.debug(f"Starting QUIC connection to {self._peer_id}")
|
||||
|
||||
# If this is a client connection, we need to establish the connection
|
||||
if self.__is_initiator:
|
||||
await self._initiate_connection()
|
||||
else:
|
||||
# For server connections, we're already connected via the listener
|
||||
self._established = True
|
||||
self._connected_event.set()
|
||||
|
||||
logger.debug(f"QUIC connection to {self._peer_id} started")
|
||||
|
||||
async def _initiate_connection(self) -> None:
|
||||
"""Initiate client-side connection establishment."""
|
||||
try:
|
||||
# Create UDP socket using trio
|
||||
self._socket = trio.socket.socket(
|
||||
family=socket.AF_INET, type=socket.SOCK_DGRAM
|
||||
)
|
||||
|
||||
# Connect the socket to the remote address
|
||||
await self._socket.connect(self._remote_addr)
|
||||
|
||||
# Start the connection establishment
|
||||
self._quic.connect(self._remote_addr, now=time.time())
|
||||
|
||||
# Send initial packet(s)
|
||||
await self._transmit()
|
||||
|
||||
# Start background tasks using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
self._handle_incoming_data, None, "QUIC INCOMING DATA"
|
||||
)
|
||||
nursery.start_soon(self._handle_timer, None, "QUIC TIMER HANDLER")
|
||||
# For client connections, we need to manage our own background tasks
|
||||
# In a real implementation, this would be managed by the transport
|
||||
# For now, we'll start them here
|
||||
if not self._background_tasks_started:
|
||||
# We would need a nursery to start background tasks
|
||||
# This is a limitation of the current design
|
||||
logger.warning("Background tasks need nursery - connection may not work properly")
|
||||
|
||||
# Wait for connection to be established
|
||||
await self._connected_event.wait()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate connection: {e}")
|
||||
raise QUICConnectionError(f"Connection initiation failed: {e}") from e
|
||||
|
||||
async def connect(self, nursery: trio.Nursery) -> None:
|
||||
"""
|
||||
Establish the QUIC connection using trio.
|
||||
|
||||
Args:
|
||||
nursery: Trio nursery for background tasks
|
||||
|
||||
"""
|
||||
if not self.__is_initiator:
|
||||
raise QUICConnectionError("connect() should only be called by client connections")
|
||||
|
||||
try:
|
||||
# Store nursery for background tasks
|
||||
self._nursery = nursery
|
||||
|
||||
# Create UDP socket using trio
|
||||
self._socket = trio.socket.socket(
|
||||
family=socket.AF_INET, type=socket.SOCK_DGRAM
|
||||
)
|
||||
|
||||
# Connect the socket to the remote address
|
||||
await self._socket.connect(self._remote_addr)
|
||||
|
||||
# Start the connection establishment
|
||||
self._quic.connect(self._remote_addr, now=time.time())
|
||||
|
||||
# Send initial packet(s)
|
||||
await self._transmit()
|
||||
|
||||
# Start background tasks
|
||||
await self._start_background_tasks(nursery)
|
||||
|
||||
# Wait for connection to be established
|
||||
await self._connected_event.wait()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect: {e}")
|
||||
raise QUICConnectionError(f"Connection failed: {e}") from e
|
||||
|
||||
async def _start_background_tasks(self, nursery: trio.Nursery) -> None:
|
||||
"""Start background tasks for connection management."""
|
||||
if self._background_tasks_started:
|
||||
return
|
||||
|
||||
self._background_tasks_started = True
|
||||
|
||||
# Start background tasks
|
||||
nursery.start_soon(self._handle_incoming_data)
|
||||
nursery.start_soon(self._handle_timer)
|
||||
|
||||
async def _handle_incoming_data(self) -> None:
|
||||
"""Handle incoming UDP datagrams in trio."""
|
||||
while not self._closed:
|
||||
@ -128,6 +230,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._quic.receive_datagram(data, addr, now=time.time())
|
||||
await self._process_events()
|
||||
await self._transmit()
|
||||
|
||||
# Small delay to prevent busy waiting
|
||||
await trio.sleep(0.001)
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
break
|
||||
except Exception as e:
|
||||
@ -137,18 +243,26 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
async def _handle_timer(self) -> None:
|
||||
"""Handle QUIC timer events in trio."""
|
||||
while not self._closed:
|
||||
timer_at = self._quic.get_timer()
|
||||
if timer_at is None:
|
||||
await trio.sleep(1.0) # No timer set, check again later
|
||||
continue
|
||||
try:
|
||||
timer_at = self._quic.get_timer()
|
||||
if timer_at is None:
|
||||
await trio.sleep(0.1) # No timer set, check again later
|
||||
continue
|
||||
|
||||
now = time.time()
|
||||
if timer_at <= now:
|
||||
self._quic.handle_timer(now=now)
|
||||
await self._process_events()
|
||||
await self._transmit()
|
||||
else:
|
||||
await trio.sleep(timer_at - now)
|
||||
now = time.time()
|
||||
if timer_at <= now:
|
||||
self._quic.handle_timer(now=now)
|
||||
await self._process_events()
|
||||
await self._transmit()
|
||||
await trio.sleep(0.001) # Small delay
|
||||
else:
|
||||
# Sleep until timer fires, but check periodically
|
||||
sleep_time = min(timer_at - now, 0.1)
|
||||
await trio.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in timer handler: {e}")
|
||||
await trio.sleep(0.1)
|
||||
|
||||
async def _process_events(self) -> None:
|
||||
"""Process QUIC events from aioquic core."""
|
||||
@ -165,6 +279,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
logger.debug("QUIC handshake completed")
|
||||
self._established = True
|
||||
self._connected_event.set()
|
||||
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
@ -177,25 +292,47 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Handle incoming stream data."""
|
||||
stream_id = event.stream_id
|
||||
|
||||
# Get or create stream
|
||||
if stream_id not in self._streams:
|
||||
# Create new stream for incoming data
|
||||
# Determine if this is an incoming stream
|
||||
is_incoming = self._is_incoming_stream(stream_id)
|
||||
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
is_initiator=False, # pyrefly: ignore
|
||||
is_initiator=not is_incoming,
|
||||
)
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
# Notify stream handler if available
|
||||
if self._stream_handler:
|
||||
# Use trio nursery to start stream handler
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._stream_handler, stream)
|
||||
# Notify stream handler for incoming streams
|
||||
if is_incoming and self._stream_handler:
|
||||
# Start stream handler in background
|
||||
# In a real implementation, you might want to use the nursery
|
||||
# passed to the connection, but for now we'll handle it directly
|
||||
try:
|
||||
await self._stream_handler(stream)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream handler: {e}")
|
||||
|
||||
# Forward data to stream
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_data_received(event.data, event.end_stream)
|
||||
|
||||
def _is_incoming_stream(self, stream_id: int) -> bool:
|
||||
"""
|
||||
Determine if a stream ID represents an incoming stream.
|
||||
|
||||
For bidirectional streams:
|
||||
- Even IDs are client-initiated
|
||||
- Odd IDs are server-initiated
|
||||
"""
|
||||
if self.__is_initiator:
|
||||
# We're the client, so odd stream IDs are incoming
|
||||
return stream_id % 2 == 1
|
||||
else:
|
||||
# We're the server, so even stream IDs are incoming
|
||||
return stream_id % 2 == 0
|
||||
|
||||
async def _handle_stream_reset(self, event: events.StreamReset) -> None:
|
||||
"""Handle stream reset."""
|
||||
stream_id = event.stream_id
|
||||
@ -210,15 +347,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if socket is None:
|
||||
return
|
||||
|
||||
for data, addr in self._quic.datagrams_to_send(now=time.time()):
|
||||
try:
|
||||
try:
|
||||
for data, addr in self._quic.datagrams_to_send(now=time.time()):
|
||||
await socket.sendto(data, addr)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send datagram: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send datagram: {e}")
|
||||
|
||||
# IRawConnection interface
|
||||
|
||||
async def write(self, data: bytes):
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the connection.
|
||||
For QUIC, this creates a new stream for each write operation.
|
||||
@ -230,7 +367,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await stream.write(data)
|
||||
await stream.close()
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
async def read(self, n: int | None = -1) -> bytes:
|
||||
"""
|
||||
Read data from the connection.
|
||||
For QUIC, this reads from the next available stream.
|
||||
@ -252,14 +389,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._closed = True
|
||||
logger.debug(f"Closing QUIC connection to {self._peer_id}")
|
||||
|
||||
# Close all streams using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
for stream in self._streams.values():
|
||||
nursery.start_soon(stream.close)
|
||||
# Close all streams
|
||||
stream_close_tasks = []
|
||||
for stream in list(self._streams.values()):
|
||||
stream_close_tasks.append(stream.close())
|
||||
|
||||
if stream_close_tasks:
|
||||
# Close streams concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
for task in stream_close_tasks:
|
||||
nursery.start_soon(lambda t=task: t)
|
||||
|
||||
# Close QUIC connection
|
||||
self._quic.close()
|
||||
await self._transmit() # Send close frames
|
||||
if self._socket:
|
||||
await self._transmit() # Send close frames
|
||||
|
||||
# Close socket
|
||||
if self._socket:
|
||||
@ -275,6 +419,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Check if connection is closed."""
|
||||
return self._closed
|
||||
|
||||
@property
|
||||
def is_established(self) -> bool:
|
||||
"""Check if connection is established (handshake completed)."""
|
||||
return self._established
|
||||
|
||||
@property
|
||||
def is_started(self) -> bool:
|
||||
"""Check if connection has been started."""
|
||||
return self._started
|
||||
|
||||
def multiaddr(self) -> multiaddr.Multiaddr:
|
||||
"""Get the multiaddr for this connection."""
|
||||
return self._maddr
|
||||
@ -283,6 +437,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Get the local peer ID."""
|
||||
return self._local_peer_id
|
||||
|
||||
def remote_peer_id(self) -> ID | None:
|
||||
"""Get the remote peer ID."""
|
||||
return self._peer_id
|
||||
|
||||
# IMuxedConn interface
|
||||
|
||||
async def open_stream(self) -> IMuxedStream:
|
||||
@ -296,23 +454,27 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._closed:
|
||||
raise QUICStreamError("Connection is closed")
|
||||
|
||||
# Generate next stream ID
|
||||
stream_id = self._next_stream_id
|
||||
self._next_stream_id += (
|
||||
2 # Increment by 2 to maintain initiator/responder distinction
|
||||
)
|
||||
if not self._started:
|
||||
raise QUICStreamError("Connection not started")
|
||||
|
||||
# Create stream
|
||||
stream = QUICStream(
|
||||
connection=self, stream_id=stream_id, is_initiator=True
|
||||
) # pyrefly: ignore
|
||||
async with self._stream_id_lock:
|
||||
# Generate next stream ID
|
||||
stream_id = self._next_stream_id
|
||||
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
# Create stream
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
is_initiator=True
|
||||
)
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
logger.debug(f"Opened QUIC stream {stream_id}")
|
||||
return stream
|
||||
|
||||
def set_stream_handler(self, handler_function: StreamHandlerFn) -> None:
|
||||
def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
|
||||
"""
|
||||
Set handler for incoming streams.
|
||||
|
||||
@ -341,17 +503,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""
|
||||
# Extract peer ID from TLS certificate
|
||||
# This should match the expected peer ID
|
||||
cert_peer_id = self._extract_peer_id_from_cert()
|
||||
try:
|
||||
cert_peer_id = self._extract_peer_id_from_cert()
|
||||
|
||||
if self._peer_id and cert_peer_id != self._peer_id:
|
||||
raise QUICConnectionError(
|
||||
f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
|
||||
)
|
||||
if self._peer_id and cert_peer_id != self._peer_id:
|
||||
raise QUICConnectionError(
|
||||
f"Peer ID mismatch: expected {self._peer_id}, got {cert_peer_id}"
|
||||
)
|
||||
|
||||
if not self._peer_id:
|
||||
self._peer_id = cert_peer_id
|
||||
if not self._peer_id:
|
||||
self._peer_id = cert_peer_id
|
||||
|
||||
logger.debug(f"Verified peer identity: {self._peer_id}")
|
||||
logger.debug(f"Verified peer identity: {self._peer_id}")
|
||||
|
||||
except NotImplementedError:
|
||||
logger.warning("Peer identity verification not implemented - skipping")
|
||||
# For now, we'll skip verification during development
|
||||
|
||||
def _extract_peer_id_from_cert(self) -> ID:
|
||||
"""Extract peer ID from TLS certificate."""
|
||||
@ -363,6 +530,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# The certificate should contain the peer ID in a specific extension
|
||||
raise NotImplementedError("Certificate peer ID extraction not implemented")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get connection statistics."""
|
||||
return {
|
||||
"peer_id": str(self._peer_id),
|
||||
"remote_addr": self._remote_addr,
|
||||
"is_initiator": self.__is_initiator,
|
||||
"is_established": self._established,
|
||||
"is_closed": self._closed,
|
||||
"is_started": self._started,
|
||||
"active_streams": len(self._streams),
|
||||
"next_stream_id": self._next_stream_id,
|
||||
}
|
||||
|
||||
def get_remote_address(self):
|
||||
return self._remote_addr
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the connection."""
|
||||
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)})"
|
||||
return f"QUICConnection(peer={self._peer_id}, streams={len(self._streams)}, established={self._established}, started={self._started})"
|
||||
|
||||
Reference in New Issue
Block a user