Merge branch 'main' into keyerror-fix

This commit is contained in:
Manu Sheel Gupta
2025-09-22 01:56:52 +05:30
committed by GitHub
48 changed files with 9477 additions and 100 deletions

View File

@ -36,10 +36,48 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python }}
- name: Install Nim for interop testing
if: matrix.toxenv == 'interop'
run: |
echo "Installing Nim for nim-libp2p interop testing..."
curl -sSf https://nim-lang.org/choosenim/init.sh | sh -s -- -y --firstInstall
echo "$HOME/.nimble/bin" >> $GITHUB_PATH
echo "$HOME/.choosenim/toolchains/nim-stable/bin" >> $GITHUB_PATH
- name: Cache nimble packages
if: matrix.toxenv == 'interop'
uses: actions/cache@v4
with:
path: |
~/.nimble
~/.choosenim/toolchains/*/lib
key: ${{ runner.os }}-nimble-${{ hashFiles('**/nim_echo_server.nim') }}
restore-keys: |
${{ runner.os }}-nimble-
- name: Build nim interop binaries
if: matrix.toxenv == 'interop'
run: |
export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH"
cd tests/interop/nim_libp2p
./scripts/setup_nim_echo.sh
- run: |
python -m pip install --upgrade pip
python -m pip install tox
- run: |
- name: Run Tests or Generate Docs
run: |
if [[ "${{ matrix.toxenv }}" == 'docs' ]]; then
export TOXENV=docs
else
export TOXENV=py${{ matrix.python }}-${{ matrix.toxenv }}
fi
# Set PATH for nim commands during tox
if [[ "${{ matrix.toxenv }}" == 'interop' ]]; then
export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH"
fi
python -m tox run -r
windows:

View File

@ -61,12 +61,12 @@ ______________________________________________________________________
### Discovery
| **Discovery** | **Status** | **Source** |
| -------------------- | :--------: | :--------------------------------------------------------------------------------: |
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
| **`random-walk`** | 🌱 | |
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
| **`rendezvous`** | 🌱 | |
| **Discovery** | **Status** | **Source** |
| -------------------- | :--------: | :----------------------------------------------------------------------------------: |
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
| **`random-walk`** | | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/random_walk) |
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
| **`rendezvous`** | 🌱 | |
______________________________________________________________________

View File

@ -0,0 +1,43 @@
QUIC Echo Demo
==============
This example demonstrates a simple ``echo`` protocol using **QUIC transport**.
QUIC provides built-in TLS security and stream multiplexing over UDP, making it an excellent transport choice for libp2p applications.
.. code-block:: console
$ python -m pip install libp2p
Collecting libp2p
...
Successfully installed libp2p-x.x.x
$ echo-quic-demo
Run this from the same folder in another console:
echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmAsbxRR1HiGJRNVPQLNMeNsBCsXT3rDjoYBQzgzNpM5mJ
Waiting for incoming connection...
Copy the line that starts with ``echo-quic-demo -p 8001``, open a new terminal in the same
folder and paste it in:
.. code-block:: console
$ echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
I am 16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
STARTING CLIENT CONNECTION PROCESS
CLIENT CONNECTED TO SERVER
Sent: hi, there!
Got: ECHO: hi, there!
**Key differences from TCP Echo:**
- Uses UDP instead of TCP: ``/udp/8000`` instead of ``/tcp/8000``
- Includes QUIC protocol identifier: ``/quic-v1`` in the multiaddr
- Built-in TLS security (no separate security transport needed)
- Native stream multiplexing over a single QUIC connection
.. literalinclude:: ../examples/echo/echo_quic.py
:language: python
:linenos:

View File

@ -9,6 +9,7 @@ Examples
examples.identify_push
examples.chat
examples.echo
examples.echo_quic
examples.ping
examples.pubsub
examples.circuit_relay

View File

@ -28,6 +28,11 @@ For Python, the most common transport is TCP. Here's how to set up a basic TCP t
.. literalinclude:: ../examples/doc-examples/example_transport.py
:language: python
Also, QUIC is a modern transport protocol that provides built-in TLS security and stream multiplexing over UDP:
.. literalinclude:: ../examples/doc-examples/example_quic_transport.py
:language: python
Connection Encryption
^^^^^^^^^^^^^^^^^^^^^

View File

@ -0,0 +1,77 @@
libp2p.transport.quic package
=============================
Submodules
----------
libp2p.transport.quic.config module
-----------------------------------
.. automodule:: libp2p.transport.quic.config
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.connection module
---------------------------------------
.. automodule:: libp2p.transport.quic.connection
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.exceptions module
---------------------------------------
.. automodule:: libp2p.transport.quic.exceptions
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.listener module
-------------------------------------
.. automodule:: libp2p.transport.quic.listener
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.security module
-------------------------------------
.. automodule:: libp2p.transport.quic.security
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.stream module
-----------------------------------
.. automodule:: libp2p.transport.quic.stream
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.transport module
--------------------------------------
.. automodule:: libp2p.transport.quic.transport
:members:
:undoc-members:
:show-inheritance:
libp2p.transport.quic.utils module
----------------------------------
.. automodule:: libp2p.transport.quic.utils
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: libp2p.transport.quic
:members:
:undoc-members:
:show-inheritance:

View File

@ -9,6 +9,11 @@ Subpackages
libp2p.transport.tcp
.. toctree::
:maxdepth: 4
libp2p.transport.quic
Submodules
----------

View File

@ -0,0 +1,35 @@
import secrets
import multiaddr
import trio
from libp2p import (
new_host,
)
from libp2p.crypto.secp256k1 import (
create_new_key_pair,
)
async def main():
# Create a key pair for the host
secret = secrets.token_bytes(32)
key_pair = create_new_key_pair(secret)
# Create a host with the key pair
host = new_host(key_pair=key_pair, enable_quic=True)
# Configure the listening address
port = 8000
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic-v1")
# Start the host
async with host.run(listen_addrs=[listen_addr]):
print("libp2p has started with QUIC transport")
print("libp2p is listening on:", host.get_addrs())
# Keep the host running
await trio.sleep_forever()
# Run the async function
trio.run(main)

178
examples/echo/echo_quic.py Normal file
View File

@ -0,0 +1,178 @@
#!/usr/bin/env python3
"""
QUIC Echo Example - Fixed version with proper client/server separation
This program demonstrates a simple echo protocol using QUIC transport where a peer
listens for connections and copies back any input received on a stream.
Fixed to properly separate client and server modes - clients don't start listeners.
"""
import argparse
import logging
from multiaddr import Multiaddr
import trio
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.network.stream.net_stream import INetStream
from libp2p.peer.peerinfo import info_from_p2p_addr
PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def _echo_stream_handler(stream: INetStream) -> None:
try:
msg = await stream.read()
await stream.write(msg)
await stream.close()
except Exception as e:
print(f"Echo handler error: {e}")
try:
await stream.close()
except: # noqa: E722
pass
async def run_server(port: int, seed: int | None = None) -> None:
"""Run echo server with QUIC transport."""
listen_addr = Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic")
if seed:
import random
random.seed(seed)
secret_number = random.getrandbits(32 * 8)
secret = secret_number.to_bytes(length=32, byteorder="big")
else:
import secrets
secret = secrets.token_bytes(32)
# Create host with QUIC transport
host = new_host(
enable_quic=True,
key_pair=create_new_key_pair(secret),
)
# Server mode: start listener
async with host.run(listen_addrs=[listen_addr]):
try:
print(f"I am {host.get_id().to_string()}")
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
print(
"Run this from the same folder in another console:\n\n"
f"python3 ./examples/echo/echo_quic.py "
f"-d {host.get_addrs()[0]}\n"
)
print("Waiting for incoming QUIC connections...")
await trio.sleep_forever()
except KeyboardInterrupt:
print("Closing server gracefully...")
await host.close()
return
async def run_client(destination: str, seed: int | None = None) -> None:
"""Run echo client with QUIC transport."""
if seed:
import random
random.seed(seed)
secret_number = random.getrandbits(32 * 8)
secret = secret_number.to_bytes(length=32, byteorder="big")
else:
import secrets
secret = secrets.token_bytes(32)
# Create host with QUIC transport
host = new_host(
enable_quic=True,
key_pair=create_new_key_pair(secret),
)
# Client mode: NO listener, just connect
async with host.run(listen_addrs=[]): # Empty listen_addrs for client
print(f"I am {host.get_id().to_string()}")
maddr = Multiaddr(destination)
info = info_from_p2p_addr(maddr)
# Connect to server
print("STARTING CLIENT CONNECTION PROCESS")
await host.connect(info)
print("CLIENT CONNECTED TO SERVER")
# Start a stream with the destination
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
msg = b"hi, there!\n"
await stream.write(msg)
response = await stream.read()
print(f"Sent: {msg.decode('utf-8')}")
print(f"Got: {response.decode('utf-8')}")
await stream.close()
await host.disconnect(info.peer_id)
async def run(port: int, destination: str, seed: int | None = None) -> None:
"""
Run echo server or client with QUIC transport.
Fixed version that properly separates client and server modes.
"""
if not destination: # Server mode
await run_server(port, seed)
else: # Client mode
await run_client(destination, seed)
def main() -> None:
"""Main function - help text updated for QUIC."""
description = """
This program demonstrates a simple echo protocol using QUIC
transport where a peer listens for connections and copies back
any input received on a stream.
QUIC provides built-in TLS security and stream multiplexing over UDP.
To use it, first run 'echo-quic-demo -p <PORT>', where <PORT> is
the UDP port number. Then, run another host with ,
'echo-quic-demo -d <DESTINATION>'
where <DESTINATION> is the QUIC multiaddress of the previous listener host.
"""
example_maddr = "/ip4/127.0.0.1/udp/8000/quic/p2p/QmQn4SwGkDZKkUEpBRBv"
parser = argparse.ArgumentParser(description=description)
parser.add_argument("-p", "--port", default=0, type=int, help="UDP port number")
parser.add_argument(
"-d",
"--destination",
type=str,
help=f"destination multiaddr string, e.g. {example_maddr}",
)
parser.add_argument(
"-s",
"--seed",
type=int,
help="provide a seed to the random number generator",
)
args = parser.parse_args()
try:
trio.run(run, args.port, args.destination, args.seed)
except KeyboardInterrupt:
pass
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("aioquic").setLevel(logging.DEBUG)
main()

View File

