mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 08:00:54 +00:00
fix: handle short quic headers and compelete connection establishment
This commit is contained in:
@ -25,15 +25,16 @@ PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def _echo_stream_handler(stream: INetStream) -> None:
|
||||
"""
|
||||
Echo stream handler - unchanged from TCP version.
|
||||
|
||||
Demonstrates transport abstraction: same handler works for both TCP and QUIC.
|
||||
"""
|
||||
# Wait until EOF
|
||||
msg = await stream.read()
|
||||
await stream.write(msg)
|
||||
await stream.close()
|
||||
try:
|
||||
msg = await stream.read()
|
||||
await stream.write(msg)
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
print(f"Echo handler error: {e}")
|
||||
try:
|
||||
await stream.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
async def run_server(port: int, seed: int | None = None) -> None:
|
||||
|
||||
@ -82,6 +82,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
transport: "QUICTransport",
|
||||
security_manager: Optional["QUICTLSConfigManager"] = None,
|
||||
resource_scope: Any | None = None,
|
||||
listener_socket: trio.socket.SocketType | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC connection with security integration.
|
||||
@ -96,6 +97,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
transport: Parent QUIC transport
|
||||
security_manager: Security manager for TLS/certificate handling
|
||||
resource_scope: Resource manager scope for tracking
|
||||
listener_socket: Socket of listener to transmit data
|
||||
|
||||
"""
|
||||
self._quic = quic_connection
|
||||
@ -109,7 +111,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._resource_scope = resource_scope
|
||||
|
||||
# Trio networking - socket may be provided by listener
|
||||
self._socket: trio.socket.SocketType | None = None
|
||||
self._socket = listener_socket if listener_socket else None
|
||||
self._owns_socket = listener_socket is None
|
||||
self._connected_event = trio.Event()
|
||||
self._closed_event = trio.Event()
|
||||
|
||||
@ -974,23 +977,56 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._closed_event.set()
|
||||
|
||||
async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
|
||||
"""Stream data handling with proper error management."""
|
||||
"""Handle stream data events - create streams and add to accept queue."""
|
||||
stream_id = event.stream_id
|
||||
self._stats["bytes_received"] += len(event.data)
|
||||
|
||||
try:
|
||||
with QUICErrorContext("stream_data_handling", "stream"):
|
||||
# Get or create stream
|
||||
stream = await self._get_or_create_stream(stream_id)
|
||||
print(f"🔧 STREAM_DATA: Handling data for stream {stream_id}")
|
||||
|
||||
# Forward data to stream
|
||||
await stream.handle_data_received(event.data, event.end_stream)
|
||||
if stream_id not in self._streams:
|
||||
if self._is_incoming_stream(stream_id):
|
||||
print(f"🔧 STREAM_DATA: Creating new incoming stream {stream_id}")
|
||||
|
||||
from .stream import QUICStream, StreamDirection
|
||||
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
direction=StreamDirection.INBOUND,
|
||||
resource_scope=self._resource_scope,
|
||||
remote_addr=self._remote_addr,
|
||||
)
|
||||
|
||||
# Store the stream
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
async with self._accept_queue_lock:
|
||||
self._stream_accept_queue.append(stream)
|
||||
self._stream_accept_event.set()
|
||||
print(
|
||||
f"✅ STREAM_DATA: Added stream {stream_id} to accept queue"
|
||||
)
|
||||
|
||||
async with self._stream_count_lock:
|
||||
self._inbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
|
||||
else:
|
||||
print(
|
||||
f"❌ STREAM_DATA: Unexpected outbound stream {stream_id} in data event"
|
||||
)
|
||||
return
|
||||
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_data_received(event.data, event.end_stream)
|
||||
print(
|
||||
f"✅ STREAM_DATA: Forwarded {len(event.data)} bytes to stream {stream_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling stream data for stream {stream_id}: {e}")
|
||||
# Reset the stream on error
|
||||
if stream_id in self._streams:
|
||||
await self._streams[stream_id].reset(error_code=1)
|
||||
print(f"❌ STREAM_DATA: Error: {e}")
|
||||
|
||||
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
|
||||
"""Get existing stream or create new inbound stream."""
|
||||
@ -1103,20 +1139,24 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Network transmission
|
||||
|
||||
async def _transmit(self) -> None:
|
||||
"""Send pending datagrams using trio."""
|
||||
"""Transmit pending QUIC packets using available socket."""
|
||||
sock = self._socket
|
||||
if not sock:
|
||||
print("No socket to transmit")
|
||||
return
|
||||
|
||||
try:
|
||||
datagrams = self._quic.datagrams_to_send(now=time.time())
|
||||
current_time = time.time()
|
||||
datagrams = self._quic.datagrams_to_send(now=current_time)
|
||||
for data, addr in datagrams:
|
||||
await sock.sendto(data, addr)
|
||||
self._stats["packets_sent"] += 1
|
||||
self._stats["bytes_sent"] += len(data)
|
||||
# Update stats if available
|
||||
if hasattr(self, "_stats"):
|
||||
self._stats["packets_sent"] += 1
|
||||
self._stats["bytes_sent"] += len(data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send datagram: {e}")
|
||||
logger.error(f"Transmission error: {e}")
|
||||
await self._handle_connection_error(e)
|
||||
|
||||
# Additional methods for stream data processing
|
||||
@ -1179,8 +1219,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await self._transmit() # Send close frames
|
||||
|
||||
# Close socket
|
||||
if self._socket:
|
||||
if self._socket and self._owns_socket:
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
|
||||
self._streams.clear()
|
||||
self._closed_event.set()
|
||||
|
||||
@ -160,11 +160,20 @@ class QUICListener(IListener):
|
||||
is_long_header = (first_byte & 0x80) != 0
|
||||
|
||||
if not is_long_header:
|
||||
# Short header packet - extract destination connection ID
|
||||
# For short headers, we need to know the connection ID length
|
||||
# This is typically managed by the connection state
|
||||
# For now, we'll handle this in the connection routing logic
|
||||
return None
|
||||
cid_length = 8 # We are using standard CID length everywhere
|
||||
|
||||
if len(data) < 1 + cid_length:
|
||||
return None
|
||||
|
||||
dest_cid = data[1 : 1 + cid_length]
|
||||
|
||||
return QUICPacketInfo(
|
||||
version=1, # Assume QUIC v1 for established connections
|
||||
destination_cid=dest_cid,
|
||||
source_cid=b"", # Not available in short header
|
||||
packet_type=QuicPacketType.ONE_RTT,
|
||||
token=b"",
|
||||
)
|
||||
|
||||
# Long header packet parsing
|
||||
offset = 1
|
||||
@ -276,6 +285,13 @@ class QUICListener(IListener):
|
||||
# Parse packet to extract connection information
|
||||
packet_info = self.parse_quic_packet(data)
|
||||
|
||||
print(f"🔧 DEBUG: Packet info: {packet_info is not None}")
|
||||
if packet_info:
|
||||
print(f"🔧 DEBUG: Packet type: {packet_info.packet_type}")
|
||||
print(
|
||||
f"🔧 DEBUG: Is short header: {packet_info.packet_type == QuicPacketType.ONE_RTT}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
|
||||
)
|
||||
@ -606,23 +622,36 @@ class QUICListener(IListener):
|
||||
async def _handle_short_header_packet(
|
||||
self, data: bytes, addr: tuple[str, int]
|
||||
) -> None:
|
||||
"""Handle short header packets using address-based fallback routing."""
|
||||
"""Handle short header packets for established connections."""
|
||||
try:
|
||||
# Check if we have a connection for this address
|
||||
print(f"🔧 SHORT_HDR: Handling short header packet from {addr}")
|
||||
|
||||
# First, try address-based lookup
|
||||
dest_cid = self._addr_to_cid.get(addr)
|
||||
if dest_cid:
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
elif dest_cid in self._pending_connections:
|
||||
quic_conn = self._pending_connections[dest_cid]
|
||||
await self._handle_pending_connection(
|
||||
quic_conn, data, addr, dest_cid
|
||||
if dest_cid and dest_cid in self._connections:
|
||||
print(f"✅ SHORT_HDR: Routing via address mapping to {dest_cid.hex()}")
|
||||
connection = self._connections[dest_cid]
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
return
|
||||
|
||||
# Fallback: try to extract CID from packet
|
||||
if len(data) >= 9: # 1 byte header + 8 byte CID
|
||||
potential_cid = data[1:9]
|
||||
|
||||
if potential_cid in self._connections:
|
||||
print(
|
||||
f"✅ SHORT_HDR: Routing via extracted CID {potential_cid.hex()}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Received short header packet from unknown address {addr}"
|
||||
)
|
||||
connection = self._connections[potential_cid]
|
||||
|
||||
# Update mappings for future packets
|
||||
self._addr_to_cid[addr] = potential_cid
|
||||
self._cid_to_addr[potential_cid] = addr
|
||||
|
||||
await self._route_to_connection(connection, data, addr)
|
||||
return
|
||||
|
||||
print(f"❌ SHORT_HDR: No matching connection found for {addr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling short header packet from {addr}: {e}")
|
||||
@ -858,7 +887,7 @@ class QUICListener(IListener):
|
||||
|
||||
# Create multiaddr for this connection
|
||||
host, port = addr
|
||||
quic_version = next(iter(self._quic_configs.keys()))
|
||||
quic_version = "quic"
|
||||
remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}")
|
||||
|
||||
from .connection import QUICConnection
|
||||
@ -872,9 +901,19 @@ class QUICListener(IListener):
|
||||
maddr=remote_maddr,
|
||||
transport=self._transport,
|
||||
security_manager=self._security_manager,
|
||||
listener_socket=self._socket,
|
||||
)
|
||||
|
||||
print(
|
||||
f"🔧 PROMOTION: Created connection with socket: {self._socket is not None}"
|
||||
)
|
||||
print(
|
||||
f"🔧 PROMOTION: Socket type: {type(self._socket) if self._socket else 'None'}"
|
||||
)
|
||||
|
||||
self._connections[dest_cid] = connection
|
||||
self._addr_to_cid[addr] = dest_cid
|
||||
self._cid_to_addr[dest_cid] = addr
|
||||
|
||||
if self._nursery:
|
||||
await connection.connect(self._nursery)
|
||||
@ -1178,9 +1217,31 @@ class QUICListener(IListener):
|
||||
async def _handle_new_established_connection(
|
||||
self, connection: QUICConnection
|
||||
) -> None:
|
||||
"""Handle a newly established connection."""
|
||||
"""Handle newly established connection with proper stream management."""
|
||||
try:
|
||||
await self._handler(connection)
|
||||
logger.debug(
|
||||
f"Handling new established connection from {connection._remote_addr}"
|
||||
)
|
||||
|
||||
# Accept incoming streams and pass them to the handler
|
||||
while not connection.is_closed:
|
||||
try:
|
||||
print(f"🔧 CONN_HANDLER: Waiting for stream...")
|
||||
stream = await connection.accept_stream(timeout=1.0)
|
||||
print(f"✅ CONN_HANDLER: Accepted stream {stream.stream_id}")
|
||||
|
||||
if self._nursery:
|
||||
# Pass STREAM to handler, not connection
|
||||
self._nursery.start_soon(self._handler, stream)
|
||||
print(
|
||||
f"✅ CONN_HANDLER: Started handler for stream {stream.stream_id}"
|
||||
)
|
||||
except trio.TooSlowError:
|
||||
continue # Timeout is normal
|
||||
except Exception as e:
|
||||
logger.error(f"Error accepting stream: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in connection handler: {e}")
|
||||
await connection.close()
|
||||
|
||||
Reference in New Issue
Block a user