From 2689040d483a8e525afc89488a9f48156124006f Mon Sep 17 00:00:00 2001 From: Akash Mondal Date: Sun, 29 Jun 2025 06:27:54 +0000 Subject: [PATCH] fix: handle short quic headers and compelete connection establishment --- examples/echo/echo_quic.py | 19 ++--- libp2p/transport/quic/connection.py | 73 ++++++++++++++----- libp2p/transport/quic/listener.py | 105 ++++++++++++++++++++++------ 3 files changed, 150 insertions(+), 47 deletions(-) diff --git a/examples/echo/echo_quic.py b/examples/echo/echo_quic.py index 532cfe3d..fbcce8db 100644 --- a/examples/echo/echo_quic.py +++ b/examples/echo/echo_quic.py @@ -25,15 +25,16 @@ PROTOCOL_ID = TProtocol("/echo/1.0.0") async def _echo_stream_handler(stream: INetStream) -> None: - """ - Echo stream handler - unchanged from TCP version. - - Demonstrates transport abstraction: same handler works for both TCP and QUIC. - """ - # Wait until EOF - msg = await stream.read() - await stream.write(msg) - await stream.close() + try: + msg = await stream.read() + await stream.write(msg) + await stream.close() + except Exception as e: + print(f"Echo handler error: {e}") + try: + await stream.close() + except: + pass async def run_server(port: int, seed: int | None = None) -> None: diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py index 11a30a54..c0861ea1 100644 --- a/libp2p/transport/quic/connection.py +++ b/libp2p/transport/quic/connection.py @@ -82,6 +82,7 @@ class QUICConnection(IRawConnection, IMuxedConn): transport: "QUICTransport", security_manager: Optional["QUICTLSConfigManager"] = None, resource_scope: Any | None = None, + listener_socket: trio.socket.SocketType | None = None, ): """ Initialize QUIC connection with security integration. @@ -96,6 +97,7 @@ class QUICConnection(IRawConnection, IMuxedConn): transport: Parent QUIC transport security_manager: Security manager for TLS/certificate handling resource_scope: Resource manager scope for tracking + listener_socket: Socket of listener to transmit data """ self._quic = quic_connection @@ -109,7 +111,8 @@ class QUICConnection(IRawConnection, IMuxedConn): self._resource_scope = resource_scope # Trio networking - socket may be provided by listener - self._socket: trio.socket.SocketType | None = None + self._socket = listener_socket if listener_socket else None + self._owns_socket = listener_socket is None self._connected_event = trio.Event() self._closed_event = trio.Event() @@ -974,23 +977,56 @@ class QUICConnection(IRawConnection, IMuxedConn): self._closed_event.set() async def _handle_stream_data(self, event: events.StreamDataReceived) -> None: - """Stream data handling with proper error management.""" + """Handle stream data events - create streams and add to accept queue.""" stream_id = event.stream_id self._stats["bytes_received"] += len(event.data) try: - with QUICErrorContext("stream_data_handling", "stream"): - # Get or create stream - stream = await self._get_or_create_stream(stream_id) + print(f"🔧 STREAM_DATA: Handling data for stream {stream_id}") - # Forward data to stream - await stream.handle_data_received(event.data, event.end_stream) + if stream_id not in self._streams: + if self._is_incoming_stream(stream_id): + print(f"🔧 STREAM_DATA: Creating new incoming stream {stream_id}") + + from .stream import QUICStream, StreamDirection + + stream = QUICStream( + connection=self, + stream_id=stream_id, + direction=StreamDirection.INBOUND, + resource_scope=self._resource_scope, + remote_addr=self._remote_addr, + ) + + # Store the stream + self._streams[stream_id] = stream + + async with self._accept_queue_lock: + self._stream_accept_queue.append(stream) + self._stream_accept_event.set() + print( + f"✅ STREAM_DATA: Added stream {stream_id} to accept queue" + ) + + async with self._stream_count_lock: + self._inbound_stream_count += 1 + self._stats["streams_opened"] += 1 + + else: + print( + f"❌ STREAM_DATA: Unexpected outbound stream {stream_id} in data event" + ) + return + + stream = self._streams[stream_id] + await stream.handle_data_received(event.data, event.end_stream) + print( + f"✅ STREAM_DATA: Forwarded {len(event.data)} bytes to stream {stream_id}" + ) except Exception as e: logger.error(f"Error handling stream data for stream {stream_id}: {e}") - # Reset the stream on error - if stream_id in self._streams: - await self._streams[stream_id].reset(error_code=1) + print(f"❌ STREAM_DATA: Error: {e}") async def _get_or_create_stream(self, stream_id: int) -> QUICStream: """Get existing stream or create new inbound stream.""" @@ -1103,20 +1139,24 @@ class QUICConnection(IRawConnection, IMuxedConn): # Network transmission async def _transmit(self) -> None: - """Send pending datagrams using trio.""" + """Transmit pending QUIC packets using available socket.""" sock = self._socket if not sock: print("No socket to transmit") return try: - datagrams = self._quic.datagrams_to_send(now=time.time()) + current_time = time.time() + datagrams = self._quic.datagrams_to_send(now=current_time) for data, addr in datagrams: await sock.sendto(data, addr) - self._stats["packets_sent"] += 1 - self._stats["bytes_sent"] += len(data) + # Update stats if available + if hasattr(self, "_stats"): + self._stats["packets_sent"] += 1 + self._stats["bytes_sent"] += len(data) + except Exception as e: - logger.error(f"Failed to send datagram: {e}") + logger.error(f"Transmission error: {e}") await self._handle_connection_error(e) # Additional methods for stream data processing @@ -1179,8 +1219,9 @@ class QUICConnection(IRawConnection, IMuxedConn): await self._transmit() # Send close frames # Close socket - if self._socket: + if self._socket and self._owns_socket: self._socket.close() + self._socket = None self._streams.clear() self._closed_event.set() diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index 0f499817..5171d21c 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -160,11 +160,20 @@ class QUICListener(IListener): is_long_header = (first_byte & 0x80) != 0 if not is_long_header: - # Short header packet - extract destination connection ID - # For short headers, we need to know the connection ID length - # This is typically managed by the connection state - # For now, we'll handle this in the connection routing logic - return None + cid_length = 8 # We are using standard CID length everywhere + + if len(data) < 1 + cid_length: + return None + + dest_cid = data[1 : 1 + cid_length] + + return QUICPacketInfo( + version=1, # Assume QUIC v1 for established connections + destination_cid=dest_cid, + source_cid=b"", # Not available in short header + packet_type=QuicPacketType.ONE_RTT, + token=b"", + ) # Long header packet parsing offset = 1 @@ -276,6 +285,13 @@ class QUICListener(IListener): # Parse packet to extract connection information packet_info = self.parse_quic_packet(data) + print(f"🔧 DEBUG: Packet info: {packet_info is not None}") + if packet_info: + print(f"🔧 DEBUG: Packet type: {packet_info.packet_type}") + print( + f"🔧 DEBUG: Is short header: {packet_info.packet_type == QuicPacketType.ONE_RTT}" + ) + print( f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}" ) @@ -606,23 +622,36 @@ class QUICListener(IListener): async def _handle_short_header_packet( self, data: bytes, addr: tuple[str, int] ) -> None: - """Handle short header packets using address-based fallback routing.""" + """Handle short header packets for established connections.""" try: - # Check if we have a connection for this address + print(f"🔧 SHORT_HDR: Handling short header packet from {addr}") + + # First, try address-based lookup dest_cid = self._addr_to_cid.get(addr) - if dest_cid: - if dest_cid in self._connections: - connection = self._connections[dest_cid] - await self._route_to_connection(connection, data, addr) - elif dest_cid in self._pending_connections: - quic_conn = self._pending_connections[dest_cid] - await self._handle_pending_connection( - quic_conn, data, addr, dest_cid + if dest_cid and dest_cid in self._connections: + print(f"✅ SHORT_HDR: Routing via address mapping to {dest_cid.hex()}") + connection = self._connections[dest_cid] + await self._route_to_connection(connection, data, addr) + return + + # Fallback: try to extract CID from packet + if len(data) >= 9: # 1 byte header + 8 byte CID + potential_cid = data[1:9] + + if potential_cid in self._connections: + print( + f"✅ SHORT_HDR: Routing via extracted CID {potential_cid.hex()}" ) - else: - logger.debug( - f"Received short header packet from unknown address {addr}" - ) + connection = self._connections[potential_cid] + + # Update mappings for future packets + self._addr_to_cid[addr] = potential_cid + self._cid_to_addr[potential_cid] = addr + + await self._route_to_connection(connection, data, addr) + return + + print(f"❌ SHORT_HDR: No matching connection found for {addr}") except Exception as e: logger.error(f"Error handling short header packet from {addr}: {e}") @@ -858,7 +887,7 @@ class QUICListener(IListener): # Create multiaddr for this connection host, port = addr - quic_version = next(iter(self._quic_configs.keys())) + quic_version = "quic" remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}") from .connection import QUICConnection @@ -872,9 +901,19 @@ class QUICListener(IListener): maddr=remote_maddr, transport=self._transport, security_manager=self._security_manager, + listener_socket=self._socket, + ) + + print( + f"🔧 PROMOTION: Created connection with socket: {self._socket is not None}" + ) + print( + f"🔧 PROMOTION: Socket type: {type(self._socket) if self._socket else 'None'}" ) self._connections[dest_cid] = connection + self._addr_to_cid[addr] = dest_cid + self._cid_to_addr[dest_cid] = addr if self._nursery: await connection.connect(self._nursery) @@ -1178,9 +1217,31 @@ class QUICListener(IListener): async def _handle_new_established_connection( self, connection: QUICConnection ) -> None: - """Handle a newly established connection.""" + """Handle newly established connection with proper stream management.""" try: - await self._handler(connection) + logger.debug( + f"Handling new established connection from {connection._remote_addr}" + ) + + # Accept incoming streams and pass them to the handler + while not connection.is_closed: + try: + print(f"🔧 CONN_HANDLER: Waiting for stream...") + stream = await connection.accept_stream(timeout=1.0) + print(f"✅ CONN_HANDLER: Accepted stream {stream.stream_id}") + + if self._nursery: + # Pass STREAM to handler, not connection + self._nursery.start_soon(self._handler, stream) + print( + f"✅ CONN_HANDLER: Started handler for stream {stream.stream_id}" + ) + except trio.TooSlowError: + continue # Timeout is normal + except Exception as e: + logger.error(f"Error accepting stream: {e}") + break + except Exception as e: logger.error(f"Error in connection handler: {e}") await connection.close()