mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into keyerror-fix
This commit is contained in:
40
.github/workflows/tox.yml
vendored
40
.github/workflows/tox.yml
vendored
@ -36,10 +36,48 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install Nim for interop testing
|
||||
if: matrix.toxenv == 'interop'
|
||||
run: |
|
||||
echo "Installing Nim for nim-libp2p interop testing..."
|
||||
curl -sSf https://nim-lang.org/choosenim/init.sh | sh -s -- -y --firstInstall
|
||||
echo "$HOME/.nimble/bin" >> $GITHUB_PATH
|
||||
echo "$HOME/.choosenim/toolchains/nim-stable/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Cache nimble packages
|
||||
if: matrix.toxenv == 'interop'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.nimble
|
||||
~/.choosenim/toolchains/*/lib
|
||||
key: ${{ runner.os }}-nimble-${{ hashFiles('**/nim_echo_server.nim') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-nimble-
|
||||
|
||||
- name: Build nim interop binaries
|
||||
if: matrix.toxenv == 'interop'
|
||||
run: |
|
||||
export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH"
|
||||
cd tests/interop/nim_libp2p
|
||||
./scripts/setup_nim_echo.sh
|
||||
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install tox
|
||||
- run: |
|
||||
|
||||
- name: Run Tests or Generate Docs
|
||||
run: |
|
||||
if [[ "${{ matrix.toxenv }}" == 'docs' ]]; then
|
||||
export TOXENV=docs
|
||||
else
|
||||
export TOXENV=py${{ matrix.python }}-${{ matrix.toxenv }}
|
||||
fi
|
||||
# Set PATH for nim commands during tox
|
||||
if [[ "${{ matrix.toxenv }}" == 'interop' ]]; then
|
||||
export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH"
|
||||
fi
|
||||
python -m tox run -r
|
||||
|
||||
windows:
|
||||
|
||||
12
README.md
12
README.md
@ -61,12 +61,12 @@ ______________________________________________________________________
|
||||
|
||||
### Discovery
|
||||
|
||||
| **Discovery** | **Status** | **Source** |
|
||||
| -------------------- | :--------: | :--------------------------------------------------------------------------------: |
|
||||
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
|
||||
| **`random-walk`** | 🌱 | |
|
||||
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
|
||||
| **`rendezvous`** | 🌱 | |
|
||||
| **Discovery** | **Status** | **Source** |
|
||||
| -------------------- | :--------: | :----------------------------------------------------------------------------------: |
|
||||
| **`bootstrap`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) |
|
||||
| **`random-walk`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/random_walk) |
|
||||
| **`mdns-discovery`** | ✅ | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) |
|
||||
| **`rendezvous`** | 🌱 | |
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
|
||||
43
docs/examples.echo_quic.rst
Normal file
43
docs/examples.echo_quic.rst
Normal file
@ -0,0 +1,43 @@
|
||||
QUIC Echo Demo
|
||||
==============
|
||||
|
||||
This example demonstrates a simple ``echo`` protocol using **QUIC transport**.
|
||||
|
||||
QUIC provides built-in TLS security and stream multiplexing over UDP, making it an excellent transport choice for libp2p applications.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m pip install libp2p
|
||||
Collecting libp2p
|
||||
...
|
||||
Successfully installed libp2p-x.x.x
|
||||
$ echo-quic-demo
|
||||
Run this from the same folder in another console:
|
||||
|
||||
echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmAsbxRR1HiGJRNVPQLNMeNsBCsXT3rDjoYBQzgzNpM5mJ
|
||||
|
||||
Waiting for incoming connection...
|
||||
|
||||
Copy the line that starts with ``echo-quic-demo -p 8001``, open a new terminal in the same
|
||||
folder and paste it in:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
|
||||
|
||||
I am 16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
|
||||
STARTING CLIENT CONNECTION PROCESS
|
||||
CLIENT CONNECTED TO SERVER
|
||||
Sent: hi, there!
|
||||
Got: ECHO: hi, there!
|
||||
|
||||
**Key differences from TCP Echo:**
|
||||
|
||||
- Uses UDP instead of TCP: ``/udp/8000`` instead of ``/tcp/8000``
|
||||
- Includes QUIC protocol identifier: ``/quic-v1`` in the multiaddr
|
||||
- Built-in TLS security (no separate security transport needed)
|
||||
- Native stream multiplexing over a single QUIC connection
|
||||
|
||||
.. literalinclude:: ../examples/echo/echo_quic.py
|
||||
:language: python
|
||||
:linenos:
|
||||
@ -9,6 +9,7 @@ Examples
|
||||
examples.identify_push
|
||||
examples.chat
|
||||
examples.echo
|
||||
examples.echo_quic
|
||||
examples.ping
|
||||
examples.pubsub
|
||||
examples.circuit_relay
|
||||
|
||||
@ -28,6 +28,11 @@ For Python, the most common transport is TCP. Here's how to set up a basic TCP t
|
||||
.. literalinclude:: ../examples/doc-examples/example_transport.py
|
||||
:language: python
|
||||
|
||||
Also, QUIC is a modern transport protocol that provides built-in TLS security and stream multiplexing over UDP:
|
||||
|
||||
.. literalinclude:: ../examples/doc-examples/example_quic_transport.py
|
||||
:language: python
|
||||
|
||||
Connection Encryption
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
77
docs/libp2p.transport.quic.rst
Normal file
77
docs/libp2p.transport.quic.rst
Normal file
@ -0,0 +1,77 @@
|
||||
libp2p.transport.quic package
|
||||
=============================
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
libp2p.transport.quic.config module
|
||||
-----------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.config
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.connection module
|
||||
---------------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.connection
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.exceptions module
|
||||
---------------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.exceptions
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.listener module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.listener
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.security module
|
||||
-------------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.security
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.stream module
|
||||
-----------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.stream
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.transport module
|
||||
--------------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.transport
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
libp2p.transport.quic.utils module
|
||||
----------------------------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic.utils
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
.. automodule:: libp2p.transport.quic
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
@ -9,6 +9,11 @@ Subpackages
|
||||
|
||||
libp2p.transport.tcp
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 4
|
||||
|
||||
libp2p.transport.quic
|
||||
|
||||
Submodules
|
||||
----------
|
||||
|
||||
|
||||
35
examples/doc-examples/example_quic_transport.py
Normal file
35
examples/doc-examples/example_quic_transport.py
Normal file
@ -0,0 +1,35 @@
|
||||
import secrets
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.crypto.secp256k1 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
# Create a key pair for the host
|
||||
secret = secrets.token_bytes(32)
|
||||
key_pair = create_new_key_pair(secret)
|
||||
|
||||
# Create a host with the key pair
|
||||
host = new_host(key_pair=key_pair, enable_quic=True)
|
||||
|
||||
# Configure the listening address
|
||||
port = 8000
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic-v1")
|
||||
|
||||
# Start the host
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
print("libp2p has started with QUIC transport")
|
||||
print("libp2p is listening on:", host.get_addrs())
|
||||
# Keep the host running
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
# Run the async function
|
||||
trio.run(main)
|
||||
178
examples/echo/echo_quic.py
Normal file
178
examples/echo/echo_quic.py
Normal file
@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
QUIC Echo Example - Fixed version with proper client/server separation
|
||||
|
||||
This program demonstrates a simple echo protocol using QUIC transport where a peer
|
||||
listens for connections and copies back any input received on a stream.
|
||||
|
||||
Fixed to properly separate client and server modes - clients don't start listeners.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.network.stream.net_stream import INetStream
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def _echo_stream_handler(stream: INetStream) -> None:
|
||||
try:
|
||||
msg = await stream.read()
|
||||
await stream.write(msg)
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
print(f"Echo handler error: {e}")
|
||||
try:
|
||||
await stream.close()
|
||||
except: # noqa: E722
|
||||
pass
|
||||
|
||||
|
||||
async def run_server(port: int, seed: int | None = None) -> None:
|
||||
"""Run echo server with QUIC transport."""
|
||||
listen_addr = Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic")
|
||||
|
||||
if seed:
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
secret_number = random.getrandbits(32 * 8)
|
||||
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||
else:
|
||||
import secrets
|
||||
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
# Create host with QUIC transport
|
||||
host = new_host(
|
||||
enable_quic=True,
|
||||
key_pair=create_new_key_pair(secret),
|
||||
)
|
||||
|
||||
# Server mode: start listener
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
try:
|
||||
print(f"I am {host.get_id().to_string()}")
|
||||
host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler)
|
||||
|
||||
print(
|
||||
"Run this from the same folder in another console:\n\n"
|
||||
f"python3 ./examples/echo/echo_quic.py "
|
||||
f"-d {host.get_addrs()[0]}\n"
|
||||
)
|
||||
print("Waiting for incoming QUIC connections...")
|
||||
await trio.sleep_forever()
|
||||
except KeyboardInterrupt:
|
||||
print("Closing server gracefully...")
|
||||
await host.close()
|
||||
return
|
||||
|
||||
|
||||
async def run_client(destination: str, seed: int | None = None) -> None:
|
||||
"""Run echo client with QUIC transport."""
|
||||
if seed:
|
||||
import random
|
||||
|
||||
random.seed(seed)
|
||||
secret_number = random.getrandbits(32 * 8)
|
||||
secret = secret_number.to_bytes(length=32, byteorder="big")
|
||||
else:
|
||||
import secrets
|
||||
|
||||
secret = secrets.token_bytes(32)
|
||||
|
||||
# Create host with QUIC transport
|
||||
host = new_host(
|
||||
enable_quic=True,
|
||||
key_pair=create_new_key_pair(secret),
|
||||
)
|
||||
|
||||
# Client mode: NO listener, just connect
|
||||
async with host.run(listen_addrs=[]): # Empty listen_addrs for client
|
||||
print(f"I am {host.get_id().to_string()}")
|
||||
|
||||
maddr = Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
|
||||
# Connect to server
|
||||
print("STARTING CLIENT CONNECTION PROCESS")
|
||||
await host.connect(info)
|
||||
print("CLIENT CONNECTED TO SERVER")
|
||||
|
||||
# Start a stream with the destination
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
|
||||
msg = b"hi, there!\n"
|
||||
|
||||
await stream.write(msg)
|
||||
response = await stream.read()
|
||||
|
||||
print(f"Sent: {msg.decode('utf-8')}")
|
||||
print(f"Got: {response.decode('utf-8')}")
|
||||
await stream.close()
|
||||
await host.disconnect(info.peer_id)
|
||||
|
||||
|
||||
async def run(port: int, destination: str, seed: int | None = None) -> None:
|
||||
"""
|
||||
Run echo server or client with QUIC transport.
|
||||
|
||||
Fixed version that properly separates client and server modes.
|
||||
"""
|
||||
if not destination: # Server mode
|
||||
await run_server(port, seed)
|
||||
else: # Client mode
|
||||
await run_client(destination, seed)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main function - help text updated for QUIC."""
|
||||
description = """
|
||||
This program demonstrates a simple echo protocol using QUIC
|
||||
transport where a peer listens for connections and copies back
|
||||
any input received on a stream.
|
||||
|
||||
QUIC provides built-in TLS security and stream multiplexing over UDP.
|
||||
|
||||
To use it, first run 'echo-quic-demo -p <PORT>', where <PORT> is
|
||||
the UDP port number. Then, run another host with ,
|
||||
'echo-quic-demo -d <DESTINATION>'
|
||||
where <DESTINATION> is the QUIC multiaddress of the previous listener host.
|
||||
"""
|
||||
|
||||
example_maddr = "/ip4/127.0.0.1/udp/8000/quic/p2p/QmQn4SwGkDZKkUEpBRBv"
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="UDP port number")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
help="provide a seed to the random number generator",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, args.destination, args.seed)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.getLogger("aioquic").setLevel(logging.DEBUG)
|
||||
main()
|
||||
@ -1,5 +1,11 @@
|
||||
"""Libp2p Python implementation."""
|
||||
|
||||
import logging
|
||||
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
from typing import Any
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
@ -38,10 +44,12 @@ from libp2p.host.routed_host import (
|
||||
RoutedHost,
|
||||
)
|
||||
from libp2p.network.swarm import (
|
||||
ConnectionConfig,
|
||||
RetryConfig,
|
||||
Swarm,
|
||||
)
|
||||
from libp2p.network.config import (
|
||||
ConnectionConfig,
|
||||
RetryConfig
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -87,6 +95,7 @@ MUXER_YAMUX = "YAMUX"
|
||||
MUXER_MPLEX = "MPLEX"
|
||||
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||
"""
|
||||
@ -162,8 +171,9 @@ def new_swarm(
|
||||
peerstore_opt: IPeerStore | None = None,
|
||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||
enable_quic: bool = False,
|
||||
retry_config: Optional["RetryConfig"] = None,
|
||||
connection_config: Optional["ConnectionConfig"] = None,
|
||||
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||
) -> INetworkService:
|
||||
"""
|
||||
Create a swarm instance based on the parameters.
|
||||
@ -174,6 +184,8 @@ def new_swarm(
|
||||
:param peerstore_opt: optional peerstore
|
||||
:param muxer_preference: optional explicit muxer preference
|
||||
:param listen_addrs: optional list of multiaddrs to listen on
|
||||
:param enable_quic: enable quic for transport
|
||||
:param quic_transport_opt: options for transport
|
||||
:return: return a default swarm instance
|
||||
|
||||
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
|
||||
@ -186,14 +198,21 @@ def new_swarm(
|
||||
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
transport: TCP | QUICTransport
|
||||
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
|
||||
|
||||
if listen_addrs is None:
|
||||
transport = TCP()
|
||||
if enable_quic:
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
else:
|
||||
transport = TCP()
|
||||
else:
|
||||
addr = listen_addrs[0]
|
||||
is_quic = is_quic_multiaddr(addr)
|
||||
if addr.__contains__("tcp"):
|
||||
transport = TCP()
|
||||
elif addr.__contains__("quic"):
|
||||
raise ValueError("QUIC not yet supported")
|
||||
elif is_quic:
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
else:
|
||||
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
|
||||
|
||||
@ -261,6 +280,8 @@ def new_host(
|
||||
enable_mDNS: bool = False,
|
||||
bootstrap: list[str] | None = None,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
enable_quic: bool = False,
|
||||
quic_transport_opt: QUICTransportConfig | None = None,
|
||||
) -> IHost:
|
||||
"""
|
||||
Create a new libp2p host based on the given parameters.
|
||||
@ -274,15 +295,23 @@ def new_host(
|
||||
:param listen_addrs: optional list of multiaddrs to listen on
|
||||
:param enable_mDNS: whether to enable mDNS discovery
|
||||
:param bootstrap: optional list of bootstrap peer addresses as strings
|
||||
:param enable_quic: optinal choice to use QUIC for transport
|
||||
:param transport_opt: optional configuration for quic transport
|
||||
:return: return a host instance
|
||||
"""
|
||||
|
||||
if not enable_quic and quic_transport_opt is not None:
|
||||
logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config")
|
||||
|
||||
swarm = new_swarm(
|
||||
enable_quic=enable_quic,
|
||||
key_pair=key_pair,
|
||||
muxer_opt=muxer_opt,
|
||||
sec_opt=sec_opt,
|
||||
peerstore_opt=peerstore_opt,
|
||||
muxer_preference=muxer_preference,
|
||||
listen_addrs=listen_addrs,
|
||||
connection_config=quic_transport_opt if enable_quic else None
|
||||
)
|
||||
|
||||
if disc_opt is not None:
|
||||
|
||||
@ -5,17 +5,17 @@ from collections.abc import (
|
||||
)
|
||||
from typing import TYPE_CHECKING, NewType, Union, cast
|
||||
|
||||
from libp2p.transport.quic.stream import QUICStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import (
|
||||
IMuxedConn,
|
||||
INetStream,
|
||||
ISecureTransport,
|
||||
)
|
||||
from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
else:
|
||||
IMuxedConn = cast(type, object)
|
||||
INetStream = cast(type, object)
|
||||
ISecureTransport = cast(type, object)
|
||||
|
||||
IMuxedStream = cast(type, object)
|
||||
QUICConnection = cast(type, object)
|
||||
|
||||
from libp2p.io.abc import (
|
||||
ReadWriteCloser,
|
||||
@ -37,4 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool]
|
||||
AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]]
|
||||
ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn]
|
||||
UnsubscribeFn = Callable[[], Awaitable[None]]
|
||||
TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]]
|
||||
TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]]
|
||||
MessageID = NewType("MessageID", str)
|
||||
|
||||
@ -213,7 +213,6 @@ class BasicHost(IHost):
|
||||
self,
|
||||
peer_id: ID,
|
||||
protocol_ids: Sequence[TProtocol],
|
||||
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> INetStream:
|
||||
"""
|
||||
:param peer_id: peer_id that host is connecting
|
||||
@ -227,7 +226,7 @@ class BasicHost(IHost):
|
||||
selected_protocol = await self.multiselect_client.select_one_of(
|
||||
list(protocol_ids),
|
||||
MultiselectCommunicator(net_stream),
|
||||
negotitate_timeout,
|
||||
self.negotiate_timeout,
|
||||
)
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
|
||||
70
libp2p/network/config.py
Normal file
70
libp2p/network/config.py
Normal file
@ -0,0 +1,70 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""
|
||||
Configuration for retry logic with exponential backoff.
|
||||
|
||||
This configuration controls how connection attempts are retried when they fail.
|
||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||
herd problems in distributed systems.
|
||||
|
||||
Attributes:
|
||||
max_retries: Maximum number of retry attempts before giving up.
|
||||
Default: 3 attempts
|
||||
initial_delay: Initial delay in seconds before the first retry.
|
||||
Default: 0.1 seconds (100ms)
|
||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||
Default: 30.0 seconds
|
||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||
the delay by this factor). Default: 2.0 (doubles each time)
|
||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 0.1
|
||||
max_delay: float = 30.0
|
||||
backoff_multiplier: float = 2.0
|
||||
jitter_factor: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""
|
||||
Configuration for multi-connection support.
|
||||
|
||||
This configuration controls how multiple connections per peer are managed,
|
||||
including connection limits, timeouts, and load balancing strategies.
|
||||
|
||||
Attributes:
|
||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||
peer. Default: 3 connections
|
||||
connection_timeout: Timeout in seconds for establishing new connections.
|
||||
Default: 30.0 seconds
|
||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||
Options: "round_robin" (default) or "least_loaded"
|
||||
|
||||
"""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration after initialization."""
|
||||
if not (
|
||||
self.load_balancing_strategy == "round_robin"
|
||||
or self.load_balancing_strategy == "least_loaded"
|
||||
):
|
||||
raise ValueError(
|
||||
"Load balancing strategy can only be 'round_robin' or 'least_loaded'"
|
||||
)
|
||||
|
||||
if self.max_connections_per_peer < 1:
|
||||
raise ValueError("Max connection per peer should be atleast 1")
|
||||
|
||||
if self.connection_timeout < 0:
|
||||
raise ValueError("Connection timeout should be positive")
|
||||
@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamError,
|
||||
MuxedStreamReset,
|
||||
)
|
||||
from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError
|
||||
|
||||
from .exceptions import (
|
||||
StreamClosed,
|
||||
@ -170,7 +171,7 @@ class NetStream(INetStream):
|
||||
elif self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_READ
|
||||
raise StreamEOF() from error
|
||||
except MuxedStreamReset as error:
|
||||
except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state in [
|
||||
StreamState.OPEN,
|
||||
@ -199,7 +200,12 @@ class NetStream(INetStream):
|
||||
|
||||
try:
|
||||
await self.muxed_stream.write(data)
|
||||
except (MuxedStreamClosed, MuxedStreamError) as error:
|
||||
except (
|
||||
MuxedStreamClosed,
|
||||
MuxedStreamError,
|
||||
QUICStreamClosedError,
|
||||
QUICStreamResetError,
|
||||
) as error:
|
||||
async with self._state_lock:
|
||||
if self.__stream_state == StreamState.OPEN:
|
||||
self.__stream_state = StreamState.CLOSE_WRITE
|
||||
|
||||
@ -2,9 +2,9 @@ from collections.abc import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import random
|
||||
from typing import cast
|
||||
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
@ -27,6 +27,7 @@ from libp2p.custom_types import (
|
||||
from libp2p.io.abc import (
|
||||
ReadWriteCloser,
|
||||
)
|
||||
from libp2p.network.config import ConnectionConfig, RetryConfig
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
@ -41,6 +42,9 @@ from libp2p.transport.exceptions import (
|
||||
OpenConnectionError,
|
||||
SecurityUpgradeFailure,
|
||||
)
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
@ -61,59 +65,6 @@ from .exceptions import (
|
||||
logger = logging.getLogger("libp2p.network.swarm")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""
|
||||
Configuration for retry logic with exponential backoff.
|
||||
|
||||
This configuration controls how connection attempts are retried when they fail.
|
||||
The retry mechanism uses exponential backoff with jitter to prevent thundering
|
||||
herd problems in distributed systems.
|
||||
|
||||
Attributes:
|
||||
max_retries: Maximum number of retry attempts before giving up.
|
||||
Default: 3 attempts
|
||||
initial_delay: Initial delay in seconds before the first retry.
|
||||
Default: 0.1 seconds (100ms)
|
||||
max_delay: Maximum delay cap in seconds to prevent excessive wait times.
|
||||
Default: 30.0 seconds
|
||||
backoff_multiplier: Multiplier for exponential backoff (each retry multiplies
|
||||
the delay by this factor). Default: 2.0 (doubles each time)
|
||||
jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays
|
||||
and prevent synchronized retries. Default: 0.1 (10% jitter)
|
||||
|
||||
"""
|
||||
|
||||
max_retries: int = 3
|
||||
initial_delay: float = 0.1
|
||||
max_delay: float = 30.0
|
||||
backoff_multiplier: float = 2.0
|
||||
jitter_factor: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionConfig:
|
||||
"""
|
||||
Configuration for multi-connection support.
|
||||
|
||||
This configuration controls how multiple connections per peer are managed,
|
||||
including connection limits, timeouts, and load balancing strategies.
|
||||
|
||||
Attributes:
|
||||
max_connections_per_peer: Maximum number of connections allowed to a single
|
||||
peer. Default: 3 connections
|
||||
connection_timeout: Timeout in seconds for establishing new connections.
|
||||
Default: 30.0 seconds
|
||||
load_balancing_strategy: Strategy for distributing streams across connections.
|
||||
Options: "round_robin" (default) or "least_loaded"
|
||||
|
||||
"""
|
||||
|
||||
max_connections_per_peer: int = 3
|
||||
connection_timeout: float = 30.0
|
||||
load_balancing_strategy: str = "round_robin" # or "least_loaded"
|
||||
|
||||
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
await network.get_manager().wait_finished()
|
||||
@ -126,8 +77,7 @@ class Swarm(Service, INetworkService):
|
||||
peerstore: IPeerStore
|
||||
upgrader: TransportUpgrader
|
||||
transport: ITransport
|
||||
# Enhanced: Support for multiple connections per peer
|
||||
connections: dict[ID, list[INetConn]] # Multiple connections per peer
|
||||
connections: dict[ID, list[INetConn]]
|
||||
listeners: dict[str, IListener]
|
||||
common_stream_handler: StreamHandlerFn
|
||||
listener_nursery: trio.Nursery | None
|
||||
@ -137,7 +87,7 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
# Enhanced: New configuration
|
||||
retry_config: RetryConfig
|
||||
connection_config: ConnectionConfig
|
||||
connection_config: ConnectionConfig | QUICTransportConfig
|
||||
_round_robin_index: dict[ID, int]
|
||||
|
||||
def __init__(
|
||||
@ -147,7 +97,7 @@ class Swarm(Service, INetworkService):
|
||||
upgrader: TransportUpgrader,
|
||||
transport: ITransport,
|
||||
retry_config: RetryConfig | None = None,
|
||||
connection_config: ConnectionConfig | None = None,
|
||||
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||
):
|
||||
self.self_id = peer_id
|
||||
self.peerstore = peerstore
|
||||
@ -178,6 +128,11 @@ class Swarm(Service, INetworkService):
|
||||
# Create a nursery for listener tasks.
|
||||
self.listener_nursery = nursery
|
||||
self.event_listener_nursery_created.set()
|
||||
|
||||
if isinstance(self.transport, QUICTransport):
|
||||
self.transport.set_background_nursery(nursery)
|
||||
self.transport.set_swarm(self)
|
||||
|
||||
try:
|
||||
await self.manager.wait_finished()
|
||||
finally:
|
||||
@ -370,6 +325,7 @@ class Swarm(Service, INetworkService):
|
||||
# Dial peer (connection to peer does not yet exist)
|
||||
# Transport dials peer (gets back a raw conn)
|
||||
try:
|
||||
addr = Multiaddr(f"{addr}/p2p/{peer_id}")
|
||||
raw_conn = await self.transport.dial(addr)
|
||||
except OpenConnectionError as error:
|
||||
logger.debug("fail to dial peer %s over base transport", peer_id)
|
||||
@ -377,6 +333,15 @@ class Swarm(Service, INetworkService):
|
||||
f"fail to open connection to peer {peer_id}"
|
||||
) from error
|
||||
|
||||
if isinstance(self.transport, QUICTransport) and isinstance(
|
||||
raw_conn, IMuxedConn
|
||||
):
|
||||
logger.info(
|
||||
"Skipping upgrade for QUIC, QUIC connections are already multiplexed"
|
||||
)
|
||||
swarm_conn = await self.add_conn(raw_conn)
|
||||
return swarm_conn
|
||||
|
||||
logger.debug("dialed peer %s over base transport", peer_id)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
@ -402,9 +367,7 @@ class Swarm(Service, INetworkService):
|
||||
logger.debug("upgraded mux for peer %s", peer_id)
|
||||
|
||||
swarm_conn = await self.add_conn(muxed_conn)
|
||||
|
||||
logger.debug("successfully dialed peer %s", peer_id)
|
||||
|
||||
return swarm_conn
|
||||
|
||||
async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
|
||||
@ -427,7 +390,6 @@ class Swarm(Service, INetworkService):
|
||||
:return: net stream instance
|
||||
"""
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
# Get existing connections or dial new ones
|
||||
connections = self.get_connections(peer_id)
|
||||
if not connections:
|
||||
@ -436,6 +398,10 @@ class Swarm(Service, INetworkService):
|
||||
# Load balancing strategy at interface level
|
||||
connection = self._select_connection(connections, peer_id)
|
||||
|
||||
if isinstance(self.transport, QUICTransport) and connection is not None:
|
||||
conn = cast(SwarmConn, connection)
|
||||
return await conn.new_stream()
|
||||
|
||||
try:
|
||||
net_stream = await connection.new_stream()
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
@ -516,6 +482,7 @@ class Swarm(Service, INetworkService):
|
||||
- Map multiaddr to listener
|
||||
"""
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
logger.debug("Starting to listen")
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
@ -527,6 +494,22 @@ class Swarm(Service, INetworkService):
|
||||
async def conn_handler(
|
||||
read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr
|
||||
) -> None:
|
||||
# No need to upgrade QUIC Connection
|
||||
if isinstance(self.transport, QUICTransport):
|
||||
try:
|
||||
quic_conn = cast(QUICConnection, read_write_closer)
|
||||
await self.add_conn(quic_conn)
|
||||
peer_id = quic_conn.peer_id
|
||||
logger.debug(
|
||||
f"successfully opened quic connection to peer {peer_id}"
|
||||
)
|
||||
# NOTE: This is a intentional barrier to prevent from the
|
||||
# handler exiting and closing the connection.
|
||||
await self.manager.wait_finished()
|
||||
except Exception:
|
||||
await read_write_closer.close()
|
||||
return
|
||||
|
||||
raw_conn = RawConnection(read_write_closer, False)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first
|
||||
@ -660,9 +643,10 @@ class Swarm(Service, INetworkService):
|
||||
muxed_conn,
|
||||
self,
|
||||
)
|
||||
|
||||
logger.debug("Swarm::add_conn | starting muxed connection")
|
||||
self.manager.run_task(muxed_conn.start)
|
||||
await muxed_conn.event_started.wait()
|
||||
logger.debug("Swarm::add_conn | starting swarm connection")
|
||||
self.manager.run_task(swarm_conn.start)
|
||||
await swarm_conn.event_started.wait()
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from builtins import AssertionError
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectCommunicator,
|
||||
)
|
||||
@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
|
||||
msg_bytes = encode_delim(msg_str.encode())
|
||||
try:
|
||||
await self.read_writer.write(msg_bytes)
|
||||
except IOException as error:
|
||||
# Handle for connection close during ongoing negotiation in QUIC
|
||||
except (IOException, AssertionError, ValueError) as error:
|
||||
raise MultiselectCommunicatorError(
|
||||
"fail to write to multiselect communicator"
|
||||
) from error
|
||||
|
||||
@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
Multiselect,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect_client import (
|
||||
@ -46,11 +47,17 @@ class MuxerMultistream:
|
||||
transports: "OrderedDict[TProtocol, TMuxerClass]"
|
||||
multiselect: Multiselect
|
||||
multiselect_client: MultiselectClient
|
||||
negotiate_timeout: int
|
||||
|
||||
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
) -> None:
|
||||
self.transports = OrderedDict()
|
||||
self.multiselect = Multiselect()
|
||||
self.multistream_client = MultiselectClient()
|
||||
self.negotiate_timeout = negotiate_timeout
|
||||
for protocol, transport in muxer_transports_by_protocol.items():
|
||||
self.add_transport(protocol, transport)
|
||||
|
||||
@ -80,10 +87,12 @@ class MuxerMultistream:
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
if conn.is_initiator:
|
||||
protocol = await self.multiselect_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||
)
|
||||
else:
|
||||
protocol, _ = await self.multiselect.negotiate(communicator)
|
||||
protocol, _ = await self.multiselect.negotiate(
|
||||
communicator, self.negotiate_timeout
|
||||
)
|
||||
if protocol is None:
|
||||
raise MultiselectError(
|
||||
"Fail to negotiate a stream muxer protocol: no protocol selected"
|
||||
@ -93,7 +102,7 @@ class MuxerMultistream:
|
||||
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
|
||||
communicator = MultiselectCommunicator(conn)
|
||||
protocol = await self.multistream_client.select_one_of(
|
||||
tuple(self.transports.keys()), communicator
|
||||
tuple(self.transports.keys()), communicator, self.negotiate_timeout
|
||||
)
|
||||
transport_class = self.transports[protocol]
|
||||
if protocol == PROTOCOL_ID:
|
||||
|
||||
0
libp2p/transport/quic/__init__.py
Normal file
0
libp2p/transport/quic/__init__.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
345
libp2p/transport/quic/config.py
Normal file
@ -0,0 +1,345 @@
|
||||
"""
|
||||
Configuration classes for QUIC transport.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
import ssl
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.network.config import ConnectionConfig
|
||||
|
||||
|
||||
class QUICTransportKwargs(TypedDict, total=False):
|
||||
"""Type definition for kwargs accepted by new_transport function."""
|
||||
|
||||
# Connection settings
|
||||
idle_timeout: float
|
||||
max_datagram_size: int
|
||||
local_port: int | None
|
||||
|
||||
# Protocol version support
|
||||
enable_draft29: bool
|
||||
enable_v1: bool
|
||||
|
||||
# TLS settings
|
||||
verify_mode: ssl.VerifyMode
|
||||
alpn_protocols: list[str]
|
||||
|
||||
# Performance settings
|
||||
max_concurrent_streams: int
|
||||
connection_window: int
|
||||
stream_window: int
|
||||
|
||||
# Logging and debugging
|
||||
enable_qlog: bool
|
||||
qlog_dir: str | None
|
||||
|
||||
# Connection management
|
||||
max_connections: int
|
||||
connection_timeout: float
|
||||
|
||||
# Protocol identifiers
|
||||
PROTOCOL_QUIC_V1: TProtocol
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class QUICTransportConfig(ConnectionConfig):
|
||||
"""Configuration for QUIC transport."""
|
||||
|
||||
# Connection settings
|
||||
idle_timeout: float = 30.0 # Seconds before an idle connection is closed.
|
||||
max_datagram_size: int = (
|
||||
1200 # Maximum size of UDP datagrams to avoid IP fragmentation.
|
||||
)
|
||||
local_port: int | None = (
|
||||
None # Local port to bind to. If None, a random port is chosen.
|
||||
)
|
||||
|
||||
# Protocol version support
|
||||
enable_draft29: bool = True # Enable QUIC draft-29 for compatibility
|
||||
enable_v1: bool = True # Enable QUIC v1 (RFC 9000)
|
||||
|
||||
# TLS settings
|
||||
verify_mode: ssl.VerifyMode = ssl.CERT_NONE
|
||||
alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
|
||||
|
||||
# Performance settings
|
||||
max_concurrent_streams: int = 100 # Maximum concurrent streams per connection
|
||||
connection_window: int = 1024 * 1024 # Connection flow control window
|
||||
stream_window: int = 64 * 1024 # Stream flow control window
|
||||
|
||||
# Logging and debugging
|
||||
enable_qlog: bool = False # Enable QUIC logging
|
||||
qlog_dir: str | None = None # Directory for QUIC logs
|
||||
|
||||
# Connection management
|
||||
max_connections: int = 1000 # Maximum number of connections
|
||||
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||
|
||||
MAX_CONCURRENT_STREAMS: int = 1000
|
||||
"""Maximum number of concurrent streams per connection."""
|
||||
|
||||
MAX_INCOMING_STREAMS: int = 1000
|
||||
"""Maximum number of incoming streams per connection."""
|
||||
|
||||
CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
|
||||
"""Timeout for connection handshake (seconds)."""
|
||||
|
||||
MAX_OUTGOING_STREAMS: int = 1000
|
||||
"""Maximum number of outgoing streams per connection."""
|
||||
|
||||
CONNECTION_CLOSE_TIMEOUT: int = 10
|
||||
"""Timeout for opening new connection (seconds)."""
|
||||
|
||||
# Stream timeouts
|
||||
STREAM_OPEN_TIMEOUT: float = 5.0
|
||||
"""Timeout for opening new streams (seconds)."""
|
||||
|
||||
STREAM_ACCEPT_TIMEOUT: float = 30.0
|
||||
"""Timeout for accepting incoming streams (seconds)."""
|
||||
|
||||
STREAM_READ_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream read operations (seconds)."""
|
||||
|
||||
STREAM_WRITE_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream write operations (seconds)."""
|
||||
|
||||
STREAM_CLOSE_TIMEOUT: float = 10.0
|
||||
"""Timeout for graceful stream close (seconds)."""
|
||||
|
||||
# Flow control configuration
|
||||
STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB
|
||||
"""Per-stream flow control window size."""
|
||||
|
||||
CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB
|
||||
"""Connection-wide flow control window size."""
|
||||
|
||||
# Buffer management
|
||||
MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Maximum receive buffer size per stream."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
|
||||
"""Low watermark for stream receive buffer."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
|
||||
"""High watermark for stream receive buffer."""
|
||||
|
||||
# Stream lifecycle configuration
|
||||
ENABLE_STREAM_RESET_ON_ERROR: bool = True
|
||||
"""Whether to automatically reset streams on errors."""
|
||||
|
||||
STREAM_RESET_ERROR_CODE: int = 1
|
||||
"""Default error code for stream resets."""
|
||||
|
||||
ENABLE_STREAM_KEEP_ALIVE: bool = False
|
||||
"""Whether to enable stream keep-alive mechanisms."""
|
||||
|
||||
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
|
||||
"""Interval for stream keep-alive pings (seconds)."""
|
||||
|
||||
# Resource management
|
||||
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
|
||||
"""Whether to track stream resource usage."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Memory limit per individual stream."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
|
||||
"""Total memory limit for all streams per connection."""
|
||||
|
||||
# Concurrency and performance
|
||||
ENABLE_STREAM_BATCHING: bool = True
|
||||
"""Whether to batch multiple stream operations."""
|
||||
|
||||
STREAM_BATCH_SIZE: int = 10
|
||||
"""Number of streams to process in a batch."""
|
||||
|
||||
STREAM_PROCESSING_CONCURRENCY: int = 100
|
||||
"""Maximum concurrent stream processing tasks."""
|
||||
|
||||
# Debugging and monitoring
|
||||
ENABLE_STREAM_METRICS: bool = True
|
||||
"""Whether to collect stream metrics."""
|
||||
|
||||
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
|
||||
"""Whether to track stream lifecycle timelines."""
|
||||
|
||||
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
|
||||
"""Interval for collecting stream metrics (seconds)."""
|
||||
|
||||
# Error handling configuration
|
||||
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
|
||||
"""Number of retry attempts for recoverable stream errors."""
|
||||
|
||||
STREAM_ERROR_RETRY_DELAY: float = 1.0
|
||||
"""Initial delay between stream error retries (seconds)."""
|
||||
|
||||
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
|
||||
"""Backoff factor for stream error retries."""
|
||||
|
||||
# Protocol identifiers matching go-libp2p
|
||||
PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000
|
||||
PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration after initialization."""
|
||||
if not (self.enable_draft29 or self.enable_v1):
|
||||
raise ValueError("At least one QUIC version must be enabled")
|
||||
|
||||
if self.idle_timeout <= 0:
|
||||
raise ValueError("Idle timeout must be positive")
|
||||
|
||||
if self.max_datagram_size < 1200:
|
||||
raise ValueError("Max datagram size must be at least 1200 bytes")
|
||||
|
||||
# Validate timeouts
|
||||
timeout_fields = [
|
||||
"STREAM_OPEN_TIMEOUT",
|
||||
"STREAM_ACCEPT_TIMEOUT",
|
||||
"STREAM_READ_TIMEOUT",
|
||||
"STREAM_WRITE_TIMEOUT",
|
||||
"STREAM_CLOSE_TIMEOUT",
|
||||
]
|
||||
for timeout_field in timeout_fields:
|
||||
if getattr(self, timeout_field) <= 0:
|
||||
raise ValueError(f"{timeout_field} must be positive")
|
||||
|
||||
# Validate flow control windows
|
||||
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
|
||||
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
|
||||
|
||||
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
|
||||
raise ValueError(
|
||||
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
|
||||
)
|
||||
|
||||
# Validate buffer sizes
|
||||
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
|
||||
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
|
||||
|
||||
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
|
||||
"exceed MAX_STREAM_RECEIVE_BUFFER"
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
|
||||
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
|
||||
):
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
|
||||
)
|
||||
|
||||
# Validate memory limits
|
||||
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
|
||||
|
||||
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
|
||||
|
||||
expected_stream_memory = (
|
||||
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
|
||||
)
|
||||
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
|
||||
# Allow some headroom, but warn if configuration seems inconsistent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Stream memory configuration may be inconsistent: "
|
||||
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
|
||||
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
|
||||
"could exceed connection limit of"
|
||||
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
|
||||
)
|
||||
|
||||
def get_stream_config_dict(self) -> dict[str, Any]:
|
||||
"""Get stream-specific configuration as dictionary."""
|
||||
stream_config = {}
|
||||
for attr_name in dir(self):
|
||||
if attr_name.startswith(
|
||||
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
|
||||
):
|
||||
stream_config[attr_name.lower()] = getattr(self, attr_name)
|
||||
return stream_config
|
||||
|
||||
|
||||
# Additional configuration classes for specific stream features
|
||||
|
||||
|
||||
class QUICStreamFlowControlConfig:
|
||||
"""Configuration for QUIC stream flow control."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_window_size: int = 512 * 1024,
|
||||
max_window_size: int = 2 * 1024 * 1024,
|
||||
window_update_threshold: float = 0.5,
|
||||
enable_auto_tuning: bool = True,
|
||||
):
|
||||
self.initial_window_size = initial_window_size
|
||||
self.max_window_size = max_window_size
|
||||
self.window_update_threshold = window_update_threshold
|
||||
self.enable_auto_tuning = enable_auto_tuning
|
||||
|
||||
|
||||
def create_stream_config_for_use_case(
|
||||
use_case: Literal[
|
||||
"high_throughput", "low_latency", "many_streams", "memory_constrained"
|
||||
],
|
||||
) -> QUICTransportConfig:
|
||||
"""
|
||||
Create optimized stream configuration for specific use cases.
|
||||
|
||||
Args:
|
||||
use_case: One of "high_throughput", "low_latency", "many_streams","
|
||||
"memory_constrained"
|
||||
|
||||
Returns:
|
||||
Optimized QUICTransportConfig
|
||||
|
||||
"""
|
||||
base_config = QUICTransportConfig()
|
||||
|
||||
if use_case == "high_throughput":
|
||||
# Optimize for high throughput
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 200
|
||||
|
||||
elif use_case == "low_latency":
|
||||
# Optimize for low latency
|
||||
base_config.STREAM_OPEN_TIMEOUT = 1.0
|
||||
base_config.STREAM_READ_TIMEOUT = 5.0
|
||||
base_config.STREAM_WRITE_TIMEOUT = 5.0
|
||||
base_config.ENABLE_STREAM_BATCHING = False
|
||||
base_config.STREAM_BATCH_SIZE = 1
|
||||
|
||||
elif use_case == "many_streams":
|
||||
# Optimize for many concurrent streams
|
||||
base_config.MAX_CONCURRENT_STREAMS = 5000
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 500
|
||||
|
||||
elif use_case == "memory_constrained":
|
||||
# Optimize for low memory usage
|
||||
base_config.MAX_CONCURRENT_STREAMS = 100
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
|
||||
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 50
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown use case: {use_case}")
|
||||
|
||||
return base_config
|
||||
1487
libp2p/transport/quic/connection.py
Normal file
1487
libp2p/transport/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
391
libp2p/transport/quic/exceptions.py
Normal file
391
libp2p/transport/quic/exceptions.py
Normal file
@ -0,0 +1,391 @@
|
||||
"""
|
||||
QUIC Transport exceptions
|
||||
"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
class QUICError(Exception):
|
||||
"""Base exception for all QUIC transport errors."""
|
||||
|
||||
def __init__(self, message: str, error_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
# Transport-level exceptions
|
||||
|
||||
|
||||
class QUICTransportError(QUICError):
|
||||
"""Base exception for QUIC transport operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICDialError(QUICTransportError):
|
||||
"""Error occurred during QUIC connection establishment."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICListenError(QUICTransportError):
|
||||
"""Error occurred during QUIC listener operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICSecurityError(QUICTransportError):
|
||||
"""Error related to QUIC security/TLS operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Connection-level exceptions
|
||||
|
||||
|
||||
class QUICConnectionError(QUICError):
|
||||
"""Base exception for QUIC connection operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionClosedError(QUICConnectionError):
|
||||
"""QUIC connection has been closed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionTimeoutError(QUICConnectionError):
|
||||
"""QUIC connection operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICHandshakeError(QUICConnectionError):
|
||||
"""Error during QUIC handshake process."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICPeerVerificationError(QUICConnectionError):
|
||||
"""Error verifying peer identity during handshake."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Stream-level exceptions
|
||||
|
||||
|
||||
class QUICStreamError(QUICError):
|
||||
"""Base exception for QUIC stream operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
):
|
||||
super().__init__(message, error_code)
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
class QUICStreamClosedError(QUICStreamError):
|
||||
"""Stream is closed and cannot be used for I/O operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamResetError(QUICStreamError):
|
||||
"""Stream was reset by local or remote peer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
reset_by_peer: bool = False,
|
||||
):
|
||||
super().__init__(message, stream_id, error_code)
|
||||
self.reset_by_peer = reset_by_peer
|
||||
|
||||
|
||||
class QUICStreamTimeoutError(QUICStreamError):
|
||||
"""Stream operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamBackpressureError(QUICStreamError):
|
||||
"""Stream write blocked due to flow control."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamLimitError(QUICStreamError):
|
||||
"""Stream limit reached (too many concurrent streams)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamStateError(QUICStreamError):
|
||||
"""Invalid operation for current stream state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
current_state: str | None = None,
|
||||
attempted_operation: str | None = None,
|
||||
):
|
||||
super().__init__(message, stream_id)
|
||||
self.current_state = current_state
|
||||
self.attempted_operation = attempted_operation
|
||||
|
||||
|
||||
# Flow control exceptions
|
||||
|
||||
|
||||
class QUICFlowControlError(QUICError):
|
||||
"""Base exception for flow control related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlViolationError(QUICFlowControlError):
|
||||
"""Flow control limits were violated."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlDeadlockError(QUICFlowControlError):
|
||||
"""Flow control deadlock detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Resource management exceptions
|
||||
|
||||
|
||||
class QUICResourceError(QUICError):
|
||||
"""Base exception for resource management errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICMemoryLimitError(QUICResourceError):
|
||||
"""Memory limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionLimitError(QUICResourceError):
|
||||
"""Connection limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Multiaddr and addressing exceptions
|
||||
|
||||
|
||||
class QUICAddressError(QUICError):
|
||||
"""Base exception for QUIC addressing errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidMultiaddrError(QUICAddressError):
|
||||
"""Invalid multiaddr format for QUIC transport."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICAddressResolutionError(QUICAddressError):
|
||||
"""Failed to resolve QUIC address."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICProtocolError(QUICError):
|
||||
"""Base exception for QUIC protocol errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICVersionNegotiationError(QUICProtocolError):
|
||||
"""QUIC version negotiation failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICUnsupportedVersionError(QUICProtocolError):
|
||||
"""Unsupported QUIC version."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Configuration exceptions
|
||||
|
||||
|
||||
class QUICConfigurationError(QUICError):
|
||||
"""Base exception for QUIC configuration errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidConfigError(QUICConfigurationError):
|
||||
"""Invalid QUIC configuration parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICCertificateError(QUICConfigurationError):
|
||||
"""Error with TLS certificate configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def map_quic_error_code(error_code: int) -> str:
|
||||
"""
|
||||
Map QUIC error codes to human-readable descriptions.
|
||||
Based on RFC 9000 Transport Error Codes.
|
||||
"""
|
||||
error_codes = {
|
||||
0x00: "NO_ERROR",
|
||||
0x01: "INTERNAL_ERROR",
|
||||
0x02: "CONNECTION_REFUSED",
|
||||
0x03: "FLOW_CONTROL_ERROR",
|
||||
0x04: "STREAM_LIMIT_ERROR",
|
||||
0x05: "STREAM_STATE_ERROR",
|
||||
0x06: "FINAL_SIZE_ERROR",
|
||||
0x07: "FRAME_ENCODING_ERROR",
|
||||
0x08: "TRANSPORT_PARAMETER_ERROR",
|
||||
0x09: "CONNECTION_ID_LIMIT_ERROR",
|
||||
0x0A: "PROTOCOL_VIOLATION",
|
||||
0x0B: "INVALID_TOKEN",
|
||||
0x0C: "APPLICATION_ERROR",
|
||||
0x0D: "CRYPTO_BUFFER_EXCEEDED",
|
||||
0x0E: "KEY_UPDATE_ERROR",
|
||||
0x0F: "AEAD_LIMIT_REACHED",
|
||||
0x10: "NO_VIABLE_PATH",
|
||||
}
|
||||
|
||||
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
|
||||
|
||||
|
||||
def create_stream_error(
|
||||
error_type: str,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
) -> QUICStreamError:
|
||||
"""
|
||||
Factory function to create appropriate stream error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
|
||||
message: Error message
|
||||
stream_id: Stream identifier
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICStreamError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICStreamClosedError(message, stream_id, error_code)
|
||||
elif error_type == "reset":
|
||||
return QUICStreamResetError(message, stream_id, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICStreamTimeoutError(message, stream_id, error_code)
|
||||
elif error_type in ("backpressure", "flow_control"):
|
||||
return QUICStreamBackpressureError(message, stream_id, error_code)
|
||||
elif error_type in ("limit", "stream_limit"):
|
||||
return QUICStreamLimitError(message, stream_id, error_code)
|
||||
elif error_type == "state":
|
||||
return QUICStreamStateError(message, stream_id)
|
||||
else:
|
||||
return QUICStreamError(message, stream_id, error_code)
|
||||
|
||||
|
||||
def create_connection_error(
|
||||
error_type: str, message: str, error_code: int | None = None
|
||||
) -> QUICConnectionError:
|
||||
"""
|
||||
Factory function to create appropriate connection error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "timeout", "handshake", etc.)
|
||||
message: Error message
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICConnectionError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICConnectionClosedError(message, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICConnectionTimeoutError(message, error_code)
|
||||
elif error_type == "handshake":
|
||||
return QUICHandshakeError(message, error_code)
|
||||
elif error_type in ("peer_verification", "verification"):
|
||||
return QUICPeerVerificationError(message, error_code)
|
||||
else:
|
||||
return QUICConnectionError(message, error_code)
|
||||
|
||||
|
||||
class QUICErrorContext:
|
||||
"""
|
||||
Context manager for handling QUIC errors with automatic error mapping.
|
||||
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
|
||||
"""
|
||||
|
||||
def __init__(self, operation: str, component: str = "quic") -> None:
|
||||
self.operation = operation
|
||||
self.component = component
|
||||
|
||||
def __enter__(self) -> "QUICErrorContext":
|
||||
return self
|
||||
|
||||
# TODO: Fix types for exc_type
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: Any,
|
||||
) -> Literal[False]:
|
||||
if exc_type is None:
|
||||
return False
|
||||
|
||||
if exc_val is None:
|
||||
return False
|
||||
|
||||
# Map common aioquic exceptions to our exceptions
|
||||
if "ConnectionClosed" in str(exc_type):
|
||||
raise QUICConnectionClosedError(
|
||||
f"Connection closed during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "StreamReset" in str(exc_type):
|
||||
raise QUICStreamResetError(
|
||||
f"Stream reset during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "timeout" in str(exc_val).lower():
|
||||
if "stream" in self.component.lower():
|
||||
raise QUICStreamTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
else:
|
||||
raise QUICConnectionTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "flow control" in str(exc_val).lower():
|
||||
raise QUICStreamBackpressureError(
|
||||
f"Flow control error during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
|
||||
# Let other exceptions propagate
|
||||
return False
|
||||
1041
libp2p/transport/quic/listener.py
Normal file
1041
libp2p/transport/quic/listener.py
Normal file
File diff suppressed because it is too large
Load Diff
1165
libp2p/transport/quic/security.py
Normal file
1165
libp2p/transport/quic/security.py
Normal file
File diff suppressed because it is too large
Load Diff
656
libp2p/transport/quic/stream.py
Normal file
656
libp2p/transport/quic/stream.py
Normal file
@ -0,0 +1,656 @@
|
||||
"""
|
||||
QUIC Stream implementation
|
||||
Provides stream interface over QUIC's native multiplexing.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import trio
|
||||
|
||||
from .exceptions import (
|
||||
QUICStreamBackpressureError,
|
||||
QUICStreamClosedError,
|
||||
QUICStreamResetError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import IMuxedStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
from .connection import QUICConnection
|
||||
else:
|
||||
IMuxedStream = cast(type, object)
|
||||
TProtocol = cast(type, object)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamState(Enum):
|
||||
"""Stream lifecycle states following libp2p patterns."""
|
||||
|
||||
OPEN = "open"
|
||||
WRITE_CLOSED = "write_closed"
|
||||
READ_CLOSED = "read_closed"
|
||||
CLOSED = "closed"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
class StreamDirection(Enum):
|
||||
"""Stream direction for tracking initiator."""
|
||||
|
||||
INBOUND = "inbound"
|
||||
OUTBOUND = "outbound"
|
||||
|
||||
|
||||
class StreamTimeline:
|
||||
"""Track stream lifecycle events for debugging and monitoring."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.created_at = time.time()
|
||||
self.opened_at: float | None = None
|
||||
self.first_data_at: float | None = None
|
||||
self.closed_at: float | None = None
|
||||
self.reset_at: float | None = None
|
||||
self.error_code: int | None = None
|
||||
|
||||
def record_open(self) -> None:
|
||||
self.opened_at = time.time()
|
||||
|
||||
def record_first_data(self) -> None:
|
||||
if self.first_data_at is None:
|
||||
self.first_data_at = time.time()
|
||||
|
||||
def record_close(self) -> None:
|
||||
self.closed_at = time.time()
|
||||
|
||||
def record_reset(self, error_code: int) -> None:
|
||||
self.reset_at = time.time()
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class QUICStream(IMuxedStream):
|
||||
"""
|
||||
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||
|
||||
Based on patterns from go-libp2p and js-libp2p, this implementation:
|
||||
- Leverages QUIC's native multiplexing and flow control
|
||||
- Integrates with libp2p resource management
|
||||
- Provides comprehensive error handling with QUIC-specific codes
|
||||
- Supports bidirectional communication with independent close semantics
|
||||
- Implements proper stream lifecycle management
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection: "QUICConnection",
|
||||
stream_id: int,
|
||||
direction: StreamDirection,
|
||||
remote_addr: tuple[str, int],
|
||||
resource_scope: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC stream.
|
||||
|
||||
Args:
|
||||
connection: Parent QUIC connection
|
||||
stream_id: QUIC stream identifier
|
||||
direction: Stream direction (inbound/outbound)
|
||||
resource_scope: Resource manager scope for memory accounting
|
||||
remote_addr: Remote addr stream is connected to
|
||||
|
||||
"""
|
||||
self._connection = connection
|
||||
self._stream_id = stream_id
|
||||
self._direction = direction
|
||||
self._resource_scope = resource_scope
|
||||
|
||||
# libp2p interface compliance
|
||||
self._protocol: TProtocol | None = None
|
||||
self._metadata: dict[str, Any] = {}
|
||||
self._remote_addr = remote_addr
|
||||
|
||||
# Stream state management
|
||||
self._state = StreamState.OPEN
|
||||
self._state_lock = trio.Lock()
|
||||
|
||||
# Flow control and buffering
|
||||
self._receive_buffer = bytearray()
|
||||
self._receive_buffer_lock = trio.Lock()
|
||||
self._receive_event = trio.Event()
|
||||
self._backpressure_event = trio.Event()
|
||||
self._backpressure_event.set() # Initially no backpressure
|
||||
|
||||
# Close/reset state
|
||||
self._write_closed = False
|
||||
self._read_closed = False
|
||||
self._close_event = trio.Event()
|
||||
self._reset_error_code: int | None = None
|
||||
|
||||
# Lifecycle tracking
|
||||
self._timeline = StreamTimeline()
|
||||
self._timeline.record_open()
|
||||
|
||||
# Resource accounting
|
||||
self._memory_reserved = 0
|
||||
|
||||
# Stream constant configurations
|
||||
self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
|
||||
self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
|
||||
self.FLOW_CONTROL_WINDOW_SIZE = (
|
||||
connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
|
||||
)
|
||||
self.MAX_RECEIVE_BUFFER_SIZE = (
|
||||
connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
|
||||
)
|
||||
|
||||
if self._resource_scope:
|
||||
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
|
||||
|
||||
logger.debug(
|
||||
f"Created QUIC stream {stream_id} "
|
||||
f"({direction.value}, connection: {connection.remote_peer_id()})"
|
||||
)
|
||||
|
||||
# Properties for libp2p interface compliance
|
||||
|
||||
@property
|
||||
def protocol(self) -> TProtocol | None:
|
||||
"""Get the protocol identifier for this stream."""
|
||||
return self._protocol
|
||||
|
||||
@protocol.setter
|
||||
def protocol(self, protocol_id: TProtocol) -> None:
|
||||
"""Set the protocol identifier for this stream."""
|
||||
self._protocol = protocol_id
|
||||
self._metadata["protocol"] = protocol_id
|
||||
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str:
|
||||
"""Get stream ID as string for libp2p compatibility."""
|
||||
return str(self._stream_id)
|
||||
|
||||
@property
|
||||
def muxed_conn(self) -> "QUICConnection": # type: ignore
|
||||
"""Get the parent muxed connection."""
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def state(self) -> StreamState:
|
||||
"""Get current stream state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def direction(self) -> StreamDirection:
|
||||
"""Get stream direction."""
|
||||
return self._direction
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
"""Check if this stream was locally initiated."""
|
||||
return self._direction == StreamDirection.OUTBOUND
|
||||
|
||||
# Core stream operations
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read data from the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
n: Maximum number of bytes to read. If None or -1, read all available.
|
||||
|
||||
Returns:
|
||||
Data read from stream
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed
|
||||
QUICStreamResetError: Stream was reset
|
||||
QUICStreamTimeoutError: Read timeout exceeded
|
||||
|
||||
"""
|
||||
if n is None:
|
||||
n = -1
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._read_closed:
|
||||
# Return any remaining buffered data, then EOF
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
return b""
|
||||
|
||||
# Wait for data with timeout
|
||||
timeout = self.READ_TIMEOUT
|
||||
try:
|
||||
with trio.move_on_after(timeout) as cancel_scope:
|
||||
while True:
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
|
||||
# Check if stream was closed while waiting
|
||||
if self._read_closed:
|
||||
return b""
|
||||
|
||||
# Wait for more data
|
||||
await self._receive_event.wait()
|
||||
self._receive_event = trio.Event() # Reset for next wait
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
|
||||
|
||||
return b""
|
||||
except QUICStreamResetError:
|
||||
# Stream was reset while reading
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading from stream {self.stream_id}: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
data: Data to write
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed for writing
|
||||
QUICStreamBackpressureError: Flow control window exhausted
|
||||
QUICStreamResetError: Stream was reset
|
||||
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._write_closed:
|
||||
raise QUICStreamClosedError(
|
||||
f"Stream {self.stream_id} write side is closed"
|
||||
)
|
||||
|
||||
try:
|
||||
# Handle flow control backpressure
|
||||
await self._backpressure_event.wait()
|
||||
|
||||
# Send data through QUIC connection
|
||||
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._timeline.record_first_data()
|
||||
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to stream {self.stream_id}: {e}")
|
||||
# Convert QUIC-specific errors
|
||||
if "flow control" in str(e).lower():
|
||||
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the stream gracefully (both read and write sides).
|
||||
|
||||
This implements proper close semantics where both sides
|
||||
are closed and resources are cleaned up.
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
return
|
||||
|
||||
logger.debug(f"Closing stream {self.stream_id}")
|
||||
|
||||
# Close both sides
|
||||
if not self._write_closed:
|
||||
await self.close_write()
|
||||
if not self._read_closed:
|
||||
await self.close_read()
|
||||
|
||||
# Update state and cleanup
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.CLOSED
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_close()
|
||||
self._close_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} closed")
|
||||
|
||||
async def close_write(self) -> None:
|
||||
"""Close the write side of the stream."""
|
||||
if self._write_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
# Send FIN to close write side
|
||||
self._connection._quic.send_stream_data(
|
||||
self._stream_id, b"", end_stream=True
|
||||
)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} write side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def close_read(self) -> None:
|
||||
"""Close the read side of the stream."""
|
||||
if self._read_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
self._read_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up any pending reads
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} read side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def reset(self, error_code: int = 0) -> None:
|
||||
"""
|
||||
Reset the stream with the given error code.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code for the reset
|
||||
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Resetting stream {self.stream_id} with error code {error_code}"
|
||||
)
|
||||
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
try:
|
||||
# Send QUIC reset frame
|
||||
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||
await self._connection._transmit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
|
||||
finally:
|
||||
# Always cleanup resources
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if stream is completely closed."""
|
||||
return self._state in (StreamState.CLOSED, StreamState.RESET)
|
||||
|
||||
def is_reset(self) -> bool:
|
||||
"""Check if stream was reset."""
|
||||
return self._state == StreamState.RESET
|
||||
|
||||
def can_read(self) -> bool:
|
||||
"""Check if stream can be read from."""
|
||||
return not self._read_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
def can_write(self) -> bool:
|
||||
"""Check if stream can be written to."""
|
||||
return not self._write_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||
"""
|
||||
Handle data received from the QUIC connection.
|
||||
|
||||
Args:
|
||||
data: Received data
|
||||
end_stream: Whether this is the last data (FIN received)
|
||||
|
||||
"""
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
if data:
|
||||
async with self._receive_buffer_lock:
|
||||
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
|
||||
logger.warning(
|
||||
f"Stream {self.stream_id} receive buffer overflow, "
|
||||
f"dropping {len(data)} bytes"
|
||||
)
|
||||
return
|
||||
|
||||
self._receive_buffer.extend(data)
|
||||
self._timeline.record_first_data()
|
||||
|
||||
# Notify waiting readers
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
|
||||
|
||||
if end_stream:
|
||||
self._read_closed = True
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up readers to process remaining data and EOF
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||
|
||||
async def handle_stop_sending(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle STOP_SENDING frame from remote peer.
|
||||
|
||||
When a STOP_SENDING frame is received, the peer is requesting that we
|
||||
stop sending data on this stream. We respond by resetting the stream.
|
||||
|
||||
Args:
|
||||
error_code: Error code from the STOP_SENDING frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})"
|
||||
)
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
# Wake up any pending write operations
|
||||
self._backpressure_event.set()
|
||||
|
||||
async with self._state_lock:
|
||||
if self.direction == StreamDirection.OUTBOUND:
|
||||
self._state = StreamState.CLOSED
|
||||
elif self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
# Only write side closed - add WRITE_CLOSED state if needed
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
# Send RESET_STREAM in response (QUIC protocol requirement)
|
||||
try:
|
||||
self._connection._quic.reset_stream(int(self.stream_id), error_code)
|
||||
await self._connection._transmit()
|
||||
logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not send RESET_STREAM for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def handle_reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle stream reset from remote peer.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code from reset frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} reset by peer with error code {error_code}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
# Wake up any pending operations
|
||||
self._receive_event.set()
|
||||
self._backpressure_event.set()
|
||||
|
||||
async def handle_flow_control_update(self, available_window: int) -> None:
|
||||
"""
|
||||
Handle flow control window updates.
|
||||
|
||||
Args:
|
||||
available_window: Available flow control window size
|
||||
|
||||
"""
|
||||
if available_window > 0:
|
||||
self._backpressure_event.set()
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} flow control".__add__(
|
||||
f"window updated: {available_window}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._backpressure_event = trio.Event() # Reset to blocking state
|
||||
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
|
||||
|
||||
def _extract_data_from_buffer(self, n: int) -> bytes:
|
||||
"""Extract data from receive buffer with specified limit."""
|
||||
if n == -1:
|
||||
# Read all available data
|
||||
data = bytes(self._receive_buffer)
|
||||
self._receive_buffer.clear()
|
||||
else:
|
||||
# Read up to n bytes
|
||||
data = bytes(self._receive_buffer[:n])
|
||||
self._receive_buffer = self._receive_buffer[n:]
|
||||
|
||||
return data
|
||||
|
||||
async def _handle_stream_error(self, error: Exception) -> None:
|
||||
"""Handle errors by resetting the stream."""
|
||||
logger.error(f"Stream {self.stream_id} error: {error}")
|
||||
await self.reset(error_code=1) # Generic error code
|
||||
|
||||
def _reserve_memory(self, size: int) -> None:
|
||||
"""Reserve memory with resource manager."""
|
||||
if self._resource_scope:
|
||||
try:
|
||||
self._resource_scope.reserve_memory(size)
|
||||
self._memory_reserved += size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to reserve memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
def _release_memory(self, size: int) -> None:
|
||||
"""Release memory with resource manager."""
|
||||
if self._resource_scope and size > 0:
|
||||
try:
|
||||
self._resource_scope.release_memory(size)
|
||||
self._memory_reserved = max(0, self._memory_reserved - size)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to release memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def _cleanup_resources(self) -> None:
|
||||
"""Clean up stream resources."""
|
||||
# Release all reserved memory
|
||||
if self._memory_reserved > 0:
|
||||
self._release_memory(self._memory_reserved)
|
||||
|
||||
# Clear receive buffer
|
||||
async with self._receive_buffer_lock:
|
||||
self._receive_buffer.clear()
|
||||
|
||||
# Remove from connection's stream registry
|
||||
self._connection._remove_stream(self._stream_id)
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} resources cleaned up")
|
||||
|
||||
# Abstact implementations
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int]:
|
||||
return self._remote_addr
|
||||
|
||||
async def __aenter__(self) -> "QUICStream":
|
||||
"""Enter the async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
logger.debug("Exiting the context and closing the stream")
|
||||
await self.close()
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
Set a deadline for the stream. QUIC does not support deadlines natively,
|
||||
so this method always returns False to indicate the operation is unsupported.
|
||||
|
||||
:param ttl: Time-to-live in seconds (ignored).
|
||||
:return: False, as deadlines are not supported.
|
||||
"""
|
||||
raise NotImplementedError("QUIC does not support setting read deadlines")
|
||||
|
||||
# String representation for debugging
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"QUICStream(id={self.stream_id}, "
|
||||
f"state={self._state.value}, "
|
||||
f"direction={self._direction.value}, "
|
||||
f"protocol={self._protocol})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"QUICStream({self.stream_id})"
|
||||
491
libp2p/transport/quic/transport.py
Normal file
491
libp2p/transport/quic/transport.py
Normal file
@ -0,0 +1,491 @@
|
||||
"""
|
||||
QUIC Transport implementation
|
||||
"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import ssl
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from aioquic.quic.configuration import (
|
||||
QuicConfiguration,
|
||||
)
|
||||
from aioquic.quic.connection import (
|
||||
QuicConnection as NativeQUICConnection,
|
||||
)
|
||||
from aioquic.quic.logger import QuicLogger
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
)
|
||||
from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.transport.quic.security import QUICTLSSecurityConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_client_config_from_base,
|
||||
create_server_config_from_base,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.network.swarm import Swarm
|
||||
else:
|
||||
Swarm = cast(type, object)
|
||||
|
||||
from .config import (
|
||||
QUICTransportConfig,
|
||||
)
|
||||
from .connection import (
|
||||
QUICConnection,
|
||||
)
|
||||
from .exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
QUICSecurityError,
|
||||
)
|
||||
from .listener import (
|
||||
QUICListener,
|
||||
)
|
||||
from .security import (
|
||||
QUICTLSConfigManager,
|
||||
create_quic_security_transport,
|
||||
)
|
||||
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QUICTransport(ITransport):
|
||||
"""
|
||||
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, private_key: PrivateKey, config: QUICTransportConfig | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Initialize QUIC transport with security integration.
|
||||
|
||||
Args:
|
||||
private_key: libp2p private key for identity and TLS cert generation
|
||||
config: QUIC transport configuration options
|
||||
|
||||
"""
|
||||
self._private_key = private_key
|
||||
self._peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
self._config = config or QUICTransportConfig()
|
||||
|
||||
# Connection management
|
||||
self._connections: dict[str, QUICConnection] = {}
|
||||
self._listeners: list[QUICListener] = []
|
||||
|
||||
# Security manager for TLS integration
|
||||
self._security_manager = create_quic_security_transport(
|
||||
self._private_key, self._peer_id
|
||||
)
|
||||
|
||||
# QUIC configurations for different versions
|
||||
self._quic_configs: dict[TProtocol, QuicConfiguration] = {}
|
||||
self._setup_quic_configurations()
|
||||
|
||||
# Resource management
|
||||
self._closed = False
|
||||
self._nursery_manager = trio.CapacityLimiter(1)
|
||||
self._background_nursery: trio.Nursery | None = None
|
||||
|
||||
self._swarm: Swarm | None = None
|
||||
|
||||
logger.debug(
|
||||
f"Initialized QUIC transport with security for peer {self._peer_id}"
|
||||
)
|
||||
|
||||
def set_background_nursery(self, nursery: trio.Nursery) -> None:
|
||||
"""Set the nursery to use for background tasks (called by swarm)."""
|
||||
self._background_nursery = nursery
|
||||
logger.debug("Transport background nursery set")
|
||||
|
||||
def set_swarm(self, swarm: Swarm) -> None:
|
||||
"""Set the swarm for adding incoming connections."""
|
||||
self._swarm = swarm
|
||||
|
||||
def _setup_quic_configurations(self) -> None:
|
||||
"""Setup QUIC configurations."""
|
||||
try:
|
||||
# Get TLS configuration from security manager
|
||||
server_tls_config = self._security_manager.create_server_config()
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
|
||||
# Base server configuration
|
||||
base_server_config = QuicConfiguration(
|
||||
is_client=False,
|
||||
alpn_protocols=get_alpn_protocols(),
|
||||
verify_mode=self._config.verify_mode,
|
||||
max_datagram_frame_size=self._config.max_datagram_size,
|
||||
idle_timeout=self._config.idle_timeout,
|
||||
)
|
||||
|
||||
# Base client configuration
|
||||
base_client_config = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=get_alpn_protocols(),
|
||||
verify_mode=self._config.verify_mode,
|
||||
max_datagram_frame_size=self._config.max_datagram_size,
|
||||
idle_timeout=self._config.idle_timeout,
|
||||
)
|
||||
|
||||
# Apply TLS configuration
|
||||
self._apply_tls_configuration(base_server_config, server_tls_config)
|
||||
self._apply_tls_configuration(base_client_config, client_tls_config)
|
||||
|
||||
# QUIC v1 (RFC 9000) configurations
|
||||
if self._config.enable_v1:
|
||||
quic_v1_server_config = create_server_config_from_base(
|
||||
base_server_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
quic_v1_client_config = create_client_config_from_base(
|
||||
base_client_config, self._security_manager, self._config
|
||||
)
|
||||
quic_v1_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
# Store both server and client configs for v1
|
||||
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = (
|
||||
quic_v1_server_config
|
||||
)
|
||||
self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = (
|
||||
quic_v1_client_config
|
||||
)
|
||||
|
||||
# QUIC draft-29 configurations for compatibility
|
||||
if self._config.enable_draft29:
|
||||
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
|
||||
draft29_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
draft29_client_config = copy.copy(base_client_config)
|
||||
draft29_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
|
||||
draft29_server_config
|
||||
)
|
||||
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
|
||||
draft29_client_config
|
||||
)
|
||||
|
||||
logger.debug("QUIC configurations initialized with libp2p TLS security")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(
|
||||
f"Failed to setup QUIC TLS configurations: {e}"
|
||||
) from e
|
||||
|
||||
def _apply_tls_configuration(
|
||||
self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
|
||||
) -> None:
|
||||
"""
|
||||
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||
|
||||
Args:
|
||||
config: QuicConfiguration to update
|
||||
tls_config: TLS configuration dictionary from security manager
|
||||
|
||||
"""
|
||||
try:
|
||||
config.certificate = tls_config.certificate
|
||||
config.private_key = tls_config.private_key
|
||||
config.certificate_chain = tls_config.certificate_chain
|
||||
config.alpn_protocols = tls_config.alpn_protocols
|
||||
config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
|
||||
|
||||
async def dial(
|
||||
self,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
) -> QUICConnection:
|
||||
"""
|
||||
Dial a remote peer using QUIC transport with security verification.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1)
|
||||
peer_id: Expected peer ID for verification
|
||||
nursery: Nursery to execute the background tasks
|
||||
|
||||
Returns:
|
||||
Raw connection interface to the remote peer
|
||||
|
||||
Raises:
|
||||
QUICDialError: If dialing fails
|
||||
QUICSecurityError: If security verification fails
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICDialError("Transport is closed")
|
||||
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
# Extract connection details from multiaddr
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
remote_peer_id = maddr.get_peer_id()
|
||||
if remote_peer_id is not None:
|
||||
remote_peer_id = ID.from_base58(remote_peer_id)
|
||||
|
||||
if remote_peer_id is None:
|
||||
logger.error("Unable to derive peer id from multiaddr")
|
||||
raise QUICDialError("Unable to derive peer id from multiaddr")
|
||||
quic_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
# Get appropriate QUIC client configuration
|
||||
config_key = TProtocol(f"{quic_version}_client")
|
||||
logger.debug("config_key", config_key, self._quic_configs.keys())
|
||||
config = self._quic_configs.get(config_key)
|
||||
if not config:
|
||||
raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
|
||||
|
||||
config.is_client = True
|
||||
config.quic_logger = QuicLogger()
|
||||
|
||||
# Ensure client certificate is properly set for mutual authentication
|
||||
if not config.certificate or not config.private_key:
|
||||
logger.warning(
|
||||
"Client config missing certificate - applying TLS config"
|
||||
)
|
||||
client_tls_config = self._security_manager.create_client_config()
|
||||
self._apply_tls_configuration(config, client_tls_config)
|
||||
|
||||
# Debug log to verify certificate is present
|
||||
logger.info(
|
||||
f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})"
|
||||
)
|
||||
|
||||
logger.debug("Starting QUIC Connection")
|
||||
# Create QUIC connection using aioquic's sans-IO core
|
||||
native_quic_connection = NativeQUICConnection(configuration=config)
|
||||
|
||||
# Create trio-based QUIC connection wrapper with security
|
||||
connection = QUICConnection(
|
||||
quic_connection=native_quic_connection,
|
||||
remote_addr=(host, port),
|
||||
remote_peer_id=remote_peer_id,
|
||||
local_peer_id=self._peer_id,
|
||||
is_initiator=True,
|
||||
maddr=maddr,
|
||||
transport=self,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
logger.debug("QUIC Connection Created")
|
||||
|
||||
if self._background_nursery is None:
|
||||
logger.error("No nursery set to execute background tasks")
|
||||
raise QUICDialError("No nursery found to execute tasks")
|
||||
|
||||
await connection.connect(self._background_nursery)
|
||||
|
||||
# Store connection for management
|
||||
conn_id = f"{host}:{port}"
|
||||
self._connections[conn_id] = connection
|
||||
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
|
||||
raise QUICDialError(f"Dial failed: {e}") from e
|
||||
|
||||
async def _verify_peer_identity(
|
||||
self, connection: QUICConnection, expected_peer_id: ID
|
||||
) -> None:
|
||||
"""
|
||||
Verify remote peer identity after TLS handshake.
|
||||
|
||||
Args:
|
||||
connection: The established QUIC connection
|
||||
expected_peer_id: Expected peer ID
|
||||
|
||||
Raises:
|
||||
QUICSecurityError: If peer verification fails
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get peer certificate from the connection
|
||||
peer_certificate = await connection.get_peer_certificate()
|
||||
|
||||
if not peer_certificate:
|
||||
raise QUICSecurityError("No peer certificate available")
|
||||
|
||||
# Verify peer identity using security manager
|
||||
verified_peer_id = self._security_manager.verify_peer_identity(
|
||||
peer_certificate, expected_peer_id
|
||||
)
|
||||
|
||||
if verified_peer_id != expected_peer_id:
|
||||
raise QUICSecurityError(
|
||||
"Peer ID verification failed: expected "
|
||||
f"{expected_peer_id}, got {verified_peer_id}"
|
||||
)
|
||||
|
||||
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||
logger.debug(f"Peer identity verified: {verified_peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
|
||||
|
||||
def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
|
||||
"""
|
||||
Create a QUIC listener with integrated security.
|
||||
|
||||
Args:
|
||||
handler_function: Function to handle new connections
|
||||
|
||||
Returns:
|
||||
QUIC listener instance
|
||||
|
||||
Raises:
|
||||
QUICListenError: If transport is closed
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise QUICListenError("Transport is closed")
|
||||
|
||||
# Get server configurations for the listener
|
||||
server_configs = {
|
||||
version: config
|
||||
for version, config in self._quic_configs.items()
|
||||
if version.endswith("_server")
|
||||
}
|
||||
|
||||
listener = QUICListener(
|
||||
transport=self,
|
||||
handler_function=handler_function,
|
||||
quic_configs=server_configs,
|
||||
config=self._config,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
|
||||
self._listeners.append(listener)
|
||||
logger.debug("Created QUIC listener with security")
|
||||
return listener
|
||||
|
||||
def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if this transport can dial the given multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if this transport can dial the address
|
||||
|
||||
"""
|
||||
return is_quic_multiaddr(maddr)
|
||||
|
||||
def protocols(self) -> list[TProtocol]:
|
||||
"""
|
||||
Get supported protocol identifiers.
|
||||
|
||||
Returns:
|
||||
List of supported protocol strings
|
||||
|
||||
"""
|
||||
protocols = [QUIC_V1_PROTOCOL]
|
||||
if self._config.enable_draft29:
|
||||
protocols.append(QUIC_DRAFT29_PROTOCOL)
|
||||
return protocols
|
||||
|
||||
def listen_order(self) -> int:
|
||||
"""
|
||||
Get the listen order priority for this transport.
|
||||
Matches go-libp2p's ListenOrder = 1 for QUIC.
|
||||
|
||||
Returns:
|
||||
Priority order for listening (lower = higher priority)
|
||||
|
||||
"""
|
||||
return 1
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the transport and cleanup resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
logger.debug("Closing QUIC transport")
|
||||
|
||||
# Close all active connections and listeners concurrently using trio nursery
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Close all connections
|
||||
for connection in self._connections.values():
|
||||
nursery.start_soon(connection.close)
|
||||
|
||||
# Close all listeners
|
||||
for listener in self._listeners:
|
||||
nursery.start_soon(listener.close)
|
||||
|
||||
self._connections.clear()
|
||||
self._listeners.clear()
|
||||
|
||||
logger.debug("QUIC transport closed")
|
||||
|
||||
async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
|
||||
"""Clean up a terminated connection from all listeners."""
|
||||
try:
|
||||
for listener in self._listeners:
|
||||
await listener._remove_connection_by_object(connection)
|
||||
logger.debug(
|
||||
"✅ TRANSPORT: Cleaned up terminated connection from all listeners"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
|
||||
|
||||
def get_stats(self) -> dict[str, int | list[str] | object]:
|
||||
"""Get transport statistics including security info."""
|
||||
return {
|
||||
"active_connections": len(self._connections),
|
||||
"active_listeners": len(self._listeners),
|
||||
"supported_protocols": self.protocols(),
|
||||
"local_peer_id": str(self._peer_id),
|
||||
"security_enabled": True,
|
||||
"tls_configured": True,
|
||||
}
|
||||
|
||||
def get_security_manager(self) -> QUICTLSConfigManager:
|
||||
"""
|
||||
Get the security manager for this transport.
|
||||
|
||||
Returns:
|
||||
The QUIC TLS configuration manager
|
||||
|
||||
"""
|
||||
return self._security_manager
|
||||
|
||||
def get_listener_socket(self) -> trio.socket.SocketType | None:
|
||||
"""Get the socket from the first active listener."""
|
||||
for listener in self._listeners:
|
||||
if listener.is_listening() and listener._socket:
|
||||
return listener._socket
|
||||
return None
|
||||
466
libp2p/transport/quic/utils.py
Normal file
466
libp2p/transport/quic/utils.py
Normal file
@ -0,0 +1,466 @@
|
||||
"""
|
||||
Multiaddr utilities for QUIC transport - Module 4.
|
||||
Essential utilities required for QUIC transport implementation.
|
||||
Based on go-libp2p and js-libp2p QUIC implementations.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
import multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||
|
||||
from .config import QUICTransportConfig
|
||||
from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Protocol constants
|
||||
QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
|
||||
QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
|
||||
UDP_PROTOCOL = "udp"
|
||||
IP4_PROTOCOL = "ip4"
|
||||
IP6_PROTOCOL = "ip6"
|
||||
|
||||
SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server"
|
||||
CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client"
|
||||
|
||||
SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server"
|
||||
CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client"
|
||||
|
||||
CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = {
|
||||
SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000
|
||||
CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000
|
||||
SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||
CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29
|
||||
}
|
||||
|
||||
# QUIC version to wire format mappings (required for aioquic)
|
||||
QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = {
|
||||
QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000
|
||||
QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29
|
||||
}
|
||||
|
||||
# ALPN protocols for libp2p over QUIC
|
||||
LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"]
|
||||
|
||||
|
||||
def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool:
|
||||
"""
|
||||
Check if a multiaddr represents a QUIC address.
|
||||
|
||||
Valid QUIC multiaddrs:
|
||||
- /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
- /ip4/127.0.0.1/udp/4001/quic
|
||||
- /ip6/::1/udp/4001/quic-v1
|
||||
- /ip6/::1/udp/4001/quic
|
||||
|
||||
Args:
|
||||
maddr: Multiaddr to check
|
||||
|
||||
Returns:
|
||||
True if the multiaddr represents a QUIC address
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
# Check for required components
|
||||
has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str
|
||||
has_udp = f"/{UDP_PROTOCOL}/" in addr_str
|
||||
has_quic = (
|
||||
f"/{QUIC_V1_PROTOCOL}" in addr_str
|
||||
or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str
|
||||
or "/quic" in addr_str
|
||||
)
|
||||
|
||||
return has_ip and has_udp and has_quic
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]:
|
||||
"""
|
||||
Extract host and port from a QUIC multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Tuple of (host, port)
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}")
|
||||
|
||||
try:
|
||||
host = None
|
||||
port = None
|
||||
|
||||
# Try to get IPv4 address
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try to get IPv6 address if IPv4 not found
|
||||
if host is None:
|
||||
try:
|
||||
host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get UDP port
|
||||
try:
|
||||
port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore
|
||||
port = int(port_str)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if host is None or port is None:
|
||||
raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}")
|
||||
|
||||
return host, port
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(
|
||||
f"Failed to parse QUIC multiaddr {maddr}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol:
|
||||
"""
|
||||
Determine QUIC version from multiaddr.
|
||||
|
||||
Args:
|
||||
maddr: QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
QUIC version identifier ("quic-v1" or "quic")
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
|
||||
|
||||
"""
|
||||
try:
|
||||
addr_str = str(maddr)
|
||||
|
||||
if f"/{QUIC_V1_PROTOCOL}" in addr_str:
|
||||
return QUIC_V1_PROTOCOL # RFC 9000
|
||||
elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
|
||||
return QUIC_DRAFT29_PROTOCOL # draft-29
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(
|
||||
f"Failed to determine QUIC version from {maddr}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def create_quic_multiaddr(
|
||||
host: str, port: int, version: str = "quic-v1"
|
||||
) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Create a QUIC multiaddr from host, port, and version.
|
||||
|
||||
Args:
|
||||
host: IP address (IPv4 or IPv6)
|
||||
port: UDP port number
|
||||
version: QUIC version ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
QUIC multiaddr
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If invalid parameters provided
|
||||
|
||||
"""
|
||||
try:
|
||||
# Determine IP version
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if isinstance(ip, ipaddress.IPv4Address):
|
||||
ip_proto = IP4_PROTOCOL
|
||||
else:
|
||||
ip_proto = IP6_PROTOCOL
|
||||
except ValueError:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
|
||||
|
||||
# Validate port
|
||||
if not (0 <= port <= 65535):
|
||||
raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
|
||||
|
||||
# Validate and normalize QUIC version
|
||||
if version == "quic-v1" or version == "/quic-v1":
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
elif version == "quic" or version == "/quic":
|
||||
quic_proto = QUIC_DRAFT29_PROTOCOL
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
|
||||
|
||||
# Construct multiaddr
|
||||
addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
|
||||
return multiaddr.Multiaddr(addr_str)
|
||||
|
||||
except Exception as e:
|
||||
raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
|
||||
|
||||
|
||||
def quic_version_to_wire_format(version: TProtocol) -> int:
|
||||
"""
|
||||
Convert QUIC version string to wire format integer for aioquic.
|
||||
|
||||
Args:
|
||||
version: QUIC version string ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
Wire format version number
|
||||
|
||||
Raises:
|
||||
QUICUnsupportedVersionError: If version is not supported
|
||||
|
||||
"""
|
||||
wire_version = QUIC_VERSION_MAPPINGS.get(version)
|
||||
if wire_version is None:
|
||||
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||
|
||||
return wire_version
|
||||
|
||||
|
||||
def custom_quic_version_to_wire_format(version: TProtocol) -> int:
|
||||
"""
|
||||
Convert QUIC version string to wire format integer for aioquic.
|
||||
|
||||
Args:
|
||||
version: QUIC version string ("quic-v1" or "quic")
|
||||
|
||||
Returns:
|
||||
Wire format version number
|
||||
|
||||
Raises:
|
||||
QUICUnsupportedVersionError: If version is not supported
|
||||
|
||||
"""
|
||||
wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
|
||||
if wire_version is None:
|
||||
raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
|
||||
|
||||
return wire_version
|
||||
|
||||
|
||||
def get_alpn_protocols() -> list[str]:
|
||||
"""
|
||||
Get ALPN protocols for libp2p over QUIC.
|
||||
|
||||
Returns:
|
||||
List of ALPN protocol identifiers
|
||||
|
||||
"""
|
||||
return LIBP2P_ALPN_PROTOCOLS.copy()
|
||||
|
||||
|
||||
def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
|
||||
"""
|
||||
Normalize a QUIC multiaddr to canonical form.
|
||||
|
||||
Args:
|
||||
maddr: Input QUIC multiaddr
|
||||
|
||||
Returns:
|
||||
Normalized multiaddr
|
||||
|
||||
Raises:
|
||||
QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
|
||||
|
||||
"""
|
||||
if not is_quic_multiaddr(maddr):
|
||||
raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
|
||||
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
return create_quic_multiaddr(host, port, version)
|
||||
|
||||
|
||||
def create_server_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a server configuration without using deepcopy.
|
||||
Manually copies attributes while handling cryptography objects properly.
|
||||
"""
|
||||
try:
|
||||
# Create new server configuration from scratch
|
||||
server_config = QuicConfiguration(is_client=False)
|
||||
server_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes (these are safe to copy)
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"stateless_retry",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(server_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(server_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
server_tls_config = security_manager.create_server_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if server_tls_config.certificate:
|
||||
server_config.certificate = server_tls_config.certificate
|
||||
if server_tls_config.private_key:
|
||||
server_config.private_key = server_tls_config.private_key
|
||||
if server_tls_config.certificate_chain:
|
||||
server_config.certificate_chain = (
|
||||
server_tls_config.certificate_chain
|
||||
)
|
||||
if server_tls_config.alpn_protocols:
|
||||
server_config.alpn_protocols = server_tls_config.alpn_protocols
|
||||
server_tls_config.request_client_certificate = True
|
||||
if getattr(server_tls_config, "request_client_certificate", False):
|
||||
server_config._libp2p_request_client_cert = True # type: ignore
|
||||
else:
|
||||
logger.error(
|
||||
"🔧 Failed to set request_client_certificate in server config"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Set transport-specific defaults if provided
|
||||
if transport_config:
|
||||
if server_config.idle_timeout == 0:
|
||||
server_config.idle_timeout = getattr(
|
||||
transport_config, "idle_timeout", 30.0
|
||||
)
|
||||
if server_config.max_datagram_frame_size is None:
|
||||
server_config.max_datagram_frame_size = getattr(
|
||||
transport_config, "max_datagram_size", 1200
|
||||
)
|
||||
# Ensure we have ALPN protocols
|
||||
if not server_config.alpn_protocols:
|
||||
server_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created server config without deepcopy")
|
||||
return server_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create server config: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def create_client_config_from_base(
|
||||
base_config: QuicConfiguration,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
transport_config: QUICTransportConfig | None = None,
|
||||
) -> QuicConfiguration:
|
||||
"""
|
||||
Create a client configuration without using deepcopy.
|
||||
"""
|
||||
try:
|
||||
# Create new client configuration from scratch
|
||||
client_config = QuicConfiguration(is_client=True)
|
||||
client_config.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Copy basic configuration attributes
|
||||
copyable_attrs = [
|
||||
"alpn_protocols",
|
||||
"verify_mode",
|
||||
"max_datagram_frame_size",
|
||||
"idle_timeout",
|
||||
"max_concurrent_streams",
|
||||
"supported_versions",
|
||||
"max_data",
|
||||
"max_stream_data",
|
||||
"quantum_readiness_test",
|
||||
]
|
||||
|
||||
for attr in copyable_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Handle cryptography objects - these need direct reference, not copying
|
||||
crypto_attrs = [
|
||||
"certificate",
|
||||
"private_key",
|
||||
"certificate_chain",
|
||||
"ca_certs",
|
||||
]
|
||||
|
||||
for attr in crypto_attrs:
|
||||
if hasattr(base_config, attr):
|
||||
value = getattr(base_config, attr)
|
||||
if value is not None:
|
||||
setattr(client_config, attr, value)
|
||||
|
||||
# Apply security manager configuration if available
|
||||
if security_manager:
|
||||
try:
|
||||
client_tls_config = security_manager.create_client_config()
|
||||
|
||||
# Override with security manager's TLS configuration
|
||||
if client_tls_config.certificate:
|
||||
client_config.certificate = client_tls_config.certificate
|
||||
if client_tls_config.private_key:
|
||||
client_config.private_key = client_tls_config.private_key
|
||||
if client_tls_config.certificate_chain:
|
||||
client_config.certificate_chain = (
|
||||
client_tls_config.certificate_chain
|
||||
)
|
||||
if client_tls_config.alpn_protocols:
|
||||
client_config.alpn_protocols = client_tls_config.alpn_protocols
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to apply security manager config: {e}")
|
||||
|
||||
# Ensure we have ALPN protocols
|
||||
if not client_config.alpn_protocols:
|
||||
client_config.alpn_protocols = ["libp2p"]
|
||||
|
||||
logger.debug("Successfully created client config without deepcopy")
|
||||
return client_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create client config: {e}")
|
||||
raise
|
||||
@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectClientError,
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.protocol_muxer.multiselect import (
|
||||
DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
)
|
||||
from libp2p.security.exceptions import (
|
||||
HandshakeFailure,
|
||||
)
|
||||
@ -37,9 +40,12 @@ class TransportUpgrader:
|
||||
self,
|
||||
secure_transports_by_protocol: TSecurityOptions,
|
||||
muxer_transports_by_protocol: TMuxerOptions,
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
):
|
||||
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
|
||||
self.muxer_multistream = MuxerMultistream(
|
||||
muxer_transports_by_protocol, negotiate_timeout
|
||||
)
|
||||
|
||||
async def upgrade_security(
|
||||
self,
|
||||
|
||||
1
newsfragments/763.feature.rst
Normal file
1
newsfragments/763.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Add QUIC transport support for faster, more efficient peer-to-peer connections with native stream multiplexing.
|
||||
1
newsfragments/896.bugfix.rst
Normal file
1
newsfragments/896.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly
|
||||
1
newsfragments/927.bugfix.rst
Normal file
1
newsfragments/927.bugfix.rst
Normal file
@ -0,0 +1 @@
|
||||
Fix multiaddr dependency to use the last py-multiaddr commit hash to resolve installation issues
|
||||
1
newsfragments/934.misc.rst
Normal file
1
newsfragments/934.misc.rst
Normal file
@ -0,0 +1 @@
|
||||
Updated multiaddr dependency from git repository to pip package version 0.0.11.
|
||||
@ -16,13 +16,14 @@ maintainers = [
|
||||
{ name = "Dave Grantham", email = "dwg@linuxprogrammer.org" },
|
||||
]
|
||||
dependencies = [
|
||||
"aioquic>=1.2.0",
|
||||
"base58>=1.0.3",
|
||||
"coincurve>=10.0.0",
|
||||
"coincurve==21.0.0",
|
||||
"exceptiongroup>=1.2.0; python_version < '3.11'",
|
||||
"fastecdsa==2.3.2; sys_platform != 'win32'",
|
||||
"grpcio>=1.41.0",
|
||||
"lru-dict>=1.1.6",
|
||||
# "multiaddr>=0.0.9",
|
||||
"multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6",
|
||||
"multiaddr>=0.0.11",
|
||||
"mypy-protobuf>=3.0.0",
|
||||
"noiseprotocol>=0.3.0",
|
||||
"protobuf>=4.25.0,<5.0.0",
|
||||
@ -32,7 +33,6 @@ dependencies = [
|
||||
"rpcudp>=3.0.0",
|
||||
"trio-typing>=0.0.4",
|
||||
"trio>=0.26.0",
|
||||
"fastecdsa==2.3.2; sys_platform != 'win32'",
|
||||
"zeroconf (>=0.147.0,<0.148.0)",
|
||||
]
|
||||
classifiers = [
|
||||
@ -52,6 +52,7 @@ Homepage = "https://github.com/libp2p/py-libp2p"
|
||||
[project.scripts]
|
||||
chat-demo = "examples.chat.chat:main"
|
||||
echo-demo = "examples.echo.echo:main"
|
||||
echo-quic-demo="examples.echo.echo_quic:main"
|
||||
ping-demo = "examples.ping.ping:main"
|
||||
identify-demo = "examples.identify.identify:main"
|
||||
identify-push-demo = "examples.identify_push.identify_push_demo:run_main"
|
||||
@ -77,6 +78,7 @@ dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-xdist>=2.4.0",
|
||||
"pytest-trio>=0.5.2",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"factory-boy>=2.12.0,<3.0.0",
|
||||
"ruff>=0.11.10",
|
||||
"pyrefly (>=0.17.1,<0.18.0)",
|
||||
@ -88,11 +90,12 @@ docs = [
|
||||
"tomli; python_version < '3.11'",
|
||||
]
|
||||
test = [
|
||||
"factory-boy>=2.12.0,<3.0.0",
|
||||
"p2pclient==0.2.0",
|
||||
"pytest>=7.0.0",
|
||||
"pytest-xdist>=2.4.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-trio>=0.5.2",
|
||||
"factory-boy>=2.12.0,<3.0.0",
|
||||
"pytest-xdist>=2.4.0",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
@ -282,4 +285,5 @@ project_excludes = [
|
||||
"**/*pb2.py",
|
||||
"**/*.pyi",
|
||||
".venv/**",
|
||||
"./tests/interop/nim_libp2p",
|
||||
]
|
||||
|
||||
@ -250,10 +250,13 @@ def test_new_swarm_tcp_multiaddr_supported():
|
||||
assert isinstance(swarm.transport, TCP)
|
||||
|
||||
|
||||
def test_new_swarm_quic_multiaddr_raises():
|
||||
def test_new_swarm_quic_multiaddr_supported():
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
|
||||
with pytest.raises(ValueError, match="QUIC not yet supported"):
|
||||
new_swarm(listen_addrs=[addr])
|
||||
swarm = new_swarm(listen_addrs=[addr])
|
||||
assert isinstance(swarm, Swarm)
|
||||
assert isinstance(swarm.transport, QUICTransport)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
|
||||
108
tests/core/stream_muxer/test_muxer_multistream.py
Normal file
108
tests/core/stream_muxer/test_muxer_multistream.py
Normal file
@ -0,0 +1,108 @@
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TMuxerClass,
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.protocol_muxer.exceptions import (
|
||||
MultiselectError,
|
||||
)
|
||||
from libp2p.stream_muxer.muxer_multistream import (
|
||||
MuxerMultistream,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_muxer_timeout_configuration():
|
||||
"""Test that muxer respects timeout configuration."""
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=1)
|
||||
assert muxer.negotiate_timeout == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_transport_passes_timeout_to_multiselect():
|
||||
"""Test that timeout is passed to multiselect client in select_transport."""
|
||||
# Mock dependencies
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.is_initiator = False
|
||||
|
||||
# Mock MultiselectClient
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=10)
|
||||
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
|
||||
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Call select_transport
|
||||
await muxer.select_transport(mock_conn)
|
||||
|
||||
# Verify that select_one_of was called with the correct timeout
|
||||
args, _ = muxer.multiselect.negotiate.call_args
|
||||
assert args[1] == 10
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_new_conn_passes_timeout_to_multistream_client():
|
||||
"""Test that timeout is passed to multistream client in new_conn."""
|
||||
# Mock dependencies
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.is_initiator = True
|
||||
mock_peer_id = ID(b"test_peer")
|
||||
mock_communicator = MagicMock()
|
||||
|
||||
# Mock MultistreamClient and transports
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
|
||||
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
|
||||
|
||||
# Call new_conn
|
||||
await muxer.new_conn(mock_conn, mock_peer_id)
|
||||
|
||||
# Verify that select_one_of was called with the correct timeout
|
||||
muxer.multistream_client.select_one_of(
|
||||
tuple(muxer.transports.keys()), mock_communicator, 30
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_select_transport_no_protocol_selected():
|
||||
"""
|
||||
Test that select_transport raises MultiselectError when no protocol is selected.
|
||||
"""
|
||||
# Mock dependencies
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.is_initiator = False
|
||||
|
||||
# Mock Multiselect to return None
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
|
||||
|
||||
# Expect MultiselectError to be raised
|
||||
with pytest.raises(MultiselectError, match="no protocol selected"):
|
||||
await muxer.select_transport(mock_conn)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_add_transport_updates_precedence():
|
||||
"""Test that adding a transport updates protocol precedence."""
|
||||
# Mock transport classes
|
||||
mock_transport1 = MagicMock(spec=TMuxerClass)
|
||||
mock_transport2 = MagicMock(spec=TMuxerClass)
|
||||
|
||||
# Initialize muxer and add transports
|
||||
muxer = MuxerMultistream({}, negotiate_timeout=30)
|
||||
muxer.add_transport(TProtocol("proto1"), mock_transport1)
|
||||
muxer.add_transport(TProtocol("proto2"), mock_transport2)
|
||||
|
||||
# Verify transport order
|
||||
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
|
||||
|
||||
# Re-add proto1 to check if it moves to the end
|
||||
muxer.add_transport(TProtocol("proto1"), mock_transport1)
|
||||
assert list(muxer.transports.keys()) == ["proto2", "proto1"]
|
||||
0
tests/core/transport/quic/test_concurrency.py
Normal file
0
tests/core/transport/quic/test_concurrency.py
Normal file
553
tests/core/transport/quic/test_connection.py
Normal file
553
tests/core/transport/quic/test_connection.py
Normal file
@ -0,0 +1,553 @@
|
||||
"""
|
||||
Enhanced tests for QUIC connection functionality - Module 3.
|
||||
Tests all new features including advanced stream management, resource management,
|
||||
error handling, and concurrent operations.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICConnectionClosedError,
|
||||
QUICConnectionError,
|
||||
QUICConnectionTimeoutError,
|
||||
QUICPeerVerificationError,
|
||||
QUICStreamLimitError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||
from libp2p.transport.quic.stream import QUICStream, StreamDirection
|
||||
|
||||
|
||||
class MockResourceScope:
|
||||
"""Mock resource scope for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.memory_reserved = 0
|
||||
|
||||
def reserve_memory(self, size):
|
||||
self.memory_reserved += size
|
||||
|
||||
def release_memory(self, size):
|
||||
self.memory_reserved = max(0, self.memory_reserved - size)
|
||||
|
||||
|
||||
class TestQUICConnection:
|
||||
"""Test suite for QUIC connection functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
"""Create mock aioquic QuicConnection."""
|
||||
mock = Mock()
|
||||
mock.next_event.return_value = None
|
||||
mock.datagrams_to_send.return_value = []
|
||||
mock.get_timer.return_value = None
|
||||
mock.connect = Mock()
|
||||
mock.close = Mock()
|
||||
mock.send_stream_data = Mock()
|
||||
mock.reset_stream = Mock()
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_transport(self):
|
||||
mock = Mock()
|
||||
mock._config = QUICTransportConfig()
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_resource_scope(self):
|
||||
"""Create mock resource scope."""
|
||||
return MockResourceScope()
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(
|
||||
self,
|
||||
mock_quic_connection: Mock,
|
||||
mock_quic_transport: Mock,
|
||||
mock_resource_scope: MockResourceScope,
|
||||
):
|
||||
"""Create test QUIC connection with enhanced features."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
mock_security_manager = Mock()
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=None,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=mock_quic_transport,
|
||||
resource_scope=mock_resource_scope,
|
||||
security_manager=mock_security_manager,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_connection(self, mock_quic_connection, mock_resource_scope):
|
||||
"""Create server-side QUIC connection."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=False,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
resource_scope=mock_resource_scope,
|
||||
)
|
||||
|
||||
# Basic functionality tests
|
||||
|
||||
def test_connection_initialization_enhanced(
|
||||
self, quic_connection, mock_resource_scope
|
||||
):
|
||||
"""Test enhanced connection initialization."""
|
||||
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
|
||||
assert quic_connection.is_initiator is True
|
||||
assert not quic_connection.is_closed
|
||||
assert not quic_connection.is_established
|
||||
assert len(quic_connection._streams) == 0
|
||||
assert quic_connection._resource_scope == mock_resource_scope
|
||||
assert quic_connection._outbound_stream_count == 0
|
||||
assert quic_connection._inbound_stream_count == 0
|
||||
assert len(quic_connection._stream_accept_queue) == 0
|
||||
|
||||
def test_stream_id_calculation_enhanced(self):
|
||||
"""Test enhanced stream ID calculation for client/server."""
|
||||
# Client connection (initiator)
|
||||
client_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=None,
|
||||
local_peer_id=Mock(),
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
assert client_conn._next_stream_id == 0 # Client starts with 0
|
||||
|
||||
# Server connection (not initiator)
|
||||
server_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=None,
|
||||
local_peer_id=Mock(),
|
||||
is_initiator=False,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
assert server_conn._next_stream_id == 1 # Server starts with 1
|
||||
|
||||
def test_incoming_stream_detection_enhanced(self, quic_connection):
|
||||
"""Test enhanced incoming stream detection logic."""
|
||||
# For client (initiator), odd stream IDs are incoming
|
||||
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
|
||||
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
|
||||
|
||||
# Stream management tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_basic(self, quic_connection):
|
||||
"""Test basic stream opening."""
|
||||
quic_connection._started = True
|
||||
|
||||
stream = await quic_connection.open_stream()
|
||||
|
||||
assert isinstance(stream, QUICStream)
|
||||
assert stream.stream_id == "0"
|
||||
assert stream.direction == StreamDirection.OUTBOUND
|
||||
assert 0 in quic_connection._streams
|
||||
assert quic_connection._outbound_stream_count == 1
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_limit_reached(self, quic_connection):
|
||||
"""Test stream limit enforcement."""
|
||||
quic_connection._started = True
|
||||
quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
|
||||
|
||||
with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
|
||||
await quic_connection.open_stream()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_timeout(self, quic_connection: QUICConnection):
|
||||
"""Test stream opening timeout."""
|
||||
quic_connection._started = True
|
||||
return
|
||||
|
||||
# Mock the stream ID lock to simulate slow operation
|
||||
async def slow_acquire():
|
||||
await trio.sleep(10) # Longer than timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection._stream_lock, "acquire", side_effect=slow_acquire
|
||||
):
|
||||
with pytest.raises(
|
||||
QUICStreamTimeoutError, match="Stream creation timed out"
|
||||
):
|
||||
await quic_connection.open_stream(timeout=0.1)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_basic(self, quic_connection):
|
||||
"""Test basic stream acceptance."""
|
||||
# Create a mock inbound stream
|
||||
mock_stream = Mock(spec=QUICStream)
|
||||
mock_stream.stream_id = "1"
|
||||
|
||||
# Add to accept queue
|
||||
quic_connection._stream_accept_queue.append(mock_stream)
|
||||
quic_connection._stream_accept_event.set()
|
||||
|
||||
accepted_stream = await quic_connection.accept_stream(timeout=0.1)
|
||||
|
||||
assert accepted_stream == mock_stream
|
||||
assert len(quic_connection._stream_accept_queue) == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_timeout(self, quic_connection):
|
||||
"""Test stream acceptance timeout."""
|
||||
with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"):
|
||||
await quic_connection.accept_stream(timeout=0.1)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_on_closed_connection(self, quic_connection):
|
||||
"""Test stream acceptance on closed connection."""
|
||||
await quic_connection.close()
|
||||
|
||||
with pytest.raises(QUICConnectionClosedError, match="Connection is closed"):
|
||||
await quic_connection.accept_stream()
|
||||
|
||||
# Stream handler tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_handler_setting(self, quic_connection):
|
||||
"""Test setting stream handler."""
|
||||
|
||||
async def mock_handler(stream):
|
||||
pass
|
||||
|
||||
quic_connection.set_stream_handler(mock_handler)
|
||||
assert quic_connection._stream_handler == mock_handler
|
||||
|
||||
# Connection lifecycle tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_client(self, quic_connection):
|
||||
"""Test client connection start."""
|
||||
with patch.object(
|
||||
quic_connection, "_initiate_connection", new_callable=AsyncMock
|
||||
) as mock_initiate:
|
||||
await quic_connection.start()
|
||||
|
||||
assert quic_connection._started
|
||||
mock_initiate.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_server(self, server_connection):
|
||||
"""Test server connection start."""
|
||||
await server_connection.start()
|
||||
|
||||
assert server_connection._started
|
||||
assert server_connection._established
|
||||
assert server_connection._connected_event.is_set()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_already_started(self, quic_connection):
|
||||
"""Test starting already started connection."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Should not raise error, just log warning
|
||||
await quic_connection.start()
|
||||
assert quic_connection._started
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_closed(self, quic_connection):
|
||||
"""Test starting closed connection."""
|
||||
quic_connection._closed = True
|
||||
|
||||
with pytest.raises(
|
||||
QUICConnectionError, match="Cannot start a closed connection"
|
||||
):
|
||||
await quic_connection.start()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_connect_with_nursery(
|
||||
self, quic_connection: QUICConnection
|
||||
):
|
||||
"""Test connection establishment with nursery."""
|
||||
quic_connection._started = True
|
||||
quic_connection._established = True
|
||||
quic_connection._connected_event.set()
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
) as mock_start_tasks:
|
||||
with patch.object(
|
||||
quic_connection,
|
||||
"_verify_peer_identity_with_security",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_verify:
|
||||
async with trio.open_nursery() as nursery:
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
assert quic_connection._nursery == nursery
|
||||
mock_start_tasks.assert_called_once()
|
||||
mock_verify.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.slow
|
||||
async def test_connection_connect_timeout(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test connection establishment timeout."""
|
||||
quic_connection._started = True
|
||||
# Don't set connected event to simulate timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
with pytest.raises(
|
||||
QUICConnectionTimeoutError, match="Connection handshake timed out"
|
||||
):
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
# Resource management tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_removal_resource_cleanup(
|
||||
self, quic_connection: QUICConnection, mock_resource_scope
|
||||
):
|
||||
"""Test stream removal and resource cleanup."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create a stream
|
||||
stream = await quic_connection.open_stream()
|
||||
|
||||
# Remove the stream
|
||||
quic_connection._remove_stream(int(stream.stream_id))
|
||||
|
||||
assert int(stream.stream_id) not in quic_connection._streams
|
||||
# Note: Count updates is async, so we can't test it directly here
|
||||
|
||||
# Error handling tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_error_handling(self, quic_connection) -> None:
|
||||
"""Test connection error handling."""
|
||||
error = Exception("Test error")
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "close", new_callable=AsyncMock
|
||||
) as mock_close:
|
||||
await quic_connection._handle_connection_error(error)
|
||||
mock_close.assert_called_once()
|
||||
|
||||
# Statistics and monitoring tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_stats_enhanced(self, quic_connection) -> None:
|
||||
"""Test enhanced connection statistics."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create some streams
|
||||
_stream1 = await quic_connection.open_stream()
|
||||
_stream2 = await quic_connection.open_stream()
|
||||
|
||||
stats = quic_connection.get_stream_stats()
|
||||
|
||||
expected_keys = [
|
||||
"total_streams",
|
||||
"outbound_streams",
|
||||
"inbound_streams",
|
||||
"max_streams",
|
||||
"stream_utilization",
|
||||
"stats",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in stats
|
||||
|
||||
assert stats["total_streams"] == 2
|
||||
assert stats["outbound_streams"] == 2
|
||||
assert stats["inbound_streams"] == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_active_streams(self, quic_connection) -> None:
|
||||
"""Test getting active streams."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create streams
|
||||
stream1 = await quic_connection.open_stream()
|
||||
stream2 = await quic_connection.open_stream()
|
||||
|
||||
active_streams = quic_connection.get_active_streams()
|
||||
|
||||
assert len(active_streams) == 2
|
||||
assert stream1 in active_streams
|
||||
assert stream2 in active_streams
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_streams_by_protocol(self, quic_connection) -> None:
|
||||
"""Test getting streams by protocol."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create streams with different protocols
|
||||
stream1 = await quic_connection.open_stream()
|
||||
stream1.protocol = "/test/1.0.0"
|
||||
|
||||
stream2 = await quic_connection.open_stream()
|
||||
stream2.protocol = "/other/1.0.0"
|
||||
|
||||
test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0")
|
||||
other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0")
|
||||
|
||||
assert len(test_streams) == 1
|
||||
assert len(other_streams) == 1
|
||||
assert stream1 in test_streams
|
||||
assert stream2 in other_streams
|
||||
|
||||
# Enhanced close tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_close_enhanced(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test enhanced connection close with stream cleanup."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create some streams
|
||||
_stream1 = await quic_connection.open_stream()
|
||||
_stream2 = await quic_connection.open_stream()
|
||||
|
||||
await quic_connection.close()
|
||||
|
||||
assert quic_connection.is_closed
|
||||
assert len(quic_connection._streams) == 0
|
||||
|
||||
# Concurrent operations tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_concurrent_stream_operations(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test concurrent stream operations."""
|
||||
quic_connection._started = True
|
||||
|
||||
async def create_stream():
|
||||
return await quic_connection.open_stream()
|
||||
|
||||
# Create multiple streams concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(10):
|
||||
nursery.start_soon(create_stream)
|
||||
|
||||
# Wait a bit for all to start
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Should have created streams without conflicts
|
||||
assert quic_connection._outbound_stream_count == 10
|
||||
assert len(quic_connection._streams) == 10
|
||||
|
||||
# Connection properties tests
|
||||
|
||||
def test_connection_properties(self, quic_connection: QUICConnection) -> None:
|
||||
"""Test connection property accessors."""
|
||||
assert quic_connection.multiaddr() == quic_connection._maddr
|
||||
assert quic_connection.local_peer_id() == quic_connection._local_peer_id
|
||||
assert quic_connection.remote_peer_id() == quic_connection._remote_peer_id
|
||||
|
||||
# IRawConnection interface tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_raw_connection_write(self, quic_connection: QUICConnection) -> None:
|
||||
"""Test raw connection write interface."""
|
||||
quic_connection._started = True
|
||||
|
||||
with patch.object(quic_connection, "open_stream") as mock_open:
|
||||
mock_stream = AsyncMock()
|
||||
mock_open.return_value = mock_stream
|
||||
|
||||
await quic_connection.write(b"test data")
|
||||
|
||||
mock_open.assert_called_once()
|
||||
mock_stream.write.assert_called_once_with(b"test data")
|
||||
mock_stream.close_write.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_raw_connection_read_not_implemented(
|
||||
self, quic_connection: QUICConnection
|
||||
) -> None:
|
||||
"""Test raw connection read raises NotImplementedError."""
|
||||
with pytest.raises(NotImplementedError):
|
||||
await quic_connection.read()
|
||||
|
||||
# Mock verification helpers
|
||||
|
||||
def test_mock_resource_scope_functionality(self, mock_resource_scope) -> None:
|
||||
"""Test mock resource scope works correctly."""
|
||||
assert mock_resource_scope.memory_reserved == 0
|
||||
|
||||
mock_resource_scope.reserve_memory(1000)
|
||||
assert mock_resource_scope.memory_reserved == 1000
|
||||
|
||||
mock_resource_scope.reserve_memory(500)
|
||||
assert mock_resource_scope.memory_reserved == 1500
|
||||
|
||||
mock_resource_scope.release_memory(600)
|
||||
assert mock_resource_scope.memory_reserved == 900
|
||||
|
||||
mock_resource_scope.release_memory(2000) # Should not go negative
|
||||
assert mock_resource_scope.memory_reserved == 0
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_invalid_certificate_verification():
|
||||
key_pair1 = create_new_key_pair()
|
||||
key_pair2 = create_new_key_pair()
|
||||
|
||||
peer_id1 = ID.from_pubkey(key_pair1.public_key)
|
||||
peer_id2 = ID.from_pubkey(key_pair2.public_key)
|
||||
|
||||
manager = QUICTLSConfigManager(
|
||||
libp2p_private_key=key_pair1.private_key, peer_id=peer_id1
|
||||
)
|
||||
|
||||
# Match the certificate against a different peer_id
|
||||
with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"):
|
||||
manager.verify_peer_identity(manager.tls_config.certificate, peer_id2)
|
||||
|
||||
from cryptography.hazmat.primitives.serialization import Encoding
|
||||
|
||||
# --- Corrupt the certificate by tampering the DER bytes ---
|
||||
cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER)
|
||||
corrupted_bytes = bytearray(cert_bytes)
|
||||
|
||||
# Flip some random bytes in the middle of the certificate
|
||||
corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
# This will still parse (structurally valid), but the signature
|
||||
# or fingerprint will break
|
||||
corrupted_cert = x509.load_der_x509_certificate(
|
||||
bytes(corrupted_bytes), backend=default_backend()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
QUICPeerVerificationError, match="Certificate verification failed"
|
||||
):
|
||||
manager.verify_peer_identity(corrupted_cert, peer_id1)
|
||||
624
tests/core/transport/quic/test_connection_id.py
Normal file
624
tests/core/transport/quic/test_connection_id.py
Normal file
@ -0,0 +1,624 @@
|
||||
"""
|
||||
QUIC Connection ID Management Tests
|
||||
|
||||
This test module covers comprehensive testing of QUIC connection ID functionality
|
||||
including generation, rotation, retirement, and validation according to RFC 9000.
|
||||
|
||||
Tests are organized into:
|
||||
1. Basic Connection ID Management
|
||||
2. Connection ID Rotation and Updates
|
||||
3. Connection ID Retirement
|
||||
4. Error Conditions and Edge Cases
|
||||
5. Integration Tests with Real Connections
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from aioquic.buffer import Buffer
|
||||
|
||||
# Import aioquic components for low-level testing
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
from aioquic.quic.connection import QuicConnection, QuicConnectionId
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
|
||||
class ConnectionIdTestHelper:
|
||||
"""Helper class for connection ID testing utilities."""
|
||||
|
||||
@staticmethod
|
||||
def generate_connection_id(length: int = 8) -> bytes:
|
||||
"""Generate a random connection ID of specified length."""
|
||||
return secrets.token_bytes(length)
|
||||
|
||||
@staticmethod
|
||||
def create_quic_connection_id(cid: bytes, sequence: int = 0) -> QuicConnectionId:
|
||||
"""Create a QuicConnectionId object."""
|
||||
return QuicConnectionId(
|
||||
cid=cid,
|
||||
sequence_number=sequence,
|
||||
stateless_reset_token=secrets.token_bytes(16),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_connection_ids_from_connection(conn: QUICConnection) -> dict[str, Any]:
|
||||
"""Extract connection ID information from a QUIC connection."""
|
||||
quic = conn._quic
|
||||
return {
|
||||
"host_cids": [cid.cid.hex() for cid in getattr(quic, "_host_cids", [])],
|
||||
"peer_cid": getattr(quic, "_peer_cid", None),
|
||||
"peer_cid_available": [
|
||||
cid.cid.hex() for cid in getattr(quic, "_peer_cid_available", [])
|
||||
],
|
||||
"retire_connection_ids": getattr(quic, "_retire_connection_ids", []),
|
||||
"host_cid_seq": getattr(quic, "_host_cid_seq", 0),
|
||||
}
|
||||
|
||||
|
||||
class TestBasicConnectionIdManagement:
|
||||
"""Test basic connection ID management functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
"""Create a mock QUIC connection with connection ID support."""
|
||||
mock_quic = Mock(spec=QuicConnection)
|
||||
mock_quic._host_cids = []
|
||||
mock_quic._host_cid_seq = 0
|
||||
mock_quic._peer_cid = None
|
||||
mock_quic._peer_cid_available = []
|
||||
mock_quic._retire_connection_ids = []
|
||||
mock_quic._configuration = Mock()
|
||||
mock_quic._configuration.connection_id_length = 8
|
||||
mock_quic._remote_active_connection_id_limit = 8
|
||||
return mock_quic
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(self, mock_quic_connection):
|
||||
"""Create a QUICConnection instance for testing."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
def test_connection_id_initialization(self, quic_connection):
|
||||
"""Test that connection ID tracking is properly initialized."""
|
||||
# Check that connection ID tracking structures are initialized
|
||||
assert hasattr(quic_connection, "_available_connection_ids")
|
||||
assert hasattr(quic_connection, "_current_connection_id")
|
||||
assert hasattr(quic_connection, "_retired_connection_ids")
|
||||
assert hasattr(quic_connection, "_connection_id_sequence_numbers")
|
||||
|
||||
# Initial state should be empty
|
||||
assert len(quic_connection._available_connection_ids) == 0
|
||||
assert quic_connection._current_connection_id is None
|
||||
assert len(quic_connection._retired_connection_ids) == 0
|
||||
assert len(quic_connection._connection_id_sequence_numbers) == 0
|
||||
|
||||
def test_connection_id_stats_tracking(self, quic_connection):
|
||||
"""Test connection ID statistics are properly tracked."""
|
||||
stats = quic_connection.get_connection_id_stats()
|
||||
|
||||
# Check that all expected stats are present
|
||||
expected_keys = [
|
||||
"available_connection_ids",
|
||||
"current_connection_id",
|
||||
"retired_connection_ids",
|
||||
"connection_ids_issued",
|
||||
"connection_ids_retired",
|
||||
"connection_id_changes",
|
||||
"available_cid_list",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in stats
|
||||
|
||||
# Initial values should be zero/empty
|
||||
assert stats["available_connection_ids"] == 0
|
||||
assert stats["current_connection_id"] is None
|
||||
assert stats["retired_connection_ids"] == 0
|
||||
assert stats["connection_ids_issued"] == 0
|
||||
assert stats["connection_ids_retired"] == 0
|
||||
assert stats["connection_id_changes"] == 0
|
||||
assert stats["available_cid_list"] == []
|
||||
|
||||
def test_current_connection_id_getter(self, quic_connection):
|
||||
"""Test getting current connection ID."""
|
||||
# Initially no connection ID
|
||||
assert quic_connection.get_current_connection_id() is None
|
||||
|
||||
# Set a connection ID
|
||||
test_cid = ConnectionIdTestHelper.generate_connection_id()
|
||||
quic_connection._current_connection_id = test_cid
|
||||
|
||||
assert quic_connection.get_current_connection_id() == test_cid
|
||||
|
||||
def test_connection_id_generation(self):
|
||||
"""Test connection ID generation utilities."""
|
||||
# Test default length
|
||||
cid1 = ConnectionIdTestHelper.generate_connection_id()
|
||||
assert len(cid1) == 8
|
||||
assert isinstance(cid1, bytes)
|
||||
|
||||
# Test custom length
|
||||
cid2 = ConnectionIdTestHelper.generate_connection_id(16)
|
||||
assert len(cid2) == 16
|
||||
|
||||
# Test uniqueness
|
||||
cid3 = ConnectionIdTestHelper.generate_connection_id()
|
||||
assert cid1 != cid3
|
||||
|
||||
|
||||
class TestConnectionIdRotationAndUpdates:
|
||||
"""Test connection ID rotation and update mechanisms."""
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Create transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
def test_connection_id_replenishment(self):
|
||||
"""Test connection ID replenishment mechanism."""
|
||||
# Create a real QuicConnection to test replenishment
|
||||
config = QuicConfiguration(is_client=True)
|
||||
config.connection_id_length = 8
|
||||
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Initial state - should have some host connection IDs
|
||||
initial_count = len(quic_conn._host_cids)
|
||||
assert initial_count > 0
|
||||
|
||||
# Remove some connection IDs to trigger replenishment
|
||||
while len(quic_conn._host_cids) > 2:
|
||||
quic_conn._host_cids.pop()
|
||||
|
||||
# Trigger replenishment
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
# Should have replenished up to the limit
|
||||
assert len(quic_conn._host_cids) >= initial_count
|
||||
|
||||
# All connection IDs should have unique sequence numbers
|
||||
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
assert len(sequences) == len(set(sequences))
|
||||
|
||||
def test_connection_id_sequence_numbers(self):
|
||||
"""Test connection ID sequence number management."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Get initial sequence number
|
||||
initial_seq = quic_conn._host_cid_seq
|
||||
|
||||
# Trigger replenishment to generate new connection IDs
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
# Sequence numbers should increment
|
||||
assert quic_conn._host_cid_seq > initial_seq
|
||||
|
||||
# All host connection IDs should have sequential numbers
|
||||
sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
sequences.sort()
|
||||
|
||||
# Check for proper sequence
|
||||
for i in range(len(sequences) - 1):
|
||||
assert sequences[i + 1] > sequences[i]
|
||||
|
||||
def test_connection_id_limits(self):
|
||||
"""Test connection ID limit enforcement."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
config.connection_id_length = 8
|
||||
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Set a reasonable limit
|
||||
quic_conn._remote_active_connection_id_limit = 4
|
||||
|
||||
# Replenish connection IDs
|
||||
quic_conn._replenish_connection_ids()
|
||||
|
||||
# Should not exceed the limit
|
||||
assert len(quic_conn._host_cids) <= quic_conn._remote_active_connection_id_limit
|
||||
|
||||
|
||||
class TestConnectionIdRetirement:
|
||||
"""Test connection ID retirement functionality."""
|
||||
|
||||
def test_connection_id_retirement_basic(self):
|
||||
"""Test basic connection ID retirement."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Create a test connection ID to retire
|
||||
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=1
|
||||
)
|
||||
|
||||
# Add it to peer connection IDs
|
||||
quic_conn._peer_cid_available.append(test_cid)
|
||||
quic_conn._peer_cid_sequence_numbers.add(1)
|
||||
|
||||
# Retire the connection ID
|
||||
quic_conn._retire_peer_cid(test_cid)
|
||||
|
||||
# Should be added to retirement list
|
||||
assert 1 in quic_conn._retire_connection_ids
|
||||
|
||||
def test_connection_id_retirement_limits(self):
|
||||
"""Test connection ID retirement limits."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Fill up retirement list near the limit
|
||||
max_retirements = 32 # Based on aioquic's default limit
|
||||
|
||||
for i in range(max_retirements):
|
||||
quic_conn._retire_connection_ids.append(i)
|
||||
|
||||
# Should be at limit
|
||||
assert len(quic_conn._retire_connection_ids) == max_retirements
|
||||
|
||||
def test_connection_id_retirement_events(self):
|
||||
"""Test that retirement generates proper events."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Create and add a host connection ID
|
||||
test_cid = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=5
|
||||
)
|
||||
quic_conn._host_cids.append(test_cid)
|
||||
|
||||
# Create a retirement frame buffer
|
||||
from aioquic.buffer import Buffer
|
||||
|
||||
buf = Buffer(capacity=16)
|
||||
buf.push_uint_var(5) # sequence number to retire
|
||||
buf.seek(0)
|
||||
|
||||
# Process retirement (this should generate an event)
|
||||
try:
|
||||
quic_conn._handle_retire_connection_id_frame(
|
||||
Mock(), # context
|
||||
0x19, # RETIRE_CONNECTION_ID frame type
|
||||
buf,
|
||||
)
|
||||
|
||||
# Check that connection ID was removed
|
||||
remaining_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
assert 5 not in remaining_sequences
|
||||
|
||||
except Exception:
|
||||
# May fail due to missing context, but that's okay for this test
|
||||
pass
|
||||
|
||||
|
||||
class TestConnectionIdErrorConditions:
|
||||
"""Test error conditions and edge cases in connection ID handling."""
|
||||
|
||||
def test_invalid_connection_id_length(self):
|
||||
"""Test handling of invalid connection ID lengths."""
|
||||
# Connection IDs must be 1-20 bytes according to RFC 9000
|
||||
|
||||
# Test too short (0 bytes) - this should be handled gracefully
|
||||
empty_cid = b""
|
||||
assert len(empty_cid) == 0
|
||||
|
||||
# Test too long (>20 bytes)
|
||||
long_cid = secrets.token_bytes(21)
|
||||
assert len(long_cid) == 21
|
||||
|
||||
# Test valid lengths
|
||||
for length in range(1, 21):
|
||||
valid_cid = secrets.token_bytes(length)
|
||||
assert len(valid_cid) == length
|
||||
|
||||
def test_duplicate_sequence_numbers(self):
|
||||
"""Test handling of duplicate sequence numbers."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Create two connection IDs with same sequence number
|
||||
cid1 = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=10
|
||||
)
|
||||
cid2 = ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), sequence=10
|
||||
)
|
||||
|
||||
# Add first connection ID
|
||||
quic_conn._peer_cid_available.append(cid1)
|
||||
quic_conn._peer_cid_sequence_numbers.add(10)
|
||||
|
||||
# Adding second with same sequence should be handled appropriately
|
||||
# (The implementation should prevent duplicates)
|
||||
if 10 not in quic_conn._peer_cid_sequence_numbers:
|
||||
quic_conn._peer_cid_available.append(cid2)
|
||||
quic_conn._peer_cid_sequence_numbers.add(10)
|
||||
|
||||
# Should only have one entry for sequence 10
|
||||
sequences = [cid.sequence_number for cid in quic_conn._peer_cid_available]
|
||||
assert sequences.count(10) <= 1
|
||||
|
||||
def test_retire_unknown_connection_id(self):
|
||||
"""Test retiring an unknown connection ID."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Try to create a buffer to retire unknown sequence number
|
||||
buf = Buffer(capacity=16)
|
||||
buf.push_uint_var(999) # Unknown sequence number
|
||||
buf.seek(0)
|
||||
|
||||
# This should raise an error when processed
|
||||
# (Testing the error condition, not the full processing)
|
||||
unknown_sequence = 999
|
||||
known_sequences = [cid.sequence_number for cid in quic_conn._host_cids]
|
||||
|
||||
assert unknown_sequence not in known_sequences
|
||||
|
||||
def test_retire_current_connection_id(self):
|
||||
"""Test that retiring current connection ID is prevented."""
|
||||
config = QuicConfiguration(is_client=True)
|
||||
quic_conn = QuicConnection(configuration=config)
|
||||
|
||||
# Get current connection ID if available
|
||||
if quic_conn._host_cids:
|
||||
current_cid = quic_conn._host_cids[0]
|
||||
current_sequence = current_cid.sequence_number
|
||||
|
||||
# Trying to retire current connection ID should be prevented
|
||||
# This is tested by checking the sequence number logic
|
||||
assert current_sequence >= 0
|
||||
|
||||
|
||||
class TestConnectionIdIntegration:
|
||||
"""Integration tests for connection ID functionality with real connections."""
|
||||
|
||||
@pytest.fixture
|
||||
def server_config(self):
|
||||
"""Server transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def client_config(self):
|
||||
"""Client transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_id_exchange_during_handshake(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test connection ID exchange during connection handshake."""
|
||||
# This test would require a full connection setup
|
||||
# For now, we test the setup components
|
||||
|
||||
server_transport = QUICTransport(server_key, server_config)
|
||||
client_transport = QUICTransport(client_key, client_config)
|
||||
|
||||
# Verify transports are created with proper configuration
|
||||
assert server_transport._config == server_config
|
||||
assert client_transport._config == client_config
|
||||
|
||||
# Test that connection ID tracking is available
|
||||
# (Integration with actual networking would require more setup)
|
||||
|
||||
def test_connection_id_extraction_utilities(self):
|
||||
"""Test connection ID extraction utilities."""
|
||||
# Create a mock connection with some connection IDs
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
mock_quic = Mock()
|
||||
mock_quic._host_cids = [
|
||||
ConnectionIdTestHelper.create_quic_connection_id(
|
||||
ConnectionIdTestHelper.generate_connection_id(), i
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
mock_quic._peer_cid = None
|
||||
mock_quic._peer_cid_available = []
|
||||
mock_quic._retire_connection_ids = []
|
||||
mock_quic._host_cid_seq = 3
|
||||
|
||||
quic_conn = QUICConnection(
|
||||
quic_connection=mock_quic,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
# Extract connection ID information
|
||||
cid_info = ConnectionIdTestHelper.extract_connection_ids_from_connection(
|
||||
quic_conn
|
||||
)
|
||||
|
||||
# Verify extraction works
|
||||
assert "host_cids" in cid_info
|
||||
assert "peer_cid" in cid_info
|
||||
assert "peer_cid_available" in cid_info
|
||||
assert "retire_connection_ids" in cid_info
|
||||
assert "host_cid_seq" in cid_info
|
||||
|
||||
# Check values
|
||||
assert len(cid_info["host_cids"]) == 3
|
||||
assert cid_info["host_cid_seq"] == 3
|
||||
assert cid_info["peer_cid"] is None
|
||||
assert len(cid_info["peer_cid_available"]) == 0
|
||||
assert len(cid_info["retire_connection_ids"]) == 0
|
||||
|
||||
|
||||
class TestConnectionIdStatistics:
|
||||
"""Test connection ID statistics and monitoring."""
|
||||
|
||||
@pytest.fixture
|
||||
def connection_with_stats(self):
|
||||
"""Create a connection with connection ID statistics."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
mock_quic = Mock()
|
||||
mock_quic._host_cids = []
|
||||
mock_quic._peer_cid = None
|
||||
mock_quic._peer_cid_available = []
|
||||
mock_quic._retire_connection_ids = []
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
remote_peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
)
|
||||
|
||||
def test_connection_id_stats_initialization(self, connection_with_stats):
|
||||
"""Test that connection ID statistics are properly initialized."""
|
||||
stats = connection_with_stats._stats
|
||||
|
||||
# Check that connection ID stats are present
|
||||
assert "connection_ids_issued" in stats
|
||||
assert "connection_ids_retired" in stats
|
||||
assert "connection_id_changes" in stats
|
||||
|
||||
# Initial values should be zero
|
||||
assert stats["connection_ids_issued"] == 0
|
||||
assert stats["connection_ids_retired"] == 0
|
||||
assert stats["connection_id_changes"] == 0
|
||||
|
||||
def test_connection_id_stats_update(self, connection_with_stats):
|
||||
"""Test updating connection ID statistics."""
|
||||
conn = connection_with_stats
|
||||
|
||||
# Add some connection IDs to tracking
|
||||
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(3)]
|
||||
|
||||
for cid in test_cids:
|
||||
conn._available_connection_ids.add(cid)
|
||||
|
||||
# Update stats (this would normally be done by the implementation)
|
||||
conn._stats["connection_ids_issued"] = len(test_cids)
|
||||
|
||||
# Verify stats
|
||||
stats = conn.get_connection_id_stats()
|
||||
assert stats["connection_ids_issued"] == 3
|
||||
assert stats["available_connection_ids"] == 3
|
||||
|
||||
def test_connection_id_list_representation(self, connection_with_stats):
|
||||
"""Test connection ID list representation in stats."""
|
||||
conn = connection_with_stats
|
||||
|
||||
# Add some connection IDs
|
||||
test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(2)]
|
||||
|
||||
for cid in test_cids:
|
||||
conn._available_connection_ids.add(cid)
|
||||
|
||||
# Get stats
|
||||
stats = conn.get_connection_id_stats()
|
||||
|
||||
# Check that CID list is properly formatted
|
||||
assert "available_cid_list" in stats
|
||||
assert len(stats["available_cid_list"]) == 2
|
||||
|
||||
# All entries should be hex strings
|
||||
for cid_hex in stats["available_cid_list"]:
|
||||
assert isinstance(cid_hex, str)
|
||||
assert len(cid_hex) == 16 # 8 bytes = 16 hex chars
|
||||
|
||||
|
||||
# Performance and stress tests
|
||||
class TestConnectionIdPerformance:
|
||||
"""Test connection ID performance and stress scenarios."""
|
||||
|
||||
def test_connection_id_generation_performance(self):
|
||||
"""Test connection ID generation performance."""
|
||||
start_time = time.time()
|
||||
|
||||
# Generate many connection IDs
|
||||
cids = []
|
||||
for _ in range(1000):
|
||||
cid = ConnectionIdTestHelper.generate_connection_id()
|
||||
cids.append(cid)
|
||||
|
||||
end_time = time.time()
|
||||
generation_time = end_time - start_time
|
||||
|
||||
# Should be reasonably fast (less than 1 second for 1000 IDs)
|
||||
assert generation_time < 1.0
|
||||
|
||||
# All should be unique
|
||||
assert len(set(cids)) == len(cids)
|
||||
|
||||
def test_connection_id_tracking_memory(self):
|
||||
"""Test memory usage of connection ID tracking."""
|
||||
conn_ids = set()
|
||||
|
||||
# Add many connection IDs
|
||||
for _ in range(1000):
|
||||
cid = ConnectionIdTestHelper.generate_connection_id()
|
||||
conn_ids.add(cid)
|
||||
|
||||
# Verify they're all stored
|
||||
assert len(conn_ids) == 1000
|
||||
|
||||
# Clean up
|
||||
conn_ids.clear()
|
||||
assert len(conn_ids) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests if executed directly
|
||||
pytest.main([__file__, "-v"])
|
||||
418
tests/core/transport/quic/test_integration.py
Normal file
418
tests/core/transport/quic/test_integration.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""
|
||||
Basic QUIC Echo Test
|
||||
|
||||
Simple test to verify the basic QUIC flow:
|
||||
1. Client connects to server
|
||||
2. Client sends data
|
||||
3. Server receives data and echoes back
|
||||
4. Client receives the echo
|
||||
|
||||
This test focuses on identifying where the accept_stream issue occurs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID
|
||||
from libp2p import new_host
|
||||
from libp2p.abc import INetStream
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.transport.quic.config import QUICTransportConfig
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
from libp2p.transport.quic.utils import create_quic_multiaddr
|
||||
|
||||
# Set up logging to see what's happening
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestBasicQUICFlow:
|
||||
"""Test basic QUIC client-server communication flow."""
|
||||
|
||||
@pytest.fixture
|
||||
def server_key(self):
|
||||
"""Generate server key pair."""
|
||||
return create_new_key_pair()
|
||||
|
||||
@pytest.fixture
|
||||
def client_key(self):
|
||||
"""Generate client key pair."""
|
||||
return create_new_key_pair()
|
||||
|
||||
@pytest.fixture
|
||||
def server_config(self):
|
||||
"""Simple server configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=10,
|
||||
max_connections=5,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def client_config(self):
|
||||
"""Simple client configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0,
|
||||
connection_timeout=5.0,
|
||||
max_concurrent_streams=5,
|
||||
)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_basic_echo_flow(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test basic client-server echo flow with detailed logging."""
|
||||
print("\n=== BASIC QUIC ECHO TEST ===")
|
||||
|
||||
# Create server components
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
|
||||
# Track test state
|
||||
server_received_data = None
|
||||
server_connection_established = False
|
||||
echo_sent = False
|
||||
|
||||
async def echo_server_handler(connection: QUICConnection) -> None:
|
||||
"""Simple echo server handler with detailed logging."""
|
||||
nonlocal server_received_data, server_connection_established, echo_sent
|
||||
|
||||
print("🔗 SERVER: Connection handler called")
|
||||
server_connection_established = True
|
||||
|
||||
try:
|
||||
print("📡 SERVER: Waiting for incoming stream...")
|
||||
|
||||
# Accept stream with timeout and detailed logging
|
||||
print("📡 SERVER: Calling accept_stream...")
|
||||
stream = await connection.accept_stream(timeout=5.0)
|
||||
|
||||
if stream is None:
|
||||
print("❌ SERVER: accept_stream returned None")
|
||||
return
|
||||
|
||||
print(f"✅ SERVER: Stream accepted! Stream ID: {stream.stream_id}")
|
||||
|
||||
# Read data from the stream
|
||||
print("📖 SERVER: Reading data from stream...")
|
||||
server_data = await stream.read(1024)
|
||||
|
||||
if not server_data:
|
||||
print("❌ SERVER: No data received from stream")
|
||||
return
|
||||
|
||||
server_received_data = server_data.decode("utf-8", errors="ignore")
|
||||
print(f"📨 SERVER: Received data: '{server_received_data}'")
|
||||
|
||||
# Echo the data back
|
||||
echo_message = f"ECHO: {server_received_data}"
|
||||
print(f"📤 SERVER: Sending echo: '{echo_message}'")
|
||||
|
||||
await stream.write(echo_message.encode())
|
||||
echo_sent = True
|
||||
print("✅ SERVER: Echo sent successfully")
|
||||
|
||||
# Close the stream
|
||||
await stream.close()
|
||||
print("🔒 SERVER: Stream closed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ SERVER: Error in handler: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Create listener
|
||||
listener = server_transport.create_listener(echo_server_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
# Variables to track client state
|
||||
client_connected = False
|
||||
client_sent_data = False
|
||||
client_received_echo = None
|
||||
|
||||
try:
|
||||
print("🚀 Starting server...")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server listener
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success, "Failed to start server listener"
|
||||
|
||||
# Get server address
|
||||
server_addrs = listener.get_addrs()
|
||||
server_addr = multiaddr.Multiaddr(
|
||||
f"{server_addrs[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
|
||||
)
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Give server a moment to be ready
|
||||
await trio.sleep(0.1)
|
||||
|
||||
print("🚀 Starting client...")
|
||||
|
||||
# Create client transport
|
||||
client_transport = QUICTransport(client_key.private_key, client_config)
|
||||
client_transport.set_background_nursery(nursery)
|
||||
|
||||
try:
|
||||
# Connect to server
|
||||
print(f"📞 CLIENT: Connecting to {server_addr}")
|
||||
connection = await client_transport.dial(server_addr)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected to server")
|
||||
|
||||
# Open a stream
|
||||
print("📤 CLIENT: Opening stream...")
|
||||
stream = await connection.open_stream()
|
||||
print(f"✅ CLIENT: Stream opened with ID: {stream.stream_id}")
|
||||
|
||||
# Send test data
|
||||
test_message = "Hello QUIC Server!"
|
||||
print(f"📨 CLIENT: Sending message: '{test_message}'")
|
||||
await stream.write(test_message.encode())
|
||||
client_sent_data = True
|
||||
print("✅ CLIENT: Message sent")
|
||||
|
||||
# Read echo response
|
||||
print("📖 CLIENT: Waiting for echo response...")
|
||||
response_data = await stream.read(1024)
|
||||
|
||||
if response_data:
|
||||
client_received_echo = response_data.decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
print(f"📬 CLIENT: Received echo: '{client_received_echo}'")
|
||||
else:
|
||||
print("❌ CLIENT: No echo response received")
|
||||
|
||||
print("🔒 CLIENT: Closing connection")
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
print("🔒 CLIENT: Closing transport")
|
||||
await client_transport.close()
|
||||
print("🔒 CLIENT: Transport closed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ CLIENT: Error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
print("🔒 CLIENT: Transport closed")
|
||||
|
||||
# Give everything time to complete
|
||||
await trio.sleep(0.5)
|
||||
|
||||
# Cancel nursery to stop server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
if not listener._closed:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
# Verify the flow worked
|
||||
print("\n📊 TEST RESULTS:")
|
||||
print(f" Server connection established: {server_connection_established}")
|
||||
print(f" Client connected: {client_connected}")
|
||||
print(f" Client sent data: {client_sent_data}")
|
||||
print(f" Server received data: '{server_received_data}'")
|
||||
print(f" Echo sent by server: {echo_sent}")
|
||||
print(f" Client received echo: '{client_received_echo}'")
|
||||
|
||||
# Test assertions
|
||||
assert server_connection_established, "Server connection handler was not called"
|
||||
assert client_connected, "Client failed to connect"
|
||||
assert client_sent_data, "Client failed to send data"
|
||||
assert server_received_data == "Hello QUIC Server!", (
|
||||
f"Server received wrong data: '{server_received_data}'"
|
||||
)
|
||||
assert echo_sent, "Server failed to send echo"
|
||||
assert client_received_echo == "ECHO: Hello QUIC Server!", (
|
||||
f"Client received wrong echo: '{client_received_echo}'"
|
||||
)
|
||||
|
||||
print("✅ BASIC ECHO TEST PASSED!")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_server_accept_stream_timeout(
|
||||
self, server_key, client_key, server_config, client_config
|
||||
):
|
||||
"""Test what happens when server accept_stream times out."""
|
||||
print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===")
|
||||
|
||||
server_transport = QUICTransport(server_key.private_key, server_config)
|
||||
|
||||
accept_stream_called = False
|
||||
accept_stream_timeout = False
|
||||
|
||||
async def timeout_test_handler(connection: QUICConnection) -> None:
|
||||
"""Handler that tests accept_stream timeout."""
|
||||
nonlocal accept_stream_called, accept_stream_timeout
|
||||
|
||||
print("🔗 SERVER: Connection established, testing accept_stream timeout")
|
||||
accept_stream_called = True
|
||||
|
||||
try:
|
||||
print("📡 SERVER: Calling accept_stream with 2 second timeout...")
|
||||
stream = await connection.accept_stream(timeout=2.0)
|
||||
print(f"✅ SERVER: accept_stream returned: {stream}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⏰ SERVER: accept_stream timed out or failed: {e}")
|
||||
accept_stream_timeout = True
|
||||
|
||||
listener = server_transport.create_listener(timeout_test_handler)
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
client_connected = False
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start server
|
||||
server_transport.set_background_nursery(nursery)
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
|
||||
server_addr = multiaddr.Multiaddr(
|
||||
f"{listener.get_addrs()[0]}/p2p/{ID.from_pubkey(server_key.public_key)}"
|
||||
)
|
||||
print(f"🔧 SERVER: Listening on {server_addr}")
|
||||
|
||||
# Create client but DON'T open a stream
|
||||
async with trio.open_nursery() as client_nursery:
|
||||
client_transport = QUICTransport(
|
||||
client_key.private_key, client_config
|
||||
)
|
||||
client_transport.set_background_nursery(client_nursery)
|
||||
|
||||
try:
|
||||
print("📞 CLIENT: Connecting (but NOT opening stream)...")
|
||||
connection = await client_transport.dial(server_addr)
|
||||
client_connected = True
|
||||
print("✅ CLIENT: Connected (no stream opened)")
|
||||
|
||||
# Wait for server timeout
|
||||
await trio.sleep(3.0)
|
||||
|
||||
await connection.close()
|
||||
print("🔒 CLIENT: Connection closed")
|
||||
|
||||
finally:
|
||||
await client_transport.close()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
finally:
|
||||
await listener.close()
|
||||
await server_transport.close()
|
||||
|
||||
print("\n📊 TIMEOUT TEST RESULTS:")
|
||||
print(f" Client connected: {client_connected}")
|
||||
print(f" accept_stream called: {accept_stream_called}")
|
||||
print(f" accept_stream timeout: {accept_stream_timeout}")
|
||||
|
||||
assert client_connected, "Client should have connected"
|
||||
assert accept_stream_called, "accept_stream should have been called"
|
||||
assert accept_stream_timeout, (
|
||||
"accept_stream should have timed out when no stream was opened"
|
||||
)
|
||||
|
||||
print("✅ TIMEOUT TEST PASSED!")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_yamux_stress_ping():
|
||||
STREAM_COUNT = 100
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
latencies = []
|
||||
failures = []
|
||||
|
||||
# === Server Setup ===
|
||||
server_host = new_host(listen_addrs=[listen_addr])
|
||||
|
||||
async def handle_ping(stream: INetStream) -> None:
|
||||
try:
|
||||
while True:
|
||||
payload = await stream.read(PING_LENGTH)
|
||||
if not payload:
|
||||
break
|
||||
await stream.write(payload)
|
||||
except Exception:
|
||||
await stream.reset()
|
||||
|
||||
server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||
|
||||
async with server_host.run(listen_addrs=[listen_addr]):
|
||||
# Give server time to start
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# === Client Setup ===
|
||||
destination = str(server_host.get_addrs()[0])
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
|
||||
client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
client_host = new_host(listen_addrs=[client_listen_addr])
|
||||
|
||||
async with client_host.run(listen_addrs=[client_listen_addr]):
|
||||
await client_host.connect(info)
|
||||
|
||||
async def ping_stream(i: int):
|
||||
stream = None
|
||||
try:
|
||||
start = trio.current_time()
|
||||
stream = await client_host.new_stream(
|
||||
info.peer_id, [PING_PROTOCOL_ID]
|
||||
)
|
||||
|
||||
await stream.write(b"\x01" * PING_LENGTH)
|
||||
|
||||
with trio.fail_after(5):
|
||||
response = await stream.read(PING_LENGTH)
|
||||
|
||||
if response == b"\x01" * PING_LENGTH:
|
||||
latency_ms = int((trio.current_time() - start) * 1000)
|
||||
latencies.append(latency_ms)
|
||||
print(f"[Ping #{i}] Latency: {latency_ms} ms")
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
print(f"[Ping #{i}] Failed: {e}")
|
||||
failures.append(i)
|
||||
if stream:
|
||||
await stream.reset()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(STREAM_COUNT):
|
||||
nursery.start_soon(ping_stream, i)
|
||||
|
||||
# === Result Summary ===
|
||||
print("\n📊 Ping Stress Test Summary")
|
||||
print(f"Total Streams Launched: {STREAM_COUNT}")
|
||||
print(f"Successful Pings: {len(latencies)}")
|
||||
print(f"Failed Pings: {len(failures)}")
|
||||
if failures:
|
||||
print(f"❌ Failed stream indices: {failures}")
|
||||
|
||||
# === Assertions ===
|
||||
assert len(latencies) == STREAM_COUNT, (
|
||||
f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}"
|
||||
)
|
||||
assert all(isinstance(x, int) and x >= 0 for x in latencies), (
|
||||
"Invalid latencies"
|
||||
)
|
||||
|
||||
avg_latency = sum(latencies) / len(latencies)
|
||||
print(f"✅ Average Latency: {avg_latency:.2f} ms")
|
||||
assert avg_latency < 1000
|
||||
150
tests/core/transport/quic/test_listener.py
Normal file
150
tests/core/transport/quic/test_listener.py
Normal file
@ -0,0 +1,150 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICListenError,
|
||||
)
|
||||
from libp2p.transport.quic.listener import QUICListener
|
||||
from libp2p.transport.quic.transport import (
|
||||
QUICTransport,
|
||||
QUICTransportConfig,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
)
|
||||
|
||||
|
||||
class TestQUICListener:
|
||||
"""Test suite for QUIC listener functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def private_key(self):
|
||||
"""Generate test private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Generate test transport configuration."""
|
||||
return QUICTransportConfig(idle_timeout=10.0)
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self, private_key, transport_config):
|
||||
"""Create test transport instance."""
|
||||
return QUICTransport(private_key, transport_config)
|
||||
|
||||
@pytest.fixture
|
||||
def connection_handler(self):
|
||||
"""Mock connection handler."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self, transport, connection_handler):
|
||||
"""Create test listener."""
|
||||
return transport.create_listener(connection_handler)
|
||||
|
||||
def test_listener_creation(self, transport, connection_handler):
|
||||
"""Test listener creation."""
|
||||
listener = transport.create_listener(connection_handler)
|
||||
|
||||
assert isinstance(listener, QUICListener)
|
||||
assert listener._transport == transport
|
||||
assert listener._handler == connection_handler
|
||||
assert not listener._listening
|
||||
assert not listener._closed
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_invalid_multiaddr(self, listener: QUICListener):
|
||||
"""Test listener with invalid multiaddr."""
|
||||
async with trio.open_nursery() as nursery:
|
||||
invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
|
||||
with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"):
|
||||
await listener.listen(invalid_addr, nursery)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_basic_lifecycle(self, listener: QUICListener):
|
||||
"""Test basic listener lifecycle."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Start listening
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
assert listener.is_listening()
|
||||
|
||||
# Check bound addresses
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) == 1
|
||||
|
||||
# Check stats
|
||||
stats = listener.get_stats()
|
||||
assert stats["is_listening"] is True
|
||||
assert stats["active_connections"] == 0
|
||||
assert stats["pending_connections"] == 0
|
||||
|
||||
# Sender Cancel Signal
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
await listener.close()
|
||||
assert not listener.is_listening()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_double_listen(self, listener: QUICListener):
|
||||
"""Test that double listen raises error."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic")
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
await trio.sleep(0.01)
|
||||
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) > 0
|
||||
async with trio.open_nursery() as nursery2:
|
||||
with pytest.raises(QUICListenError, match="Already listening"):
|
||||
await listener.listen(listen_addr, nursery2)
|
||||
nursery2.cancel_scope.cancel()
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
await listener.close()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_port_binding(self, listener: QUICListener):
|
||||
"""Test listener port binding and cleanup."""
|
||||
listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic")
|
||||
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
success = await listener.listen(listen_addr, nursery)
|
||||
assert success
|
||||
await trio.sleep(0.5)
|
||||
|
||||
addrs = listener.get_addrs()
|
||||
assert len(addrs) > 0
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
await listener.close()
|
||||
|
||||
# By the time we get here, the listener and its tasks have been fully
|
||||
# shut down, allowing the nursery to exit without hanging.
|
||||
print("TEST COMPLETED SUCCESSFULLY.")
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_listener_stats_tracking(self, listener):
|
||||
"""Test listener statistics tracking."""
|
||||
initial_stats = listener.get_stats()
|
||||
|
||||
# All counters should start at 0
|
||||
assert initial_stats["connections_accepted"] == 0
|
||||
assert initial_stats["connections_rejected"] == 0
|
||||
assert initial_stats["bytes_received"] == 0
|
||||
assert initial_stats["packets_processed"] == 0
|
||||
123
tests/core/transport/quic/test_transport.py
Normal file
123
tests/core/transport/quic/test_transport.py
Normal file
@ -0,0 +1,123 @@
|
||||
from unittest.mock import (
|
||||
Mock,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICDialError,
|
||||
QUICListenError,
|
||||
)
|
||||
from libp2p.transport.quic.transport import (
|
||||
QUICTransport,
|
||||
QUICTransportConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestQUICTransport:
|
||||
"""Test suite for QUIC transport using trio."""
|
||||
|
||||
@pytest.fixture
|
||||
def private_key(self):
|
||||
"""Generate test private key."""
|
||||
return create_new_key_pair().private_key
|
||||
|
||||
@pytest.fixture
|
||||
def transport_config(self):
|
||||
"""Generate test transport configuration."""
|
||||
return QUICTransportConfig(
|
||||
idle_timeout=10.0, enable_draft29=True, enable_v1=True
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig):
|
||||
"""Create test transport instance."""
|
||||
return QUICTransport(private_key, transport_config)
|
||||
|
||||
def test_transport_initialization(self, transport):
|
||||
"""Test transport initialization."""
|
||||
assert transport._private_key is not None
|
||||
assert transport._peer_id is not None
|
||||
assert not transport._closed
|
||||
assert len(transport._quic_configs) >= 1
|
||||
|
||||
def test_supported_protocols(self, transport):
|
||||
"""Test supported protocol identifiers."""
|
||||
protocols = transport.protocols()
|
||||
# TODO: Update when quic-v1 compatible
|
||||
# assert "quic-v1" in protocols
|
||||
assert "quic" in protocols # draft-29
|
||||
|
||||
def test_can_dial_quic_addresses(self, transport: QUICTransport):
|
||||
"""Test multiaddr compatibility checking."""
|
||||
import multiaddr
|
||||
|
||||
# Valid QUIC addresses
|
||||
valid_addrs = [
|
||||
# TODO: Update Multiaddr package to accept quic-v1
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
multiaddr.Multiaddr(
|
||||
f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}"
|
||||
),
|
||||
]
|
||||
|
||||
for addr in valid_addrs:
|
||||
assert transport.can_dial(addr)
|
||||
|
||||
# Invalid addresses
|
||||
invalid_addrs = [
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"),
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001"),
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/ws"),
|
||||
]
|
||||
|
||||
for addr in invalid_addrs:
|
||||
assert not transport.can_dial(addr)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_transport_lifecycle(self, transport):
|
||||
"""Test transport lifecycle management using trio."""
|
||||
assert not transport._closed
|
||||
|
||||
await transport.close()
|
||||
assert transport._closed
|
||||
|
||||
# Should be safe to close multiple times
|
||||
await transport.close()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_dial_closed_transport(self, transport: QUICTransport) -> None:
|
||||
"""Test dialing with closed transport raises error."""
|
||||
import multiaddr
|
||||
|
||||
await transport.close()
|
||||
|
||||
with pytest.raises(QUICDialError, match="Transport is closed"):
|
||||
await transport.dial(
|
||||
multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
)
|
||||
|
||||
def test_create_listener_closed_transport(self, transport: QUICTransport) -> None:
|
||||
"""Test creating listener with closed transport raises error."""
|
||||
transport._closed = True
|
||||
|
||||
with pytest.raises(QUICListenError, match="Transport is closed"):
|
||||
transport.create_listener(Mock())
|
||||
321
tests/core/transport/quic/test_utils.py
Normal file
321
tests/core/transport/quic/test_utils.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""
|
||||
Test suite for QUIC multiaddr utilities.
|
||||
Focused tests covering essential functionality required for QUIC transport.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICInvalidMultiaddrError,
|
||||
QUICUnsupportedVersionError,
|
||||
)
|
||||
from libp2p.transport.quic.utils import (
|
||||
create_quic_multiaddr,
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
normalize_quic_multiaddr,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
)
|
||||
|
||||
|
||||
class TestIsQuicMultiaddr:
|
||||
"""Test QUIC multiaddr detection."""
|
||||
|
||||
def test_valid_quic_v1_multiaddrs(self):
|
||||
"""Test valid QUIC v1 multiaddrs are detected."""
|
||||
valid_addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip4/192.168.1.1/udp/8080/quic-v1",
|
||||
"/ip6/::1/udp/4001/quic-v1",
|
||||
"/ip6/2001:db8::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in valid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
def test_valid_quic_draft29_multiaddrs(self):
|
||||
"""Test valid QUIC draft-29 multiaddrs are detected."""
|
||||
valid_addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip4/10.0.0.1/udp/9000/quic",
|
||||
"/ip6/::1/udp/4001/quic",
|
||||
"/ip6/fe80::1/udp/6000/quic",
|
||||
]
|
||||
|
||||
for addr_str in valid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC"
|
||||
|
||||
def test_invalid_multiaddrs(self):
|
||||
"""Test non-QUIC multiaddrs are not detected."""
|
||||
invalid_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC
|
||||
"/ip4/127.0.0.1/udp/4001", # UDP without QUIC
|
||||
"/ip4/127.0.0.1/udp/4001/ws", # WebSocket
|
||||
"/ip4/127.0.0.1/quic-v1", # Missing UDP
|
||||
"/udp/4001/quic-v1", # Missing IP
|
||||
"/dns4/example.com/tcp/443/tls", # Completely different
|
||||
]
|
||||
|
||||
for addr_str in invalid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
|
||||
|
||||
|
||||
class TestQuicMultiaddrToEndpoint:
|
||||
"""Test endpoint extraction from QUIC multiaddrs."""
|
||||
|
||||
def test_ipv4_extraction(self):
|
||||
"""Test IPv4 host/port extraction."""
|
||||
test_cases = [
|
||||
("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)),
|
||||
("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)),
|
||||
("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)),
|
||||
]
|
||||
|
||||
for addr_str, expected in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = quic_multiaddr_to_endpoint(maddr)
|
||||
assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
def test_ipv6_extraction(self):
|
||||
"""Test IPv6 host/port extraction."""
|
||||
test_cases = [
|
||||
("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)),
|
||||
("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)),
|
||||
]
|
||||
|
||||
for addr_str, expected in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = quic_multiaddr_to_endpoint(maddr)
|
||||
assert result == expected, f"Failed for {addr_str}"
|
||||
|
||||
def test_invalid_multiaddr_raises_error(self):
|
||||
"""Test invalid multiaddrs raise appropriate errors."""
|
||||
invalid_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/4001", # Not QUIC
|
||||
"/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol
|
||||
]
|
||||
|
||||
for addr_str in invalid_addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
quic_multiaddr_to_endpoint(maddr)
|
||||
|
||||
|
||||
class TestMultiaddrToQuicVersion:
|
||||
"""Test QUIC version extraction."""
|
||||
|
||||
def test_quic_v1_detection(self):
|
||||
"""Test QUIC v1 version detection."""
|
||||
addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip6/::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}"
|
||||
|
||||
def test_quic_draft29_detection(self):
|
||||
"""Test QUIC draft-29 version detection."""
|
||||
addrs = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip6/::1/udp/5000/quic",
|
||||
]
|
||||
|
||||
for addr_str in addrs:
|
||||
maddr = Multiaddr(addr_str)
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
assert version == "quic", f"Should detect quic for {addr_str}"
|
||||
|
||||
def test_non_quic_raises_error(self):
|
||||
"""Test non-QUIC multiaddrs raise error."""
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
multiaddr_to_quic_version(maddr)
|
||||
|
||||
|
||||
class TestCreateQuicMultiaddr:
|
||||
"""Test QUIC multiaddr creation."""
|
||||
|
||||
def test_ipv4_creation(self):
|
||||
"""Test IPv4 QUIC multiaddr creation."""
|
||||
test_cases = [
|
||||
("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"),
|
||||
("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"),
|
||||
("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"),
|
||||
]
|
||||
|
||||
for host, port, version, expected in test_cases:
|
||||
result = create_quic_multiaddr(host, port, version)
|
||||
assert str(result) == expected
|
||||
|
||||
def test_ipv6_creation(self):
|
||||
"""Test IPv6 QUIC multiaddr creation."""
|
||||
test_cases = [
|
||||
("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"),
|
||||
("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"),
|
||||
]
|
||||
|
||||
for host, port, version, expected in test_cases:
|
||||
result = create_quic_multiaddr(host, port, version)
|
||||
assert str(result) == expected
|
||||
|
||||
def test_default_version(self):
|
||||
"""Test default version is quic-v1."""
|
||||
result = create_quic_multiaddr("127.0.0.1", 4001)
|
||||
expected = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
assert str(result) == expected
|
||||
|
||||
def test_invalid_inputs_raise_errors(self):
|
||||
"""Test invalid inputs raise appropriate errors."""
|
||||
# Invalid IP
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("invalid-ip", 4001)
|
||||
|
||||
# Invalid port
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", 70000)
|
||||
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", -1)
|
||||
|
||||
# Invalid version
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
create_quic_multiaddr("127.0.0.1", 4001, "invalid-version")
|
||||
|
||||
|
||||
class TestQuicVersionToWireFormat:
|
||||
"""Test QUIC version to wire format conversion."""
|
||||
|
||||
def test_supported_versions(self):
|
||||
"""Test supported version conversions."""
|
||||
test_cases = [
|
||||
("quic-v1", 0x00000001), # RFC 9000
|
||||
("quic", 0xFF00001D), # draft-29
|
||||
]
|
||||
|
||||
for version, expected_wire in test_cases:
|
||||
result = quic_version_to_wire_format(TProtocol(version))
|
||||
assert result == expected_wire, f"Failed for version {version}"
|
||||
|
||||
def test_unsupported_version_raises_error(self):
|
||||
"""Test unsupported versions raise error."""
|
||||
with pytest.raises(QUICUnsupportedVersionError):
|
||||
quic_version_to_wire_format(TProtocol("unsupported-version"))
|
||||
|
||||
|
||||
class TestGetAlpnProtocols:
|
||||
"""Test ALPN protocol retrieval."""
|
||||
|
||||
def test_returns_libp2p_protocols(self):
|
||||
"""Test returns expected libp2p ALPN protocols."""
|
||||
protocols = get_alpn_protocols()
|
||||
assert protocols == ["libp2p"]
|
||||
assert isinstance(protocols, list)
|
||||
|
||||
def test_returns_copy(self):
|
||||
"""Test returns a copy, not the original list."""
|
||||
protocols1 = get_alpn_protocols()
|
||||
protocols2 = get_alpn_protocols()
|
||||
|
||||
# Modify one list
|
||||
protocols1.append("test")
|
||||
|
||||
# Other list should be unchanged
|
||||
assert protocols2 == ["libp2p"]
|
||||
|
||||
|
||||
class TestNormalizeQuicMultiaddr:
|
||||
"""Test QUIC multiaddr normalization."""
|
||||
|
||||
def test_already_normalized(self):
|
||||
"""Test already normalized multiaddrs pass through."""
|
||||
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
maddr = Multiaddr(addr_str)
|
||||
|
||||
result = normalize_quic_multiaddr(maddr)
|
||||
assert str(result) == addr_str
|
||||
|
||||
def test_normalize_different_versions(self):
|
||||
"""Test normalization works for different QUIC versions."""
|
||||
test_cases = [
|
||||
"/ip4/127.0.0.1/udp/4001/quic-v1",
|
||||
"/ip4/127.0.0.1/udp/4001/quic",
|
||||
"/ip6/::1/udp/5000/quic-v1",
|
||||
]
|
||||
|
||||
for addr_str in test_cases:
|
||||
maddr = Multiaddr(addr_str)
|
||||
result = normalize_quic_multiaddr(maddr)
|
||||
|
||||
# Should be valid QUIC multiaddr
|
||||
assert is_quic_multiaddr(result)
|
||||
|
||||
# Should be parseable
|
||||
host, port = quic_multiaddr_to_endpoint(result)
|
||||
version = multiaddr_to_quic_version(result)
|
||||
|
||||
# Should match original
|
||||
orig_host, orig_port = quic_multiaddr_to_endpoint(maddr)
|
||||
orig_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
assert host == orig_host
|
||||
assert port == orig_port
|
||||
assert version == orig_version
|
||||
|
||||
def test_non_quic_raises_error(self):
|
||||
"""Test non-QUIC multiaddrs raise error."""
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001")
|
||||
with pytest.raises(QUICInvalidMultiaddrError):
|
||||
normalize_quic_multiaddr(maddr)
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for utility functions working together."""
|
||||
|
||||
def test_round_trip_conversion(self):
|
||||
"""Test creating and parsing multiaddrs works correctly."""
|
||||
test_cases = [
|
||||
("127.0.0.1", 4001, "quic-v1"),
|
||||
("::1", 5000, "quic"),
|
||||
("192.168.1.100", 8080, "quic-v1"),
|
||||
]
|
||||
|
||||
for host, port, version in test_cases:
|
||||
# Create multiaddr
|
||||
maddr = create_quic_multiaddr(host, port, version)
|
||||
|
||||
# Should be detected as QUIC
|
||||
assert is_quic_multiaddr(maddr)
|
||||
|
||||
# Should extract original values
|
||||
extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr)
|
||||
extracted_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
assert extracted_host == host
|
||||
assert extracted_port == port
|
||||
assert extracted_version == version
|
||||
|
||||
# Should normalize to same value
|
||||
normalized = normalize_quic_multiaddr(maddr)
|
||||
assert str(normalized) == str(maddr)
|
||||
|
||||
def test_wire_format_integration(self):
|
||||
"""Test wire format conversion works with version detection."""
|
||||
addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1"
|
||||
maddr = Multiaddr(addr_str)
|
||||
|
||||
# Extract version and convert to wire format
|
||||
version = multiaddr_to_quic_version(maddr)
|
||||
wire_format = quic_version_to_wire_format(version)
|
||||
|
||||
# Should be QUIC v1 wire format
|
||||
assert wire_format == 0x00000001
|
||||
27
tests/core/transport/test_upgrader.py
Normal file
27
tests/core/transport/test_upgrader.py
Normal file
@ -0,0 +1,27 @@
|
||||
import pytest
|
||||
|
||||
from libp2p.custom_types import (
|
||||
TMuxerOptions,
|
||||
TSecurityOptions,
|
||||
)
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_transport_upgrader_security_and_muxer_initialization():
|
||||
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
|
||||
secure_transports: TSecurityOptions = {}
|
||||
muxer_transports: TMuxerOptions = {}
|
||||
negotiate_timeout = 15
|
||||
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
|
||||
)
|
||||
|
||||
# Verify security multistream initialization
|
||||
assert upgrader.security_multistream.transports == secure_transports
|
||||
# Verify muxer multistream initialization and timeout
|
||||
assert upgrader.muxer_multistream.transports == muxer_transports
|
||||
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout
|
||||
6
tests/examples/test_quic_echo_example.py
Normal file
6
tests/examples/test_quic_echo_example.py
Normal file
@ -0,0 +1,6 @@
|
||||
def test_echo_quic_example():
|
||||
"""Test that the QUIC echo example can be imported and has required functions."""
|
||||
from examples.echo import echo_quic
|
||||
|
||||
assert hasattr(echo_quic, "main")
|
||||
assert hasattr(echo_quic, "run")
|
||||
8
tests/interop/nim_libp2p/.gitignore
vendored
Normal file
8
tests/interop/nim_libp2p/.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
nimble.develop
|
||||
nimble.paths
|
||||
|
||||
*.nimble
|
||||
nim-libp2p/
|
||||
|
||||
nim_echo_server
|
||||
config.nims
|
||||
119
tests/interop/nim_libp2p/conftest.py
Normal file
119
tests/interop/nim_libp2p/conftest.py
Normal file
@ -0,0 +1,119 @@
|
||||
import fcntl
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_nim_available():
|
||||
"""Check if nim compiler is available."""
|
||||
return shutil.which("nim") is not None and shutil.which("nimble") is not None
|
||||
|
||||
|
||||
def check_nim_binary_built():
|
||||
"""Check if nim echo server binary is built."""
|
||||
current_dir = Path(__file__).parent
|
||||
binary_path = current_dir / "nim_echo_server"
|
||||
return binary_path.exists() and binary_path.stat().st_size > 0
|
||||
|
||||
|
||||
def run_nim_setup_with_lock():
|
||||
"""Run nim setup with file locking to prevent parallel execution."""
|
||||
current_dir = Path(__file__).parent
|
||||
lock_file = current_dir / ".setup_lock"
|
||||
setup_script = current_dir / "scripts" / "setup_nim_echo.sh"
|
||||
|
||||
if not setup_script.exists():
|
||||
raise RuntimeError(f"Setup script not found: {setup_script}")
|
||||
|
||||
# Try to acquire lock
|
||||
try:
|
||||
with open(lock_file, "w") as f:
|
||||
# Non-blocking lock attempt
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
|
||||
# Double-check binary doesn't exist (another worker might have built it)
|
||||
if check_nim_binary_built():
|
||||
logger.info("Binary already exists, skipping setup")
|
||||
return
|
||||
|
||||
logger.info("Acquired setup lock, running nim-libp2p setup...")
|
||||
|
||||
# Make setup script executable and run it
|
||||
setup_script.chmod(0o755)
|
||||
result = subprocess.run(
|
||||
[str(setup_script)],
|
||||
cwd=current_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5 minute timeout
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Setup failed (exit {result.returncode}):\n"
|
||||
f"stdout: {result.stdout}\n"
|
||||
f"stderr: {result.stderr}"
|
||||
)
|
||||
|
||||
# Verify binary was built
|
||||
if not check_nim_binary_built():
|
||||
raise RuntimeError("nim_echo_server binary not found after setup")
|
||||
|
||||
logger.info("nim-libp2p setup completed successfully")
|
||||
|
||||
except BlockingIOError:
|
||||
# Another worker is running setup, wait for it to complete
|
||||
logger.info("Another worker is running setup, waiting...")
|
||||
|
||||
# Wait for setup to complete (check every 2 seconds, max 5 minutes)
|
||||
for _ in range(150): # 150 * 2 = 300 seconds = 5 minutes
|
||||
if check_nim_binary_built():
|
||||
logger.info("Setup completed by another worker")
|
||||
return
|
||||
time.sleep(2)
|
||||
|
||||
raise TimeoutError("Timed out waiting for setup to complete")
|
||||
|
||||
finally:
|
||||
# Clean up lock file
|
||||
try:
|
||||
lock_file.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope="function") # Changed to function scope
|
||||
def nim_echo_binary():
|
||||
"""Get nim echo server binary path."""
|
||||
current_dir = Path(__file__).parent
|
||||
binary_path = current_dir / "nim_echo_server"
|
||||
|
||||
if not binary_path.exists():
|
||||
pytest.skip(
|
||||
"nim_echo_server binary not found. "
|
||||
"Run setup script: ./scripts/setup_nim_echo.sh"
|
||||
)
|
||||
|
||||
return binary_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def nim_server(nim_echo_binary):
|
||||
"""Start and stop nim echo server for tests."""
|
||||
# Import here to avoid circular imports
|
||||
# pyrefly: ignore
|
||||
from test_echo_interop import NimEchoServer
|
||||
|
||||
server = NimEchoServer(nim_echo_binary)
|
||||
|
||||
try:
|
||||
peer_id, listen_addr = await server.start()
|
||||
yield server, peer_id, listen_addr
|
||||
finally:
|
||||
await server.stop()
|
||||
108
tests/interop/nim_libp2p/nim_echo_server.nim
Normal file
108
tests/interop/nim_libp2p/nim_echo_server.nim
Normal file
@ -0,0 +1,108 @@
|
||||
{.used.}
|
||||
|
||||
import chronos
|
||||
import stew/byteutils
|
||||
import libp2p
|
||||
|
||||
##
|
||||
# Simple Echo Protocol Implementation for py-libp2p Interop Testing
|
||||
##
|
||||
const EchoCodec = "/echo/1.0.0"
|
||||
|
||||
type EchoProto = ref object of LPProtocol
|
||||
|
||||
proc new(T: typedesc[EchoProto]): T =
|
||||
proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} =
|
||||
try:
|
||||
echo "Echo server: Received connection from ", conn.peerId
|
||||
|
||||
# Read and echo messages in a loop
|
||||
while not conn.atEof:
|
||||
try:
|
||||
# Read length-prefixed message using nim-libp2p's readLp
|
||||
let message = await conn.readLp(1024 * 1024) # Max 1MB
|
||||
if message.len == 0:
|
||||
echo "Echo server: Empty message, closing connection"
|
||||
break
|
||||
|
||||
let messageStr = string.fromBytes(message)
|
||||
echo "Echo server: Received (", message.len, " bytes): ", messageStr
|
||||
|
||||
# Echo back using writeLp
|
||||
await conn.writeLp(message)
|
||||
echo "Echo server: Echoed message back"
|
||||
|
||||
except CatchableError as e:
|
||||
echo "Echo server: Error processing message: ", e.msg
|
||||
break
|
||||
|
||||
except CancelledError as e:
|
||||
echo "Echo server: Connection cancelled"
|
||||
raise e
|
||||
except CatchableError as e:
|
||||
echo "Echo server: Exception in handler: ", e.msg
|
||||
finally:
|
||||
echo "Echo server: Connection closed"
|
||||
await conn.close()
|
||||
|
||||
return T.new(codecs = @[EchoCodec], handler = handle)
|
||||
|
||||
##
|
||||
# Create QUIC-enabled switch
|
||||
##
|
||||
proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch =
|
||||
var switch = SwitchBuilder
|
||||
.new()
|
||||
.withRng(rng)
|
||||
.withAddress(ma)
|
||||
.withQuicTransport()
|
||||
.build()
|
||||
result = switch
|
||||
|
||||
##
|
||||
# Main server
|
||||
##
|
||||
proc main() {.async.} =
|
||||
let
|
||||
rng = newRng()
|
||||
localAddr = MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet()
|
||||
echoProto = EchoProto.new()
|
||||
|
||||
echo "=== Nim Echo Server for py-libp2p Interop ==="
|
||||
|
||||
# Create switch
|
||||
let switch = createSwitch(localAddr, rng)
|
||||
switch.mount(echoProto)
|
||||
|
||||
# Start server
|
||||
await switch.start()
|
||||
|
||||
# Print connection info
|
||||
echo "Peer ID: ", $switch.peerInfo.peerId
|
||||
echo "Listening on:"
|
||||
for addr in switch.peerInfo.addrs:
|
||||
echo " ", $addr, "/p2p/", $switch.peerInfo.peerId
|
||||
echo "Protocol: ", EchoCodec
|
||||
echo "Ready for py-libp2p connections!"
|
||||
echo ""
|
||||
|
||||
# Keep running
|
||||
try:
|
||||
await sleepAsync(100.hours)
|
||||
except CancelledError:
|
||||
echo "Shutting down..."
|
||||
finally:
|
||||
await switch.stop()
|
||||
|
||||
# Graceful shutdown handler
|
||||
proc signalHandler() {.noconv.} =
|
||||
echo "\nShutdown signal received"
|
||||
quit(0)
|
||||
|
||||
when isMainModule:
|
||||
setControlCHook(signalHandler)
|
||||
try:
|
||||
waitFor(main())
|
||||
except CatchableError as e:
|
||||
echo "Error: ", e.msg
|
||||
quit(1)
|
||||
74
tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
Executable file
74
tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
Executable file
@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env bash
|
||||
# tests/interop/nim_libp2p/scripts/setup_nim_echo.sh
|
||||
# Cache-aware setup that skips installation if packages exist
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="${SCRIPT_DIR}/.."
|
||||
|
||||
# Colors
|
||||
GREEN='\033[0;32m'
|
||||
RED='\033[0;31m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m'
|
||||
|
||||
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
||||
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||
|
||||
main() {
|
||||
log_info "Setting up nim echo server for interop testing..."
|
||||
|
||||
# Check if nim is available
|
||||
if ! command -v nim &> /dev/null || ! command -v nimble &> /dev/null; then
|
||||
log_error "Nim not found. Please install nim first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd "${PROJECT_DIR}"
|
||||
|
||||
# Create logs directory
|
||||
mkdir -p logs
|
||||
|
||||
# Check if binary already exists
|
||||
if [[ -f "nim_echo_server" ]]; then
|
||||
log_info "nim_echo_server already exists, skipping build"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Check if libp2p is already installed (cache-aware)
|
||||
if nimble list -i | grep -q "libp2p"; then
|
||||
log_info "libp2p already installed, skipping installation"
|
||||
else
|
||||
log_info "Installing nim-libp2p globally..."
|
||||
nimble install -y libp2p
|
||||
fi
|
||||
|
||||
log_info "Building nim echo server..."
|
||||
# Compile the echo server
|
||||
nim c \
|
||||
-d:release \
|
||||
-d:chronicles_log_level=INFO \
|
||||
-d:libp2p_quic_support \
|
||||
-d:chronos_event_loop=iocp \
|
||||
-d:ssl \
|
||||
--opt:speed \
|
||||
--mm:orc \
|
||||
--verbosity:1 \
|
||||
-o:nim_echo_server \
|
||||
nim_echo_server.nim
|
||||
|
||||
# Verify binary was created
|
||||
if [[ -f "nim_echo_server" ]]; then
|
||||
log_info "✅ nim_echo_server built successfully"
|
||||
log_info "Binary size: $(ls -lh nim_echo_server | awk '{print $5}')"
|
||||
else
|
||||
log_error "❌ Failed to build nim_echo_server"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "🎉 Setup complete!"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
195
tests/interop/nim_libp2p/test_echo_interop.py
Normal file
195
tests/interop/nim_libp2p/test_echo_interop.py
Normal file
@ -0,0 +1,195 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.utils.varint import encode_varint_prefixed, read_varint_prefixed_bytes
|
||||
|
||||
# Configuration
|
||||
PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
TEST_TIMEOUT = 30
|
||||
SERVER_START_TIMEOUT = 10.0
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NimEchoServer:
|
||||
"""Simple nim echo server manager."""
|
||||
|
||||
def __init__(self, binary_path: Path):
|
||||
self.binary_path = binary_path
|
||||
self.process: None | subprocess.Popen = None
|
||||
self.peer_id = None
|
||||
self.listen_addr = None
|
||||
|
||||
async def start(self):
|
||||
"""Start nim echo server and get connection info."""
|
||||
logger.info(f"Starting nim echo server: {self.binary_path}")
|
||||
|
||||
self.process = subprocess.Popen(
|
||||
[str(self.binary_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
universal_newlines=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Parse output for connection info
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < SERVER_START_TIMEOUT:
|
||||
if self.process and self.process.poll() and self.process.stdout:
|
||||
output = self.process.stdout.read()
|
||||
raise RuntimeError(f"Server exited early: {output}")
|
||||
|
||||
reader = self.process.stdout if self.process else None
|
||||
if reader:
|
||||
line = reader.readline().strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
logger.info(f"Server: {line}")
|
||||
|
||||
if line.startswith("Peer ID:"):
|
||||
self.peer_id = line.split(":", 1)[1].strip()
|
||||
|
||||
elif "/quic-v1/p2p/" in line and self.peer_id:
|
||||
if line.strip().startswith("/"):
|
||||
self.listen_addr = line.strip()
|
||||
logger.info(f"Server ready: {self.listen_addr}")
|
||||
return self.peer_id, self.listen_addr
|
||||
|
||||
await self.stop()
|
||||
raise TimeoutError(f"Server failed to start within {SERVER_START_TIMEOUT}s")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the server."""
|
||||
if self.process:
|
||||
logger.info("Stopping nim echo server...")
|
||||
try:
|
||||
self.process.terminate()
|
||||
self.process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.process.kill()
|
||||
self.process.wait()
|
||||
self.process = None
|
||||
|
||||
|
||||
async def run_echo_test(server_addr: str, messages: list[str]):
|
||||
"""Test echo protocol against nim server with proper timeout handling."""
|
||||
# Create py-libp2p QUIC client with shorter timeouts
|
||||
|
||||
host = new_host(
|
||||
enable_quic=True,
|
||||
key_pair=create_new_key_pair(),
|
||||
)
|
||||
|
||||
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/0/quic-v1")
|
||||
responses = []
|
||||
|
||||
try:
|
||||
async with host.run(listen_addrs=[listen_addr]):
|
||||
logger.info(f"Connecting to nim server: {server_addr}")
|
||||
|
||||
# Connect to nim server
|
||||
maddr = multiaddr.Multiaddr(server_addr)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
await host.connect(info)
|
||||
|
||||
# Create stream
|
||||
stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
|
||||
logger.info("Stream created")
|
||||
|
||||
# Test each message
|
||||
for i, message in enumerate(messages, 1):
|
||||
logger.info(f"Testing message {i}: {message}")
|
||||
|
||||
# Send with varint length prefix
|
||||
data = message.encode("utf-8")
|
||||
prefixed_data = encode_varint_prefixed(data)
|
||||
await stream.write(prefixed_data)
|
||||
|
||||
# Read response
|
||||
response_data = await read_varint_prefixed_bytes(stream)
|
||||
response = response_data.decode("utf-8")
|
||||
|
||||
logger.info(f"Got echo: {response}")
|
||||
responses.append(response)
|
||||
|
||||
# Verify echo
|
||||
assert message == response, (
|
||||
f"Echo failed: sent {message!r}, got {response!r}"
|
||||
)
|
||||
|
||||
await stream.close()
|
||||
logger.info("✅ All messages echoed correctly")
|
||||
|
||||
finally:
|
||||
await host.close()
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.timeout(TEST_TIMEOUT)
|
||||
async def test_basic_echo_interop(nim_server):
|
||||
"""Test basic echo functionality between py-libp2p and nim-libp2p."""
|
||||
server, peer_id, listen_addr = nim_server
|
||||
|
||||
test_messages = [
|
||||
"Hello from py-libp2p!",
|
||||
"QUIC transport working",
|
||||
"Echo test successful!",
|
||||
"Unicode: Ñoël, 测试, Ψυχή",
|
||||
]
|
||||
|
||||
logger.info(f"Testing against nim server: {peer_id}")
|
||||
|
||||
# Run test with timeout
|
||||
with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup
|
||||
responses = await run_echo_test(listen_addr, test_messages)
|
||||
|
||||
# Verify all messages echoed correctly
|
||||
assert len(responses) == len(test_messages)
|
||||
for sent, received in zip(test_messages, responses):
|
||||
assert sent == received
|
||||
|
||||
logger.info("✅ Basic echo interop test passed!")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
@pytest.mark.timeout(TEST_TIMEOUT)
|
||||
async def test_large_message_echo(nim_server):
|
||||
"""Test echo with larger messages."""
|
||||
server, peer_id, listen_addr = nim_server
|
||||
|
||||
large_messages = [
|
||||
"x" * 1024,
|
||||
"y" * 5000,
|
||||
]
|
||||
|
||||
logger.info("Testing large message echo...")
|
||||
|
||||
# Run test with timeout
|
||||
with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup
|
||||
responses = await run_echo_test(listen_addr, large_messages)
|
||||
|
||||
assert len(responses) == len(large_messages)
|
||||
for sent, received in zip(large_messages, responses):
|
||||
assert sent == received
|
||||
|
||||
logger.info("✅ Large message echo test passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
Reference in New Issue
Block a user