mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: client certificate verification done
This commit is contained in:
@ -6,6 +6,7 @@ from libp2p.transport.quic.connection import QUICConnection
|
||||
from typing import cast
|
||||
import logging
|
||||
import sys
|
||||
from typing import cast
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -42,6 +43,7 @@ from libp2p.transport.exceptions import (
|
||||
OpenConnectionError,
|
||||
SecurityUpgradeFailure,
|
||||
)
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
@ -285,7 +287,6 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
# No need to upgrade QUIC Connection
|
||||
if isinstance(self.transport, QUICTransport):
|
||||
print("Connecting QUIC Connection")
|
||||
quic_conn = cast(QUICConnection, raw_conn)
|
||||
await self.add_conn(quic_conn)
|
||||
# NOTE: This is a intentional barrier to prevent from the handler
|
||||
@ -410,7 +411,6 @@ class Swarm(Service, INetworkService):
|
||||
self,
|
||||
)
|
||||
print("add_conn called")
|
||||
|
||||
self.manager.run_task(muxed_conn.start)
|
||||
await muxed_conn.event_started.wait()
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
|
||||
@ -180,7 +180,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"connection_id_changes": 0,
|
||||
}
|
||||
|
||||
print(
|
||||
logger.debug(
|
||||
f"Created QUIC connection to {remote_peer_id} "
|
||||
f"(initiator: {is_initiator}, addr: {remote_addr}, "
|
||||
"security: {security_manager is not None})"
|
||||
@ -279,7 +279,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
self._started = True
|
||||
self.event_started.set()
|
||||
print(f"Starting QUIC connection to {self._remote_peer_id}")
|
||||
logger.debug(f"Starting QUIC connection to {self._remote_peer_id}")
|
||||
|
||||
try:
|
||||
# If this is a client connection, we need to establish the connection
|
||||
@ -290,7 +290,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._established = True
|
||||
self._connected_event.set()
|
||||
|
||||
print(f"QUIC connection to {self._remote_peer_id} started")
|
||||
logger.debug(f"QUIC connection to {self._remote_peer_id} started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start connection: {e}")
|
||||
@ -301,7 +301,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
with QUICErrorContext("connection_initiation", "connection"):
|
||||
if not self._socket:
|
||||
print("Creating new socket for outbound connection")
|
||||
logger.debug("Creating new socket for outbound connection")
|
||||
self._socket = trio.socket.socket(
|
||||
family=socket.AF_INET, type=socket.SOCK_DGRAM
|
||||
)
|
||||
@ -313,7 +313,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Send initial packet(s)
|
||||
await self._transmit()
|
||||
|
||||
print(f"Initiated QUIC connection to {self._remote_addr}")
|
||||
logger.debug(f"Initiated QUIC connection to {self._remote_addr}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate connection: {e}")
|
||||
@ -335,16 +335,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
with QUICErrorContext("connection_establishment", "connection"):
|
||||
# Start the connection if not already started
|
||||
print("STARTING TO CONNECT")
|
||||
logger.debug("STARTING TO CONNECT")
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
# Start background event processing
|
||||
if not self._background_tasks_started:
|
||||
print("STARTING BACKGROUND TASK")
|
||||
logger.debug("STARTING BACKGROUND TASK")
|
||||
await self._start_background_tasks()
|
||||
else:
|
||||
print("BACKGROUND TASK ALREADY STARTED")
|
||||
logger.debug("BACKGROUND TASK ALREADY STARTED")
|
||||
|
||||
# Wait for handshake completion with timeout
|
||||
with trio.move_on_after(
|
||||
@ -358,13 +358,18 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
f"{self.CONNECTION_HANDSHAKE_TIMEOUT}s"
|
||||
)
|
||||
|
||||
print("QUICConnection: Verifying peer identity with security manager")
|
||||
logger.debug(
|
||||
"QUICConnection: Verifying peer identity with security manager"
|
||||
)
|
||||
# Verify peer identity using security manager
|
||||
self.peer_id = await self._verify_peer_identity_with_security()
|
||||
peer_id = await self._verify_peer_identity_with_security()
|
||||
|
||||
print("QUICConnection: Peer identity verified")
|
||||
if peer_id:
|
||||
self.peer_id = peer_id
|
||||
|
||||
logger.debug(f"QUICConnection {id(self)}: Peer identity verified")
|
||||
self._established = True
|
||||
print(f"QUIC connection established with {self._remote_peer_id}")
|
||||
logger.debug(f"QUIC connection established with {self._remote_peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to establish connection: {e}")
|
||||
@ -384,11 +389,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._nursery.start_soon(async_fn=self._event_processing_loop)
|
||||
self._nursery.start_soon(async_fn=self._periodic_maintenance)
|
||||
|
||||
print("Started background tasks for QUIC connection")
|
||||
logger.debug("Started background tasks for QUIC connection")
|
||||
|
||||
async def _event_processing_loop(self) -> None:
|
||||
"""Main event processing loop for the connection."""
|
||||
print(
|
||||
logger.debug(
|
||||
f"Started QUIC event processing loop for connection id: {id(self)} "
|
||||
f"and local peer id {str(self.local_peer_id())}"
|
||||
)
|
||||
@ -411,7 +416,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
logger.error(f"Error in event processing loop: {e}")
|
||||
await self._handle_connection_error(e)
|
||||
finally:
|
||||
print("QUIC event processing loop finished")
|
||||
logger.debug("QUIC event processing loop finished")
|
||||
|
||||
async def _periodic_maintenance(self) -> None:
|
||||
"""Perform periodic connection maintenance."""
|
||||
@ -426,7 +431,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# *** NEW: Log connection ID status periodically ***
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
cid_stats = self.get_connection_id_stats()
|
||||
print(f"Connection ID stats: {cid_stats}")
|
||||
logger.debug(f"Connection ID stats: {cid_stats}")
|
||||
|
||||
# Sleep for maintenance interval
|
||||
await trio.sleep(30.0) # 30 seconds
|
||||
@ -436,15 +441,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
async def _client_packet_receiver(self) -> None:
|
||||
"""Receive packets for client connections."""
|
||||
print("Starting client packet receiver")
|
||||
print("Started QUIC client packet receiver")
|
||||
logger.debug("Starting client packet receiver")
|
||||
logger.debug("Started QUIC client packet receiver")
|
||||
|
||||
try:
|
||||
while not self._closed and self._socket:
|
||||
try:
|
||||
# Receive UDP packets
|
||||
data, addr = await self._socket.recvfrom(65536)
|
||||
print(f"Client received {len(data)} bytes from {addr}")
|
||||
logger.debug(f"Client received {len(data)} bytes from {addr}")
|
||||
|
||||
# Feed packet to QUIC connection
|
||||
self._quic.receive_datagram(data, addr, now=time.time())
|
||||
@ -456,21 +461,21 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await self._transmit()
|
||||
|
||||
except trio.ClosedResourceError:
|
||||
print("Client socket closed")
|
||||
logger.debug("Client socket closed")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving client packet: {e}")
|
||||
await trio.sleep(0.01)
|
||||
|
||||
except trio.Cancelled:
|
||||
print("Client packet receiver cancelled")
|
||||
logger.debug("Client packet receiver cancelled")
|
||||
raise
|
||||
finally:
|
||||
print("Client packet receiver terminated")
|
||||
logger.debug("Client packet receiver terminated")
|
||||
|
||||
# Security and identity methods
|
||||
|
||||
async def _verify_peer_identity_with_security(self) -> ID:
|
||||
async def _verify_peer_identity_with_security(self) -> ID | None:
|
||||
"""
|
||||
Verify peer identity using integrated security manager.
|
||||
|
||||
@ -478,22 +483,22 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
QUICPeerVerificationError: If peer verification fails
|
||||
|
||||
"""
|
||||
print("VERIFYING PEER IDENTITY")
|
||||
logger.debug("VERIFYING PEER IDENTITY")
|
||||
if not self._security_manager:
|
||||
print("No security manager available for peer verification")
|
||||
return
|
||||
logger.debug("No security manager available for peer verification")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Extract peer certificate from TLS handshake
|
||||
await self._extract_peer_certificate()
|
||||
|
||||
if not self._peer_certificate:
|
||||
print("No peer certificate available for verification")
|
||||
return
|
||||
logger.debug("No peer certificate available for verification")
|
||||
return None
|
||||
|
||||
# Validate certificate format and accessibility
|
||||
if not self._validate_peer_certificate():
|
||||
print("Validation Failed for peer cerificate")
|
||||
logger.debug("Validation Failed for peer cerificate")
|
||||
raise QUICPeerVerificationError("Peer certificate validation failed")
|
||||
|
||||
# Verify peer identity using security manager
|
||||
@ -505,7 +510,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Update peer ID if it wasn't known (inbound connections)
|
||||
if not self._remote_peer_id:
|
||||
self._remote_peer_id = verified_peer_id
|
||||
print(f"Discovered peer ID from certificate: {verified_peer_id}")
|
||||
logger.debug(f"Discovered peer ID from certificate: {verified_peer_id}")
|
||||
elif self._remote_peer_id != verified_peer_id:
|
||||
raise QUICPeerVerificationError(
|
||||
f"Peer ID mismatch: expected {self._remote_peer_id}, "
|
||||
@ -513,7 +518,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
)
|
||||
|
||||
self._peer_verified = True
|
||||
print(f"Peer identity verified successfully: {verified_peer_id}")
|
||||
logger.debug(f"Peer identity verified successfully: {verified_peer_id}")
|
||||
return verified_peer_id
|
||||
|
||||
except QUICPeerVerificationError:
|
||||
@ -534,14 +539,14 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# aioquic stores the peer certificate as cryptography
|
||||
# x509.Certificate
|
||||
self._peer_certificate = tls_context._peer_certificate
|
||||
print(
|
||||
logger.debug(
|
||||
f"Extracted peer certificate: {self._peer_certificate.subject}"
|
||||
)
|
||||
else:
|
||||
print("No peer certificate found in TLS context")
|
||||
logger.debug("No peer certificate found in TLS context")
|
||||
|
||||
else:
|
||||
print("No TLS context available for certificate extraction")
|
||||
logger.debug("No TLS context available for certificate extraction")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract peer certificate: {e}")
|
||||
@ -590,7 +595,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
subject = self._peer_certificate.subject
|
||||
serial_number = self._peer_certificate.serial_number
|
||||
|
||||
print(
|
||||
logger.debug(
|
||||
f"Certificate validation - Subject: {subject}, Serial: {serial_number}"
|
||||
)
|
||||
return True
|
||||
@ -715,7 +720,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._outbound_stream_count += 1
|
||||
self._stats["streams_opened"] += 1
|
||||
|
||||
print(f"Opened outbound QUIC stream {stream_id}")
|
||||
logger.debug(f"Opened outbound QUIC stream {stream_id}")
|
||||
return stream
|
||||
|
||||
raise QUICStreamTimeoutError(f"Stream creation timed out after {timeout}s")
|
||||
@ -777,7 +782,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
"""
|
||||
self._stream_handler = handler_function
|
||||
print("Set stream handler for incoming streams")
|
||||
logger.debug("Set stream handler for incoming streams")
|
||||
|
||||
def _remove_stream(self, stream_id: int) -> None:
|
||||
"""
|
||||
@ -804,7 +809,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(update_counts)
|
||||
|
||||
print(f"Removed stream {stream_id} from connection")
|
||||
logger.debug(f"Removed stream {stream_id} from connection")
|
||||
|
||||
# *** UPDATED: Complete QUIC event handling - FIXES THE ORIGINAL ISSUE ***
|
||||
|
||||
@ -826,15 +831,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
await self._handle_quic_event(event)
|
||||
|
||||
if events_processed > 0:
|
||||
print(f"Processed {events_processed} QUIC events")
|
||||
logger.debug(f"Processed {events_processed} QUIC events")
|
||||
|
||||
finally:
|
||||
self._event_processing_active = False
|
||||
|
||||
async def _handle_quic_event(self, event: events.QuicEvent) -> None:
|
||||
"""Handle a single QUIC event with COMPLETE event type coverage."""
|
||||
print(f"Handling QUIC event: {type(event).__name__}")
|
||||
print(f"QUIC event: {type(event).__name__}")
|
||||
logger.debug(f"Handling QUIC event: {type(event).__name__}")
|
||||
logger.debug(f"QUIC event: {type(event).__name__}")
|
||||
|
||||
try:
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
@ -860,8 +865,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
elif isinstance(event, events.StopSendingReceived):
|
||||
await self._handle_stop_sending_received(event)
|
||||
else:
|
||||
print(f"Unhandled QUIC event type: {type(event).__name__}")
|
||||
print(f"Unhandled QUIC event: {type(event).__name__}")
|
||||
logger.debug(f"Unhandled QUIC event type: {type(event).__name__}")
|
||||
logger.debug(f"Unhandled QUIC event: {type(event).__name__}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
|
||||
@ -876,8 +881,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
This is the CRITICAL missing functionality that was causing your issue!
|
||||
"""
|
||||
print(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
print(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
logger.debug(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
logger.debug(f"🆔 NEW CONNECTION ID ISSUED: {event.connection_id.hex()}")
|
||||
|
||||
# Add to available connection IDs
|
||||
self._available_connection_ids.add(event.connection_id)
|
||||
@ -885,14 +890,18 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# If we don't have a current connection ID, use this one
|
||||
if self._current_connection_id is None:
|
||||
self._current_connection_id = event.connection_id
|
||||
print(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
|
||||
print(f"🆔 Set current connection ID to: {event.connection_id.hex()}")
|
||||
logger.debug(
|
||||
f"🆔 Set current connection ID to: {event.connection_id.hex()}"
|
||||
)
|
||||
logger.debug(
|
||||
f"🆔 Set current connection ID to: {event.connection_id.hex()}"
|
||||
)
|
||||
|
||||
# Update statistics
|
||||
self._stats["connection_ids_issued"] += 1
|
||||
|
||||
print(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
print(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
|
||||
|
||||
async def _handle_connection_id_retired(
|
||||
self, event: events.ConnectionIdRetired
|
||||
@ -902,8 +911,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
This handles when the peer tells us to stop using a connection ID.
|
||||
"""
|
||||
print(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
print(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
logger.debug(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
logger.debug(f"🗑️ CONNECTION ID RETIRED: {event.connection_id.hex()}")
|
||||
|
||||
# Remove from available IDs and add to retired set
|
||||
self._available_connection_ids.discard(event.connection_id)
|
||||
@ -920,7 +929,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
else:
|
||||
self._current_connection_id = None
|
||||
logger.warning("⚠️ No available connection IDs after retirement!")
|
||||
print("⚠️ No available connection IDs after retirement!")
|
||||
logger.debug("⚠️ No available connection IDs after retirement!")
|
||||
|
||||
# Update statistics
|
||||
self._stats["connection_ids_retired"] += 1
|
||||
@ -929,13 +938,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None:
|
||||
"""Handle ping acknowledgment."""
|
||||
print(f"Ping acknowledged: uid={event.uid}")
|
||||
logger.debug(f"Ping acknowledged: uid={event.uid}")
|
||||
|
||||
async def _handle_protocol_negotiated(
|
||||
self, event: events.ProtocolNegotiated
|
||||
) -> None:
|
||||
"""Handle protocol negotiation completion."""
|
||||
print(f"Protocol negotiated: {event.alpn_protocol}")
|
||||
logger.debug(f"Protocol negotiated: {event.alpn_protocol}")
|
||||
|
||||
async def _handle_stop_sending_received(
|
||||
self, event: events.StopSendingReceived
|
||||
@ -957,7 +966,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self, event: events.HandshakeCompleted
|
||||
) -> None:
|
||||
"""Handle handshake completion with security integration."""
|
||||
print("QUIC handshake completed")
|
||||
logger.debug("QUIC handshake completed")
|
||||
self._handshake_completed = True
|
||||
|
||||
# Store handshake event for security verification
|
||||
@ -966,14 +975,14 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Try to extract certificate information after handshake
|
||||
await self._extract_peer_certificate()
|
||||
|
||||
print("✅ Setting connected event")
|
||||
logger.debug("✅ Setting connected event")
|
||||
self._connected_event.set()
|
||||
|
||||
async def _handle_connection_terminated(
|
||||
self, event: events.ConnectionTerminated
|
||||
) -> None:
|
||||
"""Handle connection termination."""
|
||||
print(f"QUIC connection terminated: {event.reason_phrase}")
|
||||
logger.debug(f"QUIC connection terminated: {event.reason_phrase}")
|
||||
|
||||
# Close all streams
|
||||
for stream in list(self._streams.values()):
|
||||
@ -999,7 +1008,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
if stream_id not in self._streams:
|
||||
if self._is_incoming_stream(stream_id):
|
||||
print(f"Creating new incoming stream {stream_id}")
|
||||
logger.debug(f"Creating new incoming stream {stream_id}")
|
||||
|
||||
from .stream import QUICStream, StreamDirection
|
||||
|
||||
@ -1034,7 +1043,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling stream data for stream {stream_id}: {e}")
|
||||
print(f"❌ STREAM_DATA: Error: {e}")
|
||||
logger.debug(f"❌ STREAM_DATA: Error: {e}")
|
||||
|
||||
async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
|
||||
"""Get existing stream or create new inbound stream."""
|
||||
@ -1091,7 +1100,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream handler for stream {stream_id}: {e}")
|
||||
|
||||
print(f"Created inbound stream {stream_id}")
|
||||
logger.debug(f"Created inbound stream {stream_id}")
|
||||
return stream
|
||||
|
||||
def _is_incoming_stream(self, stream_id: int) -> bool:
|
||||
@ -1118,7 +1127,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
stream = self._streams[stream_id]
|
||||
await stream.handle_reset(event.error_code)
|
||||
print(
|
||||
logger.debug(
|
||||
f"Handled reset for stream {stream_id}"
|
||||
f"with error code {event.error_code}"
|
||||
)
|
||||
@ -1127,13 +1136,13 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Force remove the stream
|
||||
self._remove_stream(stream_id)
|
||||
else:
|
||||
print(f"Received reset for unknown stream {stream_id}")
|
||||
logger.debug(f"Received reset for unknown stream {stream_id}")
|
||||
|
||||
async def _handle_datagram_received(
|
||||
self, event: events.DatagramFrameReceived
|
||||
) -> None:
|
||||
"""Handle datagram frame (if using QUIC datagrams)."""
|
||||
print(f"Datagram frame received: size={len(event.data)}")
|
||||
logger.debug(f"Datagram frame received: size={len(event.data)}")
|
||||
# For now, just log. Could be extended for custom datagram handling
|
||||
|
||||
async def _handle_timer_events(self) -> None:
|
||||
@ -1150,7 +1159,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""Transmit pending QUIC packets using available socket."""
|
||||
sock = self._socket
|
||||
if not sock:
|
||||
print("No socket to transmit")
|
||||
logger.debug("No socket to transmit")
|
||||
return
|
||||
|
||||
try:
|
||||
@ -1196,7 +1205,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
print(f"Closing QUIC connection to {self._remote_peer_id}")
|
||||
logger.debug(f"Closing QUIC connection to {self._remote_peer_id}")
|
||||
|
||||
try:
|
||||
# Close all streams gracefully
|
||||
@ -1238,7 +1247,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._streams.clear()
|
||||
self._closed_event.set()
|
||||
|
||||
print(f"QUIC connection to {self._remote_peer_id} closed")
|
||||
logger.debug(f"QUIC connection to {self._remote_peer_id} closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during connection close: {e}")
|
||||
@ -1253,13 +1262,15 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
try:
|
||||
if self._transport:
|
||||
await self._transport._cleanup_terminated_connection(self)
|
||||
print("Notified transport of connection termination")
|
||||
logger.debug("Notified transport of connection termination")
|
||||
return
|
||||
|
||||
for listener in self._transport._listeners:
|
||||
try:
|
||||
await listener._remove_connection_by_object(self)
|
||||
print("Found and notified listener of connection termination")
|
||||
logger.debug(
|
||||
"Found and notified listener of connection termination"
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
continue
|
||||
@ -1284,10 +1295,10 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
for tracked_cid, tracked_conn in list(listener._connections.items()):
|
||||
if tracked_conn is self:
|
||||
await listener._remove_connection(tracked_cid)
|
||||
print(f"Removed connection {tracked_cid.hex()}")
|
||||
logger.debug(f"Removed connection {tracked_cid.hex()}")
|
||||
return
|
||||
|
||||
print("Fallback cleanup by connection ID completed")
|
||||
logger.debug("Fallback cleanup by connection ID completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fallback cleanup: {e}")
|
||||
|
||||
@ -1330,9 +1341,6 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
"""
|
||||
# This method doesn't make sense for a muxed connection
|
||||
# It's here for interface compatibility but should not be used
|
||||
import traceback
|
||||
|
||||
traceback.print_stack()
|
||||
raise NotImplementedError(
|
||||
"Use streams for reading data from QUIC connections. "
|
||||
"Call accept_stream() or open_stream() instead."
|
||||
|
||||
@ -47,6 +47,7 @@ logging.basicConfig(
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class QUICPacketInfo:
|
||||
@ -368,10 +369,7 @@ class QUICListener(IListener):
|
||||
await self._transmit_for_connection(quic_conn, addr)
|
||||
|
||||
# Check if handshake completed (with minimal locking)
|
||||
if (
|
||||
hasattr(quic_conn, "_handshake_complete")
|
||||
and quic_conn._handshake_complete
|
||||
):
|
||||
if quic_conn._handshake_complete:
|
||||
logger.debug("PENDING: Handshake completed, promoting connection")
|
||||
await self._promote_pending_connection(quic_conn, addr, dest_cid)
|
||||
else:
|
||||
@ -497,6 +495,15 @@ class QUICListener(IListener):
|
||||
|
||||
# Process initial packet
|
||||
quic_conn.receive_datagram(data, addr, now=time.time())
|
||||
if quic_conn.tls:
|
||||
if self._security_manager:
|
||||
try:
|
||||
quic_conn.tls._request_client_certificate = True
|
||||
logger.debug(
|
||||
"request_client_certificate set to True in server TLS context"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"FAILED to apply request_client_certificate: {e}")
|
||||
|
||||
# Process events and send response
|
||||
await self._process_quic_events(quic_conn, addr, destination_cid)
|
||||
@ -686,12 +693,10 @@ class QUICListener(IListener):
|
||||
self._pending_connections.pop(dest_cid, None)
|
||||
|
||||
if dest_cid in self._connections:
|
||||
connection = self._connections[dest_cid]
|
||||
logger.debug(
|
||||
f"Using existing QUICConnection {id(connection)} "
|
||||
f"for {dest_cid.hex()}"
|
||||
f"⚠️ PROMOTE: Connection {dest_cid.hex()} already exists in _connections!"
|
||||
)
|
||||
|
||||
connection = self._connections[dest_cid]
|
||||
else:
|
||||
from .connection import QUICConnection
|
||||
|
||||
@ -726,7 +731,8 @@ class QUICListener(IListener):
|
||||
if self._security_manager:
|
||||
try:
|
||||
peer_id = await connection._verify_peer_identity_with_security()
|
||||
connection.peer_id = peer_id
|
||||
if peer_id:
|
||||
connection.peer_id = peer_id
|
||||
logger.info(
|
||||
f"Security verification successful for {dest_cid.hex()}"
|
||||
)
|
||||
|
||||
@ -136,21 +136,23 @@ class LibP2PExtensionHandler:
|
||||
Parse the libp2p Public Key Extension with enhanced debugging.
|
||||
"""
|
||||
try:
|
||||
print(f"🔍 Extension type: {type(extension)}")
|
||||
print(f"🔍 Extension.value type: {type(extension.value)}")
|
||||
logger.debug(f"🔍 Extension type: {type(extension)}")
|
||||
logger.debug(f"🔍 Extension.value type: {type(extension.value)}")
|
||||
|
||||
# Extract the raw bytes from the extension
|
||||
if isinstance(extension.value, UnrecognizedExtension):
|
||||
# Use the .value property to get the bytes
|
||||
raw_bytes = extension.value.value
|
||||
print("🔍 Extension is UnrecognizedExtension, using .value property")
|
||||
logger.debug(
|
||||
"🔍 Extension is UnrecognizedExtension, using .value property"
|
||||
)
|
||||
else:
|
||||
# Fallback if it's already bytes somehow
|
||||
raw_bytes = extension.value
|
||||
print("🔍 Extension.value is already bytes")
|
||||
logger.debug("🔍 Extension.value is already bytes")
|
||||
|
||||
print(f"🔍 Total extension length: {len(raw_bytes)} bytes")
|
||||
print(f"🔍 Extension hex (first 50 bytes): {raw_bytes[:50].hex()}")
|
||||
logger.debug(f"🔍 Total extension length: {len(raw_bytes)} bytes")
|
||||
logger.debug(f"🔍 Extension hex (first 50 bytes): {raw_bytes[:50].hex()}")
|
||||
|
||||
if not isinstance(raw_bytes, bytes):
|
||||
raise QUICCertificateError(f"Expected bytes, got {type(raw_bytes)}")
|
||||
@ -164,16 +166,16 @@ class LibP2PExtensionHandler:
|
||||
public_key_length = int.from_bytes(
|
||||
raw_bytes[offset : offset + 4], byteorder="big"
|
||||
)
|
||||
print(f"🔍 Public key length: {public_key_length} bytes")
|
||||
logger.debug(f"🔍 Public key length: {public_key_length} bytes")
|
||||
offset += 4
|
||||
|
||||
if len(raw_bytes) < offset + public_key_length:
|
||||
raise QUICCertificateError("Extension too short for public key data")
|
||||
|
||||
public_key_bytes = raw_bytes[offset : offset + public_key_length]
|
||||
print(f"🔍 Public key data: {public_key_bytes.hex()}")
|
||||
logger.debug(f"🔍 Public key data: {public_key_bytes.hex()}")
|
||||
offset += public_key_length
|
||||
print(f"🔍 Offset after public key: {offset}")
|
||||
logger.debug(f"🔍 Offset after public key: {offset}")
|
||||
|
||||
# Parse signature length and data
|
||||
if len(raw_bytes) < offset + 4:
|
||||
@ -182,17 +184,17 @@ class LibP2PExtensionHandler:
|
||||
signature_length = int.from_bytes(
|
||||
raw_bytes[offset : offset + 4], byteorder="big"
|
||||
)
|
||||
print(f"🔍 Signature length: {signature_length} bytes")
|
||||
logger.debug(f"🔍 Signature length: {signature_length} bytes")
|
||||
offset += 4
|
||||
print(f"🔍 Offset after signature length: {offset}")
|
||||
logger.debug(f"🔍 Offset after signature length: {offset}")
|
||||
|
||||
if len(raw_bytes) < offset + signature_length:
|
||||
raise QUICCertificateError("Extension too short for signature data")
|
||||
|
||||
signature = raw_bytes[offset : offset + signature_length]
|
||||
print(f"🔍 Extracted signature length: {len(signature)} bytes")
|
||||
print(f"🔍 Signature hex (first 20 bytes): {signature[:20].hex()}")
|
||||
print(
|
||||
logger.debug(f"🔍 Extracted signature length: {len(signature)} bytes")
|
||||
logger.debug(f"🔍 Signature hex (first 20 bytes): {signature[:20].hex()}")
|
||||
logger.debug(
|
||||
f"🔍 Signature starts with DER header: {signature[:2].hex() == '3045'}"
|
||||
)
|
||||
|
||||
@ -220,27 +222,27 @@ class LibP2PExtensionHandler:
|
||||
|
||||
# Check if we have extra data
|
||||
expected_total = 4 + public_key_length + 4 + signature_length
|
||||
print(f"🔍 Expected total length: {expected_total}")
|
||||
print(f"🔍 Actual total length: {len(raw_bytes)}")
|
||||
logger.debug(f"🔍 Expected total length: {expected_total}")
|
||||
logger.debug(f"🔍 Actual total length: {len(raw_bytes)}")
|
||||
|
||||
if len(raw_bytes) > expected_total:
|
||||
extra_bytes = len(raw_bytes) - expected_total
|
||||
print(f"⚠️ Extra {extra_bytes} bytes detected!")
|
||||
print(f"🔍 Extra data: {raw_bytes[expected_total:].hex()}")
|
||||
logger.debug(f"⚠️ Extra {extra_bytes} bytes detected!")
|
||||
logger.debug(f"🔍 Extra data: {raw_bytes[expected_total:].hex()}")
|
||||
|
||||
# Deserialize the public key
|
||||
public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes)
|
||||
print(f"🔍 Successfully deserialized public key: {type(public_key)}")
|
||||
logger.debug(f"🔍 Successfully deserialized public key: {type(public_key)}")
|
||||
|
||||
print(f"🔍 Final signature to return: {len(signature)} bytes")
|
||||
logger.debug(f"🔍 Final signature to return: {len(signature)} bytes")
|
||||
|
||||
return public_key, signature
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Extension parsing failed: {e}")
|
||||
logger.debug(f"❌ Extension parsing failed: {e}")
|
||||
import traceback
|
||||
|
||||
print(f"❌ Traceback: {traceback.format_exc()}")
|
||||
logger.debug(f"❌ Traceback: {traceback.format_exc()}")
|
||||
raise QUICCertificateError(
|
||||
f"Failed to parse signed key extension: {e}"
|
||||
) from e
|
||||
@ -424,11 +426,11 @@ class PeerAuthenticator:
|
||||
raise QUICPeerVerificationError("Certificate missing libp2p extension")
|
||||
|
||||
assert libp2p_extension.value is not None
|
||||
print(f"Extension type: {type(libp2p_extension)}")
|
||||
print(f"Extension value type: {type(libp2p_extension.value)}")
|
||||
logger.debug(f"Extension type: {type(libp2p_extension)}")
|
||||
logger.debug(f"Extension value type: {type(libp2p_extension.value)}")
|
||||
if hasattr(libp2p_extension.value, "__len__"):
|
||||
print(f"Extension value length: {len(libp2p_extension.value)}")
|
||||
print(f"Extension value: {libp2p_extension.value}")
|
||||
logger.debug(f"Extension value length: {len(libp2p_extension.value)}")
|
||||
logger.debug(f"Extension value: {libp2p_extension.value}")
|
||||
# Parse the extension to get public key and signature
|
||||
public_key, signature = self.extension_handler.parse_signed_key_extension(
|
||||
libp2p_extension
|
||||
@ -455,8 +457,8 @@ class PeerAuthenticator:
|
||||
|
||||
# Verify against expected peer ID if provided
|
||||
if expected_peer_id and derived_peer_id != expected_peer_id:
|
||||
print(f"Expected Peer id: {expected_peer_id}")
|
||||
print(f"Derived Peer ID: {derived_peer_id}")
|
||||
logger.debug(f"Expected Peer id: {expected_peer_id}")
|
||||
logger.debug(f"Derived Peer ID: {derived_peer_id}")
|
||||
raise QUICPeerVerificationError(
|
||||
f"Peer ID mismatch: expected {expected_peer_id}, "
|
||||
f"got {derived_peer_id}"
|
||||
@ -615,22 +617,24 @@ class QUICTLSSecurityConfig:
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def debug_print(self) -> None:
|
||||
"""Print debugging information about this configuration."""
|
||||
print(f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===")
|
||||
print(f"Is client config: {self.is_client_config}")
|
||||
print(f"ALPN protocols: {self.alpn_protocols}")
|
||||
print(f"Verify mode: {self.verify_mode}")
|
||||
print(f"Check hostname: {self.check_hostname}")
|
||||
print(f"Certificate chain length: {len(self.certificate_chain)}")
|
||||
def debug_config(self) -> None:
|
||||
"""logger.debug debugging information about this configuration."""
|
||||
logger.debug(
|
||||
f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ==="
|
||||
)
|
||||
logger.debug(f"Is client config: {self.is_client_config}")
|
||||
logger.debug(f"ALPN protocols: {self.alpn_protocols}")
|
||||
logger.debug(f"Verify mode: {self.verify_mode}")
|
||||
logger.debug(f"Check hostname: {self.check_hostname}")
|
||||
logger.debug(f"Certificate chain length: {len(self.certificate_chain)}")
|
||||
|
||||
cert_info: dict[Any, Any] = self.get_certificate_info()
|
||||
for key, value in cert_info.items():
|
||||
print(f"Certificate {key}: {value}")
|
||||
logger.debug(f"Certificate {key}: {value}")
|
||||
|
||||
print(f"Private key type: {type(self.private_key).__name__}")
|
||||
logger.debug(f"Private key type: {type(self.private_key).__name__}")
|
||||
if hasattr(self.private_key, "key_size"):
|
||||
print(f"Private key size: {self.private_key.key_size}")
|
||||
logger.debug(f"Private key size: {self.private_key.key_size}")
|
||||
|
||||
|
||||
def create_server_tls_config(
|
||||
@ -727,8 +731,7 @@ class QUICTLSConfigManager:
|
||||
peer_id=self.peer_id,
|
||||
)
|
||||
|
||||
print("🔧 SECURITY: Created server config")
|
||||
config.debug_print()
|
||||
logger.debug("🔧 SECURITY: Created server config")
|
||||
return config
|
||||
|
||||
def create_client_config(self) -> QUICTLSSecurityConfig:
|
||||
@ -745,8 +748,7 @@ class QUICTLSConfigManager:
|
||||
peer_id=self.peer_id,
|
||||
)
|
||||
|
||||
print("🔧 SECURITY: Created client config")
|
||||
config.debug_print()
|
||||
logger.debug("🔧 SECURITY: Created client config")
|
||||
return config
|
||||
|
||||
def verify_peer_identity(
|
||||
|
||||
@ -33,6 +33,8 @@ from libp2p.peer.id import (
|
||||
)
|
||||
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_client_config_from_base,
|
||||
create_server_config_from_base,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
@ -162,12 +164,16 @@ class QUICTransport(ITransport):
|
||||
self._apply_tls_configuration(base_client_config, client_tls_config)
|
||||
|
||||
# QUIC v1 (RFC 9000) configurations
|
||||
quic_v1_server_config = copy.copy(base_server_config)
|
||||
quic_v1_server_config = create_server_config_from_base(
|
||||
base_server_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
quic_v1_client_config = copy.copy(base_client_config)
|
||||
quic_v1_client_config = create_client_config_from_base(
|
||||
base_client_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
@ -269,9 +275,21 @@ class QUICTransport(ITransport):
|
||||
|
||||
config.is_client = True
|
||||
config.quic_logger = QuicLogger()
|
||||
print(f"Dialing QUIC connection to {host}:{port} (version: {quic_version})")
|
||||
|
||||
print("Start QUIC Connection")
|
||||
# Ensure client certificate is properly set for mutual authentication
|
||||
if not config.certificate or not config.private_key:
|
||||
logger.warning(
|
||||
"Client config missing certificate - applying TLS config"
|
||||
)
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
self._apply_tls_configuration(config, client_tls_config)
|
||||
|
||||
# Debug log to verify certificate is present
|
||||
logger.info(
|
||||
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
|
||||
)
|
||||
|
||||
logger.debug("Starting QUIC Connection")
|
||||
# Create QUIC connection using aioquic's sans-IO core
|
||||
native_quic_connection = NativeQUICConnection(configuration=config)
|
||||
|
||||
|
||||
@ -350,11 +350,18 @@ def create_server_config_from_base(
|
||||
if server_tls_config.private_key:
|
||||
server_config.private_key = server_tls_config.private_key
|
||||
if server_tls_config.certificate_chain:
|
||||
server_config.certificate_chain = server_tls_config.certificate_chain
|
||||
server_config.certificate_chain = (
|
||||
server_tls_config.certificate_chain
|
||||
)
|
||||
if server_tls_config.alpn_protocols:
|
||||
server_config.alpn_protocols = server_tls_config.alpn_protocols
|
||||
print("Setting request client certificate to True")
|
||||
server_tls_config.request_client_certificate = True
|
||||
if getattr(server_tls_config, "request_client_certificate", False):
|
||||
server_config._libp2p_request_client_cert = True # type: ignore
|
||||
else:
|
||||
logger.error(
|
||||
"🔧 Failed to set request_client_certificate in server config"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
@ -379,3 +386,81 @@ def create_server_config_from_base(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create server config: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_client_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a client configuration without using deepcopy.
|
||||
"""
|
||||
try:
|
||||
# Create new client configuration from scratch
|
||||
client_config = QuicConfiguration(is_client=True)
|
||||
client_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
client_tls_config = security_manager.create_client_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if client_tls_config.certificate:
|
||||
client_config.certificate = client_tls_config.certificate
|
||||
if client_tls_config.private_key:
|
||||
client_config.private_key = client_tls_config.private_key
|
||||
if client_tls_config.certificate_chain:
|
||||
client_config.certificate_chain = (
|
||||
client_tls_config.certificate_chain
|
||||
)
|
||||
if client_tls_config.alpn_protocols:
|
||||
client_config.alpn_protocols = client_tls_config.alpn_protocols
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Ensure we have ALPN protocols
|
||||
if not client_config.alpn_protocols:
|
||||
client_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created client config without deepcopy")
|
||||
return client_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create client config: {e}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user