fix: initial connection succesfull

This commit is contained in:
Akash Mondal
2025-06-30 11:16:08 +00:00
committed by lla-dane
parent 2689040d48
commit bbe632bd85
6 changed files with 120 additions and 79 deletions

View File

@ -115,7 +115,9 @@ async def run_client(destination: str, seed: int | None = None) -> None:
info = info_from_p2p_addr(maddr) info = info_from_p2p_addr(maddr)
# Connect to server # Connect to server
print("STARTING CLIENT CONNECTION PROCESS")
await host.connect(info) await host.connect(info)
print("CLIENT CONNECTED TO SERVER")
# Start a stream with the destination # Start a stream with the destination
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])

View File

@ -40,6 +40,7 @@ from libp2p.transport.exceptions import (
OpenConnectionError, OpenConnectionError,
SecurityUpgradeFailure, SecurityUpgradeFailure,
) )
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.upgrader import ( from libp2p.transport.upgrader import (
TransportUpgrader, TransportUpgrader,
) )
@ -114,6 +115,11 @@ class Swarm(Service, INetworkService):
# Create a nursery for listener tasks. # Create a nursery for listener tasks.
self.listener_nursery = nursery self.listener_nursery = nursery
self.event_listener_nursery_created.set() self.event_listener_nursery_created.set()
if isinstance(self.transport, QUICTransport):
self.transport.set_background_nursery(nursery)
self.transport.set_swarm(self)
try: try:
await self.manager.wait_finished() await self.manager.wait_finished()
finally: finally:
@ -177,6 +183,14 @@ class Swarm(Service, INetworkService):
""" """
Try to create a connection to peer_id with addr. Try to create a connection to peer_id with addr.
""" """
# QUIC Transport
if isinstance(self.transport, QUICTransport):
raw_conn = await self.transport.dial(addr, peer_id)
print("detected QUIC connection, skipping upgrade steps")
swarm_conn = await self.add_conn(raw_conn)
print("successfully dialed peer %s via QUIC", peer_id)
return swarm_conn
try: try:
raw_conn = await self.transport.dial(addr) raw_conn = await self.transport.dial(addr)
except OpenConnectionError as error: except OpenConnectionError as error:
@ -187,14 +201,6 @@ class Swarm(Service, INetworkService):
logger.debug("dialed peer %s over base transport", peer_id) logger.debug("dialed peer %s over base transport", peer_id)
# NEW: Check if this is a QUIC connection (already secure and muxed)
if isinstance(raw_conn, IMuxedConn):
# QUIC connections are already secure and muxed, skip upgrade steps
logger.debug("detected QUIC connection, skipping upgrade steps")
swarm_conn = await self.add_conn(raw_conn)
logger.debug("successfully dialed peer %s via QUIC", peer_id)
return swarm_conn
# Standard TCP flow - security then mux upgrade # Standard TCP flow - security then mux upgrade
try: try:
secured_conn = await self.upgrader.upgrade_security(raw_conn, True, peer_id) secured_conn = await self.upgrader.upgrade_security(raw_conn, True, peer_id)

View File

@ -147,7 +147,8 @@ class MultiselectClient(IMultiselectClient):
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error raise MultiselectClientError() from error
if response == protocol_str: print("Response: ", response)
if response == protocol:
return protocol return protocol
if response == PROTOCOL_NOT_FOUND_MSG: if response == PROTOCOL_NOT_FOUND_MSG:
raise MultiselectClientError("protocol not supported") raise MultiselectClientError("protocol not supported")

View File