@ -1,5 +1,11 @@
"""Libp2p Python implementation."""
import logging
from libp2p.transport.quic.utils import is_quic_multiaddr
from typing import Any
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.quic.config import QUICTransportConfig
from collections.abc import (
Mapping,
Sequence,
@ -38,10 +44,12 @@ from libp2p.host.routed_host import (
RoutedHost,
)
from libp2p.network.swarm import (
ConnectionConfig,
RetryConfig,
Swarm,
)
from libp2p.network.config import (
ConnectionConfig,
RetryConfig
)
from libp2p.peer.id import (
ID,
)
@ -87,6 +95,7 @@ MUXER_YAMUX = "YAMUX"
MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5
logger = logging.getLogger(__name__)
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
"""
@ -162,8 +171,9 @@ def new_swarm(
peerstore_opt: IPeerStore | None = None,
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
enable_quic: bool = False,
retry_config: Optional["RetryConfig"] = None,
connection_config: Optional["ConnectionConfig"] = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -174,6 +184,8 @@ def new_swarm(
:param peerstore_opt: optional peerstore
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_quic: enable quic for transport
:param quic_transport_opt: options for transport
:return: return a default swarm instance
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@ -186,14 +198,21 @@ def new_swarm(
id_opt = generate_peer_id_from(key_pair)
transport: TCP | QUICTransport
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
if listen_addrs is None:
transport = TCP()
if enable_quic:
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
else:
transport = TCP()
else:
addr = listen_addrs[0]
is_quic = is_quic_multiaddr(addr)
if addr.__contains__("tcp"):
transport = TCP()
elif addr.__contains__("quic"):
raise ValueError("QUIC not yet supported")
elif is_quic:
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
else:
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
@ -261,6 +280,8 @@ def new_host(
enable_mDNS: bool = False,
bootstrap: list[str] | None = None,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
enable_quic: bool = False,
quic_transport_opt: QUICTransportConfig | None = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.
@ -274,15 +295,23 @@ def new_host(
:param listen_addrs: optional list of multiaddrs to listen on
:param enable_mDNS: whether to enable mDNS discovery
:param bootstrap: optional list of bootstrap peer addresses as strings
:param enable_quic: optinal choice to use QUIC for transport
:param transport_opt: optional configuration for quic transport
:return: return a host instance
"""
if not enable_quic and quic_transport_opt is not None:
logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config")
swarm = new_swarm(
enable_quic=enable_quic,
key_pair=key_pair,
muxer_opt=muxer_opt,
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
muxer_preference=muxer_preference,
listen_addrs=listen_addrs,
connection_config=quic_transport_opt if enable_quic else None
)
if disc_opt is not None:

View File

@ -5,17 +5,17 @@ from collections.abc import (
)
from typing import TYPE_CHECKING, NewType, Union, cast
from libp2p.transport.quic.stream import QUICStream
if TYPE_CHECKING:
from libp2p.abc import (
IMuxedConn,
INetStream,
ISecureTransport,
)
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
from libp2p.transport.quic.connection import QUICConnection
else:
IMuxedConn = cast(type, object)
INetStream = cast(type, object)
ISecureTransport = cast(type, object)
IMuxedStream = cast(type, object)
QUICConnection = cast(type, object)
from libp2p.io.abc import (
ReadWriteCloser,
@ -37,4 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
UnsubscribeFn = Callable[[], Awaitable[None]]
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
MessageID = NewType("MessageID", str)

View File

@ -213,7 +213,6 @@ class BasicHost(IHost):
self,
peer_id: ID,
protocol_ids: Sequence[TProtocol],
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> INetStream:
"""
:param peer_id: peer_id that host is connecting
@ -227,7 +226,7 @@ class BasicHost(IHost):
selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids),
MultiselectCommunicator(net_stream),
negotitate_timeout,
self.negotiate_timeout,
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)

70
libp2p/network/config.py Normal file
View File

@ -0,0 +1,70 @@
from dataclasses import dataclass
@dataclass
class RetryConfig:
"""
Configuration for retry logic with exponential backoff.
This configuration controls how connection attempts are retried when they fail.
The retry mechanism uses exponential backoff with jitter to prevent thundering
herd problems in distributed systems.
Attributes:
max_retries: Maximum number of retry attempts before giving up.
Default: 3 attempts
initial_delay: Initial delay in seconds before the first retry.
Default: 0.1 seconds (100ms)
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
Default: 30.0 seconds
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
the delay by this factor). Default: 2.0 (doubles each time)
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
and prevent synchronized retries. Default: 0.1 (10% jitter)
"""
max_retries: int = 3
initial_delay: float = 0.1
max_delay: float = 30.0
backoff_multiplier: float = 2.0
jitter_factor: float = 0.1
@dataclass
class ConnectionConfig:
"""
Configuration for multi-connection support.
This configuration controls how multiple connections per peer are managed,
including connection limits, timeouts, and load balancing strategies.
Attributes:
max_connections_per_peer: Maximum number of connections allowed to a single
peer. Default: 3 connections
connection_timeout: Timeout in seconds for establishing new connections.
Default: 30.0 seconds
load_balancing_strategy: Strategy for distributing streams across connections.
Options: "round_robin" (default) or "least_loaded"
"""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
load_balancing_strategy: str = "round_robin" # or "least_loaded"
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (
self.load_balancing_strategy == "round_robin"
or self.load_balancing_strategy == "least_loaded"
):
raise ValueError(
"Load balancing strategy can only be 'round_robin' or 'least_loaded'"
)
if self.max_connections_per_peer < 1:
raise ValueError("Max connection per peer should be atleast 1")
if self.connection_timeout < 0:
raise ValueError("Connection timeout should be positive")

View File

@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamError,
MuxedStreamReset,
)
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError
from .exceptions import (
StreamClosed,
@ -170,7 +171,7 @@ class NetStream(INetStream):
elif self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_READ
raise StreamEOF() from error
except MuxedStreamReset as error:
except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
async with self._state_lock:
if self.__stream_state in [
StreamState.OPEN,
@ -199,7 +200,12 @@ class NetStream(INetStream):
try:
await self.muxed_stream.write(data)
except (MuxedStreamClosed, MuxedStreamError) as error:
except (
MuxedStreamClosed,
MuxedStreamError,
QUICStreamClosedError,
QUICStreamResetError,
) as error:
async with self._state_lock:
if self.__stream_state == StreamState.OPEN:
self.__stream_state = StreamState.CLOSE_WRITE

View File

@ -2,9 +2,9 @@ from collections.abc import (
Awaitable,
Callable,
)
from dataclasses import dataclass
import logging
import random
from typing import cast
from multiaddr import (
Multiaddr,
@ -27,6 +27,7 @@ from libp2p.custom_types import (
from libp2p.io.abc import (
ReadWriteCloser,
)
from libp2p.network.config import ConnectionConfig, RetryConfig
from libp2p.peer.id import (
ID,
)
@ -41,6 +42,9 @@ from libp2p.transport.exceptions import (
OpenConnectionError,
SecurityUpgradeFailure,
)
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.upgrader import (
TransportUpgrader,
)
@ -61,59 +65,6 @@ from .exceptions import (
logger = logging.getLogger("libp2p.network.swarm")
@dataclass
class RetryConfig:
"""
Configuration for retry logic with exponential backoff.
This configuration controls how connection attempts are retried when they fail.
The retry mechanism uses exponential backoff with jitter to prevent thundering
herd problems in distributed systems.
Attributes:
max_retries: Maximum number of retry attempts before giving up.
Default: 3 attempts
initial_delay: Initial delay in seconds before the first retry.
Default: 0.1 seconds (100ms)
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
Default: 30.0 seconds
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
the delay by this factor). Default: 2.0 (doubles each time)
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
and prevent synchronized retries. Default: 0.1 (10% jitter)
"""
max_retries: int = 3
initial_delay: float = 0.1
max_delay: float = 30.0
backoff_multiplier: float = 2.0
jitter_factor: float = 0.1
@dataclass
class ConnectionConfig:
"""
Configuration for multi-connection support.
This configuration controls how multiple connections per peer are managed,
including connection limits, timeouts, and load balancing strategies.
Attributes:
max_connections_per_peer: Maximum number of connections allowed to a single
peer. Default: 3 connections
connection_timeout: Timeout in seconds for establishing new connections.
Default: 30.0 seconds
load_balancing_strategy: Strategy for distributing streams across connections.
Options: "round_robin" (default) or "least_loaded"
"""
max_connections_per_peer: int = 3
connection_timeout: float = 30.0
load_balancing_strategy: str = "round_robin" # or "least_loaded"
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
async def stream_handler(stream: INetStream) -> None:
await network.get_manager().wait_finished()
@ -126,8 +77,7 @@ class Swarm(Service, INetworkService):
peerstore: IPeerStore
upgrader: TransportUpgrader
transport: ITransport
# Enhanced: Support for multiple connections per peer
connections: dict[ID, list[INetConn]] # Multiple connections per peer
connections: dict[ID, list[INetConn]]
listeners: dict[str, IListener]
common_stream_handler: StreamHandlerFn
listener_nursery: trio.Nursery | None
@ -137,7 +87,7 @@ class Swarm(Service, INetworkService):
# Enhanced: New configuration
retry_config: RetryConfig
connection_config: ConnectionConfig
connection_config: ConnectionConfig | QUICTransportConfig
_round_robin_index: dict[ID, int]
def __init__(
@ -147,7 +97,7 @@ class Swarm(Service, INetworkService):
upgrader: TransportUpgrader,
transport: ITransport,
retry_config: RetryConfig | None = None,
connection_config: ConnectionConfig | None = None,
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
):
self.self_id = peer_id
self.peerstore = peerstore
@ -178,6 +128,11 @@ class Swarm(Service, INetworkService):
# Create a nursery for listener tasks.
self.listener_nursery = nursery
self.event_listener_nursery_created.set()
if isinstance(self.transport, QUICTransport):
self.transport.set_background_nursery(nursery)
self.transport.set_swarm(self)
try:
await self.manager.wait_finished()
finally:
@ -370,6 +325,7 @@ class Swarm(Service, INetworkService):
# Dial peer (connection to peer does not yet exist)
# Transport dials peer (gets back a raw conn)
try:
addr = Multiaddr(f"{addr}/p2p/{peer_id}")
raw_conn = await self.transport.dial(addr)
except OpenConnectionError as error:
logger.debug("fail to dial peer %s over base transport", peer_id)
@ -377,6 +333,15 @@ class Swarm(Service, INetworkService):
f"fail to open connection to peer {peer_id}"
) from error
if isinstance(self.transport, QUICTransport) and isinstance(
raw_conn, IMuxedConn
):
logger.info(
"Skipping upgrade for QUIC, QUIC connections are already multiplexed"
)
swarm_conn = await self.add_conn(raw_conn)
return swarm_conn
logger.debug("dialed peer %s over base transport", peer_id)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
@ -402,9 +367,7 @@ class Swarm(Service, INetworkService):
logger.debug("upgraded mux for peer %s", peer_id)
swarm_conn = await self.add_conn(muxed_conn)
logger.debug("successfully dialed peer %s", peer_id)
return swarm_conn
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
@ -427,7 +390,6 @@ class Swarm(Service, INetworkService):
:return: net stream instance
"""
logger.debug("attempting to open a stream to peer %s", peer_id)
# Get existing connections or dial new ones
connections = self.get_connections(peer_id)
if not connections:
@ -436,6 +398,10 @@ class Swarm(Service, INetworkService):
# Load balancing strategy at interface level
connection = self._select_connection(connections, peer_id)
if isinstance(self.transport, QUICTransport) and connection is not None:
conn = cast(SwarmConn, connection)
return await conn.new_stream()
try:
net_stream = await connection.new_stream()
logger.debug("successfully opened a stream to peer %s", peer_id)
@ -516,6 +482,7 @@ class Swarm(Service, INetworkService):
- Map multiaddr to listener
"""
# We need to wait until `self.listener_nursery` is created.
logger.debug("Starting to listen")
await self.event_listener_nursery_created.wait()
success_count = 0
@ -527,6 +494,22 @@ class Swarm(Service, INetworkService):
async def conn_handler(
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
) -> None:
# No need to upgrade QUIC Connection
if isinstance(self.transport, QUICTransport):
try:
quic_conn = cast(QUICConnection, read_write_closer)
await self.add_conn(quic_conn)
peer_id = quic_conn.peer_id
logger.debug(
f"successfully opened quic connection to peer {peer_id}"
)
# NOTE: This is a intentional barrier to prevent from the
# handler exiting and closing the connection.
await self.manager.wait_finished()
except Exception:
await read_write_closer.close()
return
raw_conn = RawConnection(read_write_closer, False)
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
@ -660,9 +643,10 @@ class Swarm(Service, INetworkService):
muxed_conn,
self,
)
logger.debug("Swarm::add_conn | starting muxed connection")
self.manager.run_task(muxed_conn.start)
await muxed_conn.event_started.wait()
logger.debug("Swarm::add_conn | starting swarm connection")
self.manager.run_task(swarm_conn.start)
await swarm_conn.event_started.wait()

View File

@ -1,3 +1,5 @@
from builtins import AssertionError
from libp2p.abc import (
IMultiselectCommunicator,
)
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:
# Handle for connection close during ongoing negotiation in QUIC
except (IOException, AssertionError, ValueError) as error:
raise MultiselectCommunicatorError(
"fail to write to multiselect communicator"
) from error

View File

@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
DEFAULT_NEGOTIATE_TIMEOUT,
Multiselect,
)
from libp2p.protocol_muxer.multiselect_client import (
@ -46,11 +47,17 @@ class MuxerMultistream:
transports: "OrderedDict[TProtocol, TMuxerClass]"
multiselect: Multiselect
multiselect_client: MultiselectClient
negotiate_timeout: int
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
def __init__(
self,
muxer_transports_by_protocol: TMuxerOptions,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> None:
self.transports = OrderedDict()
self.multiselect = Multiselect()
self.multistream_client = MultiselectClient()
self.negotiate_timeout = negotiate_timeout
for protocol, transport in muxer_transports_by_protocol.items():
self.add_transport(protocol, transport)
@ -80,10 +87,12 @@ class MuxerMultistream:
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
tuple(self.transports.keys()), communicator
tuple(self.transports.keys()), communicator, self.negotiate_timeout
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
protocol, _ = await self.multiselect.negotiate(
communicator, self.negotiate_timeout
)
if protocol is None:
raise MultiselectError(
"Fail to negotiate a stream muxer protocol: no protocol selected"
@ -93,7 +102,7 @@ class MuxerMultistream:
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
communicator = MultiselectCommunicator(conn)
protocol = await self.multistream_client.select_one_of(
tuple(self.transports.keys()), communicator
tuple(self.transports.keys()), communicator, self.negotiate_timeout
)
transport_class = self.transports[protocol]
if protocol == PROTOCOL_ID:

View File

View File

@ -0,0 +1,345 @@
"""
Configuration classes for QUIC transport.
"""
from dataclasses import (
dataclass,
field,
)
import ssl
from typing import Any, Literal, TypedDict
from libp2p.custom_types import TProtocol
from libp2p.network.config import ConnectionConfig
class QUICTransportKwargs(TypedDict, total=False):
"""Type definition for kwargs accepted by new_transport function."""
# Connection settings
idle_timeout: float
max_datagram_size: int
local_port: int | None
# Protocol version support
enable_draft29: bool
enable_v1: bool
# TLS settings
verify_mode: ssl.VerifyMode
alpn_protocols: list[str]
# Performance settings
max_concurrent_streams: int
connection_window: int
stream_window: int
# Logging and debugging
enable_qlog: bool
qlog_dir: str | None
# Connection management
max_connections: int
connection_timeout: float
# Protocol identifiers
PROTOCOL_QUIC_V1: TProtocol
PROTOCOL_QUIC_DRAFT29: TProtocol
@dataclass
class QUICTransportConfig(ConnectionConfig):
"""Configuration for QUIC transport."""
# Connection settings
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
max_datagram_size: int = (
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
)
local_port: int | None = (
None # Local port to bind to. If None, a random port is chosen.
)
# Protocol version support
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
# TLS settings
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
# Performance settings
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
connection_window: int = 1024 * 1024 # Connection flow control window
stream_window: int = 64 * 1024 # Stream flow control window
# Logging and debugging
enable_qlog: bool = False # Enable QUIC logging
qlog_dir: str | None = None # Directory for QUIC logs
# Connection management
max_connections: int = 1000 # Maximum number of connections
connection_timeout: float = 10.0 # Connection establishment timeout
MAX_CONCURRENT_STREAMS: int = 1000
"""Maximum number of concurrent streams per connection."""
MAX_INCOMING_STREAMS: int = 1000
"""Maximum number of incoming streams per connection."""
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
"""Timeout for connection handshake (seconds)."""
MAX_OUTGOING_STREAMS: int = 1000
"""Maximum number of outgoing streams per connection."""
CONNECTION_CLOSE_TIMEOUT: int = 10
"""Timeout for opening new connection (seconds)."""
# Stream timeouts
STREAM_OPEN_TIMEOUT: float = 5.0
"""Timeout for opening new streams (seconds)."""
STREAM_ACCEPT_TIMEOUT: float = 30.0
"""Timeout for accepting incoming streams (seconds)."""
STREAM_READ_TIMEOUT: float = 30.0
"""Default timeout for stream read operations (seconds)."""
STREAM_WRITE_TIMEOUT: float = 30.0
"""Default timeout for stream write operations (seconds)."""
STREAM_CLOSE_TIMEOUT: float = 10.0
"""Timeout for graceful stream close (seconds)."""
# Flow control configuration
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
"""Per-stream flow control window size."""
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
"""Connection-wide flow control window size."""
# Buffer management
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
"""Maximum receive buffer size per stream."""
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
"""Low watermark for stream receive buffer."""
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
"""High watermark for stream receive buffer."""
# Stream lifecycle configuration
ENABLE_STREAM_RESET_ON_ERROR: bool = True
"""Whether to automatically reset streams on errors."""
STREAM_RESET_ERROR_CODE: int = 1
"""Default error code for stream resets."""
ENABLE_STREAM_KEEP_ALIVE: bool = False
"""Whether to enable stream keep-alive mechanisms."""
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
"""Interval for stream keep-alive pings (seconds)."""
# Resource management
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
"""Whether to track stream resource usage."""
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
"""Memory limit per individual stream."""
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
"""Total memory limit for all streams per connection."""
# Concurrency and performance
ENABLE_STREAM_BATCHING: bool = True
"""Whether to batch multiple stream operations."""
STREAM_BATCH_SIZE: int = 10
"""Number of streams to process in a batch."""
STREAM_PROCESSING_CONCURRENCY: int = 100
"""Maximum concurrent stream processing tasks."""
# Debugging and monitoring
ENABLE_STREAM_METRICS: bool = True
"""Whether to collect stream metrics."""
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
"""Whether to track stream lifecycle timelines."""
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
"""Interval for collecting stream metrics (seconds)."""
# Error handling configuration
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
"""Number of retry attempts for recoverable stream errors."""
STREAM_ERROR_RETRY_DELAY: float = 1.0
"""Initial delay between stream error retries (seconds)."""
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
"""Backoff factor for stream error retries."""
# Protocol identifiers matching go-libp2p
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
def __post_init__(self) -> None:
"""Validate configuration after initialization."""
if not (self.enable_draft29 or self.enable_v1):
raise ValueError("At least one QUIC version must be enabled")
if self.idle_timeout <= 0:
raise ValueError("Idle timeout must be positive")
if self.max_datagram_size < 1200:
raise ValueError("Max datagram size must be at least 1200 bytes")
# Validate timeouts
timeout_fields = [
"STREAM_OPEN_TIMEOUT",
"STREAM_ACCEPT_TIMEOUT",
"STREAM_READ_TIMEOUT",
"STREAM_WRITE_TIMEOUT",
"STREAM_CLOSE_TIMEOUT",
]
for timeout_field in timeout_fields:
if getattr(self, timeout_field) <= 0:
raise ValueError(f"{timeout_field} must be positive")
# Validate flow control windows
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
raise ValueError(
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
)
# Validate buffer sizes
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
raise ValueError(
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
"exceed MAX_STREAM_RECEIVE_BUFFER"
)
)
if (
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
):
raise ValueError(
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
)
# Validate memory limits
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
expected_stream_memory = (
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
)
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
# Allow some headroom, but warn if configuration seems inconsistent
import logging
logger = logging.getLogger(__name__)
logger.warning(
"Stream memory configuration may be inconsistent: "
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
"could exceed connection limit of"
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
)
def get_stream_config_dict(self) -> dict[str, Any]:
"""Get stream-specific configuration as dictionary."""
stream_config = {}
for attr_name in dir(self):
if attr_name.startswith(
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
):
stream_config[attr_name.lower()] = getattr(self, attr_name)
return stream_config
# Additional configuration classes for specific stream features
class QUICStreamFlowControlConfig:
"""Configuration for QUIC stream flow control."""
def __init__(
self,
initial_window_size: int = 512 * 1024,
max_window_size: int = 2 * 1024 * 1024,
window_update_threshold: float = 0.5,
enable_auto_tuning: bool = True,
):
self.initial_window_size = initial_window_size
self.max_window_size = max_window_size
self.window_update_threshold = window_update_threshold
self.enable_auto_tuning = enable_auto_tuning
def create_stream_config_for_use_case(
use_case: Literal[
"high_throughput", "low_latency", "many_streams", "memory_constrained"
],
) -> QUICTransportConfig:
"""
Create optimized stream configuration for specific use cases.
Args:
use_case: One of "high_throughput", "low_latency", "many_streams","
"memory_constrained"
Returns:
Optimized QUICTransportConfig
"""
base_config = QUICTransportConfig()
if use_case == "high_throughput":
# Optimize for high throughput
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
base_config.STREAM_PROCESSING_CONCURRENCY = 200
elif use_case == "low_latency":
# Optimize for low latency
base_config.STREAM_OPEN_TIMEOUT = 1.0
base_config.STREAM_READ_TIMEOUT = 5.0
base_config.STREAM_WRITE_TIMEOUT = 5.0
base_config.ENABLE_STREAM_BATCHING = False
base_config.STREAM_BATCH_SIZE = 1
elif use_case == "many_streams":
# Optimize for many concurrent streams
base_config.MAX_CONCURRENT_STREAMS = 5000
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
base_config.STREAM_PROCESSING_CONCURRENCY = 500
elif use_case == "memory_constrained":
# Optimize for low memory usage
base_config.MAX_CONCURRENT_STREAMS = 100
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
base_config.STREAM_PROCESSING_CONCURRENCY = 50
else:
raise ValueError(f"Unknown use case: {use_case}")
return base_config

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,391 @@
"""
QUIC Transport exceptions
"""
from typing import Any, Literal
class QUICError(Exception):
"""Base exception for all QUIC transport errors."""
def __init__(self, message: str, error_code: int | None = None):
super().__init__(message)
self.error_code = error_code
# Transport-level exceptions
class QUICTransportError(QUICError):
"""Base exception for QUIC transport operations."""
pass
class QUICDialError(QUICTransportError):
"""Error occurred during QUIC connection establishment."""
pass
class QUICListenError(QUICTransportError):
"""Error occurred during QUIC listener operations."""
pass
class QUICSecurityError(QUICTransportError):
"""Error related to QUIC security/TLS operations."""
pass
# Connection-level exceptions
class QUICConnectionError(QUICError):
"""Base exception for QUIC connection operations."""
pass
class QUICConnectionClosedError(QUICConnectionError):
"""QUIC connection has been closed."""
pass
class QUICConnectionTimeoutError(QUICConnectionError):
"""QUIC connection operation timed out."""
pass
class QUICHandshakeError(QUICConnectionError):
"""Error during QUIC handshake process."""
pass
class QUICPeerVerificationError(QUICConnectionError):
"""Error verifying peer identity during handshake."""
pass
# Stream-level exceptions
class QUICStreamError(QUICError):
"""Base exception for QUIC stream operations."""
def __init__(
self,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
):
super().__init__(message, error_code)
self.stream_id = stream_id
class QUICStreamClosedError(QUICStreamError):
"""Stream is closed and cannot be used for I/O operations."""
pass
class QUICStreamResetError(QUICStreamError):
"""Stream was reset by local or remote peer."""
def __init__(
self,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
reset_by_peer: bool = False,
):
super().__init__(message, stream_id, error_code)
self.reset_by_peer = reset_by_peer
class QUICStreamTimeoutError(QUICStreamError):
"""Stream operation timed out."""
pass
class QUICStreamBackpressureError(QUICStreamError):
"""Stream write blocked due to flow control."""
pass
class QUICStreamLimitError(QUICStreamError):
"""Stream limit reached (too many concurrent streams)."""
pass
class QUICStreamStateError(QUICStreamError):
"""Invalid operation for current stream state."""
def __init__(
self,
message: str,
stream_id: str | None = None,
current_state: str | None = None,
attempted_operation: str | None = None,
):
super().__init__(message, stream_id)
self.current_state = current_state
self.attempted_operation = attempted_operation
# Flow control exceptions
class QUICFlowControlError(QUICError):
"""Base exception for flow control related errors."""
pass
class QUICFlowControlViolationError(QUICFlowControlError):
"""Flow control limits were violated."""
pass
class QUICFlowControlDeadlockError(QUICFlowControlError):
"""Flow control deadlock detected."""
pass
# Resource management exceptions
class QUICResourceError(QUICError):
"""Base exception for resource management errors."""
pass
class QUICMemoryLimitError(QUICResourceError):
"""Memory limit exceeded."""
pass
class QUICConnectionLimitError(QUICResourceError):
"""Connection limit exceeded."""
pass
# Multiaddr and addressing exceptions
class QUICAddressError(QUICError):
"""Base exception for QUIC addressing errors."""
pass
class QUICInvalidMultiaddrError(QUICAddressError):
"""Invalid multiaddr format for QUIC transport."""
pass
class QUICAddressResolutionError(QUICAddressError):
"""Failed to resolve QUIC address."""
pass
class QUICProtocolError(QUICError):
"""Base exception for QUIC protocol errors."""
pass
class QUICVersionNegotiationError(QUICProtocolError):
"""QUIC version negotiation failed."""
pass
class QUICUnsupportedVersionError(QUICProtocolError):
"""Unsupported QUIC version."""
pass
# Configuration exceptions
class QUICConfigurationError(QUICError):
"""Base exception for QUIC configuration errors."""
pass
class QUICInvalidConfigError(QUICConfigurationError):
"""Invalid QUIC configuration parameters."""
pass
class QUICCertificateError(QUICConfigurationError):
"""Error with TLS certificate configuration."""
pass
def map_quic_error_code(error_code: int) -> str:
"""
Map QUIC error codes to human-readable descriptions.
Based on RFC 9000 Transport Error Codes.
"""
error_codes = {
0x00: "NO_ERROR",
0x01: "INTERNAL_ERROR",
0x02: "CONNECTION_REFUSED",
0x03: "FLOW_CONTROL_ERROR",
0x04: "STREAM_LIMIT_ERROR",
0x05: "STREAM_STATE_ERROR",
0x06: "FINAL_SIZE_ERROR",
0x07: "FRAME_ENCODING_ERROR",
0x08: "TRANSPORT_PARAMETER_ERROR",
0x09: "CONNECTION_ID_LIMIT_ERROR",
0x0A: "PROTOCOL_VIOLATION",
0x0B: "INVALID_TOKEN",
0x0C: "APPLICATION_ERROR",
0x0D: "CRYPTO_BUFFER_EXCEEDED",
0x0E: "KEY_UPDATE_ERROR",
0x0F: "AEAD_LIMIT_REACHED",
0x10: "NO_VIABLE_PATH",
}
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
def create_stream_error(
error_type: str,
message: str,
stream_id: str | None = None,
error_code: int | None = None,
) -> QUICStreamError:
"""
Factory function to create appropriate stream error based on type.
Args:
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
message: Error message
stream_id: Stream identifier
error_code: QUIC error code
Returns:
Appropriate QUICStreamError subclass
"""
error_type = error_type.lower()
if error_type in ("closed", "close"):
return QUICStreamClosedError(message, stream_id, error_code)
elif error_type == "reset":
return QUICStreamResetError(message, stream_id, error_code)
elif error_type == "timeout":
return QUICStreamTimeoutError(message, stream_id, error_code)
elif error_type in ("backpressure", "flow_control"):
return QUICStreamBackpressureError(message, stream_id, error_code)
elif error_type in ("limit", "stream_limit"):
return QUICStreamLimitError(message, stream_id, error_code)
elif error_type == "state":
return QUICStreamStateError(message, stream_id)
else:
return QUICStreamError(message, stream_id, error_code)
def create_connection_error(
error_type: str, message: str, error_code: int | None = None
) -> QUICConnectionError:
"""
Factory function to create appropriate connection error based on type.
Args:
error_type: Type of error ("closed", "timeout", "handshake", etc.)
message: Error message
error_code: QUIC error code
Returns:
Appropriate QUICConnectionError subclass
"""
error_type = error_type.lower()
if error_type in ("closed", "close"):
return QUICConnectionClosedError(message, error_code)
elif error_type == "timeout":
return QUICConnectionTimeoutError(message, error_code)
elif error_type == "handshake":
return QUICHandshakeError(message, error_code)
elif error_type in ("peer_verification", "verification"):
return QUICPeerVerificationError(message, error_code)
else:
return QUICConnectionError(message, error_code)
class QUICErrorContext:
"""
Context manager for handling QUIC errors with automatic error mapping.
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
"""
def __init__(self, operation: str, component: str = "quic") -> None:
self.operation = operation
self.component = component
def __enter__(self) -> "QUICErrorContext":
return self
# TODO: Fix types for exc_type
def __exit__(
self,
exc_type: type[BaseException] | None | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
if exc_type is None:
return False
if exc_val is None:
return False
# Map common aioquic exceptions to our exceptions
if "ConnectionClosed" in str(exc_type):
raise QUICConnectionClosedError(
f"Connection closed during {self.operation}: {exc_val}"
) from exc_val
elif "StreamReset" in str(exc_type):
raise QUICStreamResetError(
f"Stream reset during {self.operation}: {exc_val}"
) from exc_val
elif "timeout" in str(exc_val).lower():
if "stream" in self.component.lower():
raise QUICStreamTimeoutError(
f"Timeout during {self.operation}: {exc_val}"
) from exc_val
else:
raise QUICConnectionTimeoutError(
f"Timeout during {self.operation}: {exc_val}"
) from exc_val
elif "flow control" in str(exc_val).lower():
raise QUICStreamBackpressureError(
f"Flow control error during {self.operation}: {exc_val}"
) from exc_val
# Let other exceptions propagate
return False

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,656 @@
"""
QUIC Stream implementation
Provides stream interface over QUIC's native multiplexing.
"""
from enum import Enum
import logging
import time
from types import TracebackType
from typing import TYPE_CHECKING, Any, cast
import trio
from .exceptions import (
QUICStreamBackpressureError,
QUICStreamClosedError,
QUICStreamResetError,
QUICStreamTimeoutError,
)
if TYPE_CHECKING:
from libp2p.abc import IMuxedStream
from libp2p.custom_types import TProtocol
from .connection import QUICConnection
else:
IMuxedStream = cast(type, object)
TProtocol = cast(type, object)
logger = logging.getLogger(__name__)
class StreamState(Enum):
"""Stream lifecycle states following libp2p patterns."""
OPEN = "open"
WRITE_CLOSED = "write_closed"
READ_CLOSED = "read_closed"
CLOSED = "closed"
RESET = "reset"
class StreamDirection(Enum):
"""Stream direction for tracking initiator."""
INBOUND = "inbound"
OUTBOUND = "outbound"
class StreamTimeline:
"""Track stream lifecycle events for debugging and monitoring."""
def __init__(self) -> None:
self.created_at = time.time()
self.opened_at: float | None = None
self.first_data_at: float | None = None
self.closed_at: float | None = None
self.reset_at: float | None = None
self.error_code: int | None = None
def record_open(self) -> None:
self.opened_at = time.time()
def record_first_data(self) -> None:
if self.first_data_at is None:
self.first_data_at = time.time()
def record_close(self) -> None:
self.closed_at = time.time()
def record_reset(self, error_code: int) -> None:
self.reset_at = time.time()
self.error_code = error_code
class QUICStream(IMuxedStream):
"""
QUIC Stream implementation following libp2p IMuxedStream interface.
Based on patterns from go-libp2p and js-libp2p, this implementation:
- Leverages QUIC's native multiplexing and flow control
- Integrates with libp2p resource management
- Provides comprehensive error handling with QUIC-specific codes
- Supports bidirectional communication with independent close semantics
- Implements proper stream lifecycle management
"""
def __init__(
self,
connection: "QUICConnection",
stream_id: int,
direction: StreamDirection,
remote_addr: tuple[str, int],
resource_scope: Any | None = None,
):
"""
Initialize QUIC stream.
Args:
connection: Parent QUIC connection
stream_id: QUIC stream identifier
direction: Stream direction (inbound/outbound)
resource_scope: Resource manager scope for memory accounting
remote_addr: Remote addr stream is connected to
"""
self._connection = connection
self._stream_id = stream_id
self._direction = direction
self._resource_scope = resource_scope
# libp2p interface compliance
self._protocol: TProtocol | None = None
self._metadata: dict[str, Any] = {}
self._remote_addr = remote_addr
# Stream state management
self._state = StreamState.OPEN
self._state_lock = trio.Lock()
# Flow control and buffering
self._receive_buffer = bytearray()
self._receive_buffer_lock = trio.Lock()
self._receive_event = trio.Event()
self._backpressure_event = trio.Event()
self._backpressure_event.set() # Initially no backpressure
# Close/reset state
self._write_closed = False
self._read_closed = False
self._close_event = trio.Event()
self._reset_error_code: int | None = None
# Lifecycle tracking
self._timeline = StreamTimeline()
self._timeline.record_open()
# Resource accounting
self._memory_reserved = 0
# Stream constant configurations
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
self.FLOW_CONTROL_WINDOW_SIZE = (
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
)
self.MAX_RECEIVE_BUFFER_SIZE = (
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
)
if self._resource_scope:
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
logger.debug(
f"Created QUIC stream {stream_id} "
f"({direction.value}, connection: {connection.remote_peer_id()})"
)
# Properties for libp2p interface compliance
@property
def protocol(self) -> TProtocol | None:
"""Get the protocol identifier for this stream."""
return self._protocol
@protocol.setter
def protocol(self, protocol_id: TProtocol) -> None:
"""Set the protocol identifier for this stream."""
self._protocol = protocol_id
self._metadata["protocol"] = protocol_id
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
@property
def stream_id(self) -> str:
"""Get stream ID as string for libp2p compatibility."""
return str(self._stream_id)
@property
def muxed_conn(self) -> "QUICConnection": # type: ignore
"""Get the parent muxed connection."""
return self._connection
@property
def state(self) -> StreamState:
"""Get current stream state."""
return self._state
@property
def direction(self) -> StreamDirection:
"""Get stream direction."""
return self._direction
@property
def is_initiator(self) -> bool:
"""Check if this stream was locally initiated."""
return self._direction == StreamDirection.OUTBOUND
# Core stream operations
async def read(self, n: int | None = None) -> bytes:
"""
Read data from the stream with QUIC flow control.
Args:
n: Maximum number of bytes to read. If None or -1, read all available.
Returns:
Data read from stream
Raises:
QUICStreamClosedError: Stream is closed
QUICStreamResetError: Stream was reset
QUICStreamTimeoutError: Read timeout exceeded
"""
if n is None:
n = -1
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
if self._read_closed:
# Return any remaining buffered data, then EOF
async with self._receive_buffer_lock:
if self._receive_buffer:
data = self._extract_data_from_buffer(n)
self._timeline.record_first_data()
return data
return b""
# Wait for data with timeout
timeout = self.READ_TIMEOUT
try:
with trio.move_on_after(timeout) as cancel_scope:
while True:
async with self._receive_buffer_lock:
if self._receive_buffer:
data = self._extract_data_from_buffer(n)
self._timeline.record_first_data()
return data
# Check if stream was closed while waiting
if self._read_closed:
return b""
# Wait for more data
await self._receive_event.wait()
self._receive_event = trio.Event() # Reset for next wait
if cancel_scope.cancelled_caught:
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
return b""
except QUICStreamResetError:
# Stream was reset while reading
raise
except Exception as e:
logger.error(f"Error reading from stream {self.stream_id}: {e}")
await self._handle_stream_error(e)
raise
async def write(self, data: bytes) -> None:
"""
Write data to the stream with QUIC flow control.
Args:
data: Data to write
Raises:
QUICStreamClosedError: Stream is closed for writing
QUICStreamBackpressureError: Flow control window exhausted
QUICStreamResetError: Stream was reset
"""
if not data:
return
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
if self._write_closed:
raise QUICStreamClosedError(
f"Stream {self.stream_id} write side is closed"
)
try:
# Handle flow control backpressure
await self._backpressure_event.wait()
# Send data through QUIC connection
self._connection._quic.send_stream_data(self._stream_id, data)
await self._connection._transmit()
self._timeline.record_first_data()
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
except Exception as e:
logger.error(f"Error writing to stream {self.stream_id}: {e}")
# Convert QUIC-specific errors
if "flow control" in str(e).lower():
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
await self._handle_stream_error(e)
raise
async def close(self) -> None:
"""
Close the stream gracefully (both read and write sides).
This implements proper close semantics where both sides
are closed and resources are cleaned up.
"""
async with self._state_lock:
if self._state in (StreamState.CLOSED, StreamState.RESET):
return
logger.debug(f"Closing stream {self.stream_id}")
# Close both sides
if not self._write_closed:
await self.close_write()
if not self._read_closed:
await self.close_read()
# Update state and cleanup
async with self._state_lock:
self._state = StreamState.CLOSED
await self._cleanup_resources()
self._timeline.record_close()
self._close_event.set()
logger.debug(f"Stream {self.stream_id} closed")
async def close_write(self) -> None:
"""Close the write side of the stream."""
if self._write_closed:
return
try:
# Send FIN to close write side
self._connection._quic.send_stream_data(
self._stream_id, b"", end_stream=True
)
await self._connection._transmit()
self._write_closed = True
async with self._state_lock:
if self._read_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.WRITE_CLOSED
logger.debug(f"Stream {self.stream_id} write side closed")
except Exception as e:
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
async def close_read(self) -> None:
"""Close the read side of the stream."""
if self._read_closed:
return
try:
self._read_closed = True
async with self._state_lock:
if self._write_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.READ_CLOSED
# Wake up any pending reads
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} read side closed")
except Exception as e:
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
async def reset(self, error_code: int = 0) -> None:
"""
Reset the stream with the given error code.
Args:
error_code: QUIC error code for the reset
"""
async with self._state_lock:
if self._state == StreamState.RESET:
return
logger.debug(
f"Resetting stream {self.stream_id} with error code {error_code}"
)
self._state = StreamState.RESET
self._reset_error_code = error_code
try:
# Send QUIC reset frame
self._connection._quic.reset_stream(self._stream_id, error_code)
await self._connection._transmit()
except Exception as e:
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
finally:
# Always cleanup resources
await self._cleanup_resources()
self._timeline.record_reset(error_code)
self._close_event.set()
def is_closed(self) -> bool:
"""Check if stream is completely closed."""
return self._state in (StreamState.CLOSED, StreamState.RESET)
def is_reset(self) -> bool:
"""Check if stream was reset."""
return self._state == StreamState.RESET
def can_read(self) -> bool:
"""Check if stream can be read from."""
return not self._read_closed and self._state not in (
StreamState.CLOSED,
StreamState.RESET,
)
def can_write(self) -> bool:
"""Check if stream can be written to."""
return not self._write_closed and self._state not in (
StreamState.CLOSED,
StreamState.RESET,
)
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
"""
Handle data received from the QUIC connection.
Args:
data: Received data
end_stream: Whether this is the last data (FIN received)
"""
if self._state == StreamState.RESET:
return
if data:
async with self._receive_buffer_lock:
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
logger.warning(
f"Stream {self.stream_id} receive buffer overflow, "
f"dropping {len(data)} bytes"
)
return
self._receive_buffer.extend(data)
self._timeline.record_first_data()
# Notify waiting readers
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
if end_stream:
self._read_closed = True
async with self._state_lock:
if self._write_closed:
self._state = StreamState.CLOSED
else:
self._state = StreamState.READ_CLOSED
# Wake up readers to process remaining data and EOF
self._receive_event.set()
logger.debug(f"Stream {self.stream_id} received FIN")
async def handle_stop_sending(self, error_code: int) -> None:
"""
Handle STOP_SENDING frame from remote peer.
When a STOP_SENDING frame is received, the peer is requesting that we
stop sending data on this stream. We respond by resetting the stream.
Args:
error_code: Error code from the STOP_SENDING frame
"""
logger.debug(
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
)
self._write_closed = True
# Wake up any pending write operations
self._backpressure_event.set()
async with self._state_lock:
if self.direction == StreamDirection.OUTBOUND:
self._state = StreamState.CLOSED
elif self._read_closed:
self._state = StreamState.CLOSED
else:
# Only write side closed - add WRITE_CLOSED state if needed
self._state = StreamState.WRITE_CLOSED
# Send RESET_STREAM in response (QUIC protocol requirement)
try:
self._connection._quic.reset_stream(int(self.stream_id), error_code)
await self._connection._transmit()
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
except Exception as e:
logger.warning(
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
)
async def handle_reset(self, error_code: int) -> None:
"""
Handle stream reset from remote peer.
Args:
error_code: QUIC error code from reset frame
"""
logger.debug(
f"Stream {self.stream_id} reset by peer with error code {error_code}"
)
async with self._state_lock:
self._state = StreamState.RESET
self._reset_error_code = error_code
await self._cleanup_resources()
self._timeline.record_reset(error_code)
self._close_event.set()
# Wake up any pending operations
self._receive_event.set()
self._backpressure_event.set()
async def handle_flow_control_update(self, available_window: int) -> None:
"""
Handle flow control window updates.
Args:
available_window: Available flow control window size
"""
if available_window > 0:
self._backpressure_event.set()
logger.debug(
f"Stream {self.stream_id} flow control".__add__(
f"window updated: {available_window}"
)
)
else:
self._backpressure_event = trio.Event() # Reset to blocking state
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
def _extract_data_from_buffer(self, n: int) -> bytes:
"""Extract data from receive buffer with specified limit."""
if n == -1:
# Read all available data
data = bytes(self._receive_buffer)
self._receive_buffer.clear()
else:
# Read up to n bytes
data = bytes(self._receive_buffer[:n])
self._receive_buffer = self._receive_buffer[n:]
return data
async def _handle_stream_error(self, error: Exception) -> None:
"""Handle errors by resetting the stream."""
logger.error(f"Stream {self.stream_id} error: {error}")
await self.reset(error_code=1) # Generic error code
def _reserve_memory(self, size: int) -> None:
"""Reserve memory with resource manager."""
if self._resource_scope:
try:
self._resource_scope.reserve_memory(size)
self._memory_reserved += size
except Exception as e:
logger.warning(
f"Failed to reserve memory for stream {self.stream_id}: {e}"
)
def _release_memory(self, size: int) -> None:
"""Release memory with resource manager."""
if self._resource_scope and size > 0:
try:
self._resource_scope.release_memory(size)
self._memory_reserved = max(0, self._memory_reserved - size)
except Exception as e:
logger.warning(
f"Failed to release memory for stream {self.stream_id}: {e}"
)
async def _cleanup_resources(self) -> None:
"""Clean up stream resources."""
# Release all reserved memory
if self._memory_reserved > 0:
self._release_memory(self._memory_reserved)
# Clear receive buffer
async with self._receive_buffer_lock:
self._receive_buffer.clear()
# Remove from connection's stream registry
self._connection._remove_stream(self._stream_id)
logger.debug(f"Stream {self.stream_id} resources cleaned up")
# Abstact implementations
def get_remote_address(self) -> tuple[str, int]:
return self._remote_addr
async def __aenter__(self) -> "QUICStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager and close the stream."""
logger.debug("Exiting the context and closing the stream")
await self.close()
def set_deadline(self, ttl: int) -> bool:
"""
Set a deadline for the stream. QUIC does not support deadlines natively,
so this method always returns False to indicate the operation is unsupported.
:param ttl: Time-to-live in seconds (ignored).
:return: False, as deadlines are not supported.
"""
raise NotImplementedError("QUIC does not support setting read deadlines")
# String representation for debugging
def __repr__(self) -> str:
return (
f"QUICStream(id={self.stream_id}, "
f"state={self._state.value}, "
f"direction={self._direction.value}, "
f"protocol={self._protocol})"
)
def __str__(self) -> str:
return f"QUICStream({self.stream_id})"

