fix: handle short quic headers and compelete connection establishment

This commit is contained in:
Akash Mondal
2025-06-29 06:27:54 +00:00
committed by lla-dane
parent 8263052f88
commit 2689040d48
3 changed files with 150 additions and 47 deletions

View File

@ -25,15 +25,16 @@ PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def _echo_stream_handler(stream: INetStream) -> None:
"""
Echo stream handler - unchanged from TCP version.
Demonstrates transport abstraction: same handler works for both TCP and QUIC.
"""
# Wait until EOF
msg = await stream.read()
await stream.write(msg)
await stream.close()
try:
msg = await stream.read()
await stream.write(msg)
await stream.close()
except Exception as e:
print(f"Echo handler error: {e}")
try:
await stream.close()
except:
pass
async def run_server(port: int, seed: int | None = None) -> None:

View File

@ -82,6 +82,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
transport: "QUICTransport",
security_manager: Optional["QUICTLSConfigManager"] = None,
resource_scope: Any | None = None,
listener_socket: trio.socket.SocketType | None = None,
):
"""
Initialize QUIC connection with security integration.
@ -96,6 +97,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
transport: Parent QUIC transport
security_manager: Security manager for TLS/certificate handling
resource_scope: Resource manager scope for tracking
listener_socket: Socket of listener to transmit data
"""
self._quic = quic_connection
@ -109,7 +111,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._resource_scope = resource_scope
# Trio networking - socket may be provided by listener
self._socket: trio.socket.SocketType | None = None
self._socket = listener_socket if listener_socket else None
self._owns_socket = listener_socket is None
self._connected_event = trio.Event()
self._closed_event = trio.Event()
@ -974,23 +977,56 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._closed_event.set()
async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
"""Stream data handling with proper error management."""
"""Handle stream data events - create streams and add to accept queue."""
stream_id = event.stream_id
self._stats["bytes_received"] += len(event.data)
try:
with QUICErrorContext("stream_data_handling", "stream"):
# Get or create stream
stream = await self._get_or_create_stream(stream_id)
print(f"🔧 STREAM_DATA: Handling data for stream {stream_id}")
# Forward data to stream
await stream.handle_data_received(event.data, event.end_stream)
if stream_id not in self._streams:
if self._is_incoming_stream(stream_id):
print(f"🔧 STREAM_DATA: Creating new incoming stream {stream_id}")
from .stream import QUICStream, StreamDirection
stream = QUICStream(
connection=self,
stream_id=stream_id,
direction=StreamDirection.INBOUND,
resource_scope=self._resource_scope,
remote_addr=self._remote_addr,
)
# Store the stream
self._streams[stream_id] = stream
async with self._accept_queue_lock:
self._stream_accept_queue.append(stream)
self._stream_accept_event.set()
print(
f"✅ STREAM_DATA: Added stream {stream_id} to accept queue"
)
async with self._stream_count_lock:
self._inbound_stream_count += 1
self._stats["streams_opened"] += 1
else:
print(
f"❌ STREAM_DATA: Unexpected outbound stream {stream_id} in data event"
)
return
stream = self._streams[stream_id]
await stream.handle_data_received(event.data, event.end_stream)
print(
f"✅ STREAM_DATA: Forwarded {len(event.data)} bytes to stream {stream_id}"
)
except Exception as e:
logger.error(f"Error handling stream data for stream {stream_id}: {e}")
# Reset the stream on error
if stream_id in self._streams:
await self._streams[stream_id].reset(error_code=1)
print(f"❌ STREAM_DATA: Error: {e}")
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
"""Get existing stream or create new inbound stream."""
@ -1103,20 +1139,24 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Network transmission
async def _transmit(self) -> None:
"""Send pending datagrams using trio."""
"""Transmit pending QUIC packets using available socket."""
sock = self._socket
if not sock:
print("No socket to transmit")
return
try:
datagrams = self._quic.datagrams_to_send(now=time.time())
current_time = time.time()
datagrams = self._quic.datagrams_to_send(now=current_time)
for data, addr in datagrams:
await sock.sendto(data, addr)
self._stats["packets_sent"] += 1
self._stats["bytes_sent"] += len(data)
# Update stats if available
if hasattr(self, "_stats"):
self._stats["packets_sent"] += 1
self._stats["bytes_sent"] += len(data)
except Exception as e:
logger.error(f"Failed to send datagram: {e}")
logger.error(f"Transmission error: {e}")
await self._handle_connection_error(e)
# Additional methods for stream data processing
@ -1179,8 +1219,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
await self._transmit() # Send close frames
# Close socket
if self._socket:
if self._socket and self._owns_socket:
self._socket.close()
self._socket = None
self._streams.clear()
self._closed_event.set()

View File

