mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: add quic utils test and improve connection performance
This commit is contained in:
@ -3,14 +3,16 @@ QUIC Connection implementation.
|
||||
Manages bidirectional QUIC connections with integrated stream multiplexing.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from aioquic.quic import events
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.events import QuicEvent
|
||||
from cryptography import x509
|
||||
import multiaddr
|
||||
import trio
|
||||
@ -104,12 +106,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._connected_event = trio.Event()
|
||||
self._closed_event = trio.Event()
|
||||
|
||||
# Stream management
|
||||
self._streams: dict[int, QUICStream] = {}
|
||||
self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups
|
||||
self._next_stream_id: int = self._calculate_initial_stream_id()
|
||||
self._stream_handler: TQUICStreamHandlerFn | None = None
|
||||
self._stream_id_lock = trio.Lock()
|
||||
self._stream_count_lock = trio.Lock()
|
||||
|
||||
# Single lock for all stream operations
|
||||
self._stream_lock = trio.Lock()
|
||||
|
||||
# Stream counting and limits
|
||||
self._outbound_stream_count = 0
|
||||
@ -118,7 +121,6 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Stream acceptance for incoming streams
|
||||
self._stream_accept_queue: list[QUICStream] = []
|
||||
self._stream_accept_event = trio.Event()
|
||||
self._accept_queue_lock = trio.Lock()
|
||||
|
||||
# Connection state
|
||||
self._closed: bool = False
|
||||
@ -143,9 +145,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._retired_connection_ids: set[bytes] = set()
|
||||
self._connection_id_sequence_numbers: set[int] = set()
|
||||
|
||||
# Event processing control
|
||||
# Event processing control with batching
|
||||
self._event_processing_active = False
|
||||
self._pending_events: list[events.QuicEvent] = []
|
||||
self._event_batch: list[events.QuicEvent] = []
|
||||
self._event_batch_size = 10
|
||||
self._last_event_time = 0.0
|
||||
|
||||
# Set quic connection configuration
|
||||
self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT
|
||||
@ -250,6 +254,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Get the current connection ID."""
|
||||
return self._current_connection_id
|
||||
|
||||
# Fast stream lookup with caching
|
||||
def _get_stream_fast(self, stream_id: int) -> QUICStream | None:
|
||||
"""Get stream with caching for performance."""
|
||||
# Try cache first
|
||||
stream = self._stream_cache.get(stream_id)
|
||||
if stream is not None:
|
||||
return stream
|
||||
|
||||
# Fallback to main dict
|
||||
stream = self._streams.get(stream_id)
|
||||
if stream is not None:
|
||||
self._stream_cache[stream_id] = stream
|
||||
|
||||
return stream
|
||||
|
||||
# Connection lifecycle methods
|
||||
|
||||
async def start(self) -> None:
|
||||
@ -389,8 +408,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
try:
|
||||
while not self._closed:
|
||||
# Process QUIC events
|
||||
await self._process_quic_events()
|
||||
# Batch process events
|
||||
await self._process_quic_events_batched()
|
||||
|
||||
# Handle timer events
|
||||
await self._handle_timer_events()
|
||||
@ -421,12 +440,25 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
cid_stats = self.get_connection_id_stats()
|
||||
logger.debug(f"Connection ID stats: {cid_stats}")
|
||||
|
||||
# Clean cache periodically
|
||||
await self._cleanup_cache()
|
||||
|
||||
# Sleep for maintenance interval
|
||||
await trio.sleep(30.0) # 30 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in periodic maintenance: {e}")
|
||||
|
||||
async def _cleanup_cache(self) -> None:
|
||||
"""Clean up stream cache periodically to prevent memory leaks."""
|
||||
if len(self._stream_cache) > 100: # Arbitrary threshold
|
||||
# Remove closed streams from cache
|
||||
closed_stream_ids = [
|
||||
sid for sid, stream in self._stream_cache.items() if stream.is_closed()
|
||||
]
|
||||
for sid in closed_stream_ids:
|
||||
self._stream_cache.pop(sid, None)
|
||||
|
||||
async def _client_packet_receiver(self) -> None:
|
||||
"""Receive packets for client connections."""
|
||||
logger.debug("Starting client packet receiver")
|
||||
@ -442,8 +474,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Feed packet to QUIC connection
|
||||
self._quic.receive_datagram(data, addr, now=time.time())
|
||||
|
||||
# Process any events that result from the packet
|
||||
await self._process_quic_events()
|
||||
# Batch process events
|
||||
await self._process_quic_events_batched()
|
||||
|
||||
# Send any response packets
|
||||
await self._transmit()
|
||||
@ -675,15 +707,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if not self._started:
|
||||
raise QUICConnectionError("Connection not started")
|
||||
|
||||
# Check stream limits
|
||||
async with self._stream_count_lock:
|
||||
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
|
||||
raise QUICStreamLimitError(
|
||||
f"Maximum outbound streams ({self.MAX_OUTGOING_STREAMS}) reached"
|
||||
)
|
||||
|
||||
# Use single lock for all stream operations
|
||||
with trio.move_on_after(timeout):
|
||||
async with self._stream_id_lock:
|
||||
async with self._stream_lock:
|
||||
# Check stream limits inside lock
|
||||
if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
|
||||
raise QUICStreamLimitError(
|
||||
"Maximum outbound streams "
|
||||
f"({self.MAX_OUTGOING_STREAMS}) reached"
|
||||
)
|
||||
|
||||
# Generate next stream ID
|
||||
stream_id = self._next_stream_id
|
||||
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
|
||||
@ -697,10 +730,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
)
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
self._stream_cache[stream_id] = stream # Add to cache
|
||||
|
||||
async with self._stream_count_lock:
|
||||
self._outbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
self._outbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
|
||||
logger.debug(f"Opened outbound QUIC stream {stream_id}")
|
||||
return stream
|
||||
@ -737,7 +770,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._closed:
|
||||
raise MuxedConnUnavailable("QUIC connection is closed")
|
||||
|
||||
async with self._accept_queue_lock:
|
||||
# Use single lock for stream acceptance
|
||||
async with self._stream_lock:
|
||||
if self._stream_accept_queue:
|
||||
stream = self._stream_accept_queue.pop(0)
|
||||
logger.debug(f"Accepted inbound stream {stream.stream_id}")
|
||||
@ -769,10 +803,12 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""
|
||||
if stream_id in self._streams:
|
||||
stream = self._streams.pop(stream_id)
|
||||
# Remove from cache too
|
||||
self._stream_cache.pop(stream_id, None)
|
||||
|
||||
# Update stream counts asynchronously
|
||||
async def update_counts() -> None:
|
||||
async with self._stream_count_lock:
|
||||
async with self._stream_lock:
|
||||
if stream.direction == StreamDirection.OUTBOUND:
|
||||
self._outbound_stream_count = max(
|
||||
0, self._outbound_stream_count - 1
|
||||
@ -789,29 +825,140 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
logger.debug(f"Removed stream {stream_id} from connection")
|
||||
|
||||
async def _process_quic_events(self) -> None:
|
||||
"""Process all pending QUIC events."""
|
||||
# Batched event processing to reduce overhead
|
||||
async def _process_quic_events_batched(self) -> None:
|
||||
"""Process QUIC events in batches for better performance."""
|
||||
if self._event_processing_active:
|
||||
return # Prevent recursion
|
||||
|
||||
self._event_processing_active = True
|
||||
|
||||
try:
|
||||
current_time = time.time()
|
||||
events_processed = 0
|
||||
while True:
|
||||
|
||||
# Collect events into batch
|
||||
while events_processed < self._event_batch_size:
|
||||
event = self._quic.next_event()
|
||||
if event is None:
|
||||
break
|
||||
|
||||
self._event_batch.append(event)
|
||||
events_processed += 1
|
||||
await self._handle_quic_event(event)
|
||||
|
||||
if events_processed > 0:
|
||||
logger.debug(f"Processed {events_processed} QUIC events")
|
||||
# Process batch if we have events or timeout
|
||||
if self._event_batch and (
|
||||
len(self._event_batch) >= self._event_batch_size
|
||||
or current_time - self._last_event_time > 0.01 # 10ms timeout
|
||||
):
|
||||
await self._process_event_batch()
|
||||
self._event_batch.clear()
|
||||
self._last_event_time = current_time
|
||||
|
||||
finally:
|
||||
self._event_processing_active = False
|
||||
|
||||
async def _process_event_batch(self) -> None:
|
||||
"""Process a batch of events efficiently."""
|
||||
if not self._event_batch:
|
||||
return
|
||||
|
||||
# Group events by type for batch processing where possible
|
||||
events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list)
|
||||
for event in self._event_batch:
|
||||
events_by_type[type(event).__name__].append(event)
|
||||
|
||||
# Process events by type
|
||||
for event_type, event_list in events_by_type.items():
|
||||
if event_type == type(events.StreamDataReceived).__name__:
|
||||
await self._handle_stream_data_batch(
|
||||
cast(list[events.StreamDataReceived], event_list)
|
||||
)
|
||||
else:
|
||||
# Process other events individually
|
||||
for event in event_list:
|
||||
await self._handle_quic_event(event)
|
||||
|
||||
logger.debug(f"Processed batch of {len(self._event_batch)} events")
|
||||
|
||||
async def _handle_stream_data_batch(
|
||||
self, events_list: list[events.StreamDataReceived]
|
||||
) -> None:
|
||||
"""Handle stream data events in batch for better performance."""
|
||||
# Group by stream ID
|
||||
events_by_stream: defaultdict[int, list[QuicEvent]] = defaultdict(list)
|
||||
for event in events_list:
|
||||
events_by_stream[event.stream_id].append(event)
|
||||
|
||||
# Process each stream's events
|
||||
for stream_id, stream_events in events_by_stream.items():
|
||||
stream = self._get_stream_fast(stream_id) # Use fast lookup
|
||||
|
||||
if not stream:
|
||||
if self._is_incoming_stream(stream_id):
|
||||
try:
|
||||
stream = await self._create_inbound_stream(stream_id)
|
||||
except QUICStreamLimitError:
|
||||
# Reset stream if we can't handle it
|
||||
self._quic.reset_stream(stream_id, error_code=0x04)
|
||||
await self._transmit()
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected outbound stream {stream_id} in data event"
|
||||
)
|
||||
continue
|
||||
|
||||
# Process all events for this stream
|
||||
for received_event in stream_events:
|
||||
if hasattr(received_event, "data"):
|
||||
self._stats["bytes_received"] += len(received_event.data) # type: ignore
|
||||
|
||||
if hasattr(received_event, "end_stream"):
|
||||
await stream.handle_data_received(
|
||||
received_event.data, # type: ignore
|
||||
received_event.end_stream, # type: ignore
|
||||
)
|
||||
|
||||
async def _create_inbound_stream(self, stream_id: int) -> QUICStream:
|
||||
"""Create inbound stream with proper limit checking."""
|
||||
async with self._stream_lock:
|
||||
# Double-check stream doesn't exist
|
||||
existing_stream = self._streams.get(stream_id)
|
||||
if existing_stream:
|
||||
return existing_stream
|
||||
|
||||
# Check limits
|
||||
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
|
||||
logger.warning(f"Rejecting inbound stream {stream_id}: limit reached")
|
||||
raise QUICStreamLimitError("Too many inbound streams")
|
||||
|
||||
# Create stream
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
direction=StreamDirection.INBOUND,
|
||||
resource_scope=self._resource_scope,
|
||||
remote_addr=self._remote_addr,
|
||||
)
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
self._stream_cache[stream_id] = stream # Add to cache
|
||||
self._inbound_stream_count += 1
|
||||
self._stats["streams_accepted"] += 1
|
||||
|
||||
# Add to accept queue
|
||||
self._stream_accept_queue.append(stream)
|
||||
self._stream_accept_event.set()
|
||||
|
||||
logger.debug(f"Created inbound stream {stream_id}")
|
||||
return stream
|
||||
|
||||
async def _process_quic_events(self) -> None:
|
||||
"""Process all pending QUIC events."""
|
||||
# Delegate to batched processing for better performance
|
||||
await self._process_quic_events_batched()
|
||||
|
||||
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
|
||||
"""Handle a single QUIC event with COMPLETE event type coverage."""
|
||||
logger.debug(f"Handling QUIC event: {type(event).__name__}")
|
||||
@ -929,8 +1076,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
f"stream_id={event.stream_id}, error_code={event.error_code}"
|
||||
)
|
||||
|
||||
if event.stream_id in self._streams:
|
||||
stream: QUICStream = self._streams[event.stream_id]
|
||||
# Use fast lookup
|
||||
stream = self._get_stream_fast(event.stream_id)
|
||||
if stream:
|
||||
# Handle stop sending on the stream if method exists
|
||||
await stream.handle_stop_sending(event.error_code)
|
||||
|
||||
@ -964,6 +1112,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await stream.close()
|
||||
|
||||
self._streams.clear()
|
||||
self._stream_cache.clear() # Clear cache too
|
||||
self._closed = True
|
||||
self._closed_event.set()
|
||||
|
||||
@ -978,39 +1127,19 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._stats["bytes_received"] += len(event.data)
|
||||
|
||||
try:
|
||||
if stream_id not in self._streams:
|
||||
# Use fast lookup
|
||||
stream = self._get_stream_fast(stream_id)
|
||||
|
||||
if not stream:
|
||||
if self._is_incoming_stream(stream_id):
|
||||
logger.debug(f"Creating new incoming stream {stream_id}")
|
||||
|
||||
from .stream import QUICStream, StreamDirection
|
||||
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
direction=StreamDirection.INBOUND,
|
||||
resource_scope=self._resource_scope,
|
||||
remote_addr=self._remote_addr,
|
||||
)
|
||||
|
||||
# Store the stream
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
async with self._accept_queue_lock:
|
||||
self._stream_accept_queue.append(stream)
|
||||
self._stream_accept_event.set()
|
||||
logger.debug(f"Added stream {stream_id} to accept queue")
|
||||
|
||||
async with self._stream_count_lock:
|
||||
self._inbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
|
||||
stream = await self._create_inbound_stream(stream_id)
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected outbound stream {stream_id} in data event"
|
||||
)
|
||||
return
|
||||
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_data_received(event.data, event.end_stream)
|
||||
|
||||
except Exception as e:
|
||||
@ -1019,8 +1148,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
|
||||
"""Get existing stream or create new inbound stream."""
|
||||
if stream_id in self._streams:
|
||||
return self._streams[stream_id]
|
||||
# Use fast lookup
|
||||
stream = self._get_stream_fast(stream_id)
|
||||
if stream:
|
||||
return stream
|
||||
|
||||
# Check if this is an incoming stream
|
||||
is_incoming = self._is_incoming_stream(stream_id)
|
||||
@ -1031,49 +1162,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
f"Received data for unknown outbound stream {stream_id}"
|
||||
)
|
||||
|
||||
# Check stream limits for incoming streams
|
||||
async with self._stream_count_lock:
|
||||
if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
|
||||
logger.warning(f"Rejecting incoming stream {stream_id}: limit reached")
|
||||
# Send reset to reject the stream
|
||||
self._quic.reset_stream(
|
||||
stream_id, error_code=0x04
|
||||
) # STREAM_LIMIT_ERROR
|
||||
await self._transmit()
|
||||
raise QUICStreamLimitError("Too many inbound streams")
|
||||
|
||||
# Create new inbound stream
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
direction=StreamDirection.INBOUND,
|
||||
resource_scope=self._resource_scope,
|
||||
remote_addr=self._remote_addr,
|
||||
)
|
||||
|
||||
self._streams[stream_id] = stream
|
||||
|
||||
async with self._stream_count_lock:
|
||||
self._inbound_stream_count += 1
|
||||
self._stats["streams_accepted"] += 1
|
||||
|
||||
# Add to accept queue and notify handler
|
||||
async with self._accept_queue_lock:
|
||||
self._stream_accept_queue.append(stream)
|
||||
self._stream_accept_event.set()
|
||||
|
||||
# Handle directly with stream handler if available
|
||||
if self._stream_handler:
|
||||
try:
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(self._stream_handler, stream)
|
||||
else:
|
||||
await self._stream_handler(stream)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
|
||||
|
||||
logger.debug(f"Created inbound stream {stream_id}")
|
||||
return stream
|
||||
return await self._create_inbound_stream(stream_id)
|
||||
|
||||
def _is_incoming_stream(self, stream_id: int) -> bool:
|
||||
"""
|
||||
@ -1095,9 +1185,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
stream_id = event.stream_id
|
||||
self._stats["streams_reset"] += 1
|
||||
|
||||
if stream_id in self._streams:
|
||||
# Use fast lookup
|
||||
stream = self._get_stream_fast(stream_id)
|
||||
if stream:
|
||||
try:
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_reset(event.error_code)
|
||||
logger.debug(
|
||||
f"Handled reset for stream {stream_id}"
|
||||
@ -1137,12 +1228,20 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
current_time = time.time()
|
||||
datagrams = self._quic.datagrams_to_send(now=current_time)
|
||||
|
||||
# Batch stats updates
|
||||
packet_count = 0
|
||||
total_bytes = 0
|
||||
|
||||
for data, addr in datagrams:
|
||||
await sock.sendto(data, addr)
|
||||
# Update stats if available
|
||||
if hasattr(self, "_stats"):
|
||||
self._stats["packets_sent"] += 1
|
||||
self._stats["bytes_sent"] += len(data)
|
||||
packet_count += 1
|
||||
total_bytes += len(data)
|
||||
|
||||
# Update stats in batch
|
||||
if packet_count > 0:
|
||||
self._stats["packets_sent"] += packet_count
|
||||
self._stats["bytes_sent"] += total_bytes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transmission error: {e}")
|
||||
@ -1217,6 +1316,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._socket = None
|
||||
|
||||
self._streams.clear()
|
||||
self._stream_cache.clear() # Clear cache
|
||||
self._closed_event.set()
|
||||
|
||||
logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
|
||||
@ -1328,6 +1428,9 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"max_streams": self.MAX_CONCURRENT_STREAMS,
|
||||
"stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS,
|
||||
"stats": self._stats.copy(),
|
||||
"cache_size": len(
|
||||
self._stream_cache
|
||||
), # Include cache metrics for monitoring
|
||||
}
|
||||
|
||||
def get_active_streams(self) -> list[QUICStream]:
|
||||
|
||||
@ -267,56 +267,37 @@ class QUICListener(IListener):
|
||||
return value, 8
|
||||
|
||||
async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""Process incoming QUIC packet with fine-grained locking."""
|
||||
"""Process incoming QUIC packet with optimized routing."""
|
||||
try:
|
||||
self._stats["packets_processed"] += 1
|
||||
self._stats["bytes_received"] += len(data)
|
||||
|
||||
logger.debug(f"Processing packet of {len(data)} bytes from {addr}")
|
||||
|
||||
# Parse packet header OUTSIDE the lock
|
||||
packet_info = self.parse_quic_packet(data)
|
||||
if packet_info is None:
|
||||
logger.error(f"Failed to parse packet header quic packet from {addr}")
|
||||
self._stats["invalid_packets"] += 1
|
||||
return
|
||||
|
||||
dest_cid = packet_info.destination_cid
|
||||
connection_obj = None
|
||||
pending_quic_conn = None
|
||||
|
||||
# Single lock acquisition with all lookups
|
||||
async with self._connection_lock:
|
||||
if dest_cid in self._connections:
|
||||
connection_obj = self._connections[dest_cid]
|
||||
logger.debug(f"Routing to established connection {dest_cid.hex()}")
|
||||
connection_obj = self._connections.get(dest_cid)
|
||||
pending_quic_conn = self._pending_connections.get(dest_cid)
|
||||
|
||||
elif dest_cid in self._pending_connections:
|
||||
pending_quic_conn = self._pending_connections[dest_cid]
|
||||
logger.debug(f"Routing to pending connection {dest_cid.hex()}")
|
||||
|
||||
else:
|
||||
# Check if this is a new connection
|
||||
if packet_info.packet_type.name == "INITIAL":
|
||||
logger.debug(
|
||||
f"Received INITIAL Packet Creating new conn for {addr}"
|
||||
)
|
||||
|
||||
# Create new connection INSIDE the lock for safety
|
||||
if not connection_obj and not pending_quic_conn:
|
||||
if packet_info.packet_type == QuicPacketType.INITIAL:
|
||||
pending_quic_conn = await self._handle_new_connection(
|
||||
data, addr, packet_info
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
# CRITICAL: Process packets OUTSIDE the lock to prevent deadlock
|
||||
# Process outside the lock
|
||||
if connection_obj:
|
||||
# Handle established connection
|
||||
await self._handle_established_connection_packet(
|
||||
connection_obj, data, addr, dest_cid
|
||||
)
|
||||
|
||||
elif pending_quic_conn:
|
||||
# Handle pending connection
|
||||
await self._handle_pending_connection_packet(
|
||||
pending_quic_conn, data, addr, dest_cid
|
||||
)
|
||||
@ -431,6 +412,7 @@ class QUICListener(IListener):
|
||||
f"No configuration found for version 0x{packet_info.version:08x}"
|
||||
)
|
||||
await self._send_version_negotiation(addr, packet_info.source_cid)
|
||||
return None
|
||||
|
||||
if not quic_config:
|
||||
raise QUICListenError("Cannot determine QUIC configuration")
|
||||
|
||||
@ -108,21 +108,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
|
||||
# Try to get IPv4 address
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
|
||||
except ValueError:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get IPv6 address if IPv4 not found
|
||||
if host is None:
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
|
||||
except ValueError:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get UDP port
|
||||
try:
|
||||
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if host is None or port is None:
|
||||
@ -203,8 +203,7 @@ def create_quic_multiaddr(
|
||||
if version == "quic-v1" or version == "/quic-v1":
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
elif version == "quic" or version == "/quic":
|
||||
# This is DRAFT Protocol
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
quic_proto = QUIC_DRAFT29_PROTOCOL
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
|
||||
|
||||
|
||||
@ -192,7 +192,7 @@ class TestQUICConnection:
|
||||
await trio.sleep(10) # Longer than timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection._stream_id_lock, "acquire", side_effect=slow_acquire
|
||||
quic_connection._stream_lock, "acquire", side_effect=slow_acquire
|
||||
):
|
||||
with pytest.raises(
|
||||
QUICStreamTimeoutError, match="Stream creation timed out"
|
||||
|
||||
@ -3,333 +3,319 @@ Test suite for QUIC multiaddr utilities.
|
||||
Focused tests covering essential functionality required for QUIC transport.
|
||||
"""
|
||||
|
||||
# TODO: Enable this test after multiaddr repo supports protocol quic-v1
|
||||
|
||||
# import pytest
|
||||
# from multiaddr import Multiaddr
|
||||
|
||||
# from libp2p.custom_types import TProtocol
|
||||
# from libp2p.transport.quic.exceptions import (
|
||||
# QUICInvalidMultiaddrError,
|
||||
# QUICUnsupportedVersionError,
|
||||
# )
|
||||
# from libp2p.transport.quic.utils import (
|
||||
# create_quic_multiaddr,
|
||||
# get_alpn_protocols,
|
||||
# is_quic_multiaddr,
|
||||
# multiaddr_to_quic_version,
|
||||
# normalize_quic_multiaddr,
|
||||
# quic_multiaddr_to_endpoint,
|
||||
# quic_version_to_wire_format,
|
||||
# )
|
||||
|
||||
|
||||
# class TestIsQuicMultiaddr:
|
||||
# """Test QUIC multiaddr detection."""
|
||||
|
||||
# def test_valid_quic_v1_multiaddrs(self):
|
||||
# """Test valid QUIC v1 multiaddrs are detected."""
|
||||
# valid_addrs = [
|
||||
# "/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
# "/ip4/192.168.1.1/udp/8080/quic-v1",
|
||||
# "/ip6/::1/udp/4001/quic-v1",
|
||||
# "/ip6/2001:db8::1/udp/5000/quic-v1",
|
||||
# ]
|
||||
|
||||
# for addr_str in valid_addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
# def test_valid_quic_draft29_multiaddrs(self):
|
||||
# """Test valid QUIC draft-29 multiaddrs are detected."""
|
||||
# valid_addrs = [
|
||||
# "/ip4/127.0.0.1/udp/4001/quic",
|
||||
# "/ip4/10.0.0.1/udp/9000/quic",
|
||||
# "/ip6/::1/udp/4001/quic",
|
||||
# "/ip6/fe80::1/udp/6000/quic",
|
||||
# ]
|
||||
|
||||
# for addr_str in valid_addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
# def test_invalid_multiaddrs(self):
|
||||
# """Test non-QUIC multiaddrs are not detected."""
|
||||
# invalid_addrs = [
|
||||
# "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
|
||||
# "/ip4/127.0.0.1/udp/4001", # UDP without QUIC
|
||||
# "/ip4/127.0.0.1/udp/4001/ws", # WebSocket
|
||||
# "/ip4/127.0.0.1/quic-v1", # Missing UDP
|
||||
# "/udp/4001/quic-v1", # Missing IP
|
||||
# "/dns4/example.com/tcp/443/tls", # Completely different
|
||||
# ]
|
||||
|
||||
# for addr_str in invalid_addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# assert not is_quic_multiaddr(maddr),
|
||||
# f"Should not detect {addr_str} as QUIC"
|
||||
|
||||
# def test_malformed_multiaddrs(self):
|
||||
# """Test malformed multiaddrs don't crash."""
|
||||
# # These should not raise exceptions, just return False
|
||||
# malformed = [
|
||||
# Multiaddr("/ip4/127.0.0.1"),
|
||||
# Multiaddr("/invalid"),
|
||||
# ]
|
||||
|
||||
# for maddr in malformed:
|
||||
# assert not is_quic_multiaddr(maddr)
|
||||
|
||||
|
||||
# class TestQuicMultiaddrToEndpoint:
|
||||
# """Test endpoint extraction from QUIC multiaddrs."""
|
||||
|
||||
# def test_ipv4_extraction(self):
|
||||
# """Test IPv4 host/port extraction."""
|
||||
# test_cases = [
|
||||
# ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
|
||||
# ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
|
||||
# ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
|
||||
# ]
|
||||
|
||||
# for addr_str, expected in test_cases:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# result = quic_multiaddr_to_endpoint(maddr)
|
||||
# assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
# def test_ipv6_extraction(self):
|
||||
# """Test IPv6 host/port extraction."""
|
||||
# test_cases = [
|
||||
# ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
|
||||
# ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
|
||||
# ]
|
||||
|
||||
# for addr_str, expected in test_cases:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# result = quic_multiaddr_to_endpoint(maddr)
|
||||
# assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
# def test_invalid_multiaddr_raises_error(self):
|
||||
# """Test invalid multiaddrs raise appropriate errors."""
|
||||
# invalid_addrs = [
|
||||
# "/ip4/127.0.0.1/tcp/4001", # Not QUIC
|
||||
# "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
|
||||
# ]
|
||||
|
||||
# for addr_str in invalid_addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# quic_multiaddr_to_endpoint(maddr)
|
||||
|
||||
|
||||
# class TestMultiaddrToQuicVersion:
|
||||
# """Test QUIC version extraction."""
|
||||
|
||||
# def test_quic_v1_detection(self):
|
||||
# """Test QUIC v1 version detection."""
|
||||
# addrs = [
|
||||
# "/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
# "/ip6/::1/udp/5000/quic-v1",
|
||||
# ]
|
||||
|
||||
# for addr_str in addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# version = multiaddr_to_quic_version(maddr)
|
||||
# assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
|
||||
|
||||
# def test_quic_draft29_detection(self):
|
||||
# """Test QUIC draft-29 version detection."""
|
||||
# addrs = [
|
||||
# "/ip4/127.0.0.1/udp/4001/quic",
|
||||
# "/ip6/::1/udp/5000/quic",
|
||||
# ]
|
||||
|
||||
# for addr_str in addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# version = multiaddr_to_quic_version(maddr)
|
||||
# assert version == "quic", f"Should detect quic for {addr_str}"
|
||||
|
||||
# def test_non_quic_raises_error(self):
|
||||
# """Test non-QUIC multiaddrs raise error."""
|
||||
# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# multiaddr_to_quic_version(maddr)
|
||||
|
||||
|
||||
# class TestCreateQuicMultiaddr:
|
||||
# """Test QUIC multiaddr creation."""
|
||||
|
||||
# def test_ipv4_creation(self):
|
||||
# """Test IPv4 QUIC multiaddr creation."""
|
||||
# test_cases = [
|
||||
# ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
|
||||
# ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
|
||||
# ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
|
||||
# ]
|
||||
|
||||
# for host, port, version, expected in test_cases:
|
||||
# result = create_quic_multiaddr(host, port, version)
|
||||
# assert str(result) == expected
|
||||
|
||||
# def test_ipv6_creation(self):
|
||||
# """Test IPv6 QUIC multiaddr creation."""
|
||||
# test_cases = [
|
||||
# ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
|
||||
# ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
|
||||
# ]
|
||||
|
||||
# for host, port, version, expected in test_cases:
|
||||
# result = create_quic_multiaddr(host, port, version)
|
||||
# assert str(result) == expected
|
||||
|
||||
# def test_default_version(self):
|
||||
# """Test default version is quic-v1."""
|
||||
# result = create_quic_multiaddr("127.0.0.1", 4001)
|
||||
# expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
# assert str(result) == expected
|
||||
|
||||
# def test_invalid_inputs_raise_errors(self):
|
||||
# """Test invalid inputs raise appropriate errors."""
|
||||
# # Invalid IP
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# create_quic_multiaddr("invalid-ip", 4001)
|
||||
|
||||
# # Invalid port
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# create_quic_multiaddr("127.0.0.1", 70000)
|
||||
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# create_quic_multiaddr("127.0.0.1", -1)
|
||||
|
||||
# # Invalid version
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
|
||||
|
||||
|
||||
# class TestQuicVersionToWireFormat:
|
||||
# """Test QUIC version to wire format conversion."""
|
||||
|
||||
# def test_supported_versions(self):
|
||||
# """Test supported version conversions."""
|
||||
# test_cases = [
|
||||
# ("quic-v1", 0x00000001), # RFC 9000
|
||||
# ("quic", 0xFF00001D), # draft-29
|
||||
# ]
|
||||
|
||||
# for version, expected_wire in test_cases:
|
||||
# result = quic_version_to_wire_format(TProtocol(version))
|
||||
# assert result == expected_wire, f"Failed for version {version}"
|
||||
|
||||
# def test_unsupported_version_raises_error(self):
|
||||
# """Test unsupported versions raise error."""
|
||||
# with pytest.raises(QUICUnsupportedVersionError):
|
||||
# quic_version_to_wire_format(TProtocol("unsupported-version"))
|
||||
|
||||
|
||||
# class TestGetAlpnProtocols:
|
||||
# """Test ALPN protocol retrieval."""
|
||||
|
||||
# def test_returns_libp2p_protocols(self):
|
||||
# """Test returns expected libp2p ALPN protocols."""
|
||||
# protocols = get_alpn_protocols()
|
||||
# assert protocols == ["libp2p"]
|
||||
# assert isinstance(protocols, list)
|
||||
|
||||
# def test_returns_copy(self):
|
||||
# """Test returns a copy, not the original list."""
|
||||
# protocols1 = get_alpn_protocols()
|
||||
# protocols2 = get_alpn_protocols()
|
||||
|
||||
# # Modify one list
|
||||
# protocols1.append("test")
|
||||
|
||||
# # Other list should be unchanged
|
||||
# assert protocols2 == ["libp2p"]
|
||||
|
||||
|
||||
# class TestNormalizeQuicMultiaddr:
|
||||
# """Test QUIC multiaddr normalization."""
|
||||
|
||||
# def test_already_normalized(self):
|
||||
# """Test already normalized multiaddrs pass through."""
|
||||
# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
# maddr = Multiaddr(addr_str)
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICInvalidMultiaddrError,
|
||||
QUICUnsupportedVersionError,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
normalize_quic_multiaddr,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
)
|
||||
|
||||
|
||||
class TestIsQuicMultiaddr:
|
||||
"""Test QUIC multiaddr detection."""
|
||||
|
||||
def test_valid_quic_v1_multiaddrs(self):
|
||||
"""Test valid QUIC v1 multiaddrs are detected."""
|
||||
valid_addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip4/192.168.1.1/udp/8080/quic-v1",
|
||||
"/ip6/::1/udp/4001/quic-v1",
|
||||
"/ip6/2001:db8::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in valid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
def test_valid_quic_draft29_multiaddrs(self):
|
||||
"""Test valid QUIC draft-29 multiaddrs are detected."""
|
||||
valid_addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip4/10.0.0.1/udp/9000/quic",
|
||||
"/ip6/::1/udp/4001/quic",
|
||||
"/ip6/fe80::1/udp/6000/quic",
|
||||
]
|
||||
|
||||
for addr_str in valid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
def test_invalid_multiaddrs(self):
|
||||
"""Test non-QUIC multiaddrs are not detected."""
|
||||
invalid_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
|
||||
"/ip4/127.0.0.1/udp/4001", # UDP without QUIC
|
||||
"/ip4/127.0.0.1/udp/4001/ws", # WebSocket
|
||||
"/ip4/127.0.0.1/quic-v1", # Missing UDP
|
||||
"/udp/4001/quic-v1", # Missing IP
|
||||
"/dns4/example.com/tcp/443/tls", # Completely different
|
||||
]
|
||||
|
||||
for addr_str in invalid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
|
||||
|
||||
|
||||
class TestQuicMultiaddrToEndpoint:
|
||||
"""Test endpoint extraction from QUIC multiaddrs."""
|
||||
|
||||
def test_ipv4_extraction(self):
|
||||
"""Test IPv4 host/port extraction."""
|
||||
test_cases = [
|
||||
("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
|
||||
("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
|
||||
("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
|
||||
]
|
||||
|
||||
for addr_str, expected in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = quic_multiaddr_to_endpoint(maddr)
|
||||
assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
def test_ipv6_extraction(self):
|
||||
"""Test IPv6 host/port extraction."""
|
||||
test_cases = [
|
||||
("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
|
||||
("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
|
||||
]
|
||||
|
||||
for addr_str, expected in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = quic_multiaddr_to_endpoint(maddr)
|
||||
assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
def test_invalid_multiaddr_raises_error(self):
|
||||
"""Test invalid multiaddrs raise appropriate errors."""
|
||||
invalid_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/4001", # Not QUIC
|
||||
"/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
|
||||
]
|
||||
|
||||
for addr_str in invalid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
quic_multiaddr_to_endpoint(maddr)
|
||||
|
||||
|
||||
class TestMultiaddrToQuicVersion:
|
||||
"""Test QUIC version extraction."""
|
||||
|
||||
def test_quic_v1_detection(self):
|
||||
"""Test QUIC v1 version detection."""
|
||||
addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip6/::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
|
||||
|
||||
def test_quic_draft29_detection(self):
|
||||
"""Test QUIC draft-29 version detection."""
|
||||
addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip6/::1/udp/5000/quic",
|
||||
]
|
||||
|
||||
for addr_str in addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
assert version == "quic", f"Should detect quic for {addr_str}"
|
||||
|
||||
def test_non_quic_raises_error(self):
|
||||
"""Test non-QUIC multiaddrs raise error."""
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
multiaddr_to_quic_version(maddr)
|
||||
|
||||
|
||||
class TestCreateQuicMultiaddr:
|
||||
"""Test QUIC multiaddr creation."""
|
||||
|
||||
def test_ipv4_creation(self):
|
||||
"""Test IPv4 QUIC multiaddr creation."""
|
||||
test_cases = [
|
||||
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
|
||||
("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
|
||||
("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
|
||||
]
|
||||
|
||||
for host, port, version, expected in test_cases:
|
||||
result = create_quic_multiaddr(host, port, version)
|
||||
assert str(result) == expected
|
||||
|
||||
def test_ipv6_creation(self):
|
||||
"""Test IPv6 QUIC multiaddr creation."""
|
||||
test_cases = [
|
||||
("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
|
||||
("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
|
||||
]
|
||||
|
||||
for host, port, version, expected in test_cases:
|
||||
result = create_quic_multiaddr(host, port, version)
|
||||
assert str(result) == expected
|
||||
|
||||
def test_default_version(self):
|
||||
"""Test default version is quic-v1."""
|
||||
result = create_quic_multiaddr("127.0.0.1", 4001)
|
||||
expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
assert str(result) == expected
|
||||
|
||||
def test_invalid_inputs_raise_errors(self):
|
||||
"""Test invalid inputs raise appropriate errors."""
|
||||
# Invalid IP
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("invalid-ip", 4001)
|
||||
|
||||
# Invalid port
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", 70000)
|
||||
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", -1)
|
||||
|
||||
# Invalid version
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
|
||||
|
||||
|
||||
class TestQuicVersionToWireFormat:
|
||||
"""Test QUIC version to wire format conversion."""
|
||||
|
||||
def test_supported_versions(self):
|
||||
"""Test supported version conversions."""
|
||||
test_cases = [
|
||||
("quic-v1", 0x00000001), # RFC 9000
|
||||
("quic", 0xFF00001D), # draft-29
|
||||
]
|
||||
|
||||
for version, expected_wire in test_cases:
|
||||
result = quic_version_to_wire_format(TProtocol(version))
|
||||
assert result == expected_wire, f"Failed for version {version}"
|
||||
|
||||
def test_unsupported_version_raises_error(self):
|
||||
"""Test unsupported versions raise error."""
|
||||
with pytest.raises(QUICUnsupportedVersionError):
|
||||
quic_version_to_wire_format(TProtocol("unsupported-version"))
|
||||
|
||||
|
||||
class TestGetAlpnProtocols:
|
||||
"""Test ALPN protocol retrieval."""
|
||||
|
||||
def test_returns_libp2p_protocols(self):
|
||||
"""Test returns expected libp2p ALPN protocols."""
|
||||
protocols = get_alpn_protocols()
|
||||
assert protocols == ["libp2p"]
|
||||
assert isinstance(protocols, list)
|
||||
|
||||
def test_returns_copy(self):
|
||||
"""Test returns a copy, not the original list."""
|
||||
protocols1 = get_alpn_protocols()
|
||||
protocols2 = get_alpn_protocols()
|
||||
|
||||
# Modify one list
|
||||
protocols1.append("test")
|
||||
|
||||
# Other list should be unchanged
|
||||
assert protocols2 == ["libp2p"]
|
||||
|
||||
|
||||
class TestNormalizeQuicMultiaddr:
|
||||
"""Test QUIC multiaddr normalization."""
|
||||
|
||||
def test_already_normalized(self):
|
||||
"""Test already normalized multiaddrs pass through."""
|
||||
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
maddr = Multiaddr(addr_str)
|
||||
|
||||
# result = normalize_quic_multiaddr(maddr)
|
||||
# assert str(result) == addr_str
|
||||
|
||||
# def test_normalize_different_versions(self):
|
||||
# """Test normalization works for different QUIC versions."""
|
||||
# test_cases = [
|
||||
# "/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
# "/ip4/127.0.0.1/udp/4001/quic",
|
||||
# "/ip6/::1/udp/5000/quic-v1",
|
||||
# ]
|
||||
|
||||
# for addr_str in test_cases:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# result = normalize_quic_multiaddr(maddr)
|
||||
|
||||
# # Should be valid QUIC multiaddr
|
||||
# assert is_quic_multiaddr(result)
|
||||
|
||||
# # Should be parseable
|
||||
# host, port = quic_multiaddr_to_endpoint(result)
|
||||
# version = multiaddr_to_quic_version(result)
|
||||
result = normalize_quic_multiaddr(maddr)
|
||||
assert str(result) == addr_str
|
||||
|
||||
def test_normalize_different_versions(self):
|
||||
"""Test normalization works for different QUIC versions."""
|
||||
test_cases = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip6/::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = normalize_quic_multiaddr(maddr)
|
||||
|
||||
# Should be valid QUIC multiaddr
|
||||
assert is_quic_multiaddr(result)
|
||||
|
||||
# Should be parseable
|
||||
host, port = quic_multiaddr_to_endpoint(result)
|
||||
version = multiaddr_to_quic_version(result)
|
||||
|
||||
# # Should match original
|
||||
# orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
|
||||
# orig_version = multiaddr_to_quic_version(maddr)
|
||||
# Should match original
|
||||
orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
|
||||
orig_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
# assert host == orig_host
|
||||
# assert port == orig_port
|
||||
# assert version == orig_version
|
||||
assert host == orig_host
|
||||
assert port == orig_port
|
||||
assert version == orig_version
|
||||
|
||||
# def test_non_quic_raises_error(self):
|
||||
# """Test non-QUIC multiaddrs raise error."""
|
||||
# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
# with pytest.raises(QUICInvalidMultiaddrError):
|
||||
# normalize_quic_multiaddr(maddr)
|
||||
def test_non_quic_raises_error(self):
|
||||
"""Test non-QUIC multiaddrs raise error."""
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
normalize_quic_multiaddr(maddr)
|
||||
|
||||
|
||||
# class TestIntegration:
|
||||
# """Integration tests for utility functions working together."""
|
||||
class TestIntegration:
|
||||
"""Integration tests for utility functions working together."""
|
||||
|
||||
# def test_round_trip_conversion(self):
|
||||
# """Test creating and parsing multiaddrs works correctly."""
|
||||
# test_cases = [
|
||||
# ("127.0.0.1", 4001, "quic-v1"),
|
||||
# ("::1", 5000, "quic"),
|
||||
# ("192.168.1.100", 8080, "quic-v1"),
|
||||
# ]
|
||||
def test_round_trip_conversion(self):
|
||||
"""Test creating and parsing multiaddrs works correctly."""
|
||||
test_cases = [
|
||||
("127.0.0.1", 4001, "quic-v1"),
|
||||
("::1", 5000, "quic"),
|
||||
("192.168.1.100", 8080, "quic-v1"),
|
||||
]
|
||||
|
||||
# for host, port, version in test_cases:
|
||||
# # Create multiaddr
|
||||
# maddr = create_quic_multiaddr(host, port, version)
|
||||
for host, port, version in test_cases:
|
||||
# Create multiaddr
|
||||
maddr = create_quic_multiaddr(host, port, version)
|
||||
|
||||
# # Should be detected as QUIC
|
||||
# assert is_quic_multiaddr(maddr)
|
||||
|
||||
# # Should extract original values
|
||||
# extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
|
||||
# extracted_version = multiaddr_to_quic_version(maddr)
|
||||
# Should be detected as QUIC
|
||||
assert is_quic_multiaddr(maddr)
|
||||
|
||||
# Should extract original values
|
||||
extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
|
||||
extracted_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
# assert extracted_host == host
|
||||
# assert extracted_port == port
|
||||
# assert extracted_version == version
|
||||
assert extracted_host == host
|
||||
assert extracted_port == port
|
||||
assert extracted_version == version
|
||||
|
||||
# # Should normalize to same value
|
||||
# normalized = normalize_quic_multiaddr(maddr)
|
||||
# assert str(normalized) == str(maddr)
|
||||
# Should normalize to same value
|
||||
normalized = normalize_quic_multiaddr(maddr)
|
||||
assert str(normalized) == str(maddr)
|
||||
|
||||
# def test_wire_format_integration(self):
|
||||
# """Test wire format conversion works with version detection."""
|
||||
# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
# maddr = Multiaddr(addr_str)
|
||||
def test_wire_format_integration(self):
|
||||
"""Test wire format conversion works with version detection."""
|
||||
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
maddr = Multiaddr(addr_str)
|
||||
|
||||
# # Extract version and convert to wire format
|
||||
# version = multiaddr_to_quic_version(maddr)
|
||||
# wire_format = quic_version_to_wire_format(version)
|
||||
# Extract version and convert to wire format
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
wire_format = quic_version_to_wire_format(version)
|
||||
|
||||
# # Should be QUIC v1 wire format
|
||||
# assert wire_format == 0x00000001
|
||||
# Should be QUIC v1 wire format
|
||||
assert wire_format == 0x00000001
|
||||
|
||||
Reference in New Issue
Block a user