diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py index 1610bde9..428acd83 100644 --- a/libp2p/transport/quic/connection.py +++ b/libp2p/transport/quic/connection.py @@ -3,14 +3,16 @@ QUIC Connection implementation. Manages bidirectional QUIC connections with integrated stream multiplexing. """ +from collections import defaultdict from collections.abc import Awaitable, Callable import logging import socket import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast from aioquic.quic import events from aioquic.quic.connection import QuicConnection +from aioquic.quic.events import QuicEvent from cryptography import x509 import multiaddr import trio @@ -104,12 +106,13 @@ class QUICConnection(IRawConnection, IMuxedConn): self._connected_event = trio.Event() self._closed_event = trio.Event() - # Stream management self._streams: dict[int, QUICStream] = {} + self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups self._next_stream_id: int = self._calculate_initial_stream_id() self._stream_handler: TQUICStreamHandlerFn | None = None - self._stream_id_lock = trio.Lock() - self._stream_count_lock = trio.Lock() + + # Single lock for all stream operations + self._stream_lock = trio.Lock() # Stream counting and limits self._outbound_stream_count = 0 @@ -118,7 +121,6 @@ class QUICConnection(IRawConnection, IMuxedConn): # Stream acceptance for incoming streams self._stream_accept_queue: list[QUICStream] = [] self._stream_accept_event = trio.Event() - self._accept_queue_lock = trio.Lock() # Connection state self._closed: bool = False @@ -143,9 +145,11 @@ class QUICConnection(IRawConnection, IMuxedConn): self._retired_connection_ids: set[bytes] = set() self._connection_id_sequence_numbers: set[int] = set() - # Event processing control + # Event processing control with batching self._event_processing_active = False - self._pending_events: list[events.QuicEvent] = [] + self._event_batch: list[events.QuicEvent] = [] + self._event_batch_size = 10 + self._last_event_time = 0.0 # Set quic connection configuration self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT @@ -250,6 +254,21 @@ class QUICConnection(IRawConnection, IMuxedConn): """Get the current connection ID.""" return self._current_connection_id + # Fast stream lookup with caching + def _get_stream_fast(self, stream_id: int) -> QUICStream | None: + """Get stream with caching for performance.""" + # Try cache first + stream = self._stream_cache.get(stream_id) + if stream is not None: + return stream + + # Fallback to main dict + stream = self._streams.get(stream_id) + if stream is not None: + self._stream_cache[stream_id] = stream + + return stream + # Connection lifecycle methods async def start(self) -> None: @@ -389,8 +408,8 @@ class QUICConnection(IRawConnection, IMuxedConn): try: while not self._closed: - # Process QUIC events - await self._process_quic_events() + # Batch process events + await self._process_quic_events_batched() # Handle timer events await self._handle_timer_events() @@ -421,12 +440,25 @@ class QUICConnection(IRawConnection, IMuxedConn): cid_stats = self.get_connection_id_stats() logger.debug(f"Connection ID stats: {cid_stats}") + # Clean cache periodically + await self._cleanup_cache() + # Sleep for maintenance interval await trio.sleep(30.0) # 30 seconds except Exception as e: logger.error(f"Error in periodic maintenance: {e}") + async def _cleanup_cache(self) -> None: + """Clean up stream cache periodically to prevent memory leaks.""" + if len(self._stream_cache) > 100: # Arbitrary threshold + # Remove closed streams from cache + closed_stream_ids = [ + sid for sid, stream in self._stream_cache.items() if stream.is_closed() + ] + for sid in closed_stream_ids: + self._stream_cache.pop(sid, None) + async def _client_packet_receiver(self) -> None: """Receive packets for client connections.""" logger.debug("Starting client packet receiver") @@ -442,8 +474,8 @@ class QUICConnection(IRawConnection, IMuxedConn): # Feed packet to QUIC connection self._quic.receive_datagram(data, addr, now=time.time()) - # Process any events that result from the packet - await self._process_quic_events() + # Batch process events + await self._process_quic_events_batched() # Send any response packets await self._transmit() @@ -675,15 +707,16 @@ class QUICConnection(IRawConnection, IMuxedConn): if not self._started: raise QUICConnectionError("Connection not started") - # Check stream limits - async with self._stream_count_lock: - if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS: - raise QUICStreamLimitError( - f"Maximum outbound streams ({self.MAX_OUTGOING_STREAMS}) reached" - ) - + # Use single lock for all stream operations with trio.move_on_after(timeout): - async with self._stream_id_lock: + async with self._stream_lock: + # Check stream limits inside lock + if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS: + raise QUICStreamLimitError( + "Maximum outbound streams " + f"({self.MAX_OUTGOING_STREAMS}) reached" + ) + # Generate next stream ID stream_id = self._next_stream_id self._next_stream_id += 4 # Increment by 4 for bidirectional streams @@ -697,10 +730,10 @@ class QUICConnection(IRawConnection, IMuxedConn): ) self._streams[stream_id] = stream + self._stream_cache[stream_id] = stream # Add to cache - async with self._stream_count_lock: - self._outbound_stream_count += 1 - self._stats["streams_opened"] += 1 + self._outbound_stream_count += 1 + self._stats["streams_opened"] += 1 logger.debug(f"Opened outbound QUIC stream {stream_id}") return stream @@ -737,7 +770,8 @@ class QUICConnection(IRawConnection, IMuxedConn): if self._closed: raise MuxedConnUnavailable("QUIC connection is closed") - async with self._accept_queue_lock: + # Use single lock for stream acceptance + async with self._stream_lock: if self._stream_accept_queue: stream = self._stream_accept_queue.pop(0) logger.debug(f"Accepted inbound stream {stream.stream_id}") @@ -769,10 +803,12 @@ class QUICConnection(IRawConnection, IMuxedConn): """ if stream_id in self._streams: stream = self._streams.pop(stream_id) + # Remove from cache too + self._stream_cache.pop(stream_id, None) # Update stream counts asynchronously async def update_counts() -> None: - async with self._stream_count_lock: + async with self._stream_lock: if stream.direction == StreamDirection.OUTBOUND: self._outbound_stream_count = max( 0, self._outbound_stream_count - 1 @@ -789,29 +825,140 @@ class QUICConnection(IRawConnection, IMuxedConn): logger.debug(f"Removed stream {stream_id} from connection") - async def _process_quic_events(self) -> None: - """Process all pending QUIC events.""" + # Batched event processing to reduce overhead + async def _process_quic_events_batched(self) -> None: + """Process QUIC events in batches for better performance.""" if self._event_processing_active: return # Prevent recursion self._event_processing_active = True try: + current_time = time.time() events_processed = 0 - while True: + + # Collect events into batch + while events_processed < self._event_batch_size: event = self._quic.next_event() if event is None: break + self._event_batch.append(event) events_processed += 1 - await self._handle_quic_event(event) - if events_processed > 0: - logger.debug(f"Processed {events_processed} QUIC events") + # Process batch if we have events or timeout + if self._event_batch and ( + len(self._event_batch) >= self._event_batch_size + or current_time - self._last_event_time > 0.01 # 10ms timeout + ): + await self._process_event_batch() + self._event_batch.clear() + self._last_event_time = current_time finally: self._event_processing_active = False + async def _process_event_batch(self) -> None: + """Process a batch of events efficiently.""" + if not self._event_batch: + return + + # Group events by type for batch processing where possible + events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list) + for event in self._event_batch: + events_by_type[type(event).__name__].append(event) + + # Process events by type + for event_type, event_list in events_by_type.items(): + if event_type == type(events.StreamDataReceived).__name__: + await self._handle_stream_data_batch( + cast(list[events.StreamDataReceived], event_list) + ) + else: + # Process other events individually + for event in event_list: + await self._handle_quic_event(event) + + logger.debug(f"Processed batch of {len(self._event_batch)} events") + + async def _handle_stream_data_batch( + self, events_list: list[events.StreamDataReceived] + ) -> None: + """Handle stream data events in batch for better performance.""" + # Group by stream ID + events_by_stream: defaultdict[int, list[QuicEvent]] = defaultdict(list) + for event in events_list: + events_by_stream[event.stream_id].append(event) + + # Process each stream's events + for stream_id, stream_events in events_by_stream.items(): + stream = self._get_stream_fast(stream_id) # Use fast lookup + + if not stream: + if self._is_incoming_stream(stream_id): + try: + stream = await self._create_inbound_stream(stream_id) + except QUICStreamLimitError: + # Reset stream if we can't handle it + self._quic.reset_stream(stream_id, error_code=0x04) + await self._transmit() + continue + else: + logger.error( + f"Unexpected outbound stream {stream_id} in data event" + ) + continue + + # Process all events for this stream + for received_event in stream_events: + if hasattr(received_event, "data"): + self._stats["bytes_received"] += len(received_event.data) # type: ignore + + if hasattr(received_event, "end_stream"): + await stream.handle_data_received( + received_event.data, # type: ignore + received_event.end_stream, # type: ignore + ) + + async def _create_inbound_stream(self, stream_id: int) -> QUICStream: + """Create inbound stream with proper limit checking.""" + async with self._stream_lock: + # Double-check stream doesn't exist + existing_stream = self._streams.get(stream_id) + if existing_stream: + return existing_stream + + # Check limits + if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS: + logger.warning(f"Rejecting inbound stream {stream_id}: limit reached") + raise QUICStreamLimitError("Too many inbound streams") + + # Create stream + stream = QUICStream( + connection=self, + stream_id=stream_id, + direction=StreamDirection.INBOUND, + resource_scope=self._resource_scope, + remote_addr=self._remote_addr, + ) + + self._streams[stream_id] = stream + self._stream_cache[stream_id] = stream # Add to cache + self._inbound_stream_count += 1 + self._stats["streams_accepted"] += 1 + + # Add to accept queue + self._stream_accept_queue.append(stream) + self._stream_accept_event.set() + + logger.debug(f"Created inbound stream {stream_id}") + return stream + + async def _process_quic_events(self) -> None: + """Process all pending QUIC events.""" + # Delegate to batched processing for better performance + await self._process_quic_events_batched() + async def _handle_quic_event(self, event: events.QuicEvent) -> None: """Handle a single QUIC event with COMPLETE event type coverage.""" logger.debug(f"Handling QUIC event: {type(event).__name__}") @@ -929,8 +1076,9 @@ class QUICConnection(IRawConnection, IMuxedConn): f"stream_id={event.stream_id}, error_code={event.error_code}" ) - if event.stream_id in self._streams: - stream: QUICStream = self._streams[event.stream_id] + # Use fast lookup + stream = self._get_stream_fast(event.stream_id) + if stream: # Handle stop sending on the stream if method exists await stream.handle_stop_sending(event.error_code) @@ -964,6 +1112,7 @@ class QUICConnection(IRawConnection, IMuxedConn): await stream.close() self._streams.clear() + self._stream_cache.clear() # Clear cache too self._closed = True self._closed_event.set() @@ -978,39 +1127,19 @@ class QUICConnection(IRawConnection, IMuxedConn): self._stats["bytes_received"] += len(event.data) try: - if stream_id not in self._streams: + # Use fast lookup + stream = self._get_stream_fast(stream_id) + + if not stream: if self._is_incoming_stream(stream_id): logger.debug(f"Creating new incoming stream {stream_id}") - - from .stream import QUICStream, StreamDirection - - stream = QUICStream( - connection=self, - stream_id=stream_id, - direction=StreamDirection.INBOUND, - resource_scope=self._resource_scope, - remote_addr=self._remote_addr, - ) - - # Store the stream - self._streams[stream_id] = stream - - async with self._accept_queue_lock: - self._stream_accept_queue.append(stream) - self._stream_accept_event.set() - logger.debug(f"Added stream {stream_id} to accept queue") - - async with self._stream_count_lock: - self._inbound_stream_count += 1 - self._stats["streams_opened"] += 1 - + stream = await self._create_inbound_stream(stream_id) else: logger.error( f"Unexpected outbound stream {stream_id} in data event" ) return - stream = self._streams[stream_id] await stream.handle_data_received(event.data, event.end_stream) except Exception as e: @@ -1019,8 +1148,10 @@ class QUICConnection(IRawConnection, IMuxedConn): async def _get_or_create_stream(self, stream_id: int) -> QUICStream: """Get existing stream or create new inbound stream.""" - if stream_id in self._streams: - return self._streams[stream_id] + # Use fast lookup + stream = self._get_stream_fast(stream_id) + if stream: + return stream # Check if this is an incoming stream is_incoming = self._is_incoming_stream(stream_id) @@ -1031,49 +1162,8 @@ class QUICConnection(IRawConnection, IMuxedConn): f"Received data for unknown outbound stream {stream_id}" ) - # Check stream limits for incoming streams - async with self._stream_count_lock: - if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS: - logger.warning(f"Rejecting incoming stream {stream_id}: limit reached") - # Send reset to reject the stream - self._quic.reset_stream( - stream_id, error_code=0x04 - ) # STREAM_LIMIT_ERROR - await self._transmit() - raise QUICStreamLimitError("Too many inbound streams") - # Create new inbound stream - stream = QUICStream( - connection=self, - stream_id=stream_id, - direction=StreamDirection.INBOUND, - resource_scope=self._resource_scope, - remote_addr=self._remote_addr, - ) - - self._streams[stream_id] = stream - - async with self._stream_count_lock: - self._inbound_stream_count += 1 - self._stats["streams_accepted"] += 1 - - # Add to accept queue and notify handler - async with self._accept_queue_lock: - self._stream_accept_queue.append(stream) - self._stream_accept_event.set() - - # Handle directly with stream handler if available - if self._stream_handler: - try: - if self._nursery: - self._nursery.start_soon(self._stream_handler, stream) - else: - await self._stream_handler(stream) - except Exception as e: - logger.error(f"Error in stream handler for stream {stream_id}: {e}") - - logger.debug(f"Created inbound stream {stream_id}") - return stream + return await self._create_inbound_stream(stream_id) def _is_incoming_stream(self, stream_id: int) -> bool: """ @@ -1095,9 +1185,10 @@ class QUICConnection(IRawConnection, IMuxedConn): stream_id = event.stream_id self._stats["streams_reset"] += 1 - if stream_id in self._streams: + # Use fast lookup + stream = self._get_stream_fast(stream_id) + if stream: try: - stream = self._streams[stream_id] await stream.handle_reset(event.error_code) logger.debug( f"Handled reset for stream {stream_id}" @@ -1137,12 +1228,20 @@ class QUICConnection(IRawConnection, IMuxedConn): try: current_time = time.time() datagrams = self._quic.datagrams_to_send(now=current_time) + + # Batch stats updates + packet_count = 0 + total_bytes = 0 + for data, addr in datagrams: await sock.sendto(data, addr) - # Update stats if available - if hasattr(self, "_stats"): - self._stats["packets_sent"] += 1 - self._stats["bytes_sent"] += len(data) + packet_count += 1 + total_bytes += len(data) + + # Update stats in batch + if packet_count > 0: + self._stats["packets_sent"] += packet_count + self._stats["bytes_sent"] += total_bytes except Exception as e: logger.error(f"Transmission error: {e}") @@ -1217,6 +1316,7 @@ class QUICConnection(IRawConnection, IMuxedConn): self._socket = None self._streams.clear() + self._stream_cache.clear() # Clear cache self._closed_event.set() logger.debug(f"QUIC connection to {self._remote_peer_id} closed") @@ -1328,6 +1428,9 @@ class QUICConnection(IRawConnection, IMuxedConn): "max_streams": self.MAX_CONCURRENT_STREAMS, "stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS, "stats": self._stats.copy(), + "cache_size": len( + self._stream_cache + ), # Include cache metrics for monitoring } def get_active_streams(self) -> list[QUICStream]: diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py index fd7cc0f1..0e8e66ad 100644 --- a/libp2p/transport/quic/listener.py +++ b/libp2p/transport/quic/listener.py @@ -267,56 +267,37 @@ class QUICListener(IListener): return value, 8 async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: - """Process incoming QUIC packet with fine-grained locking.""" + """Process incoming QUIC packet with optimized routing.""" try: self._stats["packets_processed"] += 1 self._stats["bytes_received"] += len(data) - logger.debug(f"Processing packet of {len(data)} bytes from {addr}") - - # Parse packet header OUTSIDE the lock packet_info = self.parse_quic_packet(data) if packet_info is None: - logger.error(f"Failed to parse packet header quic packet from {addr}") self._stats["invalid_packets"] += 1 return dest_cid = packet_info.destination_cid - connection_obj = None - pending_quic_conn = None + # Single lock acquisition with all lookups async with self._connection_lock: - if dest_cid in self._connections: - connection_obj = self._connections[dest_cid] - logger.debug(f"Routing to established connection {dest_cid.hex()}") + connection_obj = self._connections.get(dest_cid) + pending_quic_conn = self._pending_connections.get(dest_cid) - elif dest_cid in self._pending_connections: - pending_quic_conn = self._pending_connections[dest_cid] - logger.debug(f"Routing to pending connection {dest_cid.hex()}") - - else: - # Check if this is a new connection - if packet_info.packet_type.name == "INITIAL": - logger.debug( - f"Received INITIAL Packet Creating new conn for {addr}" - ) - - # Create new connection INSIDE the lock for safety + if not connection_obj and not pending_quic_conn: + if packet_info.packet_type == QuicPacketType.INITIAL: pending_quic_conn = await self._handle_new_connection( data, addr, packet_info ) else: return - # CRITICAL: Process packets OUTSIDE the lock to prevent deadlock + # Process outside the lock if connection_obj: - # Handle established connection await self._handle_established_connection_packet( connection_obj, data, addr, dest_cid ) - elif pending_quic_conn: - # Handle pending connection await self._handle_pending_connection_packet( pending_quic_conn, data, addr, dest_cid ) @@ -431,6 +412,7 @@ class QUICListener(IListener): f"No configuration found for version 0x{packet_info.version:08x}" ) await self._send_version_negotiation(addr, packet_info.source_cid) + return None if not quic_config: raise QUICListenError("Cannot determine QUIC configuration") diff --git a/libp2p/transport/quic/utils.py b/libp2p/transport/quic/utils.py index f57f92a7..37b7880b 100644 --- a/libp2p/transport/quic/utils.py +++ b/libp2p/transport/quic/utils.py @@ -108,21 +108,21 @@ def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]: # Try to get IPv4 address try: host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore - except ValueError: + except Exception: pass # Try to get IPv6 address if IPv4 not found if host is None: try: host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore - except ValueError: + except Exception: pass # Get UDP port try: port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore port = int(port_str) - except ValueError: + except Exception: pass if host is None or port is None: @@ -203,8 +203,7 @@ def create_quic_multiaddr( if version == "quic-v1" or version == "/quic-v1": quic_proto = QUIC_V1_PROTOCOL elif version == "quic" or version == "/quic": - # This is DRAFT Protocol - quic_proto = QUIC_V1_PROTOCOL + quic_proto = QUIC_DRAFT29_PROTOCOL else: raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}") diff --git a/tests/core/transport/quic/test_connection.py b/tests/core/transport/quic/test_connection.py index 40bfc96f..9b3ad3a9 100644 --- a/tests/core/transport/quic/test_connection.py +++ b/tests/core/transport/quic/test_connection.py @@ -192,7 +192,7 @@ class TestQUICConnection: await trio.sleep(10) # Longer than timeout with patch.object( - quic_connection._stream_id_lock, "acquire", side_effect=slow_acquire + quic_connection._stream_lock, "acquire", side_effect=slow_acquire ): with pytest.raises( QUICStreamTimeoutError, match="Stream creation timed out" diff --git a/tests/core/transport/quic/test_utils.py b/tests/core/transport/quic/test_utils.py index acc96ade..900c5c7e 100644 --- a/tests/core/transport/quic/test_utils.py +++ b/tests/core/transport/quic/test_utils.py @@ -3,333 +3,319 @@ Test suite for QUIC multiaddr utilities. Focused tests covering essential functionality required for QUIC transport. """ -# TODO: Enable this test after multiaddr repo supports protocol quic-v1 - -# import pytest -# from multiaddr import Multiaddr - -# from libp2p.custom_types import TProtocol -# from libp2p.transport.quic.exceptions import ( -# QUICInvalidMultiaddrError, -# QUICUnsupportedVersionError, -# ) -# from libp2p.transport.quic.utils import ( -# create_quic_multiaddr, -# get_alpn_protocols, -# is_quic_multiaddr, -# multiaddr_to_quic_version, -# normalize_quic_multiaddr, -# quic_multiaddr_to_endpoint, -# quic_version_to_wire_format, -# ) - - -# class TestIsQuicMultiaddr: -# """Test QUIC multiaddr detection.""" - -# def test_valid_quic_v1_multiaddrs(self): -# """Test valid QUIC v1 multiaddrs are detected.""" -# valid_addrs = [ -# "/ip4/127.0.0.1/udp/4001/quic-v1", -# "/ip4/192.168.1.1/udp/8080/quic-v1", -# "/ip6/::1/udp/4001/quic-v1", -# "/ip6/2001:db8::1/udp/5000/quic-v1", -# ] - -# for addr_str in valid_addrs: -# maddr = Multiaddr(addr_str) -# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" - -# def test_valid_quic_draft29_multiaddrs(self): -# """Test valid QUIC draft-29 multiaddrs are detected.""" -# valid_addrs = [ -# "/ip4/127.0.0.1/udp/4001/quic", -# "/ip4/10.0.0.1/udp/9000/quic", -# "/ip6/::1/udp/4001/quic", -# "/ip6/fe80::1/udp/6000/quic", -# ] - -# for addr_str in valid_addrs: -# maddr = Multiaddr(addr_str) -# assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" - -# def test_invalid_multiaddrs(self): -# """Test non-QUIC multiaddrs are not detected.""" -# invalid_addrs = [ -# "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC -# "/ip4/127.0.0.1/udp/4001", # UDP without QUIC -# "/ip4/127.0.0.1/udp/4001/ws", # WebSocket -# "/ip4/127.0.0.1/quic-v1", # Missing UDP -# "/udp/4001/quic-v1", # Missing IP -# "/dns4/example.com/tcp/443/tls", # Completely different -# ] - -# for addr_str in invalid_addrs: -# maddr = Multiaddr(addr_str) -# assert not is_quic_multiaddr(maddr), -# f"Should not detect {addr_str} as QUIC" - -# def test_malformed_multiaddrs(self): -# """Test malformed multiaddrs don't crash.""" -# # These should not raise exceptions, just return False -# malformed = [ -# Multiaddr("/ip4/127.0.0.1"), -# Multiaddr("/invalid"), -# ] - -# for maddr in malformed: -# assert not is_quic_multiaddr(maddr) - - -# class TestQuicMultiaddrToEndpoint: -# """Test endpoint extraction from QUIC multiaddrs.""" - -# def test_ipv4_extraction(self): -# """Test IPv4 host/port extraction.""" -# test_cases = [ -# ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)), -# ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)), -# ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)), -# ] - -# for addr_str, expected in test_cases: -# maddr = Multiaddr(addr_str) -# result = quic_multiaddr_to_endpoint(maddr) -# assert result == expected, f"Failed for {addr_str}" - -# def test_ipv6_extraction(self): -# """Test IPv6 host/port extraction.""" -# test_cases = [ -# ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)), -# ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)), -# ] - -# for addr_str, expected in test_cases: -# maddr = Multiaddr(addr_str) -# result = quic_multiaddr_to_endpoint(maddr) -# assert result == expected, f"Failed for {addr_str}" - -# def test_invalid_multiaddr_raises_error(self): -# """Test invalid multiaddrs raise appropriate errors.""" -# invalid_addrs = [ -# "/ip4/127.0.0.1/tcp/4001", # Not QUIC -# "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol -# ] - -# for addr_str in invalid_addrs: -# maddr = Multiaddr(addr_str) -# with pytest.raises(QUICInvalidMultiaddrError): -# quic_multiaddr_to_endpoint(maddr) - - -# class TestMultiaddrToQuicVersion: -# """Test QUIC version extraction.""" - -# def test_quic_v1_detection(self): -# """Test QUIC v1 version detection.""" -# addrs = [ -# "/ip4/127.0.0.1/udp/4001/quic-v1", -# "/ip6/::1/udp/5000/quic-v1", -# ] - -# for addr_str in addrs: -# maddr = Multiaddr(addr_str) -# version = multiaddr_to_quic_version(maddr) -# assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}" - -# def test_quic_draft29_detection(self): -# """Test QUIC draft-29 version detection.""" -# addrs = [ -# "/ip4/127.0.0.1/udp/4001/quic", -# "/ip6/::1/udp/5000/quic", -# ] - -# for addr_str in addrs: -# maddr = Multiaddr(addr_str) -# version = multiaddr_to_quic_version(maddr) -# assert version == "quic", f"Should detect quic for {addr_str}" - -# def test_non_quic_raises_error(self): -# """Test non-QUIC multiaddrs raise error.""" -# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") -# with pytest.raises(QUICInvalidMultiaddrError): -# multiaddr_to_quic_version(maddr) - - -# class TestCreateQuicMultiaddr: -# """Test QUIC multiaddr creation.""" - -# def test_ipv4_creation(self): -# """Test IPv4 QUIC multiaddr creation.""" -# test_cases = [ -# ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"), -# ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"), -# ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"), -# ] - -# for host, port, version, expected in test_cases: -# result = create_quic_multiaddr(host, port, version) -# assert str(result) == expected - -# def test_ipv6_creation(self): -# """Test IPv6 QUIC multiaddr creation.""" -# test_cases = [ -# ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"), -# ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"), -# ] - -# for host, port, version, expected in test_cases: -# result = create_quic_multiaddr(host, port, version) -# assert str(result) == expected - -# def test_default_version(self): -# """Test default version is quic-v1.""" -# result = create_quic_multiaddr("127.0.0.1", 4001) -# expected = "/ip4/127.0.0.1/udp/4001/quic-v1" -# assert str(result) == expected - -# def test_invalid_inputs_raise_errors(self): -# """Test invalid inputs raise appropriate errors.""" -# # Invalid IP -# with pytest.raises(QUICInvalidMultiaddrError): -# create_quic_multiaddr("invalid-ip", 4001) - -# # Invalid port -# with pytest.raises(QUICInvalidMultiaddrError): -# create_quic_multiaddr("127.0.0.1", 70000) - -# with pytest.raises(QUICInvalidMultiaddrError): -# create_quic_multiaddr("127.0.0.1", -1) - -# # Invalid version -# with pytest.raises(QUICInvalidMultiaddrError): -# create_quic_multiaddr("127.0.0.1", 4001, "invalid-version") - - -# class TestQuicVersionToWireFormat: -# """Test QUIC version to wire format conversion.""" - -# def test_supported_versions(self): -# """Test supported version conversions.""" -# test_cases = [ -# ("quic-v1", 0x00000001), # RFC 9000 -# ("quic", 0xFF00001D), # draft-29 -# ] - -# for version, expected_wire in test_cases: -# result = quic_version_to_wire_format(TProtocol(version)) -# assert result == expected_wire, f"Failed for version {version}" - -# def test_unsupported_version_raises_error(self): -# """Test unsupported versions raise error.""" -# with pytest.raises(QUICUnsupportedVersionError): -# quic_version_to_wire_format(TProtocol("unsupported-version")) - - -# class TestGetAlpnProtocols: -# """Test ALPN protocol retrieval.""" - -# def test_returns_libp2p_protocols(self): -# """Test returns expected libp2p ALPN protocols.""" -# protocols = get_alpn_protocols() -# assert protocols == ["libp2p"] -# assert isinstance(protocols, list) - -# def test_returns_copy(self): -# """Test returns a copy, not the original list.""" -# protocols1 = get_alpn_protocols() -# protocols2 = get_alpn_protocols() - -# # Modify one list -# protocols1.append("test") - -# # Other list should be unchanged -# assert protocols2 == ["libp2p"] - - -# class TestNormalizeQuicMultiaddr: -# """Test QUIC multiaddr normalization.""" - -# def test_already_normalized(self): -# """Test already normalized multiaddrs pass through.""" -# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" -# maddr = Multiaddr(addr_str) +import pytest +from multiaddr import Multiaddr + +from libp2p.custom_types import TProtocol +from libp2p.transport.quic.exceptions import ( + QUICInvalidMultiaddrError, + QUICUnsupportedVersionError, +) +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, + get_alpn_protocols, + is_quic_multiaddr, + multiaddr_to_quic_version, + normalize_quic_multiaddr, + quic_multiaddr_to_endpoint, + quic_version_to_wire_format, +) + + +class TestIsQuicMultiaddr: + """Test QUIC multiaddr detection.""" + + def test_valid_quic_v1_multiaddrs(self): + """Test valid QUIC v1 multiaddrs are detected.""" + valid_addrs = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip4/192.168.1.1/udp/8080/quic-v1", + "/ip6/::1/udp/4001/quic-v1", + "/ip6/2001:db8::1/udp/5000/quic-v1", + ] + + for addr_str in valid_addrs: + maddr = Multiaddr(addr_str) + assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" + + def test_valid_quic_draft29_multiaddrs(self): + """Test valid QUIC draft-29 multiaddrs are detected.""" + valid_addrs = [ + "/ip4/127.0.0.1/udp/4001/quic", + "/ip4/10.0.0.1/udp/9000/quic", + "/ip6/::1/udp/4001/quic", + "/ip6/fe80::1/udp/6000/quic", + ] + + for addr_str in valid_addrs: + maddr = Multiaddr(addr_str) + assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" + + def test_invalid_multiaddrs(self): + """Test non-QUIC multiaddrs are not detected.""" + invalid_addrs = [ + "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC + "/ip4/127.0.0.1/udp/4001", # UDP without QUIC + "/ip4/127.0.0.1/udp/4001/ws", # WebSocket + "/ip4/127.0.0.1/quic-v1", # Missing UDP + "/udp/4001/quic-v1", # Missing IP + "/dns4/example.com/tcp/443/tls", # Completely different + ] + + for addr_str in invalid_addrs: + maddr = Multiaddr(addr_str) + assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC" + + +class TestQuicMultiaddrToEndpoint: + """Test endpoint extraction from QUIC multiaddrs.""" + + def test_ipv4_extraction(self): + """Test IPv4 host/port extraction.""" + test_cases = [ + ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)), + ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)), + ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)), + ] + + for addr_str, expected in test_cases: + maddr = Multiaddr(addr_str) + result = quic_multiaddr_to_endpoint(maddr) + assert result == expected, f"Failed for {addr_str}" + + def test_ipv6_extraction(self): + """Test IPv6 host/port extraction.""" + test_cases = [ + ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)), + ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)), + ] + + for addr_str, expected in test_cases: + maddr = Multiaddr(addr_str) + result = quic_multiaddr_to_endpoint(maddr) + assert result == expected, f"Failed for {addr_str}" + + def test_invalid_multiaddr_raises_error(self): + """Test invalid multiaddrs raise appropriate errors.""" + invalid_addrs = [ + "/ip4/127.0.0.1/tcp/4001", # Not QUIC + "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol + ] + + for addr_str in invalid_addrs: + maddr = Multiaddr(addr_str) + with pytest.raises(QUICInvalidMultiaddrError): + quic_multiaddr_to_endpoint(maddr) + + +class TestMultiaddrToQuicVersion: + """Test QUIC version extraction.""" + + def test_quic_v1_detection(self): + """Test QUIC v1 version detection.""" + addrs = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip6/::1/udp/5000/quic-v1", + ] + + for addr_str in addrs: + maddr = Multiaddr(addr_str) + version = multiaddr_to_quic_version(maddr) + assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}" + + def test_quic_draft29_detection(self): + """Test QUIC draft-29 version detection.""" + addrs = [ + "/ip4/127.0.0.1/udp/4001/quic", + "/ip6/::1/udp/5000/quic", + ] + + for addr_str in addrs: + maddr = Multiaddr(addr_str) + version = multiaddr_to_quic_version(maddr) + assert version == "quic", f"Should detect quic for {addr_str}" + + def test_non_quic_raises_error(self): + """Test non-QUIC multiaddrs raise error.""" + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + with pytest.raises(QUICInvalidMultiaddrError): + multiaddr_to_quic_version(maddr) + + +class TestCreateQuicMultiaddr: + """Test QUIC multiaddr creation.""" + + def test_ipv4_creation(self): + """Test IPv4 QUIC multiaddr creation.""" + test_cases = [ + ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"), + ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"), + ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"), + ] + + for host, port, version, expected in test_cases: + result = create_quic_multiaddr(host, port, version) + assert str(result) == expected + + def test_ipv6_creation(self): + """Test IPv6 QUIC multiaddr creation.""" + test_cases = [ + ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"), + ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"), + ] + + for host, port, version, expected in test_cases: + result = create_quic_multiaddr(host, port, version) + assert str(result) == expected + + def test_default_version(self): + """Test default version is quic-v1.""" + result = create_quic_multiaddr("127.0.0.1", 4001) + expected = "/ip4/127.0.0.1/udp/4001/quic-v1" + assert str(result) == expected + + def test_invalid_inputs_raise_errors(self): + """Test invalid inputs raise appropriate errors.""" + # Invalid IP + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("invalid-ip", 4001) + + # Invalid port + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", 70000) + + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", -1) + + # Invalid version + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", 4001, "invalid-version") + + +class TestQuicVersionToWireFormat: + """Test QUIC version to wire format conversion.""" + + def test_supported_versions(self): + """Test supported version conversions.""" + test_cases = [ + ("quic-v1", 0x00000001), # RFC 9000 + ("quic", 0xFF00001D), # draft-29 + ] + + for version, expected_wire in test_cases: + result = quic_version_to_wire_format(TProtocol(version)) + assert result == expected_wire, f"Failed for version {version}" + + def test_unsupported_version_raises_error(self): + """Test unsupported versions raise error.""" + with pytest.raises(QUICUnsupportedVersionError): + quic_version_to_wire_format(TProtocol("unsupported-version")) + + +class TestGetAlpnProtocols: + """Test ALPN protocol retrieval.""" + + def test_returns_libp2p_protocols(self): + """Test returns expected libp2p ALPN protocols.""" + protocols = get_alpn_protocols() + assert protocols == ["libp2p"] + assert isinstance(protocols, list) + + def test_returns_copy(self): + """Test returns a copy, not the original list.""" + protocols1 = get_alpn_protocols() + protocols2 = get_alpn_protocols() + + # Modify one list + protocols1.append("test") + + # Other list should be unchanged + assert protocols2 == ["libp2p"] + + +class TestNormalizeQuicMultiaddr: + """Test QUIC multiaddr normalization.""" + + def test_already_normalized(self): + """Test already normalized multiaddrs pass through.""" + addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" + maddr = Multiaddr(addr_str) -# result = normalize_quic_multiaddr(maddr) -# assert str(result) == addr_str - -# def test_normalize_different_versions(self): -# """Test normalization works for different QUIC versions.""" -# test_cases = [ -# "/ip4/127.0.0.1/udp/4001/quic-v1", -# "/ip4/127.0.0.1/udp/4001/quic", -# "/ip6/::1/udp/5000/quic-v1", -# ] - -# for addr_str in test_cases: -# maddr = Multiaddr(addr_str) -# result = normalize_quic_multiaddr(maddr) - -# # Should be valid QUIC multiaddr -# assert is_quic_multiaddr(result) - -# # Should be parseable -# host, port = quic_multiaddr_to_endpoint(result) -# version = multiaddr_to_quic_version(result) + result = normalize_quic_multiaddr(maddr) + assert str(result) == addr_str + + def test_normalize_different_versions(self): + """Test normalization works for different QUIC versions.""" + test_cases = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip4/127.0.0.1/udp/4001/quic", + "/ip6/::1/udp/5000/quic-v1", + ] + + for addr_str in test_cases: + maddr = Multiaddr(addr_str) + result = normalize_quic_multiaddr(maddr) + + # Should be valid QUIC multiaddr + assert is_quic_multiaddr(result) + + # Should be parseable + host, port = quic_multiaddr_to_endpoint(result) + version = multiaddr_to_quic_version(result) -# # Should match original -# orig_host, orig_port = quic_multiaddr_to_endpoint(maddr) -# orig_version = multiaddr_to_quic_version(maddr) + # Should match original + orig_host, orig_port = quic_multiaddr_to_endpoint(maddr) + orig_version = multiaddr_to_quic_version(maddr) -# assert host == orig_host -# assert port == orig_port -# assert version == orig_version + assert host == orig_host + assert port == orig_port + assert version == orig_version -# def test_non_quic_raises_error(self): -# """Test non-QUIC multiaddrs raise error.""" -# maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") -# with pytest.raises(QUICInvalidMultiaddrError): -# normalize_quic_multiaddr(maddr) + def test_non_quic_raises_error(self): + """Test non-QUIC multiaddrs raise error.""" + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + with pytest.raises(QUICInvalidMultiaddrError): + normalize_quic_multiaddr(maddr) -# class TestIntegration: -# """Integration tests for utility functions working together.""" +class TestIntegration: + """Integration tests for utility functions working together.""" -# def test_round_trip_conversion(self): -# """Test creating and parsing multiaddrs works correctly.""" -# test_cases = [ -# ("127.0.0.1", 4001, "quic-v1"), -# ("::1", 5000, "quic"), -# ("192.168.1.100", 8080, "quic-v1"), -# ] + def test_round_trip_conversion(self): + """Test creating and parsing multiaddrs works correctly.""" + test_cases = [ + ("127.0.0.1", 4001, "quic-v1"), + ("::1", 5000, "quic"), + ("192.168.1.100", 8080, "quic-v1"), + ] -# for host, port, version in test_cases: -# # Create multiaddr -# maddr = create_quic_multiaddr(host, port, version) + for host, port, version in test_cases: + # Create multiaddr + maddr = create_quic_multiaddr(host, port, version) -# # Should be detected as QUIC -# assert is_quic_multiaddr(maddr) - -# # Should extract original values -# extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr) -# extracted_version = multiaddr_to_quic_version(maddr) + # Should be detected as QUIC + assert is_quic_multiaddr(maddr) + + # Should extract original values + extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr) + extracted_version = multiaddr_to_quic_version(maddr) -# assert extracted_host == host -# assert extracted_port == port -# assert extracted_version == version + assert extracted_host == host + assert extracted_port == port + assert extracted_version == version -# # Should normalize to same value -# normalized = normalize_quic_multiaddr(maddr) -# assert str(normalized) == str(maddr) + # Should normalize to same value + normalized = normalize_quic_multiaddr(maddr) + assert str(normalized) == str(maddr) -# def test_wire_format_integration(self): -# """Test wire format conversion works with version detection.""" -# addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" -# maddr = Multiaddr(addr_str) + def test_wire_format_integration(self): + """Test wire format conversion works with version detection.""" + addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" + maddr = Multiaddr(addr_str) -# # Extract version and convert to wire format -# version = multiaddr_to_quic_version(maddr) -# wire_format = quic_version_to_wire_format(version) + # Extract version and convert to wire format + version = multiaddr_to_quic_version(maddr) + wire_format = quic_version_to_wire_format(version) -# # Should be QUIC v1 wire format -# assert wire_format == 0x00000001 + # Should be QUIC v1 wire format + assert wire_format == 0x00000001