View File

@ -0,0 +1,491 @@
"""
QUIC Transport implementation
"""
import copy
import logging
import ssl
from typing import TYPE_CHECKING, cast
from aioquic.quic.configuration import (
QuicConfiguration,
)
from aioquic.quic.connection import (
QuicConnection as NativeQUICConnection,
)
from aioquic.quic.logger import QuicLogger
import multiaddr
import trio
from libp2p.abc import (
ITransport,
)
from libp2p.crypto.keys import (
PrivateKey,
)
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
from libp2p.peer.id import (
ID,
)
from libp2p.transport.quic.security import QUICTLSSecurityConfig
from libp2p.transport.quic.utils import (
create_client_config_from_base,
create_server_config_from_base,
get_alpn_protocols,
is_quic_multiaddr,
multiaddr_to_quic_version,
quic_multiaddr_to_endpoint,
quic_version_to_wire_format,
)
if TYPE_CHECKING:
from libp2p.network.swarm import Swarm
else:
Swarm = cast(type, object)
from .config import (
QUICTransportConfig,
)
from .connection import (
QUICConnection,
)
from .exceptions import (
QUICDialError,
QUICListenError,
QUICSecurityError,
)
from .listener import (
QUICListener,
)
from .security import (
QUICTLSConfigManager,
create_quic_security_transport,
)
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
logger = logging.getLogger(__name__)
class QUICTransport(ITransport):
"""
QUIC Stream implementation following libp2p IMuxedStream interface.
"""
def __init__(
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
) -> None:
"""
Initialize QUIC transport with security integration.
Args:
private_key: libp2p private key for identity and TLS cert generation
config: QUIC transport configuration options
"""
self._private_key = private_key
self._peer_id = ID.from_pubkey(private_key.get_public_key())
self._config = config or QUICTransportConfig()
# Connection management
self._connections: dict[str, QUICConnection] = {}
self._listeners: list[QUICListener] = []
# Security manager for TLS integration
self._security_manager = create_quic_security_transport(
self._private_key, self._peer_id
)
# QUIC configurations for different versions
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
self._setup_quic_configurations()
# Resource management
self._closed = False
self._nursery_manager = trio.CapacityLimiter(1)
self._background_nursery: trio.Nursery | None = None
self._swarm: Swarm | None = None
logger.debug(
f"Initialized QUIC transport with security for peer {self._peer_id}"
)
def set_background_nursery(self, nursery: trio.Nursery) -> None:
"""Set the nursery to use for background tasks (called by swarm)."""
self._background_nursery = nursery
logger.debug("Transport background nursery set")
def set_swarm(self, swarm: Swarm) -> None:
"""Set the swarm for adding incoming connections."""
self._swarm = swarm
def _setup_quic_configurations(self) -> None:
"""Setup QUIC configurations."""
try:
# Get TLS configuration from security manager
server_tls_config = self._security_manager.create_server_config()
client_tls_config = self._security_manager.create_client_config()
# Base server configuration
base_server_config = QuicConfiguration(
is_client=False,
alpn_protocols=get_alpn_protocols(),
verify_mode=self._config.verify_mode,
max_datagram_frame_size=self._config.max_datagram_size,
idle_timeout=self._config.idle_timeout,
)
# Base client configuration
base_client_config = QuicConfiguration(
is_client=True,
alpn_protocols=get_alpn_protocols(),
verify_mode=self._config.verify_mode,
max_datagram_frame_size=self._config.max_datagram_size,
idle_timeout=self._config.idle_timeout,
)
# Apply TLS configuration
self._apply_tls_configuration(base_server_config, server_tls_config)
self._apply_tls_configuration(base_client_config, client_tls_config)
# QUIC v1 (RFC 9000) configurations
if self._config.enable_v1:
quic_v1_server_config = create_server_config_from_base(
base_server_config, self._security_manager, self._config
)
quic_v1_server_config.supported_versions = [
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
]
quic_v1_client_config = create_client_config_from_base(
base_client_config, self._security_manager, self._config
)
quic_v1_client_config.supported_versions = [
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
]
# Store both server and client configs for v1
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
quic_v1_server_config
)
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
quic_v1_client_config
)
# QUIC draft-29 configurations for compatibility
if self._config.enable_draft29:
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
draft29_server_config.supported_versions = [
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
]
draft29_client_config = copy.copy(base_client_config)
draft29_client_config.supported_versions = [
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
]
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
draft29_server_config
)
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
draft29_client_config
)
logger.debug("QUIC configurations initialized with libp2p TLS security")
except Exception as e:
raise QUICSecurityError(
f"Failed to setup QUIC TLS configurations: {e}"
) from e
def _apply_tls_configuration(
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
) -> None:
"""
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
Args:
config: QuicConfiguration to update
tls_config: TLS configuration dictionary from security manager
"""
try:
config.certificate = tls_config.certificate
config.private_key = tls_config.private_key
config.certificate_chain = tls_config.certificate_chain
config.alpn_protocols = tls_config.alpn_protocols
config.verify_mode = ssl.CERT_NONE
logger.debug("Successfully applied TLS configuration to QUIC config")
except Exception as e:
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
async def dial(
self,
maddr: multiaddr.Multiaddr,
) -> QUICConnection:
"""
Dial a remote peer using QUIC transport with security verification.
Args:
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
peer_id: Expected peer ID for verification
nursery: Nursery to execute the background tasks
Returns:
Raw connection interface to the remote peer
Raises:
QUICDialError: If dialing fails
QUICSecurityError: If security verification fails
"""
if self._closed:
raise QUICDialError("Transport is closed")
if not is_quic_multiaddr(maddr):
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
try:
# Extract connection details from multiaddr
host, port = quic_multiaddr_to_endpoint(maddr)
remote_peer_id = maddr.get_peer_id()
if remote_peer_id is not None:
remote_peer_id = ID.from_base58(remote_peer_id)
if remote_peer_id is None:
logger.error("Unable to derive peer id from multiaddr")
raise QUICDialError("Unable to derive peer id from multiaddr")
quic_version = multiaddr_to_quic_version(maddr)
# Get appropriate QUIC client configuration
config_key = TProtocol(f"{quic_version}_client")
logger.debug("config_key", config_key, self._quic_configs.keys())
config = self._quic_configs.get(config_key)
if not config:
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
config.is_client = True
config.quic_logger = QuicLogger()
# Ensure client certificate is properly set for mutual authentication
if not config.certificate or not config.private_key:
logger.warning(
"Client config missing certificate - applying TLS config"
)
client_tls_config = self._security_manager.create_client_config()
self._apply_tls_configuration(config, client_tls_config)
# Debug log to verify certificate is present
logger.info(
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
)
logger.debug("Starting QUIC Connection")
# Create QUIC connection using aioquic's sans-IO core
native_quic_connection = NativeQUICConnection(configuration=config)
# Create trio-based QUIC connection wrapper with security
connection = QUICConnection(
quic_connection=native_quic_connection,
remote_addr=(host, port),
remote_peer_id=remote_peer_id,
local_peer_id=self._peer_id,
is_initiator=True,
maddr=maddr,
transport=self,
security_manager=self._security_manager,
)
logger.debug("QUIC Connection Created")
if self._background_nursery is None:
logger.error("No nursery set to execute background tasks")
raise QUICDialError("No nursery found to execute tasks")
await connection.connect(self._background_nursery)
# Store connection for management
conn_id = f"{host}:{port}"
self._connections[conn_id] = connection
return connection
except Exception as e:
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
raise QUICDialError(f"Dial failed: {e}") from e
async def _verify_peer_identity(
self, connection: QUICConnection, expected_peer_id: ID
) -> None:
"""
Verify remote peer identity after TLS handshake.
Args:
connection: The established QUIC connection
expected_peer_id: Expected peer ID
Raises:
QUICSecurityError: If peer verification fails
"""
try:
# Get peer certificate from the connection
peer_certificate = await connection.get_peer_certificate()
if not peer_certificate:
raise QUICSecurityError("No peer certificate available")
# Verify peer identity using security manager
verified_peer_id = self._security_manager.verify_peer_identity(
peer_certificate, expected_peer_id
)
if verified_peer_id != expected_peer_id:
raise QUICSecurityError(
"Peer ID verification failed: expected "
f"{expected_peer_id}, got {verified_peer_id}"
)
logger.debug(f"Peer identity verified: {verified_peer_id}")
logger.debug(f"Peer identity verified: {verified_peer_id}")
except Exception as e:
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
"""
Create a QUIC listener with integrated security.
Args:
handler_function: Function to handle new connections
Returns:
QUIC listener instance
Raises:
QUICListenError: If transport is closed
"""
if self._closed:
raise QUICListenError("Transport is closed")
# Get server configurations for the listener
server_configs = {
version: config
for version, config in self._quic_configs.items()
if version.endswith("_server")
}
listener = QUICListener(
transport=self,
handler_function=handler_function,
quic_configs=server_configs,
config=self._config,
security_manager=self._security_manager,
)
self._listeners.append(listener)
logger.debug("Created QUIC listener with security")
return listener
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
"""
Check if this transport can dial the given multiaddr.
Args:
maddr: Multiaddr to check
Returns:
True if this transport can dial the address
"""
return is_quic_multiaddr(maddr)
def protocols(self) -> list[TProtocol]:
"""
Get supported protocol identifiers.
Returns:
List of supported protocol strings
"""
protocols = [QUIC_V1_PROTOCOL]
if self._config.enable_draft29:
protocols.append(QUIC_DRAFT29_PROTOCOL)
return protocols
def listen_order(self) -> int:
"""
Get the listen order priority for this transport.
Matches go-libp2p's ListenOrder = 1 for QUIC.
Returns:
Priority order for listening (lower = higher priority)
"""
return 1
async def close(self) -> None:
"""Close the transport and cleanup resources."""
if self._closed:
return
self._closed = True
logger.debug("Closing QUIC transport")
# Close all active connections and listeners concurrently using trio nursery
async with trio.open_nursery() as nursery:
# Close all connections
for connection in self._connections.values():
nursery.start_soon(connection.close)
# Close all listeners
for listener in self._listeners:
nursery.start_soon(listener.close)
self._connections.clear()
self._listeners.clear()
logger.debug("QUIC transport closed")
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
"""Clean up a terminated connection from all listeners."""
try:
for listener in self._listeners:
await listener._remove_connection_by_object(connection)
logger.debug(
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
)
except Exception as e:
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
def get_stats(self) -> dict[str, int | list[str] | object]:
"""Get transport statistics including security info."""
return {
"active_connections": len(self._connections),
"active_listeners": len(self._listeners),
"supported_protocols": self.protocols(),
"local_peer_id": str(self._peer_id),
"security_enabled": True,
"tls_configured": True,
}
def get_security_manager(self) -> QUICTLSConfigManager:
"""
Get the security manager for this transport.
Returns:
The QUIC TLS configuration manager
"""
return self._security_manager
def get_listener_socket(self) -> trio.socket.SocketType | None:
"""Get the socket from the first active listener."""
for listener in self._listeners:
if listener.is_listening() and listener._socket:
return listener._socket
return None