@ -3,11 +3,12 @@ QUIC Connection implementation.
Uses aioquic's sans-IO core with trio for async operations. Uses aioquic's sans-IO core with trio for async operations.
""" """
from collections.abc import Awaitable, Callable
import logging import logging
import socket import socket
from sys import stdout from sys import stdout
import time import time
from typing import TYPE_CHECKING, Any, Optional, Set from typing import TYPE_CHECKING, Any, Optional
from aioquic.quic import events from aioquic.quic import events
from aioquic.quic.connection import QuicConnection from aioquic.quic.connection import QuicConnection
@ -75,7 +76,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self, self,
quic_connection: QuicConnection, quic_connection: QuicConnection,
remote_addr: tuple[str, int], remote_addr: tuple[str, int],
peer_id: ID | None, peer_id: ID,
local_peer_id: ID, local_peer_id: ID,
is_initiator: bool, is_initiator: bool,
maddr: multiaddr.Multiaddr, maddr: multiaddr.Multiaddr,
@ -102,7 +103,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
""" """
self._quic = quic_connection self._quic = quic_connection
self._remote_addr = remote_addr self._remote_addr = remote_addr
self._peer_id = peer_id self.peer_id = peer_id
self._local_peer_id = local_peer_id self._local_peer_id = local_peer_id
self.__is_initiator = is_initiator self.__is_initiator = is_initiator
self._maddr = maddr self._maddr = maddr
@ -147,12 +148,14 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._background_tasks_started = False self._background_tasks_started = False
self._nursery: trio.Nursery | None = None self._nursery: trio.Nursery | None = None
self._event_processing_task: Any | None = None self._event_processing_task: Any | None = None
self.on_close: Callable[[], Awaitable[None]] | None = None
self.event_started = trio.Event()
# *** NEW: Connection ID tracking - CRITICAL for fixing the original issue *** # *** NEW: Connection ID tracking - CRITICAL for fixing the original issue ***
self._available_connection_ids: Set[bytes] = set() self._available_connection_ids: set[bytes] = set()
self._current_connection_id: Optional[bytes] = None self._current_connection_id: bytes | None = None
self._retired_connection_ids: Set[bytes] = set() self._retired_connection_ids: set[bytes] = set()
self._connection_id_sequence_numbers: Set[int] = set() self._connection_id_sequence_numbers: set[int] = set()
# Event processing control # Event processing control
self._event_processing_active = False self._event_processing_active = False
@ -235,7 +238,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
def remote_peer_id(self) -> ID | None: def remote_peer_id(self) -> ID | None:
"""Get the remote peer ID.""" """Get the remote peer ID."""
return self._peer_id return self.peer_id
# *** NEW: Connection ID management methods *** # *** NEW: Connection ID management methods ***
def get_connection_id_stats(self) -> dict[str, Any]: def get_connection_id_stats(self) -> dict[str, Any]:
@ -252,7 +255,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
"available_cid_list": [cid.hex() for cid in self._available_connection_ids], "available_cid_list": [cid.hex() for cid in self._available_connection_ids],
} }
def get_current_connection_id(self) -> Optional[bytes]: def get_current_connection_id(self) -> bytes | None:
"""Get the current connection ID.""" """Get the current connection ID."""
return self._current_connection_id return self._current_connection_id
@ -273,7 +276,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
raise QUICConnectionError("Cannot start a closed connection") raise QUICConnectionError("Cannot start a closed connection")
self._started = True self._started = True
logger.debug(f"Starting QUIC connection to {self._peer_id}") self.event_started.set()
logger.debug(f"Starting QUIC connection to {self.peer_id}")
try: try:
# If this is a client connection, we need to establish the connection # If this is a client connection, we need to establish the connection
@ -284,7 +288,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._established = True self._established = True
self._connected_event.set() self._connected_event.set()
logger.debug(f"QUIC connection to {self._peer_id} started") logger.debug(f"QUIC connection to {self.peer_id} started")
except Exception as e: except Exception as e:
logger.error(f"Failed to start connection: {e}") logger.error(f"Failed to start connection: {e}")
@ -356,7 +360,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
await self._verify_peer_identity_with_security() await self._verify_peer_identity_with_security()
self._established = True self._established = True
logger.info(f"QUIC connection established with {self._peer_id}") logger.info(f"QUIC connection established with {self.peer_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to establish connection: {e}") logger.error(f"Failed to establish connection: {e}")
@ -491,17 +495,16 @@ class QUICConnection(IRawConnection, IMuxedConn):
# Verify peer identity using security manager # Verify peer identity using security manager
verified_peer_id = self._security_manager.verify_peer_identity( verified_peer_id = self._security_manager.verify_peer_identity(
self._peer_certificate, self._peer_certificate,
self._peer_id, # Expected peer ID for outbound connections self.peer_id, # Expected peer ID for outbound connections
) )
# Update peer ID if it wasn't known (inbound connections) # Update peer ID if it wasn't known (inbound connections)
if not self._peer_id: if not self.peer_id:
self._peer_id = verified_peer_id self.peer_id = verified_peer_id
logger.info(f"Discovered peer ID from certificate: {verified_peer_id}") logger.info(f"Discovered peer ID from certificate: {verified_peer_id}")
elif self._peer_id != verified_peer_id: elif self.peer_id != verified_peer_id:
raise QUICPeerVerificationError( raise QUICPeerVerificationError(
f"Peer ID mismatch: expected {self._peer_id}, " f"Peer ID mismatch: expected {self.peer_id}, got {verified_peer_id}"
f"got {verified_peer_id}"
) )
self._peer_verified = True self._peer_verified = True
@ -605,7 +608,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
info: dict[str, bool | Any | None] = { info: dict[str, bool | Any | None] = {
"peer_verified": self._peer_verified, "peer_verified": self._peer_verified,
"handshake_complete": self._handshake_completed, "handshake_complete": self._handshake_completed,
"peer_id": str(self._peer_id) if self._peer_id else None, "peer_id": str(self.peer_id) if self.peer_id else None,
"local_peer_id": str(self._local_peer_id), "local_peer_id": str(self._local_peer_id),
"is_initiator": self.__is_initiator, "is_initiator": self.__is_initiator,
"has_certificate": self._peer_certificate is not None, "has_certificate": self._peer_certificate is not None,
@ -1188,7 +1191,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
return return
self._closed = True self._closed = True
logger.debug(f"Closing QUIC connection to {self._peer_id}") logger.debug(f"Closing QUIC connection to {self.peer_id}")
try: try:
# Close all streams gracefully # Close all streams gracefully
@ -1213,8 +1216,12 @@ class QUICConnection(IRawConnection, IMuxedConn):
except Exception: except Exception:
pass pass
if self.on_close:
await self.on_close()
# Close QUIC connection # Close QUIC connection
self._quic.close() self._quic.close()
if self._socket: if self._socket:
await self._transmit() # Send close frames await self._transmit() # Send close frames
@ -1226,7 +1233,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
self._streams.clear() self._streams.clear()
self._closed_event.set() self._closed_event.set()
logger.debug(f"QUIC connection to {self._peer_id} closed") logger.debug(f"QUIC connection to {self.peer_id} closed")
except Exception as e: except Exception as e:
logger.error(f"Error during connection close: {e}") logger.error(f"Error during connection close: {e}")
@ -1266,6 +1273,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
QUICStreamClosedError: If stream is closed for reading. QUICStreamClosedError: If stream is closed for reading.
QUICStreamResetError: If stream was reset. QUICStreamResetError: If stream was reset.
QUICStreamTimeoutError: If read timeout occurs. QUICStreamTimeoutError: If read timeout occurs.
""" """
# This method doesn't make sense for a muxed connection # This method doesn't make sense for a muxed connection
# It's here for interface compatibility but should not be used # It's here for interface compatibility but should not be used
@ -1325,7 +1333,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"QUICConnection(peer={self._peer_id}, " f"QUICConnection(peer={self.peer_id}, "
f"addr={self._remote_addr}, " f"addr={self._remote_addr}, "
f"initiator={self.__is_initiator}, " f"initiator={self.__is_initiator}, "
f"verified={self._peer_verified}, " f"verified={self._peer_verified}, "
@ -1335,4 +1343,4 @@ class QUICConnection(IRawConnection, IMuxedConn):
) )
def __str__(self) -> str: def __str__(self) -> str:
return f"QUICConnection({self._peer_id})" return f"QUICConnection({self.peer_id})"