@ -160,11 +160,20 @@ class QUICListener(IListener):
is_long_header = (first_byte & 0x80) != 0
if not is_long_header:
# Short header packet - extract destination connection ID
# For short headers, we need to know the connection ID length
# This is typically managed by the connection state
# For now, we'll handle this in the connection routing logic
return None
cid_length = 8 # We are using standard CID length everywhere
if len(data) < 1 + cid_length:
return None
dest_cid = data[1 : 1 + cid_length]
return QUICPacketInfo(
version=1, # Assume QUIC v1 for established connections
destination_cid=dest_cid,
source_cid=b"", # Not available in short header
packet_type=QuicPacketType.ONE_RTT,
token=b"",
)
# Long header packet parsing
offset = 1
@ -276,6 +285,13 @@ class QUICListener(IListener):
# Parse packet to extract connection information
packet_info = self.parse_quic_packet(data)
print(f"🔧 DEBUG: Packet info: {packet_info is not None}")
if packet_info:
print(f"🔧 DEBUG: Packet type: {packet_info.packet_type}")
print(
f"🔧 DEBUG: Is short header: {packet_info.packet_type == QuicPacketType.ONE_RTT}"
)
print(
f"🔧 DEBUG: Pending connections: {[cid.hex() for cid in self._pending_connections.keys()]}"
)
@ -606,23 +622,36 @@ class QUICListener(IListener):
async def _handle_short_header_packet(
self, data: bytes, addr: tuple[str, int]
) -> None:
"""Handle short header packets using address-based fallback routing."""
"""Handle short header packets for established connections."""
try:
# Check if we have a connection for this address
print(f"🔧 SHORT_HDR: Handling short header packet from {addr}")
# First, try address-based lookup
dest_cid = self._addr_to_cid.get(addr)
if dest_cid:
if dest_cid in self._connections:
connection = self._connections[dest_cid]
await self._route_to_connection(connection, data, addr)
elif dest_cid in self._pending_connections:
quic_conn = self._pending_connections[dest_cid]
await self._handle_pending_connection(
quic_conn, data, addr, dest_cid
if dest_cid and dest_cid in self._connections:
print(f"✅ SHORT_HDR: Routing via address mapping to {dest_cid.hex()}")
connection = self._connections[dest_cid]
await self._route_to_connection(connection, data, addr)
return
# Fallback: try to extract CID from packet
if len(data) >= 9: # 1 byte header + 8 byte CID
potential_cid = data[1:9]
if potential_cid in self._connections:
print(
f"✅ SHORT_HDR: Routing via extracted CID {potential_cid.hex()}"
)
else:
logger.debug(
f"Received short header packet from unknown address {addr}"
)
connection = self._connections[potential_cid]
# Update mappings for future packets
self._addr_to_cid[addr] = potential_cid
self._cid_to_addr[potential_cid] = addr
await self._route_to_connection(connection, data, addr)
return
print(f"❌ SHORT_HDR: No matching connection found for {addr}")
except Exception as e:
logger.error(f"Error handling short header packet from {addr}: {e}")
@ -858,7 +887,7 @@ class QUICListener(IListener):
# Create multiaddr for this connection
host, port = addr
quic_version = next(iter(self._quic_configs.keys()))
quic_version = "quic"
remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}")
from .connection import QUICConnection
@ -872,9 +901,19 @@ class QUICListener(IListener):
maddr=remote_maddr,
transport=self._transport,
security_manager=self._security_manager,
listener_socket=self._socket,
)
print(
f"🔧 PROMOTION: Created connection with socket: {self._socket is not None}"
)
print(
f"🔧 PROMOTION: Socket type: {type(self._socket) if self._socket else 'None'}"
)
self._connections[dest_cid] = connection
self._addr_to_cid[addr] = dest_cid
self._cid_to_addr[dest_cid] = addr
if self._nursery:
await connection.connect(self._nursery)
@ -1178,9 +1217,31 @@ class QUICListener(IListener):
async def _handle_new_established_connection(
self, connection: QUICConnection
) -> None:
"""Handle a newly established connection."""
"""Handle newly established connection with proper stream management."""
try:
await self._handler(connection)
logger.debug(
f"Handling new established connection from {connection._remote_addr}"
)
# Accept incoming streams and pass them to the handler
while not connection.is_closed:
try:
print(f"🔧 CONN_HANDLER: Waiting for stream...")
stream = await connection.accept_stream(timeout=1.0)
print(f"✅ CONN_HANDLER: Accepted stream {stream.stream_id}")
if self._nursery:
# Pass STREAM to handler, not connection
self._nursery.start_soon(self._handler, stream)
print(
f"✅ CONN_HANDLER: Started handler for stream {stream.stream_id}"
)
except trio.TooSlowError:
continue # Timeout is normal
except Exception as e:
logger.error(f"Error accepting stream: {e}")
break
except Exception as e:
logger.error(f"Error in connection handler: {e}")
await connection.close()