View File

@ -0,0 +1,466 @@
"""
Multiaddr utilities for QUIC transport - Module 4.
Essential utilities required for QUIC transport implementation.
Based on go-libp2p and js-libp2p QUIC implementations.
"""
import ipaddress
import logging
import ssl
from aioquic.quic.configuration import QuicConfiguration
import multiaddr
from libp2p.custom_types import TProtocol
from libp2p.transport.quic.security import QUICTLSConfigManager
from .config import QUICTransportConfig
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
logger = logging.getLogger(__name__)
# Protocol constants
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
UDP_PROTOCOL = "udp"
IP4_PROTOCOL = "ip4"
IP6_PROTOCOL = "ip6"
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
}
# QUIC version to wire format mappings (required for aioquic)
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
}
# ALPN protocols for libp2p over QUIC
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
"""
Check if a multiaddr represents a QUIC address.
Valid QUIC multiaddrs:
- /ip4/127.0.0.1/udp/4001/quic-v1
- /ip4/127.0.0.1/udp/4001/quic
- /ip6/::1/udp/4001/quic-v1
- /ip6/::1/udp/4001/quic
Args:
maddr: Multiaddr to check
Returns:
True if the multiaddr represents a QUIC address
"""
try:
addr_str = str(maddr)
# Check for required components
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
has_quic = (
f"/{QUIC_V1_PROTOCOL}" in addr_str
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
or "/quic" in addr_str
)
return has_ip and has_udp and has_quic
except Exception:
return False
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
"""
Extract host and port from a QUIC multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
Tuple of (host, port)
Raises:
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
"""
if not is_quic_multiaddr(maddr):
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
try:
host = None
port = None
# Try to get IPv4 address
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
except Exception:
pass
# Try to get IPv6 address if IPv4 not found
if host is None:
try:
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
except Exception:
pass
# Get UDP port
try:
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
port = int(port_str)
except Exception:
pass
if host is None or port is None:
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
return host, port
except Exception as e:
raise QUICInvalidMultiaddrError(
f"Failed to parse QUIC multiaddr {maddr}: {e}"
) from e
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
"""
Determine QUIC version from multiaddr.
Args:
maddr: QUIC multiaddr
Returns:
QUIC version identifier ("quic-v1" or "quic")
Raises:
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
"""
try:
addr_str = str(maddr)
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
return QUIC_V1_PROTOCOL # RFC 9000
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
return QUIC_DRAFT29_PROTOCOL # draft-29
else:
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
except Exception as e:
raise QUICInvalidMultiaddrError(
f"Failed to determine QUIC version from {maddr}: {e}"
) from e
def create_quic_multiaddr(
host: str, port: int, version: str = "quic-v1"
) -> multiaddr.Multiaddr:
"""
Create a QUIC multiaddr from host, port, and version.
Args:
host: IP address (IPv4 or IPv6)
port: UDP port number
version: QUIC version ("quic-v1" or "quic")
Returns:
QUIC multiaddr
Raises:
QUICInvalidMultiaddrError: If invalid parameters provided
"""
try:
# Determine IP version
try:
ip = ipaddress.ip_address(host)
if isinstance(ip, ipaddress.IPv4Address):
ip_proto = IP4_PROTOCOL
else:
ip_proto = IP6_PROTOCOL
except ValueError:
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
# Validate port
if not (0 <= port <= 65535):
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
# Validate and normalize QUIC version
if version == "quic-v1" or version == "/quic-v1":
quic_proto = QUIC_V1_PROTOCOL
elif version == "quic" or version == "/quic":
quic_proto = QUIC_DRAFT29_PROTOCOL
else:
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
# Construct multiaddr
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
return multiaddr.Multiaddr(addr_str)
except Exception as e:
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
def quic_version_to_wire_format(version: TProtocol) -> int:
"""
Convert QUIC version string to wire format integer for aioquic.
Args:
version: QUIC version string ("quic-v1" or "quic")
Returns:
Wire format version number
Raises:
QUICUnsupportedVersionError: If version is not supported
"""
wire_version = QUIC_VERSION_MAPPINGS.get(version)
if wire_version is None:
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
return wire_version
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
"""
Convert QUIC version string to wire format integer for aioquic.
Args:
version: QUIC version string ("quic-v1" or "quic")
Returns:
Wire format version number
Raises:
QUICUnsupportedVersionError: If version is not supported
"""
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
if wire_version is None:
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
return wire_version
def get_alpn_protocols() -> list[str]:
"""
Get ALPN protocols for libp2p over QUIC.
Returns:
List of ALPN protocol identifiers
"""
return LIBP2P_ALPN_PROTOCOLS.copy()
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
"""
Normalize a QUIC multiaddr to canonical form.
Args:
maddr: Input QUIC multiaddr
Returns:
Normalized multiaddr
Raises:
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
"""
if not is_quic_multiaddr(maddr):
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
host, port = quic_multiaddr_to_endpoint(maddr)
version = multiaddr_to_quic_version(maddr)
return create_quic_multiaddr(host, port, version)
def create_server_config_from_base(
base_config: QuicConfiguration,
security_manager: QUICTLSConfigManager | None = None,
transport_config: QUICTransportConfig | None = None,
) -> QuicConfiguration:
"""
Create a server configuration without using deepcopy.
Manually copies attributes while handling cryptography objects properly.
"""
try:
# Create new server configuration from scratch
server_config = QuicConfiguration(is_client=False)
server_config.verify_mode = ssl.CERT_NONE
# Copy basic configuration attributes (these are safe to copy)
copyable_attrs = [
"alpn_protocols",
"verify_mode",
"max_datagram_frame_size",
"idle_timeout",
"max_concurrent_streams",
"supported_versions",
"max_data",
"max_stream_data",
"stateless_retry",
"quantum_readiness_test",
]
for attr in copyable_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(server_config, attr, value)
# Handle cryptography objects - these need direct reference, not copying
crypto_attrs = [
"certificate",
"private_key",
"certificate_chain",
"ca_certs",
]
for attr in crypto_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(server_config, attr, value)
# Apply security manager configuration if available
if security_manager:
try:
server_tls_config = security_manager.create_server_config()
# Override with security manager's TLS configuration
if server_tls_config.certificate:
server_config.certificate = server_tls_config.certificate
if server_tls_config.private_key:
server_config.private_key = server_tls_config.private_key
if server_tls_config.certificate_chain:
server_config.certificate_chain = (
server_tls_config.certificate_chain
)
if server_tls_config.alpn_protocols:
server_config.alpn_protocols = server_tls_config.alpn_protocols
server_tls_config.request_client_certificate = True
if getattr(server_tls_config, "request_client_certificate", False):
server_config._libp2p_request_client_cert = True # type: ignore
else:
logger.error(
"🔧 Failed to set request_client_certificate in server config"
)
except Exception as e:
logger.warning(f"Failed to apply security manager config: {e}")
# Set transport-specific defaults if provided
if transport_config:
if server_config.idle_timeout == 0:
server_config.idle_timeout = getattr(
transport_config, "idle_timeout", 30.0
)
if server_config.max_datagram_frame_size is None:
server_config.max_datagram_frame_size = getattr(
transport_config, "max_datagram_size", 1200
)
# Ensure we have ALPN protocols
if not server_config.alpn_protocols:
server_config.alpn_protocols = ["libp2p"]
logger.debug("Successfully created server config without deepcopy")
return server_config
except Exception as e:
logger.error(f"Failed to create server config: {e}")
raise
def create_client_config_from_base(
base_config: QuicConfiguration,
security_manager: QUICTLSConfigManager | None = None,
transport_config: QUICTransportConfig | None = None,
) -> QuicConfiguration:
"""
Create a client configuration without using deepcopy.
"""
try:
# Create new client configuration from scratch
client_config = QuicConfiguration(is_client=True)
client_config.verify_mode = ssl.CERT_NONE
# Copy basic configuration attributes
copyable_attrs = [
"alpn_protocols",
"verify_mode",
"max_datagram_frame_size",
"idle_timeout",
"max_concurrent_streams",
"supported_versions",
"max_data",
"max_stream_data",
"quantum_readiness_test",
]
for attr in copyable_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(client_config, attr, value)
# Handle cryptography objects - these need direct reference, not copying
crypto_attrs = [
"certificate",
"private_key",
"certificate_chain",
"ca_certs",
]
for attr in crypto_attrs:
if hasattr(base_config, attr):
value = getattr(base_config, attr)
if value is not None:
setattr(client_config, attr, value)
# Apply security manager configuration if available
if security_manager:
try:
client_tls_config = security_manager.create_client_config()
# Override with security manager's TLS configuration
if client_tls_config.certificate:
client_config.certificate = client_tls_config.certificate
if client_tls_config.private_key:
client_config.private_key = client_tls_config.private_key
if client_tls_config.certificate_chain:
client_config.certificate_chain = (
client_tls_config.certificate_chain
)
if client_tls_config.alpn_protocols:
client_config.alpn_protocols = client_tls_config.alpn_protocols
except Exception as e:
logger.warning(f"Failed to apply security manager config: {e}")
# Ensure we have ALPN protocols
if not client_config.alpn_protocols:
client_config.alpn_protocols = ["libp2p"]
logger.debug("Successfully created client config without deepcopy")
return client_config
except Exception as e:
logger.error(f"Failed to create client config: {e}")
raise