View File

@ -12,18 +12,19 @@ from typing import TYPE_CHECKING
from aioquic.quic import events from aioquic.quic import events
from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection from aioquic.quic.connection import QuicConnection
from aioquic.quic.packet import QuicPacketType
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio import trio
from libp2p.abc import IListener from libp2p.abc import IListener
from libp2p.custom_types import THandler, TProtocol from libp2p.custom_types import (
TProtocol,
TQUICConnHandlerFn,
)
from libp2p.transport.quic.security import ( from libp2p.transport.quic.security import (
LIBP2P_TLS_EXTENSION_OID, LIBP2P_TLS_EXTENSION_OID,
QUICTLSConfigManager, QUICTLSConfigManager,
) )
from libp2p.custom_types import TQUICConnHandlerFn
from libp2p.custom_types import TQUICStreamHandlerFn
from aioquic.quic.packet import QuicPacketType
from .config import QUICTransportConfig from .config import QUICTransportConfig
from .connection import QUICConnection from .connection import QUICConnection
@ -1099,12 +1100,21 @@ class QUICListener(IListener):
if not is_quic_multiaddr(maddr): if not is_quic_multiaddr(maddr):
raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}")
if self._transport._background_nursery:
active_nursery = self._transport._background_nursery
logger.debug("Using transport background nursery for listener")
elif nursery:
active_nursery = nursery
logger.debug("Using provided nursery for listener")
else:
raise QUICListenError("No nursery available")
try: try:
host, port = quic_multiaddr_to_endpoint(maddr) host, port = quic_multiaddr_to_endpoint(maddr)
# Create and configure socket # Create and configure socket
self._socket = await self._create_socket(host, port) self._socket = await self._create_socket(host, port)
self._nursery = nursery self._nursery = active_nursery
# Get the actual bound address # Get the actual bound address
bound_host, bound_port = self._socket.getsockname() bound_host, bound_port = self._socket.getsockname()
@ -1115,7 +1125,7 @@ class QUICListener(IListener):
self._listening = True self._listening = True
# Start packet handling loop # Start packet handling loop
nursery.start_soon(self._handle_incoming_packets) active_nursery.start_soon(self._handle_incoming_packets)
logger.info( logger.info(
f"QUIC listener started on {bound_maddr} with connection ID support" f"QUIC listener started on {bound_maddr} with connection ID support"
@ -1217,33 +1227,22 @@ class QUICListener(IListener):
async def _handle_new_established_connection( async def _handle_new_established_connection(
self, connection: QUICConnection self, connection: QUICConnection
) -> None: ) -> None:
"""Handle newly established connection with proper stream management.""" """Handle newly established connection by adding to swarm."""
try: try:
logger.debug( logger.debug(
f"Handling new established connection from {connection._remote_addr}" f"New QUIC connection established from {connection._remote_addr}"
) )
# Accept incoming streams and pass them to the handler if self._transport._swarm:
while not connection.is_closed: logger.debug("Adding QUIC connection directly to swarm")
try: await self._transport._swarm.add_conn(connection)
print(f"🔧 CONN_HANDLER: Waiting for stream...") logger.debug("Successfully added QUIC connection to swarm")
stream = await connection.accept_stream(timeout=1.0) else:
print(f"✅ CONN_HANDLER: Accepted stream {stream.stream_id}") logger.error("No swarm available for QUIC connection")
await connection.close()
if self._nursery:
# Pass STREAM to handler, not connection
self._nursery.start_soon(self._handler, stream)
print(
f"✅ CONN_HANDLER: Started handler for stream {stream.stream_id}"
)
except trio.TooSlowError:
continue # Timeout is normal
except Exception as e:
logger.error(f"Error accepting stream: {e}")
break
except Exception as e: except Exception as e:
logger.error(f"Error in connection handler: {e}") logger.error(f"Error adding QUIC connection to swarm: {e}")
await connection.close() await connection.close()
def get_addrs(self) -> tuple[Multiaddr]: def get_addrs(self) -> tuple[Multiaddr]:

View File

@ -9,6 +9,7 @@ import copy
import logging import logging
import ssl import ssl
import sys import sys
from typing import TYPE_CHECKING, cast
from aioquic.quic.configuration import ( from aioquic.quic.configuration import (
QuicConfiguration, QuicConfiguration,
@ -21,13 +22,12 @@ import multiaddr
import trio import trio
from libp2p.abc import ( from libp2p.abc import (
IRawConnection,
ITransport, ITransport,
) )
from libp2p.crypto.keys import ( from libp2p.crypto.keys import (
PrivateKey, PrivateKey,
) )
from libp2p.custom_types import THandler, TProtocol, TQUICConnHandlerFn from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
from libp2p.peer.id import ( from libp2p.peer.id import (
ID, ID,
) )
@ -40,6 +40,11 @@ from libp2p.transport.quic.utils import (
quic_version_to_wire_format, quic_version_to_wire_format,
) )
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm
else:
Swarm = cast(type, object)
from .config import ( from .config import (
QUICTransportConfig, QUICTransportConfig,
) )
@ -112,10 +117,20 @@ class QUICTransport(ITransport):
# Resource management # Resource management
self._closed = False self._closed = False
self._nursery_manager = trio.CapacityLimiter(1) self._nursery_manager = trio.CapacityLimiter(1)
self._background_nursery: trio.Nursery | None = None
logger.info( self._swarm = None
f"Initialized QUIC transport with security for peer {self._peer_id}"
) print(f"Initialized QUIC transport with security for peer {self._peer_id}")
def set_background_nursery(self, nursery: trio.Nursery) -> None:
"""Set the nursery to use for background tasks (called by swarm)."""
self._background_nursery = nursery
print("Transport background nursery set")
def set_swarm(self, swarm) -> None:
"""Set the swarm for adding incoming connections."""
self._swarm = swarm
def _setup_quic_configurations(self) -> None: def _setup_quic_configurations(self) -> None:
"""Setup QUIC configurations.""" """Setup QUIC configurations."""
@ -184,7 +199,7 @@ class QUICTransport(ITransport):
draft29_client_config draft29_client_config
) )
logger.info("QUIC configurations initialized with libp2p TLS security") print("QUIC configurations initialized with libp2p TLS security")
except Exception as e: except Exception as e:
raise QUICSecurityError( raise QUICSecurityError(
@ -214,14 +229,13 @@ class QUICTransport(ITransport):
config.verify_mode = ssl.CERT_NONE config.verify_mode = ssl.CERT_NONE
logger.debug("Successfully applied TLS configuration to QUIC config") print("Successfully applied TLS configuration to QUIC config")
except Exception as e: except Exception as e:
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
async def dial( # type: ignore
self, maddr: multiaddr.Multiaddr, peer_id: ID | None = None async def dial(self, maddr: multiaddr.Multiaddr, peer_id: ID) -> QUICConnection:
) -> QUICConnection:
""" """
Dial a remote peer using QUIC transport with security verification. Dial a remote peer using QUIC transport with security verification.
@ -243,6 +257,9 @@ class QUICTransport(ITransport):
if not is_quic_multiaddr(maddr): if not is_quic_multiaddr(maddr):
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}") raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
if not peer_id:
raise QUICDialError("Peer id cannot be null")
try: try:
# Extract connection details from multiaddr # Extract connection details from multiaddr
host, port = quic_multiaddr_to_endpoint(maddr) host, port = quic_multiaddr_to_endpoint(maddr)
@ -257,9 +274,7 @@ class QUICTransport(ITransport):
config.is_client = True config.is_client = True
config.quic_logger = QuicLogger() config.quic_logger = QuicLogger()
logger.debug( print(f"Dialing QUIC connection to {host}:{port} (version: {quic_version})")
f"Dialing QUIC connection to {host}:{port} (version: {quic_version})"
)
print("Start QUIC Connection") print("Start QUIC Connection")
# Create QUIC connection using aioquic's sans-IO core # Create QUIC connection using aioquic's sans-IO core
@ -279,8 +294,18 @@ class QUICTransport(ITransport):
) )
# Establish connection using trio # Establish connection using trio
async with trio.open_nursery() as nursery: if self._background_nursery:
await connection.connect(nursery) # Use swarm's long-lived nursery - background tasks persist!
await connection.connect(self._background_nursery)
print("Using background nursery for connection tasks")
else:
# Fallback to temporary nursery (with warning)
print(
"No background nursery available. Connection background tasks "
"may be cancelled when dial completes."
)
async with trio.open_nursery() as temp_nursery:
await connection.connect(temp_nursery)
# Verify peer identity after TLS handshake # Verify peer identity after TLS handshake
if peer_id: if peer_id:
@ -290,7 +315,7 @@ class QUICTransport(ITransport):
conn_id = f"{host}:{port}:{peer_id}" conn_id = f"{host}:{port}:{peer_id}"
self._connections[conn_id] = connection self._connections[conn_id] = connection
logger.info(f"Successfully dialed secure QUIC connection to {peer_id}") print(f"Successfully dialed secure QUIC connection to {peer_id}")
return connection return connection
except Exception as e: except Exception as e:
@ -329,7 +354,7 @@ class QUICTransport(ITransport):
f"{expected_peer_id}, got {verified_peer_id}" f"{expected_peer_id}, got {verified_peer_id}"
) )
logger.info(f"Peer identity verified: {verified_peer_id}") print(f"Peer identity verified: {verified_peer_id}")
print(f"Peer identity verified: {verified_peer_id}") print(f"Peer identity verified: {verified_peer_id}")
except Exception as e: except Exception as e:
@ -368,7 +393,7 @@ class QUICTransport(ITransport):
) )
self._listeners.append(listener) self._listeners.append(listener)
logger.debug("Created QUIC listener with security") print("Created QUIC listener with security")
return listener return listener
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool: def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
@ -414,7 +439,7 @@ class QUICTransport(ITransport):
return return
self._closed = True self._closed = True
logger.info("Closing QUIC transport") print("Closing QUIC transport")
# Close all active connections and listeners concurrently using trio nursery # Close all active connections and listeners concurrently using trio nursery
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
@ -429,7 +454,7 @@ class QUICTransport(ITransport):
self._connections.clear() self._connections.clear()
self._listeners.clear() self._listeners.clear()
logger.info("QUIC transport closed") print("QUIC transport closed")
def get_stats(self) -> dict[str, int | list[str] | object]: def get_stats(self) -> dict[str, int | list[str] | object]:
"""Get transport statistics including security info.""" """Get transport statistics including security info."""