View File

@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
DEFAULT_NEGOTIATE_TIMEOUT,
)
from libp2p.security.exceptions import (
HandshakeFailure,
)
@ -37,9 +40,12 @@ class TransportUpgrader:
self,
secure_transports_by_protocol: TSecurityOptions,
muxer_transports_by_protocol: TMuxerOptions,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
):
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
self.muxer_multistream = MuxerMultistream(
muxer_transports_by_protocol, negotiate_timeout
)
async def upgrade_security(
self,

View File

@ -0,0 +1 @@
Add QUIC transport support for faster, more efficient peer-to-peer connections with native stream multiplexing.

View File

@ -0,0 +1 @@
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly

View File

@ -0,0 +1 @@
Fix multiaddr dependency to use the last py-multiaddr commit hash to resolve installation issues

View File

@ -0,0 +1 @@
Updated multiaddr dependency from git repository to pip package version 0.0.11.

View File

@ -16,13 +16,14 @@ maintainers = [
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
]
dependencies = [
"aioquic>=1.2.0",
"base58>=1.0.3",
"coincurve>=10.0.0",
"coincurve==21.0.0",
"exceptiongroup>=1.2.0; python_version < '3.11'",
"fastecdsa==2.3.2; sys_platform != 'win32'",
"grpcio>=1.41.0",
"lru-dict>=1.1.6",
# "multiaddr>=0.0.9",
"multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6",
"multiaddr>=0.0.11",
"mypy-protobuf>=3.0.0",
"noiseprotocol>=0.3.0",
"protobuf>=4.25.0,<5.0.0",
@ -32,7 +33,6 @@ dependencies = [
"rpcudp>=3.0.0",
"trio-typing>=0.0.4",
"trio>=0.26.0",
"fastecdsa==2.3.2; sys_platform != 'win32'",
"zeroconf (>=0.147.0,<0.148.0)",
]
classifiers = [
@ -52,6 +52,7 @@ Homepage = "https://github.com/libp2p/py-libp2p"
[project.scripts]
chat-demo = "examples.chat.chat:main"
echo-demo = "examples.echo.echo:main"
echo-quic-demo="examples.echo.echo_quic:main"
ping-demo = "examples.ping.ping:main"
identify-demo = "examples.identify.identify:main"
identify-push-demo = "examples.identify_push.identify_push_demo:run_main"
@ -77,6 +78,7 @@ dev = [
"pytest>=7.0.0",
"pytest-xdist>=2.4.0",
"pytest-trio>=0.5.2",
"pytest-timeout>=2.4.0",
"factory-boy>=2.12.0,<3.0.0",
"ruff>=0.11.10",
"pyrefly (>=0.17.1,<0.18.0)",
@ -88,11 +90,12 @@ docs = [
"tomli; python_version < '3.11'",
]
test = [
"factory-boy>=2.12.0,<3.0.0",
"p2pclient==0.2.0",
"pytest>=7.0.0",
"pytest-xdist>=2.4.0",
"pytest-timeout>=2.4.0",
"pytest-trio>=0.5.2",
"factory-boy>=2.12.0,<3.0.0",
"pytest-xdist>=2.4.0",
]
[tool.setuptools]
@ -282,4 +285,5 @@ project_excludes = [
"**/*pb2.py",
"**/*.pyi",
".venv/**",
"./tests/interop/nim_libp2p",
]

View File

@ -250,10 +250,13 @@ def test_new_swarm_tcp_multiaddr_supported():
assert isinstance(swarm.transport, TCP)
def test_new_swarm_quic_multiaddr_raises():
def test_new_swarm_quic_multiaddr_supported():
from libp2p.transport.quic.transport import QUICTransport
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
with pytest.raises(ValueError, match="QUIC not yet supported"):
new_swarm(listen_addrs=[addr])
swarm = new_swarm(listen_addrs=[addr])
assert isinstance(swarm, Swarm)
assert isinstance(swarm.transport, QUICTransport)
@pytest.mark.trio

View File

@ -0,0 +1,108 @@
from unittest.mock import (
AsyncMock,
MagicMock,
)
import pytest
from libp2p.custom_types import (
TMuxerClass,
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.stream_muxer.muxer_multistream import (
MuxerMultistream,
)
@pytest.mark.trio
async def test_muxer_timeout_configuration():
"""Test that muxer respects timeout configuration."""
muxer = MuxerMultistream({}, negotiate_timeout=1)
assert muxer.negotiate_timeout == 1
@pytest.mark.trio
async def test_select_transport_passes_timeout_to_multiselect():
"""Test that timeout is passed to multiselect client in select_transport."""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = False
# Mock MultiselectClient
muxer = MuxerMultistream({}, negotiate_timeout=10)
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
# Call select_transport
await muxer.select_transport(mock_conn)
# Verify that select_one_of was called with the correct timeout
args, _ = muxer.multiselect.negotiate.call_args
assert args[1] == 10
@pytest.mark.trio
async def test_new_conn_passes_timeout_to_multistream_client():
"""Test that timeout is passed to multistream client in new_conn."""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = True
mock_peer_id = ID(b"test_peer")
mock_communicator = MagicMock()
# Mock MultistreamClient and transports
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
# Call new_conn
await muxer.new_conn(mock_conn, mock_peer_id)
# Verify that select_one_of was called with the correct timeout
muxer.multistream_client.select_one_of(
tuple(muxer.transports.keys()), mock_communicator, 30
)
@pytest.mark.trio
async def test_select_transport_no_protocol_selected():
"""
Test that select_transport raises MultiselectError when no protocol is selected.
"""
# Mock dependencies
mock_conn = MagicMock()
mock_conn.is_initiator = False
# Mock Multiselect to return None
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
# Expect MultiselectError to be raised
with pytest.raises(MultiselectError, match="no protocol selected"):
await muxer.select_transport(mock_conn)
@pytest.mark.trio
async def test_add_transport_updates_precedence():
"""Test that adding a transport updates protocol precedence."""
# Mock transport classes
mock_transport1 = MagicMock(spec=TMuxerClass)
mock_transport2 = MagicMock(spec=TMuxerClass)
# Initialize muxer and add transports
muxer = MuxerMultistream({}, negotiate_timeout=30)
muxer.add_transport(TProtocol("proto1"), mock_transport1)
muxer.add_transport(TProtocol("proto2"), mock_transport2)
# Verify transport order
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
# Re-add proto1 to check if it moves to the end
muxer.add_transport(TProtocol("proto1"), mock_transport1)
assert list(muxer.transports.keys()) == ["proto2", "proto1"]

View File

@ -0,0 +1,553 @@
"""
Enhanced tests for QUIC connection functionality - Module 3.
Tests all new features including advanced stream management, resource management,
error handling, and concurrent operations.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
from multiaddr.multiaddr import Multiaddr
import trio
from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.exceptions import (
QUICConnectionClosedError,
QUICConnectionError,
QUICConnectionTimeoutError,
QUICPeerVerificationError,
QUICStreamLimitError,
QUICStreamTimeoutError,
)
from libp2p.transport.quic.security import QUICTLSConfigManager
from libp2p.transport.quic.stream import QUICStream, StreamDirection
class MockResourceScope:
"""Mock resource scope for testing."""
def __init__(self):
self.memory_reserved = 0
def reserve_memory(self, size):
self.memory_reserved += size
def release_memory(self, size):
self.memory_reserved = max(0, self.memory_reserved - size)
class TestQUICConnection:
"""Test suite for QUIC connection functionality."""
@pytest.fixture
def mock_quic_connection(self):
"""Create mock aioquic QuicConnection."""
mock = Mock()
mock.next_event.return_value = None
mock.datagrams_to_send.return_value = []
mock.get_timer.return_value = None
mock.connect = Mock()
mock.close = Mock()
mock.send_stream_data = Mock()
mock.reset_stream = Mock()
return mock
@pytest.fixture
def mock_quic_transport(self):
mock = Mock()
mock._config = QUICTransportConfig()
return mock
@pytest.fixture
def mock_resource_scope(self):
"""Create mock resource scope."""
return MockResourceScope()
@pytest.fixture
def quic_connection(
self,
mock_quic_connection: Mock,
mock_quic_transport: Mock,
mock_resource_scope: MockResourceScope,
):
"""Create test QUIC connection with enhanced features."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
mock_security_manager = Mock()
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=None,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=mock_quic_transport,
resource_scope=mock_resource_scope,
security_manager=mock_security_manager,
)
@pytest.fixture
def server_connection(self, mock_quic_connection, mock_resource_scope):
"""Create server-side QUIC connection."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=False,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
resource_scope=mock_resource_scope,
)
# Basic functionality tests
def test_connection_initialization_enhanced(
self, quic_connection, mock_resource_scope
):
"""Test enhanced connection initialization."""
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
assert quic_connection.is_initiator is True
assert not quic_connection.is_closed
assert not quic_connection.is_established
assert len(quic_connection._streams) == 0
assert quic_connection._resource_scope == mock_resource_scope
assert quic_connection._outbound_stream_count == 0
assert quic_connection._inbound_stream_count == 0
assert len(quic_connection._stream_accept_queue) == 0
def test_stream_id_calculation_enhanced(self):
"""Test enhanced stream ID calculation for client/server."""
# Client connection (initiator)
client_conn = QUICConnection(
quic_connection=Mock(),
remote_addr=("127.0.0.1", 4001),
remote_peer_id=None,
local_peer_id=Mock(),
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
assert client_conn._next_stream_id == 0 # Client starts with 0
# Server connection (not initiator)
server_conn = QUICConnection(
quic_connection=Mock(),
remote_addr=("127.0.0.1", 4001),
remote_peer_id=None,
local_peer_id=Mock(),
is_initiator=False,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
assert server_conn._next_stream_id == 1 # Server starts with 1
def test_incoming_stream_detection_enhanced(self, quic_connection):
"""Test enhanced incoming stream detection logic."""
# For client (initiator), odd stream IDs are incoming
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
# Stream management tests
@pytest.mark.trio
async def test_open_stream_basic(self, quic_connection):
"""Test basic stream opening."""
quic_connection._started = True
stream = await quic_connection.open_stream()
assert isinstance(stream, QUICStream)
assert stream.stream_id == "0"
assert stream.direction == StreamDirection.OUTBOUND
assert 0 in quic_connection._streams
assert quic_connection._outbound_stream_count == 1
@pytest.mark.trio
async def test_open_stream_limit_reached(self, quic_connection):
"""Test stream limit enforcement."""
quic_connection._started = True
quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
await quic_connection.open_stream()
@pytest.mark.trio
async def test_open_stream_timeout(self, quic_connection: QUICConnection):
"""Test stream opening timeout."""
quic_connection._started = True
return
# Mock the stream ID lock to simulate slow operation
async def slow_acquire():
await trio.sleep(10) # Longer than timeout
with patch.object(
quic_connection._stream_lock, "acquire", side_effect=slow_acquire
):
with pytest.raises(
QUICStreamTimeoutError, match="Stream creation timed out"
):
await quic_connection.open_stream(timeout=0.1)
@pytest.mark.trio
async def test_accept_stream_basic(self, quic_connection):
"""Test basic stream acceptance."""
# Create a mock inbound stream
mock_stream = Mock(spec=QUICStream)
mock_stream.stream_id = "1"
# Add to accept queue
quic_connection._stream_accept_queue.append(mock_stream)
quic_connection._stream_accept_event.set()
accepted_stream = await quic_connection.accept_stream(timeout=0.1)
assert accepted_stream == mock_stream
assert len(quic_connection._stream_accept_queue) == 0
@pytest.mark.trio
async def test_accept_stream_timeout(self, quic_connection):
"""Test stream acceptance timeout."""
with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"):
await quic_connection.accept_stream(timeout=0.1)
@pytest.mark.trio
async def test_accept_stream_on_closed_connection(self, quic_connection):
"""Test stream acceptance on closed connection."""
await quic_connection.close()
with pytest.raises(QUICConnectionClosedError, match="Connection is closed"):
await quic_connection.accept_stream()
# Stream handler tests
@pytest.mark.trio
async def test_stream_handler_setting(self, quic_connection):
"""Test setting stream handler."""
async def mock_handler(stream):
pass
quic_connection.set_stream_handler(mock_handler)
assert quic_connection._stream_handler == mock_handler
# Connection lifecycle tests
@pytest.mark.trio
async def test_connection_start_client(self, quic_connection):
"""Test client connection start."""
with patch.object(
quic_connection, "_initiate_connection", new_callable=AsyncMock
) as mock_initiate:
await quic_connection.start()
assert quic_connection._started
mock_initiate.assert_called_once()
@pytest.mark.trio
async def test_connection_start_server(self, server_connection):
"""Test server connection start."""
await server_connection.start()
assert server_connection._started
assert server_connection._established
assert server_connection._connected_event.is_set()
@pytest.mark.trio
async def test_connection_start_already_started(self, quic_connection):
"""Test starting already started connection."""
quic_connection._started = True
# Should not raise error, just log warning
await quic_connection.start()
assert quic_connection._started
@pytest.mark.trio
async def test_connection_start_closed(self, quic_connection):
"""Test starting closed connection."""
quic_connection._closed = True
with pytest.raises(
QUICConnectionError, match="Cannot start a closed connection"
):
await quic_connection.start()
@pytest.mark.trio
async def test_connection_connect_with_nursery(
self, quic_connection: QUICConnection
):
"""Test connection establishment with nursery."""
quic_connection._started = True
quic_connection._established = True
quic_connection._connected_event.set()
with patch.object(
quic_connection, "_start_background_tasks", new_callable=AsyncMock
) as mock_start_tasks:
with patch.object(
quic_connection,
"_verify_peer_identity_with_security",
new_callable=AsyncMock,
) as mock_verify:
async with trio.open_nursery() as nursery:
await quic_connection.connect(nursery)
assert quic_connection._nursery == nursery
mock_start_tasks.assert_called_once()
mock_verify.assert_called_once()
@pytest.mark.trio
@pytest.mark.slow
async def test_connection_connect_timeout(
self, quic_connection: QUICConnection
) -> None:
"""Test connection establishment timeout."""
quic_connection._started = True
# Don't set connected event to simulate timeout
with patch.object(
quic_connection, "_start_background_tasks", new_callable=AsyncMock
):
async with trio.open_nursery() as nursery:
with pytest.raises(
QUICConnectionTimeoutError, match="Connection handshake timed out"
):
await quic_connection.connect(nursery)
# Resource management tests
@pytest.mark.trio
async def test_stream_removal_resource_cleanup(
self, quic_connection: QUICConnection, mock_resource_scope
):
"""Test stream removal and resource cleanup."""
quic_connection._started = True
# Create a stream
stream = await quic_connection.open_stream()
# Remove the stream
quic_connection._remove_stream(int(stream.stream_id))
assert int(stream.stream_id) not in quic_connection._streams
# Note: Count updates is async, so we can't test it directly here
# Error handling tests
@pytest.mark.trio
async def test_connection_error_handling(self, quic_connection) -> None:
"""Test connection error handling."""
error = Exception("Test error")
with patch.object(
quic_connection, "close", new_callable=AsyncMock
) as mock_close:
await quic_connection._handle_connection_error(error)
mock_close.assert_called_once()
# Statistics and monitoring tests
@pytest.mark.trio
async def test_connection_stats_enhanced(self, quic_connection) -> None:
"""Test enhanced connection statistics."""
quic_connection._started = True
# Create some streams
_stream1 = await quic_connection.open_stream()
_stream2 = await quic_connection.open_stream()
stats = quic_connection.get_stream_stats()
expected_keys = [
"total_streams",
"outbound_streams",
"inbound_streams",
"max_streams",
"stream_utilization",
"stats",
]
for key in expected_keys:
assert key in stats
assert stats["total_streams"] == 2
assert stats["outbound_streams"] == 2
assert stats["inbound_streams"] == 0
@pytest.mark.trio
async def test_get_active_streams(self, quic_connection) -> None:
"""Test getting active streams."""
quic_connection._started = True
# Create streams
stream1 = await quic_connection.open_stream()
stream2 = await quic_connection.open_stream()
active_streams = quic_connection.get_active_streams()
assert len(active_streams) == 2
assert stream1 in active_streams
assert stream2 in active_streams
@pytest.mark.trio
async def test_get_streams_by_protocol(self, quic_connection) -> None:
"""Test getting streams by protocol."""
quic_connection._started = True
# Create streams with different protocols
stream1 = await quic_connection.open_stream()
stream1.protocol = "/test/1.0.0"
stream2 = await quic_connection.open_stream()
stream2.protocol = "/other/1.0.0"
test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0")
other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0")
assert len(test_streams) == 1
assert len(other_streams) == 1
assert stream1 in test_streams
assert stream2 in other_streams
# Enhanced close tests
@pytest.mark.trio
async def test_connection_close_enhanced(
self, quic_connection: QUICConnection
) -> None:
"""Test enhanced connection close with stream cleanup."""
quic_connection._started = True
# Create some streams
_stream1 = await quic_connection.open_stream()
_stream2 = await quic_connection.open_stream()
await quic_connection.close()
assert quic_connection.is_closed
assert len(quic_connection._streams) == 0
# Concurrent operations tests
@pytest.mark.trio
async def test_concurrent_stream_operations(
self, quic_connection: QUICConnection
) -> None:
"""Test concurrent stream operations."""
quic_connection._started = True
async def create_stream():
return await quic_connection.open_stream()
# Create multiple streams concurrently
async with trio.open_nursery() as nursery:
for i in range(10):
nursery.start_soon(create_stream)
# Wait a bit for all to start
await trio.sleep(0.1)
# Should have created streams without conflicts
assert quic_connection._outbound_stream_count == 10
assert len(quic_connection._streams) == 10
# Connection properties tests
def test_connection_properties(self, quic_connection: QUICConnection) -> None:
"""Test connection property accessors."""
assert quic_connection.multiaddr() == quic_connection._maddr
assert quic_connection.local_peer_id() == quic_connection._local_peer_id
assert quic_connection.remote_peer_id() == quic_connection._remote_peer_id
# IRawConnection interface tests
@pytest.mark.trio
async def test_raw_connection_write(self, quic_connection: QUICConnection) -> None:
"""Test raw connection write interface."""
quic_connection._started = True
with patch.object(quic_connection, "open_stream") as mock_open:
mock_stream = AsyncMock()
mock_open.return_value = mock_stream
await quic_connection.write(b"test data")
mock_open.assert_called_once()
mock_stream.write.assert_called_once_with(b"test data")
mock_stream.close_write.assert_called_once()
@pytest.mark.trio
async def test_raw_connection_read_not_implemented(
self, quic_connection: QUICConnection
) -> None:
"""Test raw connection read raises NotImplementedError."""
with pytest.raises(NotImplementedError):
await quic_connection.read()
# Mock verification helpers
def test_mock_resource_scope_functionality(self, mock_resource_scope) -> None:
"""Test mock resource scope works correctly."""
assert mock_resource_scope.memory_reserved == 0
mock_resource_scope.reserve_memory(1000)
assert mock_resource_scope.memory_reserved == 1000
mock_resource_scope.reserve_memory(500)
assert mock_resource_scope.memory_reserved == 1500
mock_resource_scope.release_memory(600)
assert mock_resource_scope.memory_reserved == 900
mock_resource_scope.release_memory(2000) # Should not go negative
assert mock_resource_scope.memory_reserved == 0
@pytest.mark.trio
async def test_invalid_certificate_verification():
key_pair1 = create_new_key_pair()
key_pair2 = create_new_key_pair()
peer_id1 = ID.from_pubkey(key_pair1.public_key)
peer_id2 = ID.from_pubkey(key_pair2.public_key)
manager = QUICTLSConfigManager(
libp2p_private_key=key_pair1.private_key, peer_id=peer_id1
)
# Match the certificate against a different peer_id
with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"):
manager.verify_peer_identity(manager.tls_config.certificate, peer_id2)
from cryptography.hazmat.primitives.serialization import Encoding
# --- Corrupt the certificate by tampering the DER bytes ---
cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER)
corrupted_bytes = bytearray(cert_bytes)
# Flip some random bytes in the middle of the certificate
corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF
from cryptography import x509
from cryptography.hazmat.backends import default_backend
# This will still parse (structurally valid), but the signature
# or fingerprint will break
corrupted_cert = x509.load_der_x509_certificate(
bytes(corrupted_bytes), backend=default_backend()
)
with pytest.raises(
QUICPeerVerificationError, match="Certificate verification failed"
):
manager.verify_peer_identity(corrupted_cert, peer_id1)

View File

@ -0,0 +1,624 @@
"""
QUIC Connection ID Management Tests
This test module covers comprehensive testing of QUIC connection ID functionality
including generation, rotation, retirement, and validation according to RFC 9000.
Tests are organized into:
1. Basic Connection ID Management
2. Connection ID Rotation and Updates
3. Connection ID Retirement
4. Error Conditions and Edge Cases
5. Integration Tests with Real Connections
"""
import secrets
import time
from typing import Any
from unittest.mock import Mock
import pytest
from aioquic.buffer import Buffer
# Import aioquic components for low-level testing
from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection, QuicConnectionId
from multiaddr import Multiaddr
from libp2p.crypto.ed25519 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
class ConnectionIdTestHelper:
"""Helper class for connection ID testing utilities."""
@staticmethod
def generate_connection_id(length: int = 8) -> bytes:
"""Generate a random connection ID of specified length."""
return secrets.token_bytes(length)
@staticmethod
def create_quic_connection_id(cid: bytes, sequence: int = 0) -> QuicConnectionId:
"""Create a QuicConnectionId object."""
return QuicConnectionId(
cid=cid,
sequence_number=sequence,
stateless_reset_token=secrets.token_bytes(16),
)
@staticmethod
def extract_connection_ids_from_connection(conn: QUICConnection) -> dict[str, Any]:
"""Extract connection ID information from a QUIC connection."""
quic = conn._quic
return {
"host_cids": [cid.cid.hex() for cid in getattr(quic, "_host_cids", [])],
"peer_cid": getattr(quic, "_peer_cid", None),
"peer_cid_available": [
cid.cid.hex() for cid in getattr(quic, "_peer_cid_available", [])
],
"retire_connection_ids": getattr(quic, "_retire_connection_ids", []),
"host_cid_seq": getattr(quic, "_host_cid_seq", 0),
}
class TestBasicConnectionIdManagement:
"""Test basic connection ID management functionality."""
@pytest.fixture
def mock_quic_connection(self):
"""Create a mock QUIC connection with connection ID support."""
mock_quic = Mock(spec=QuicConnection)
mock_quic._host_cids = []
mock_quic._host_cid_seq = 0
mock_quic._peer_cid = None
mock_quic._peer_cid_available = []
mock_quic._retire_connection_ids = []
mock_quic._configuration = Mock()
mock_quic._configuration.connection_id_length = 8
mock_quic._remote_active_connection_id_limit = 8
return mock_quic
@pytest.fixture
def quic_connection(self, mock_quic_connection):
"""Create a QUICConnection instance for testing."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
return QUICConnection(
quic_connection=mock_quic_connection,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
def test_connection_id_initialization(self, quic_connection):
"""Test that connection ID tracking is properly initialized."""
# Check that connection ID tracking structures are initialized
assert hasattr(quic_connection, "_available_connection_ids")
assert hasattr(quic_connection, "_current_connection_id")
assert hasattr(quic_connection, "_retired_connection_ids")
assert hasattr(quic_connection, "_connection_id_sequence_numbers")
# Initial state should be empty
assert len(quic_connection._available_connection_ids) == 0
assert quic_connection._current_connection_id is None
assert len(quic_connection._retired_connection_ids) == 0
assert len(quic_connection._connection_id_sequence_numbers) == 0
def test_connection_id_stats_tracking(self, quic_connection):
"""Test connection ID statistics are properly tracked."""
stats = quic_connection.get_connection_id_stats()
# Check that all expected stats are present
expected_keys = [
"available_connection_ids",
"current_connection_id",
"retired_connection_ids",
"connection_ids_issued",
"connection_ids_retired",
"connection_id_changes",
"available_cid_list",
]
for key in expected_keys:
assert key in stats
# Initial values should be zero/empty
assert stats["available_connection_ids"] == 0
assert stats["current_connection_id"] is None
assert stats["retired_connection_ids"] == 0
assert stats["connection_ids_issued"] == 0
assert stats["connection_ids_retired"] == 0
assert stats["connection_id_changes"] == 0
assert stats["available_cid_list"] == []
def test_current_connection_id_getter(self, quic_connection):
"""Test getting current connection ID."""
# Initially no connection ID
assert quic_connection.get_current_connection_id() is None
# Set a connection ID
test_cid = ConnectionIdTestHelper.generate_connection_id()
quic_connection._current_connection_id = test_cid
assert quic_connection.get_current_connection_id() == test_cid
def test_connection_id_generation(self):
"""Test connection ID generation utilities."""
# Test default length
cid1 = ConnectionIdTestHelper.generate_connection_id()
assert len(cid1) == 8
assert isinstance(cid1, bytes)
# Test custom length
cid2 = ConnectionIdTestHelper.generate_connection_id(16)
assert len(cid2) == 16
# Test uniqueness
cid3 = ConnectionIdTestHelper.generate_connection_id()
assert cid1 != cid3
class TestConnectionIdRotationAndUpdates:
"""Test connection ID rotation and update mechanisms."""
@pytest.fixture
def transport_config(self):
"""Create transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=100,
)
@pytest.fixture
def server_key(self):
"""Generate server private key."""
return create_new_key_pair().private_key
@pytest.fixture
def client_key(self):
"""Generate client private key."""
return create_new_key_pair().private_key
def test_connection_id_replenishment(self):
"""Test connection ID replenishment mechanism."""
# Create a real QuicConnection to test replenishment
config = QuicConfiguration(is_client=True)
config.connection_id_length = 8
quic_conn = QuicConnection(configuration=config)
# Initial state - should have some host connection IDs
initial_count = len(quic_conn._host_cids)
assert initial_count > 0
# Remove some connection IDs to trigger replenishment
while len(quic_conn._host_cids) > 2:
quic_conn._host_cids.pop()
# Trigger replenishment
quic_conn._replenish_connection_ids()
# Should have replenished up to the limit
assert len(quic_conn._host_cids) >= initial_count
# All connection IDs should have unique sequence numbers
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
assert len(sequences) == len(set(sequences))
def test_connection_id_sequence_numbers(self):
"""Test connection ID sequence number management."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Get initial sequence number
initial_seq = quic_conn._host_cid_seq
# Trigger replenishment to generate new connection IDs
quic_conn._replenish_connection_ids()
# Sequence numbers should increment
assert quic_conn._host_cid_seq > initial_seq
# All host connection IDs should have sequential numbers
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
sequences.sort()
# Check for proper sequence
for i in range(len(sequences) - 1):
assert sequences[i + 1] > sequences[i]
def test_connection_id_limits(self):
"""Test connection ID limit enforcement."""
config = QuicConfiguration(is_client=True)
config.connection_id_length = 8
quic_conn = QuicConnection(configuration=config)
# Set a reasonable limit
quic_conn._remote_active_connection_id_limit = 4
# Replenish connection IDs
quic_conn._replenish_connection_ids()
# Should not exceed the limit
assert len(quic_conn._host_cids) <= quic_conn._remote_active_connection_id_limit
class TestConnectionIdRetirement:
"""Test connection ID retirement functionality."""
def test_connection_id_retirement_basic(self):
"""Test basic connection ID retirement."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Create a test connection ID to retire
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=1
)
# Add it to peer connection IDs
quic_conn._peer_cid_available.append(test_cid)
quic_conn._peer_cid_sequence_numbers.add(1)
# Retire the connection ID
quic_conn._retire_peer_cid(test_cid)
# Should be added to retirement list
assert 1 in quic_conn._retire_connection_ids
def test_connection_id_retirement_limits(self):
"""Test connection ID retirement limits."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Fill up retirement list near the limit
max_retirements = 32 # Based on aioquic's default limit
for i in range(max_retirements):
quic_conn._retire_connection_ids.append(i)
# Should be at limit
assert len(quic_conn._retire_connection_ids) == max_retirements
def test_connection_id_retirement_events(self):
"""Test that retirement generates proper events."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Create and add a host connection ID
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=5
)
quic_conn._host_cids.append(test_cid)
# Create a retirement frame buffer
from aioquic.buffer import Buffer
buf = Buffer(capacity=16)
buf.push_uint_var(5) # sequence number to retire
buf.seek(0)
# Process retirement (this should generate an event)
try:
quic_conn._handle_retire_connection_id_frame(
Mock(), # context
0x19, # RETIRE_CONNECTION_ID frame type
buf,
)
# Check that connection ID was removed
remaining_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
assert 5 not in remaining_sequences
except Exception:
# May fail due to missing context, but that's okay for this test
pass
class TestConnectionIdErrorConditions:
"""Test error conditions and edge cases in connection ID handling."""
def test_invalid_connection_id_length(self):
"""Test handling of invalid connection ID lengths."""
# Connection IDs must be 1-20 bytes according to RFC 9000
# Test too short (0 bytes) - this should be handled gracefully
empty_cid = b""
assert len(empty_cid) == 0
# Test too long (>20 bytes)
long_cid = secrets.token_bytes(21)
assert len(long_cid) == 21
# Test valid lengths
for length in range(1, 21):
valid_cid = secrets.token_bytes(length)
assert len(valid_cid) == length
def test_duplicate_sequence_numbers(self):
"""Test handling of duplicate sequence numbers."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Create two connection IDs with same sequence number
cid1 = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=10
)
cid2 = ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), sequence=10
)
# Add first connection ID
quic_conn._peer_cid_available.append(cid1)
quic_conn._peer_cid_sequence_numbers.add(10)
# Adding second with same sequence should be handled appropriately
# (The implementation should prevent duplicates)
if 10 not in quic_conn._peer_cid_sequence_numbers:
quic_conn._peer_cid_available.append(cid2)
quic_conn._peer_cid_sequence_numbers.add(10)
# Should only have one entry for sequence 10
sequences = [cid.sequence_number for cid in quic_conn._peer_cid_available]
assert sequences.count(10) <= 1
def test_retire_unknown_connection_id(self):
"""Test retiring an unknown connection ID."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Try to create a buffer to retire unknown sequence number
buf = Buffer(capacity=16)
buf.push_uint_var(999) # Unknown sequence number
buf.seek(0)
# This should raise an error when processed
# (Testing the error condition, not the full processing)
unknown_sequence = 999
known_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
assert unknown_sequence not in known_sequences
def test_retire_current_connection_id(self):
"""Test that retiring current connection ID is prevented."""
config = QuicConfiguration(is_client=True)
quic_conn = QuicConnection(configuration=config)
# Get current connection ID if available
if quic_conn._host_cids:
current_cid = quic_conn._host_cids[0]
current_sequence = current_cid.sequence_number
# Trying to retire current connection ID should be prevented
# This is tested by checking the sequence number logic
assert current_sequence >= 0
class TestConnectionIdIntegration:
"""Integration tests for connection ID functionality with real connections."""
@pytest.fixture
def server_config(self):
"""Server transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=100,
)
@pytest.fixture
def client_config(self):
"""Client transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
)
@pytest.fixture
def server_key(self):
"""Generate server private key."""
return create_new_key_pair().private_key
@pytest.fixture
def client_key(self):
"""Generate client private key."""
return create_new_key_pair().private_key
@pytest.mark.trio
async def test_connection_id_exchange_during_handshake(
self, server_key, client_key, server_config, client_config
):
"""Test connection ID exchange during connection handshake."""
# This test would require a full connection setup
# For now, we test the setup components
server_transport = QUICTransport(server_key, server_config)
client_transport = QUICTransport(client_key, client_config)
# Verify transports are created with proper configuration
assert server_transport._config == server_config
assert client_transport._config == client_config
# Test that connection ID tracking is available
# (Integration with actual networking would require more setup)
def test_connection_id_extraction_utilities(self):
"""Test connection ID extraction utilities."""
# Create a mock connection with some connection IDs
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
mock_quic = Mock()
mock_quic._host_cids = [
ConnectionIdTestHelper.create_quic_connection_id(
ConnectionIdTestHelper.generate_connection_id(), i
)
for i in range(3)
]
mock_quic._peer_cid = None
mock_quic._peer_cid_available = []
mock_quic._retire_connection_ids = []
mock_quic._host_cid_seq = 3
quic_conn = QUICConnection(
quic_connection=mock_quic,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
# Extract connection ID information
cid_info = ConnectionIdTestHelper.extract_connection_ids_from_connection(
quic_conn
)
# Verify extraction works
assert "host_cids" in cid_info
assert "peer_cid" in cid_info
assert "peer_cid_available" in cid_info
assert "retire_connection_ids" in cid_info
assert "host_cid_seq" in cid_info
# Check values
assert len(cid_info["host_cids"]) == 3
assert cid_info["host_cid_seq"] == 3
assert cid_info["peer_cid"] is None
assert len(cid_info["peer_cid_available"]) == 0
assert len(cid_info["retire_connection_ids"]) == 0
class TestConnectionIdStatistics:
"""Test connection ID statistics and monitoring."""
@pytest.fixture
def connection_with_stats(self):
"""Create a connection with connection ID statistics."""
private_key = create_new_key_pair().private_key
peer_id = ID.from_pubkey(private_key.get_public_key())
mock_quic = Mock()
mock_quic._host_cids = []
mock_quic._peer_cid = None
mock_quic._peer_cid_available = []
mock_quic._retire_connection_ids = []
return QUICConnection(
quic_connection=mock_quic,
remote_addr=("127.0.0.1", 4001),
remote_peer_id=peer_id,
local_peer_id=peer_id,
is_initiator=True,
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
transport=Mock(),
)
def test_connection_id_stats_initialization(self, connection_with_stats):
"""Test that connection ID statistics are properly initialized."""
stats = connection_with_stats._stats
# Check that connection ID stats are present
assert "connection_ids_issued" in stats
assert "connection_ids_retired" in stats
assert "connection_id_changes" in stats
# Initial values should be zero
assert stats["connection_ids_issued"] == 0
assert stats["connection_ids_retired"] == 0
assert stats["connection_id_changes"] == 0
def test_connection_id_stats_update(self, connection_with_stats):
"""Test updating connection ID statistics."""
conn = connection_with_stats
# Add some connection IDs to tracking
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(3)]
for cid in test_cids:
conn._available_connection_ids.add(cid)
# Update stats (this would normally be done by the implementation)
conn._stats["connection_ids_issued"] = len(test_cids)
# Verify stats
stats = conn.get_connection_id_stats()
assert stats["connection_ids_issued"] == 3
assert stats["available_connection_ids"] == 3
def test_connection_id_list_representation(self, connection_with_stats):
"""Test connection ID list representation in stats."""
conn = connection_with_stats
# Add some connection IDs
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(2)]
for cid in test_cids:
conn._available_connection_ids.add(cid)
# Get stats
stats = conn.get_connection_id_stats()
# Check that CID list is properly formatted
assert "available_cid_list" in stats
assert len(stats["available_cid_list"]) == 2
# All entries should be hex strings
for cid_hex in stats["available_cid_list"]:
assert isinstance(cid_hex, str)
assert len(cid_hex) == 16 # 8 bytes = 16 hex chars
# Performance and stress tests
class TestConnectionIdPerformance:
"""Test connection ID performance and stress scenarios."""
def test_connection_id_generation_performance(self):
"""Test connection ID generation performance."""
start_time = time.time()
# Generate many connection IDs
cids = []
for _ in range(1000):
cid = ConnectionIdTestHelper.generate_connection_id()
cids.append(cid)
end_time = time.time()
generation_time = end_time - start_time
# Should be reasonably fast (less than 1 second for 1000 IDs)
assert generation_time < 1.0
# All should be unique
assert len(set(cids)) == len(cids)
def test_connection_id_tracking_memory(self):
"""Test memory usage of connection ID tracking."""
conn_ids = set()
# Add many connection IDs
for _ in range(1000):
cid = ConnectionIdTestHelper.generate_connection_id()
conn_ids.add(cid)
# Verify they're all stored
assert len(conn_ids) == 1000
# Clean up
conn_ids.clear()
assert len(conn_ids) == 0
if __name__ == "__main__":
# Run tests if executed directly
pytest.main([__file__, "-v"])

View File

@ -0,0 +1,418 @@
"""
Basic QUIC Echo Test
Simple test to verify the basic QUIC flow:
1. Client connects to server
2. Client sends data
3. Server receives data and echoes back
4. Client receives the echo
This test focuses on identifying where the accept_stream issue occurs.
"""
import logging
import pytest
import multiaddr
import trio
from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID
from libp2p import new_host
from libp2p.abc import INetStream
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.transport.quic.config import QUICTransportConfig
from libp2p.transport.quic.connection import QUICConnection
from libp2p.transport.quic.transport import QUICTransport
from libp2p.transport.quic.utils import create_quic_multiaddr
# Set up logging to see what's happening
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class TestBasicQUICFlow:
"""Test basic QUIC client-server communication flow."""
@pytest.fixture
def server_key(self):
"""Generate server key pair."""
return create_new_key_pair()
@pytest.fixture
def client_key(self):
"""Generate client key pair."""
return create_new_key_pair()
@pytest.fixture
def server_config(self):
"""Simple server configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=10,
max_connections=5,
)
@pytest.fixture
def client_config(self):
"""Simple client configuration."""
return QUICTransportConfig(
idle_timeout=10.0,
connection_timeout=5.0,
max_concurrent_streams=5,
)
@pytest.mark.trio
async def test_basic_echo_flow(
self, server_key, client_key, server_config, client_config
):
"""Test basic client-server echo flow with detailed logging."""
print("\n=== BASIC QUIC ECHO TEST ===")
# Create server components
server_transport = QUICTransport(server_key.private_key, server_config)
# Track test state
server_received_data = None
server_connection_established = False
echo_sent = False
async def echo_server_handler(connection: QUICConnection) -> None:
"""Simple echo server handler with detailed logging."""
nonlocal server_received_data, server_connection_established, echo_sent
print("🔗 SERVER: Connection handler called")
server_connection_established = True
try:
print("📡 SERVER: Waiting for incoming stream...")
# Accept stream with timeout and detailed logging
print("📡 SERVER: Calling accept_stream...")
stream = await connection.accept_stream(timeout=5.0)
if stream is None:
print("❌ SERVER: accept_stream returned None")
return
print(f"✅ SERVER: Stream accepted! Stream ID: {stream.stream_id}")
# Read data from the stream
print("📖 SERVER: Reading data from stream...")
server_data = await stream.read(1024)
if not server_data:
print("❌ SERVER: No data received from stream")
return
server_received_data = server_data.decode("utf-8", errors="ignore")
print(f"📨 SERVER: Received data: '{server_received_data}'")
# Echo the data back
echo_message = f"ECHO: {server_received_data}"
print(f"📤 SERVER: Sending echo: '{echo_message}'")
await stream.write(echo_message.encode())
echo_sent = True
print("✅ SERVER: Echo sent successfully")
# Close the stream
await stream.close()
print("🔒 SERVER: Stream closed")
except Exception as e:
print(f"❌ SERVER: Error in handler: {e}")
import traceback
traceback.print_exc()
# Create listener
listener = server_transport.create_listener(echo_server_handler)
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
# Variables to track client state
client_connected = False
client_sent_data = False
client_received_echo = None
try:
print("🚀 Starting server...")
async with trio.open_nursery() as nursery:
# Start server listener
success = await listener.listen(listen_addr, nursery)
assert success, "Failed to start server listener"
# Get server address
server_addrs = listener.get_addrs()
server_addr = multiaddr.Multiaddr(
f"{server_addrs[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
)
print(f"🔧 SERVER: Listening on {server_addr}")
# Give server a moment to be ready
await trio.sleep(0.1)
print("🚀 Starting client...")
# Create client transport
client_transport = QUICTransport(client_key.private_key, client_config)
client_transport.set_background_nursery(nursery)
try:
# Connect to server
print(f"📞 CLIENT: Connecting to {server_addr}")
connection = await client_transport.dial(server_addr)
client_connected = True
print("✅ CLIENT: Connected to server")
# Open a stream
print("📤 CLIENT: Opening stream...")
stream = await connection.open_stream()
print(f"✅ CLIENT: Stream opened with ID: {stream.stream_id}")
# Send test data
test_message = "Hello QUIC Server!"
print(f"📨 CLIENT: Sending message: '{test_message}'")
await stream.write(test_message.encode())
client_sent_data = True
print("✅ CLIENT: Message sent")
# Read echo response
print("📖 CLIENT: Waiting for echo response...")
response_data = await stream.read(1024)
if response_data:
client_received_echo = response_data.decode(
"utf-8", errors="ignore"
)
print(f"📬 CLIENT: Received echo: '{client_received_echo}'")
else:
print("❌ CLIENT: No echo response received")
print("🔒 CLIENT: Closing connection")
await connection.close()
print("🔒 CLIENT: Connection closed")
print("🔒 CLIENT: Closing transport")
await client_transport.close()
print("🔒 CLIENT: Transport closed")
except Exception as e:
print(f"❌ CLIENT: Error: {e}")
import traceback
traceback.print_exc()
finally:
await client_transport.close()
print("🔒 CLIENT: Transport closed")
# Give everything time to complete
await trio.sleep(0.5)
# Cancel nursery to stop server
nursery.cancel_scope.cancel()
finally:
# Cleanup
if not listener._closed:
await listener.close()
await server_transport.close()
# Verify the flow worked
print("\n📊 TEST RESULTS:")
print(f" Server connection established: {server_connection_established}")
print(f" Client connected: {client_connected}")
print(f" Client sent data: {client_sent_data}")
print(f" Server received data: '{server_received_data}'")
print(f" Echo sent by server: {echo_sent}")
print(f" Client received echo: '{client_received_echo}'")
# Test assertions
assert server_connection_established, "Server connection handler was not called"
assert client_connected, "Client failed to connect"
assert client_sent_data, "Client failed to send data"
assert server_received_data == "Hello QUIC Server!", (
f"Server received wrong data: '{server_received_data}'"
)
assert echo_sent, "Server failed to send echo"
assert client_received_echo == "ECHO: Hello QUIC Server!", (
f"Client received wrong echo: '{client_received_echo}'"
)
print("✅ BASIC ECHO TEST PASSED!")
@pytest.mark.trio
async def test_server_accept_stream_timeout(
self, server_key, client_key, server_config, client_config
):
"""Test what happens when server accept_stream times out."""
print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===")
server_transport = QUICTransport(server_key.private_key, server_config)
accept_stream_called = False
accept_stream_timeout = False
async def timeout_test_handler(connection: QUICConnection) -> None:
"""Handler that tests accept_stream timeout."""
nonlocal accept_stream_called, accept_stream_timeout
print("🔗 SERVER: Connection established, testing accept_stream timeout")
accept_stream_called = True
try:
print("📡 SERVER: Calling accept_stream with 2 second timeout...")
stream = await connection.accept_stream(timeout=2.0)
print(f"✅ SERVER: accept_stream returned: {stream}")
except Exception as e:
print(f"⏰ SERVER: accept_stream timed out or failed: {e}")
accept_stream_timeout = True
listener = server_transport.create_listener(timeout_test_handler)
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
client_connected = False
try:
async with trio.open_nursery() as nursery:
# Start server
server_transport.set_background_nursery(nursery)
success = await listener.listen(listen_addr, nursery)
assert success
server_addr = multiaddr.Multiaddr(
f"{listener.get_addrs()[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
)
print(f"🔧 SERVER: Listening on {server_addr}")
# Create client but DON'T open a stream
async with trio.open_nursery() as client_nursery:
client_transport = QUICTransport(
client_key.private_key, client_config
)
client_transport.set_background_nursery(client_nursery)
try:
print("📞 CLIENT: Connecting (but NOT opening stream)...")
connection = await client_transport.dial(server_addr)
client_connected = True
print("✅ CLIENT: Connected (no stream opened)")
# Wait for server timeout
await trio.sleep(3.0)
await connection.close()
print("🔒 CLIENT: Connection closed")
finally:
await client_transport.close()
nursery.cancel_scope.cancel()
finally:
await listener.close()
await server_transport.close()
print("\n📊 TIMEOUT TEST RESULTS:")
print(f" Client connected: {client_connected}")
print(f" accept_stream called: {accept_stream_called}")
print(f" accept_stream timeout: {accept_stream_timeout}")
assert client_connected, "Client should have connected"
assert accept_stream_called, "accept_stream should have been called"
assert accept_stream_timeout, (
"accept_stream should have timed out when no stream was opened"
)
print("✅ TIMEOUT TEST PASSED!")
@pytest.mark.trio
async def test_yamux_stress_ping():
STREAM_COUNT = 100
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
latencies = []
failures = []
# === Server Setup ===
server_host = new_host(listen_addrs=[listen_addr])
async def handle_ping(stream: INetStream) -> None:
try:
while True:
payload = await stream.read(PING_LENGTH)
if not payload:
break
await stream.write(payload)
except Exception:
await stream.reset()
server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
async with server_host.run(listen_addrs=[listen_addr]):
# Give server time to start
await trio.sleep(0.1)
# === Client Setup ===
destination = str(server_host.get_addrs()[0])
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
client_host = new_host(listen_addrs=[client_listen_addr])
async with client_host.run(listen_addrs=[client_listen_addr]):
await client_host.connect(info)
async def ping_stream(i: int):
stream = None
try:
start = trio.current_time()
stream = await client_host.new_stream(
info.peer_id, [PING_PROTOCOL_ID]
)
await stream.write(b"\x01" * PING_LENGTH)
with trio.fail_after(5):
response = await stream.read(PING_LENGTH)
if response == b"\x01" * PING_LENGTH:
latency_ms = int((trio.current_time() - start) * 1000)
latencies.append(latency_ms)
print(f"[Ping #{i}] Latency: {latency_ms} ms")
await stream.close()
except Exception as e:
print(f"[Ping #{i}] Failed: {e}")
failures.append(i)
if stream:
await stream.reset()
async with trio.open_nursery() as nursery:
for i in range(STREAM_COUNT):
nursery.start_soon(ping_stream, i)
# === Result Summary ===
print("\n📊 Ping Stress Test Summary")
print(f"Total Streams Launched: {STREAM_COUNT}")
print(f"Successful Pings: {len(latencies)}")
print(f"Failed Pings: {len(failures)}")
if failures:
print(f"❌ Failed stream indices: {failures}")
# === Assertions ===
assert len(latencies) == STREAM_COUNT, (
f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}"
)
assert all(isinstance(x, int) and x >= 0 for x in latencies), (
"Invalid latencies"
)
avg_latency = sum(latencies) / len(latencies)
print(f"✅ Average Latency: {avg_latency:.2f} ms")
assert avg_latency < 1000

View File

@ -0,0 +1,150 @@
from unittest.mock import AsyncMock
import pytest
from multiaddr.multiaddr import Multiaddr
import trio
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.transport.quic.exceptions import (
QUICListenError,
)
from libp2p.transport.quic.listener import QUICListener
from libp2p.transport.quic.transport import (
QUICTransport,
QUICTransportConfig,
)
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
)
class TestQUICListener:
"""Test suite for QUIC listener functionality."""
@pytest.fixture
def private_key(self):
"""Generate test private key."""
return create_new_key_pair().private_key
@pytest.fixture
def transport_config(self):
"""Generate test transport configuration."""
return QUICTransportConfig(idle_timeout=10.0)
@pytest.fixture
def transport(self, private_key, transport_config):
"""Create test transport instance."""
return QUICTransport(private_key, transport_config)
@pytest.fixture
def connection_handler(self):
"""Mock connection handler."""
return AsyncMock()
@pytest.fixture
def listener(self, transport, connection_handler):
"""Create test listener."""
return transport.create_listener(connection_handler)
def test_listener_creation(self, transport, connection_handler):
"""Test listener creation."""
listener = transport.create_listener(connection_handler)
assert isinstance(listener, QUICListener)
assert listener._transport == transport
assert listener._handler == connection_handler
assert not listener._listening
assert not listener._closed
@pytest.mark.trio
async def test_listener_invalid_multiaddr(self, listener: QUICListener):
"""Test listener with invalid multiaddr."""
async with trio.open_nursery() as nursery:
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"):
await listener.listen(invalid_addr, nursery)
@pytest.mark.trio
async def test_listener_basic_lifecycle(self, listener: QUICListener):
"""Test basic listener lifecycle."""
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random
async with trio.open_nursery() as nursery:
# Start listening
success = await listener.listen(listen_addr, nursery)
assert success
assert listener.is_listening()
# Check bound addresses
addrs = listener.get_addrs()
assert len(addrs) == 1
# Check stats
stats = listener.get_stats()
assert stats["is_listening"] is True
assert stats["active_connections"] == 0
assert stats["pending_connections"] == 0
# Sender Cancel Signal
nursery.cancel_scope.cancel()
await listener.close()
assert not listener.is_listening()
@pytest.mark.trio
async def test_listener_double_listen(self, listener: QUICListener):
"""Test that double listen raises error."""
listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic")
try:
async with trio.open_nursery() as nursery:
success = await listener.listen(listen_addr, nursery)
assert success
await trio.sleep(0.01)
addrs = listener.get_addrs()
assert len(addrs) > 0
async with trio.open_nursery() as nursery2:
with pytest.raises(QUICListenError, match="Already listening"):
await listener.listen(listen_addr, nursery2)
nursery2.cancel_scope.cancel()
nursery.cancel_scope.cancel()
finally:
await listener.close()
@pytest.mark.trio
async def test_listener_port_binding(self, listener: QUICListener):
"""Test listener port binding and cleanup."""
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
try:
async with trio.open_nursery() as nursery:
success = await listener.listen(listen_addr, nursery)
assert success
await trio.sleep(0.5)
addrs = listener.get_addrs()
assert len(addrs) > 0
nursery.cancel_scope.cancel()
finally:
await listener.close()
# By the time we get here, the listener and its tasks have been fully
# shut down, allowing the nursery to exit without hanging.
print("TEST COMPLETED SUCCESSFULLY.")
@pytest.mark.trio
async def test_listener_stats_tracking(self, listener):
"""Test listener statistics tracking."""
initial_stats = listener.get_stats()
# All counters should start at 0
assert initial_stats["connections_accepted"] == 0
assert initial_stats["connections_rejected"] == 0
assert initial_stats["bytes_received"] == 0
assert initial_stats["packets_processed"] == 0

View File

@ -0,0 +1,123 @@
from unittest.mock import (
Mock,
)
import pytest
from libp2p.crypto.ed25519 import (
create_new_key_pair,
)
from libp2p.crypto.keys import PrivateKey
from libp2p.transport.quic.exceptions import (
QUICDialError,
QUICListenError,
)
from libp2p.transport.quic.transport import (
QUICTransport,
QUICTransportConfig,
)
class TestQUICTransport:
"""Test suite for QUIC transport using trio."""
@pytest.fixture
def private_key(self):
"""Generate test private key."""
return create_new_key_pair().private_key
@pytest.fixture
def transport_config(self):
"""Generate test transport configuration."""
return QUICTransportConfig(
idle_timeout=10.0, enable_draft29=True, enable_v1=True
)
@pytest.fixture
def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig):
"""Create test transport instance."""
return QUICTransport(private_key, transport_config)
def test_transport_initialization(self, transport):
"""Test transport initialization."""
assert transport._private_key is not None
assert transport._peer_id is not None
assert not transport._closed
assert len(transport._quic_configs) >= 1
def test_supported_protocols(self, transport):
"""Test supported protocol identifiers."""
protocols = transport.protocols()
# TODO: Update when quic-v1 compatible
# assert "quic-v1" in protocols
assert "quic" in protocols # draft-29
def test_can_dial_quic_addresses(self, transport: QUICTransport):
"""Test multiaddr compatibility checking."""
import multiaddr
# Valid QUIC addresses
valid_addrs = [
# TODO: Update Multiaddr package to accept quic-v1
multiaddr.Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
),
multiaddr.Multiaddr(
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
multiaddr.Multiaddr(
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
multiaddr.Multiaddr(
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
),
]
for addr in valid_addrs:
assert transport.can_dial(addr)
# Invalid addresses
invalid_addrs = [
multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"),
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001"),
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/ws"),
]
for addr in invalid_addrs:
assert not transport.can_dial(addr)
@pytest.mark.trio
async def test_transport_lifecycle(self, transport):
"""Test transport lifecycle management using trio."""
assert not transport._closed
await transport.close()
assert transport._closed
# Should be safe to close multiple times
await transport.close()
@pytest.mark.trio
async def test_dial_closed_transport(self, transport: QUICTransport) -> None:
"""Test dialing with closed transport raises error."""
import multiaddr
await transport.close()
with pytest.raises(QUICDialError, match="Transport is closed"):
await transport.dial(
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
)
def test_create_listener_closed_transport(self, transport: QUICTransport) -> None:
"""Test creating listener with closed transport raises error."""
transport._closed = True
with pytest.raises(QUICListenError, match="Transport is closed"):
transport.create_listener(Mock())

View File

@ -0,0 +1,321 @@
"""
Test suite for QUIC multiaddr utilities.
Focused tests covering essential functionality required for QUIC transport.
"""
import pytest
from multiaddr import Multiaddr
from libp2p.custom_types import TProtocol
from libp2p.transport.quic.exceptions import (
QUICInvalidMultiaddrError,
QUICUnsupportedVersionError,
)
from libp2p.transport.quic.utils import (
create_quic_multiaddr,
get_alpn_protocols,
is_quic_multiaddr,
multiaddr_to_quic_version,
normalize_quic_multiaddr,
quic_multiaddr_to_endpoint,
quic_version_to_wire_format,
)
class TestIsQuicMultiaddr:
"""Test QUIC multiaddr detection."""
def test_valid_quic_v1_multiaddrs(self):
"""Test valid QUIC v1 multiaddrs are detected."""
valid_addrs = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip4/192.168.1.1/udp/8080/quic-v1",
"/ip6/::1/udp/4001/quic-v1",
"/ip6/2001:db8::1/udp/5000/quic-v1",
]
for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
def test_valid_quic_draft29_multiaddrs(self):
"""Test valid QUIC draft-29 multiaddrs are detected."""
valid_addrs = [
"/ip4/127.0.0.1/udp/4001/quic",
"/ip4/10.0.0.1/udp/9000/quic",
"/ip6/::1/udp/4001/quic",
"/ip6/fe80::1/udp/6000/quic",
]
for addr_str in valid_addrs:
maddr = Multiaddr(addr_str)
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
def test_invalid_multiaddrs(self):
"""Test non-QUIC multiaddrs are not detected."""
invalid_addrs = [
"/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
"/ip4/127.0.0.1/udp/4001", # UDP without QUIC
"/ip4/127.0.0.1/udp/4001/ws", # WebSocket
"/ip4/127.0.0.1/quic-v1", # Missing UDP
"/udp/4001/quic-v1", # Missing IP
"/dns4/example.com/tcp/443/tls", # Completely different
]
for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
class TestQuicMultiaddrToEndpoint:
"""Test endpoint extraction from QUIC multiaddrs."""
def test_ipv4_extraction(self):
"""Test IPv4 host/port extraction."""
test_cases = [
("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
]
for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
assert result == expected, f"Failed for {addr_str}"
def test_ipv6_extraction(self):
"""Test IPv6 host/port extraction."""
test_cases = [
("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
]
for addr_str, expected in test_cases:
maddr = Multiaddr(addr_str)
result = quic_multiaddr_to_endpoint(maddr)
assert result == expected, f"Failed for {addr_str}"
def test_invalid_multiaddr_raises_error(self):
"""Test invalid multiaddrs raise appropriate errors."""
invalid_addrs = [
"/ip4/127.0.0.1/tcp/4001", # Not QUIC
"/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
]
for addr_str in invalid_addrs:
maddr = Multiaddr(addr_str)
with pytest.raises(QUICInvalidMultiaddrError):
quic_multiaddr_to_endpoint(maddr)
class TestMultiaddrToQuicVersion:
"""Test QUIC version extraction."""
def test_quic_v1_detection(self):
"""Test QUIC v1 version detection."""
addrs = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip6/::1/udp/5000/quic-v1",
]
for addr_str in addrs:
maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
def test_quic_draft29_detection(self):
"""Test QUIC draft-29 version detection."""
addrs = [
"/ip4/127.0.0.1/udp/4001/quic",
"/ip6/::1/udp/5000/quic",
]
for addr_str in addrs:
maddr = Multiaddr(addr_str)
version = multiaddr_to_quic_version(maddr)
assert version == "quic", f"Should detect quic for {addr_str}"
def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICInvalidMultiaddrError):
multiaddr_to_quic_version(maddr)
class TestCreateQuicMultiaddr:
"""Test QUIC multiaddr creation."""
def test_ipv4_creation(self):
"""Test IPv4 QUIC multiaddr creation."""
test_cases = [
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
]
for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
def test_ipv6_creation(self):
"""Test IPv6 QUIC multiaddr creation."""
test_cases = [
("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
]
for host, port, version, expected in test_cases:
result = create_quic_multiaddr(host, port, version)
assert str(result) == expected
def test_default_version(self):
"""Test default version is quic-v1."""
result = create_quic_multiaddr("127.0.0.1", 4001)
expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
assert str(result) == expected
def test_invalid_inputs_raise_errors(self):
"""Test invalid inputs raise appropriate errors."""
# Invalid IP
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("invalid-ip", 4001)
# Invalid port
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", 70000)
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", -1)
# Invalid version
with pytest.raises(QUICInvalidMultiaddrError):
create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
class TestQuicVersionToWireFormat:
"""Test QUIC version to wire format conversion."""
def test_supported_versions(self):
"""Test supported version conversions."""
test_cases = [
("quic-v1", 0x00000001), # RFC 9000
("quic", 0xFF00001D), # draft-29
]
for version, expected_wire in test_cases:
result = quic_version_to_wire_format(TProtocol(version))
assert result == expected_wire, f"Failed for version {version}"
def test_unsupported_version_raises_error(self):
"""Test unsupported versions raise error."""
with pytest.raises(QUICUnsupportedVersionError):
quic_version_to_wire_format(TProtocol("unsupported-version"))
class TestGetAlpnProtocols:
"""Test ALPN protocol retrieval."""
def test_returns_libp2p_protocols(self):
"""Test returns expected libp2p ALPN protocols."""
protocols = get_alpn_protocols()
assert protocols == ["libp2p"]
assert isinstance(protocols, list)
def test_returns_copy(self):
"""Test returns a copy, not the original list."""
protocols1 = get_alpn_protocols()
protocols2 = get_alpn_protocols()
# Modify one list
protocols1.append("test")
# Other list should be unchanged
assert protocols2 == ["libp2p"]
class TestNormalizeQuicMultiaddr:
"""Test QUIC multiaddr normalization."""
def test_already_normalized(self):
"""Test already normalized multiaddrs pass through."""
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
result = normalize_quic_multiaddr(maddr)
assert str(result) == addr_str
def test_normalize_different_versions(self):
"""Test normalization works for different QUIC versions."""
test_cases = [
"/ip4/127.0.0.1/udp/4001/quic-v1",
"/ip4/127.0.0.1/udp/4001/quic",
"/ip6/::1/udp/5000/quic-v1",
]
for addr_str in test_cases:
maddr = Multiaddr(addr_str)
result = normalize_quic_multiaddr(maddr)
# Should be valid QUIC multiaddr
assert is_quic_multiaddr(result)
# Should be parseable
host, port = quic_multiaddr_to_endpoint(result)
version = multiaddr_to_quic_version(result)
# Should match original
orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
orig_version = multiaddr_to_quic_version(maddr)
assert host == orig_host
assert port == orig_port
assert version == orig_version
def test_non_quic_raises_error(self):
"""Test non-QUIC multiaddrs raise error."""
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
with pytest.raises(QUICInvalidMultiaddrError):
normalize_quic_multiaddr(maddr)
class TestIntegration:
"""Integration tests for utility functions working together."""
def test_round_trip_conversion(self):
"""Test creating and parsing multiaddrs works correctly."""
test_cases = [
("127.0.0.1", 4001, "quic-v1"),
("::1", 5000, "quic"),
("192.168.1.100", 8080, "quic-v1"),
]
for host, port, version in test_cases:
# Create multiaddr
maddr = create_quic_multiaddr(host, port, version)
# Should be detected as QUIC
assert is_quic_multiaddr(maddr)
# Should extract original values
extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
extracted_version = multiaddr_to_quic_version(maddr)
assert extracted_host == host
assert extracted_port == port
assert extracted_version == version
# Should normalize to same value
normalized = normalize_quic_multiaddr(maddr)
assert str(normalized) == str(maddr)
def test_wire_format_integration(self):
"""Test wire format conversion works with version detection."""
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
maddr = Multiaddr(addr_str)
# Extract version and convert to wire format
version = multiaddr_to_quic_version(maddr)
wire_format = quic_version_to_wire_format(version)
# Should be QUIC v1 wire format
assert wire_format == 0x00000001

View File

@ -0,0 +1,27 @@
import pytest
from libp2p.custom_types import (
TMuxerOptions,
TSecurityOptions,
)
from libp2p.transport.upgrader import (
TransportUpgrader,
)
@pytest.mark.trio
async def test_transport_upgrader_security_and_muxer_initialization():
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
secure_transports: TSecurityOptions = {}
muxer_transports: TMuxerOptions = {}
negotiate_timeout = 15
upgrader = TransportUpgrader(
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
)
# Verify security multistream initialization
assert upgrader.security_multistream.transports == secure_transports
# Verify muxer multistream initialization and timeout
assert upgrader.muxer_multistream.transports == muxer_transports
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout

View File

@ -0,0 +1,6 @@
def test_echo_quic_example():
"""Test that the QUIC echo example can be imported and has required functions."""
from examples.echo import echo_quic
assert hasattr(echo_quic, "main")
assert hasattr(echo_quic, "run")

8
tests/interop/nim_libp2p/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
nimble.develop
nimble.paths
*.nimble
nim-libp2p/
nim_echo_server
config.nims

View File

@ -0,0 +1,119 @@
import fcntl
import logging
from pathlib import Path
import shutil
import subprocess
import time
import pytest
logger = logging.getLogger(__name__)
def check_nim_available():
"""Check if nim compiler is available."""
return shutil.which("nim") is not None and shutil.which("nimble") is not None
def check_nim_binary_built():
"""Check if nim echo server binary is built."""
current_dir = Path(__file__).parent
binary_path = current_dir / "nim_echo_server"
return binary_path.exists() and binary_path.stat().st_size > 0
def run_nim_setup_with_lock():
"""Run nim setup with file locking to prevent parallel execution."""
current_dir = Path(__file__).parent
lock_file = current_dir / ".setup_lock"
setup_script = current_dir / "scripts" / "setup_nim_echo.sh"
if not setup_script.exists():
raise RuntimeError(f"Setup script not found: {setup_script}")
# Try to acquire lock
try:
with open(lock_file, "w") as f:
# Non-blocking lock attempt
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
# Double-check binary doesn't exist (another worker might have built it)
if check_nim_binary_built():
logger.info("Binary already exists, skipping setup")
return
logger.info("Acquired setup lock, running nim-libp2p setup...")
# Make setup script executable and run it
setup_script.chmod(0o755)
result = subprocess.run(
[str(setup_script)],
cwd=current_dir,
capture_output=True,
text=True,
timeout=300, # 5 minute timeout
)
if result.returncode != 0:
raise RuntimeError(
f"Setup failed (exit {result.returncode}):\n"
f"stdout: {result.stdout}\n"
f"stderr: {result.stderr}"
)
# Verify binary was built
if not check_nim_binary_built():
raise RuntimeError("nim_echo_server binary not found after setup")
logger.info("nim-libp2p setup completed successfully")
except BlockingIOError:
# Another worker is running setup, wait for it to complete
logger.info("Another worker is running setup, waiting...")
# Wait for setup to complete (check every 2 seconds, max 5 minutes)
for _ in range(150): # 150 * 2 = 300 seconds = 5 minutes
if check_nim_binary_built():
logger.info("Setup completed by another worker")
return
time.sleep(2)
raise TimeoutError("Timed out waiting for setup to complete")
finally:
# Clean up lock file
try:
lock_file.unlink(missing_ok=True)
except Exception:
pass
@pytest.fixture(scope="function") # Changed to function scope
def nim_echo_binary():
"""Get nim echo server binary path."""
current_dir = Path(__file__).parent
binary_path = current_dir / "nim_echo_server"
if not binary_path.exists():
pytest.skip(
"nim_echo_server binary not found. "
"Run setup script: ./scripts/setup_nim_echo.sh"
)
return binary_path
@pytest.fixture
async def nim_server(nim_echo_binary):
"""Start and stop nim echo server for tests."""
# Import here to avoid circular imports
# pyrefly: ignore
from test_echo_interop import NimEchoServer
server = NimEchoServer(nim_echo_binary)
try:
peer_id, listen_addr = await server.start()
yield server, peer_id, listen_addr
finally:
await server.stop()

View File

@ -0,0 +1,108 @@
{.used.}
import chronos
import stew/byteutils
import libp2p
##
# Simple Echo Protocol Implementation for py-libp2p Interop Testing
##
const EchoCodec = "/echo/1.0.0"
type EchoProto = ref object of LPProtocol
proc new(T: typedesc[EchoProto]): T =
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
try:
echo "Echo server: Received connection from ", conn.peerId
# Read and echo messages in a loop
while not conn.atEof:
try:
# Read length-prefixed message using nim-libp2p's readLp
let message = await conn.readLp(1024 * 1024) # Max 1MB
if message.len == 0:
echo "Echo server: Empty message, closing connection"
break
let messageStr = string.fromBytes(message)
echo "Echo server: Received (", message.len, " bytes): ", messageStr
# Echo back using writeLp
await conn.writeLp(message)
echo "Echo server: Echoed message back"
except CatchableError as e:
echo "Echo server: Error processing message: ", e.msg
break
except CancelledError as e:
echo "Echo server: Connection cancelled"
raise e
except CatchableError as e:
echo "Echo server: Exception in handler: ", e.msg
finally:
echo "Echo server: Connection closed"
await conn.close()
return T.new(codecs = @[EchoCodec], handler = handle)
##
# Create QUIC-enabled switch
##
proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
var switch = SwitchBuilder
.new()
.withRng(rng)
.withAddress(ma)
.withQuicTransport()
.build()
result = switch
##
# Main server
##
proc main() {.async.} =
let
rng = newRng()
localAddr = MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()
echoProto = EchoProto.new()
echo "=== Nim Echo Server for py-libp2p Interop ==="
# Create switch
let switch = createSwitch(localAddr, rng)
switch.mount(echoProto)
# Start server
await switch.start()
# Print connection info
echo "Peer ID: ", $switch.peerInfo.peerId
echo "Listening on:"
for addr in switch.peerInfo.addrs:
echo " ", $addr, "/p2p/", $switch.peerInfo.peerId
echo "Protocol: ", EchoCodec
echo "Ready for py-libp2p connections!"
echo ""
# Keep running
try:
await sleepAsync(100.hours)
except CancelledError:
echo "Shutting down..."
finally:
await switch.stop()
# Graceful shutdown handler
proc signalHandler() {.noconv.} =
echo "\nShutdown signal received"
quit(0)
when isMainModule:
setControlCHook(signalHandler)
try:
waitFor(main())
except CatchableError as e:
echo "Error: ", e.msg
quit(1)

View File

@ -0,0 +1,74 @@
#!/usr/bin/env bash
# tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
# Cache-aware setup that skips installation if packages exist
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_DIR="${SCRIPT_DIR}/.."
# Colors
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m'
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
main() {
log_info "Setting up nim echo server for interop testing..."
# Check if nim is available
if ! command -v nim &> /dev/null || ! command -v nimble &> /dev/null; then
log_error "Nim not found. Please install nim first."
exit 1
fi
cd "${PROJECT_DIR}"
# Create logs directory
mkdir -p logs
# Check if binary already exists
if [[ -f "nim_echo_server" ]]; then
log_info "nim_echo_server already exists, skipping build"
return 0
fi
# Check if libp2p is already installed (cache-aware)
if nimble list -i | grep -q "libp2p"; then
log_info "libp2p already installed, skipping installation"
else
log_info "Installing nim-libp2p globally..."
nimble install -y libp2p
fi
log_info "Building nim echo server..."
# Compile the echo server
nim c \
-d:release \
-d:chronicles_log_level=INFO \
-d:libp2p_quic_support \
-d:chronos_event_loop=iocp \
-d:ssl \
--opt:speed \
--mm:orc \
--verbosity:1 \
-o:nim_echo_server \
nim_echo_server.nim
# Verify binary was created
if [[ -f "nim_echo_server" ]]; then
log_info "✅ nim_echo_server built successfully"
log_info "Binary size: $(ls -lh nim_echo_server | awk '{print $5}')"
else
log_error "❌ Failed to build nim_echo_server"
exit 1
fi
log_info "🎉 Setup complete!"
}
main "$@"

View File

@ -0,0 +1,195 @@
import logging
from pathlib import Path
import subprocess
import time
import pytest
import multiaddr
import trio
from libp2p import new_host
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.utils.varint import encode_varint_prefixed, read_varint_prefixed_bytes
# Configuration
PROTOCOL_ID = TProtocol("/echo/1.0.0")
TEST_TIMEOUT = 30
SERVER_START_TIMEOUT = 10.0
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NimEchoServer:
"""Simple nim echo server manager."""
def __init__(self, binary_path: Path):
self.binary_path = binary_path
self.process: None | subprocess.Popen = None
self.peer_id = None
self.listen_addr = None
async def start(self):
"""Start nim echo server and get connection info."""
logger.info(f"Starting nim echo server: {self.binary_path}")
self.process = subprocess.Popen(
[str(self.binary_path)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
)
# Parse output for connection info
start_time = time.time()
while time.time() - start_time < SERVER_START_TIMEOUT:
if self.process and self.process.poll() and self.process.stdout:
output = self.process.stdout.read()
raise RuntimeError(f"Server exited early: {output}")
reader = self.process.stdout if self.process else None
if reader:
line = reader.readline().strip()
if not line:
continue
logger.info(f"Server: {line}")
if line.startswith("Peer ID:"):
self.peer_id = line.split(":", 1)[1].strip()
elif "/quic-v1/p2p/" in line and self.peer_id:
if line.strip().startswith("/"):
self.listen_addr = line.strip()
logger.info(f"Server ready: {self.listen_addr}")
return self.peer_id, self.listen_addr
await self.stop()
raise TimeoutError(f"Server failed to start within {SERVER_START_TIMEOUT}s")
async def stop(self):
"""Stop the server."""
if self.process:
logger.info("Stopping nim echo server...")
try:
self.process.terminate()
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.process.kill()
self.process.wait()
self.process = None
async def run_echo_test(server_addr: str, messages: list[str]):
"""Test echo protocol against nim server with proper timeout handling."""
# Create py-libp2p QUIC client with shorter timeouts
host = new_host(
enable_quic=True,
key_pair=create_new_key_pair(),
)
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/0/quic-v1")
responses = []
try:
async with host.run(listen_addrs=[listen_addr]):
logger.info(f"Connecting to nim server: {server_addr}")
# Connect to nim server
maddr = multiaddr.Multiaddr(server_addr)
info = info_from_p2p_addr(maddr)
await host.connect(info)
# Create stream
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
logger.info("Stream created")
# Test each message
for i, message in enumerate(messages, 1):
logger.info(f"Testing message {i}: {message}")
# Send with varint length prefix
data = message.encode("utf-8")
prefixed_data = encode_varint_prefixed(data)
await stream.write(prefixed_data)
# Read response
response_data = await read_varint_prefixed_bytes(stream)
response = response_data.decode("utf-8")
logger.info(f"Got echo: {response}")
responses.append(response)
# Verify echo
assert message == response, (
f"Echo failed: sent {message!r}, got {response!r}"
)
await stream.close()
logger.info("✅ All messages echoed correctly")
finally:
await host.close()
return responses
@pytest.mark.trio
@pytest.mark.timeout(TEST_TIMEOUT)
async def test_basic_echo_interop(nim_server):
"""Test basic echo functionality between py-libp2p and nim-libp2p."""
server, peer_id, listen_addr = nim_server
test_messages = [
"Hello from py-libp2p!",
"QUIC transport working",
"Echo test successful!",
"Unicode: Ñoël, 测试, Ψυχή",
]
logger.info(f"Testing against nim server: {peer_id}")
# Run test with timeout
with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup
responses = await run_echo_test(listen_addr, test_messages)
# Verify all messages echoed correctly
assert len(responses) == len(test_messages)
for sent, received in zip(test_messages, responses):
assert sent == received
logger.info("✅ Basic echo interop test passed!")
@pytest.mark.trio
@pytest.mark.timeout(TEST_TIMEOUT)
async def test_large_message_echo(nim_server):
"""Test echo with larger messages."""
server, peer_id, listen_addr = nim_server
large_messages = [
"x" * 1024,
"y" * 5000,
]
logger.info("Testing large message echo...")
# Run test with timeout
with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup
responses = await run_echo_test(listen_addr, large_messages)
assert len(responses) == len(large_messages)
for sent, received in zip(large_messages, responses):
assert sent == received
logger.info("✅ Large message echo test passed!")
if __name__ == "__main__":
# Run tests directly
pytest.main([__file__, "-v", "--tb=short"])