diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index ef963f80..56d6a0bc 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -36,10 +36,48 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} + + - name: Install Nim for interop testing + if: matrix.toxenv == 'interop' + run: | + echo "Installing Nim for nim-libp2p interop testing..." + curl -sSf https://nim-lang.org/choosenim/init.sh | sh -s -- -y --firstInstall + echo "$HOME/.nimble/bin" >> $GITHUB_PATH + echo "$HOME/.choosenim/toolchains/nim-stable/bin" >> $GITHUB_PATH + + - name: Cache nimble packages + if: matrix.toxenv == 'interop' + uses: actions/cache@v4 + with: + path: | + ~/.nimble + ~/.choosenim/toolchains/*/lib + key: ${{ runner.os }}-nimble-${{ hashFiles('**/nim_echo_server.nim') }} + restore-keys: | + ${{ runner.os }}-nimble- + + - name: Build nim interop binaries + if: matrix.toxenv == 'interop' + run: | + export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH" + cd tests/interop/nim_libp2p + ./scripts/setup_nim_echo.sh + - run: | python -m pip install --upgrade pip python -m pip install tox - - run: | + + - name: Run Tests or Generate Docs + run: | + if [[ "${{ matrix.toxenv }}" == 'docs' ]]; then + export TOXENV=docs + else + export TOXENV=py${{ matrix.python }}-${{ matrix.toxenv }} + fi + # Set PATH for nim commands during tox + if [[ "${{ matrix.toxenv }}" == 'interop' ]]; then + export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH" + fi python -m tox run -r windows: @@ -65,5 +103,5 @@ jobs: if [[ "${{ matrix.toxenv }}" == "wheel" ]]; then python -m tox run -e windows-wheel else - python -m tox run -e py311-${{ matrix.toxenv }} + python -m tox run -e py${{ matrix.python-version }}-${{ matrix.toxenv }} fi diff --git a/.gitignore b/.gitignore index fd2c8231..11e75cda 100644 --- a/.gitignore +++ b/.gitignore @@ -178,6 +178,10 @@ env.bak/ #lockfiles uv.lock poetry.lock +tests/interop/js_libp2p/js_node/node_modules/ +tests/interop/js_libp2p/js_node/package-lock.json +tests/interop/js_libp2p/js_node/src/node_modules/ +tests/interop/js_libp2p/js_node/src/package-lock.json # Sphinx documentation build _build/ diff --git a/README.md b/README.md index 77166429..f87fbea6 100644 --- a/README.md +++ b/README.md @@ -61,12 +61,12 @@ ______________________________________________________________________ ### Discovery -| **Discovery** | **Status** | **Source** | -| -------------------- | :--------: | :--------------------------------------------------------------------------------: | -| **`bootstrap`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) | -| **`random-walk`** | 🌱 | | -| **`mdns-discovery`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) | -| **`rendezvous`** | 🌱 | | +| **Discovery** | **Status** | **Source** | +| -------------------- | :--------: | :----------------------------------------------------------------------------------: | +| **`bootstrap`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) | +| **`random-walk`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/random_walk) | +| **`mdns-discovery`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) | +| **`rendezvous`** | 🌱 | | ______________________________________________________________________ diff --git a/docs/examples.circuit_relay.rst b/docs/examples.circuit_relay.rst index 2a14c3c5..055aafdf 100644 --- a/docs/examples.circuit_relay.rst +++ b/docs/examples.circuit_relay.rst @@ -36,12 +36,14 @@ Create a file named ``relay_node.py`` with the following content: from libp2p.relay.circuit_v2.transport import CircuitV2Transport from libp2p.relay.circuit_v2.config import RelayConfig from libp2p.tools.async_service import background_trio_service + from libp2p.utils import get_wildcard_address logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("relay_node") async def run_relay(): - listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9000") + # Use wildcard address to listen on all interfaces + listen_addr = get_wildcard_address(9000) host = new_host() config = RelayConfig( @@ -107,6 +109,7 @@ Create a file named ``destination_node.py`` with the following content: from libp2p.relay.circuit_v2.config import RelayConfig from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.tools.async_service import background_trio_service + from libp2p.utils import get_wildcard_address logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("destination_node") @@ -139,7 +142,8 @@ Create a file named ``destination_node.py`` with the following content: Run a simple destination node that accepts connections. This is a simplified version that doesn't use the relay functionality. """ - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/9001") + # Create a libp2p host - use wildcard address to listen on all interfaces + listen_addr = get_wildcard_address(9001) host = new_host() # Configure as a relay receiver (stop) @@ -252,14 +256,15 @@ Create a file named ``source_node.py`` with the following content: from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.tools.async_service import background_trio_service from libp2p.relay.circuit_v2.discovery import RelayInfo + from libp2p.utils import get_wildcard_address # Configure logging logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("source_node") async def run_source(relay_peer_id=None, destination_peer_id=None): - # Create a libp2p host - listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9002") + # Create a libp2p host - use wildcard address to listen on all interfaces + listen_addr = get_wildcard_address(9002) host = new_host() # Configure as a relay client @@ -428,7 +433,7 @@ Running the Example Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx ================================================== - Listening on: [] + Listening on: [] Protocol service started Relay service started successfully Relay limits: RelayLimits(duration=3600, data=10485760, max_circuit_conns=8, max_reservations=4) @@ -447,7 +452,7 @@ Running the Example Use this ID in the source node: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s ================================================== - Listening on: [] + Listening on: [] Registered echo protocol handler Protocol service started Transport created @@ -469,7 +474,7 @@ Running the Example $ python source_node.py Source node started with ID: QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3 - Listening on: [] + Listening on: [] Protocol service started No relay peer ID provided. Please enter the relay\'s peer ID: Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx diff --git a/docs/examples.echo_quic.rst b/docs/examples.echo_quic.rst new file mode 100644 index 00000000..0e3313df --- /dev/null +++ b/docs/examples.echo_quic.rst @@ -0,0 +1,43 @@ +QUIC Echo Demo +============== + +This example demonstrates a simple ``echo`` protocol using **QUIC transport**. + +QUIC provides built-in TLS security and stream multiplexing over UDP, making it an excellent transport choice for libp2p applications. + +.. code-block:: console + + $ python -m pip install libp2p + Collecting libp2p + ... + Successfully installed libp2p-x.x.x + $ echo-quic-demo + Run this from the same folder in another console: + + echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmAsbxRR1HiGJRNVPQLNMeNsBCsXT3rDjoYBQzgzNpM5mJ + + Waiting for incoming connection... + +Copy the line that starts with ``echo-quic-demo -p 8001``, open a new terminal in the same +folder and paste it in: + +.. code-block:: console + + $ echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu + + I am 16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu + STARTING CLIENT CONNECTION PROCESS + CLIENT CONNECTED TO SERVER + Sent: hi, there! + Got: ECHO: hi, there! + +**Key differences from TCP Echo:** + +- Uses UDP instead of TCP: ``/udp/8000`` instead of ``/tcp/8000`` +- Includes QUIC protocol identifier: ``/quic-v1`` in the multiaddr +- Built-in TLS security (no separate security transport needed) +- Native stream multiplexing over a single QUIC connection + +.. literalinclude:: ../examples/echo/echo_quic.py + :language: python + :linenos: diff --git a/docs/examples.identify.rst b/docs/examples.identify.rst index 9623f112..ba3e13c4 100644 --- a/docs/examples.identify.rst +++ b/docs/examples.identify.rst @@ -12,7 +12,7 @@ This example demonstrates how to use the libp2p ``identify`` protocol. $ identify-demo First host listening. Run this from another console: - identify-demo -p 8889 -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM + identify-demo -p 8889 -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM Waiting for incoming identify request... @@ -21,13 +21,13 @@ folder and paste it in: .. code-block:: console - $ identify-demo -p 8889 -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM - dialer (host_b) listening on /ip4/0.0.0.0/tcp/8889 + $ identify-demo -p 8889 -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM + dialer (host_b) listening on /ip4/127.0.0.1/tcp/8889 Second host connecting to peer: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM Starting identify protocol... Identify response: Public Key (Base64): CAASpgIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDC6c/oNPP9X13NDQ3Xrlp3zOj+ErXIWb/A4JGwWchiDBwMhMslEX3ct8CqI0BqUYKuwdFjowqqopOJ3cS2MlqtGaiP6Dg9bvGqSDoD37BpNaRVNcebRxtB0nam9SQy3PYLbHAmz0vR4ToSiL9OLRORnGOxCtHBuR8ZZ5vS0JEni8eQMpNa7IuXwyStnuty/QjugOZudBNgYSr8+9gH722KTjput5IRL7BrpIdd4HNXGVRm4b9BjNowvHu404x3a/ifeNblpy/FbYyFJEW0looygKF7hpRHhRbRKIDZt2BqOfT1sFkbqsHE85oY859+VMzP61YELgvGwai2r7KcjkW/AgMBAAE= - Listen Addresses: ['/ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM'] + Listen Addresses: ['/ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM'] Protocols: ['/ipfs/id/1.0.0', '/ipfs/ping/1.0.0'] Observed Address: ['/ip4/127.0.0.1/tcp/38082'] Protocol Version: ipfs/0.1.0 diff --git a/docs/examples.identify_push.rst b/docs/examples.identify_push.rst index 5b217d38..614d37bd 100644 --- a/docs/examples.identify_push.rst +++ b/docs/examples.identify_push.rst @@ -34,11 +34,11 @@ There is also a more interactive version of the example which runs as separate l ==== Starting Identify-Push Listener on port 8888 ==== Listener host ready! - Listening on: /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM + Listening on: /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM Peer ID: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM Run dialer with command: - identify-push-listener-dialer-demo -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM + identify-push-listener-dialer-demo -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM Waiting for incoming connections... (Ctrl+C to exit) @@ -47,12 +47,12 @@ folder and paste it in: .. code-block:: console - $ identify-push-listener-dialer-demo -d /ip4/0.0.0.0/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM + $ identify-push-listener-dialer-demo -d /ip4/127.0.0.1/tcp/8888/p2p/QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM ==== Starting Identify-Push Dialer on port 8889 ==== Dialer host ready! - Listening on: /ip4/0.0.0.0/tcp/8889/p2p/QmZyXwVuTaBcDeRsSkJpOpWrSt + Listening on: /ip4/127.0.0.1/tcp/8889/p2p/QmZyXwVuTaBcDeRsSkJpOpWrSt Connecting to peer: QmUiN4R3fNrCoQugGgmmb3v35neMEjKFNrsbNGVDsRHWpM Successfully connected to listener! diff --git a/docs/examples.pubsub.rst b/docs/examples.pubsub.rst index f3a8500f..8990e5c0 100644 --- a/docs/examples.pubsub.rst +++ b/docs/examples.pubsub.rst @@ -15,7 +15,7 @@ This example demonstrates how to create a chat application using libp2p's PubSub 2025-04-06 23:59:17,471 - pubsub-demo - INFO - Your selected topic is: pubsub-chat 2025-04-06 23:59:17,472 - pubsub-demo - INFO - Using random available port: 33269 2025-04-06 23:59:17,490 - pubsub-demo - INFO - Node started with peer ID: QmcJnocH1d1tz3Zp4MotVDjNfNFawXHw2dpB9tMYGTXJp7 - 2025-04-06 23:59:17,490 - pubsub-demo - INFO - Listening on: /ip4/0.0.0.0/tcp/33269 + 2025-04-06 23:59:17,490 - pubsub-demo - INFO - Listening on: /ip4/127.0.0.1/tcp/33269 2025-04-06 23:59:17,490 - pubsub-demo - INFO - Initializing PubSub and GossipSub... 2025-04-06 23:59:17,491 - pubsub-demo - INFO - Pubsub and GossipSub services started. 2025-04-06 23:59:17,491 - pubsub-demo - INFO - Pubsub ready. @@ -35,7 +35,7 @@ Copy the line that starts with ``pubsub-demo -d``, open a new terminal and paste 2025-04-07 00:00:59,846 - pubsub-demo - INFO - Your selected topic is: pubsub-chat 2025-04-07 00:00:59,846 - pubsub-demo - INFO - Using random available port: 51977 2025-04-07 00:00:59,864 - pubsub-demo - INFO - Node started with peer ID: QmYQKCm95Ut1aXsjHmWVYqdaVbno1eKTYC8KbEVjqUaKaQ - 2025-04-07 00:00:59,864 - pubsub-demo - INFO - Listening on: /ip4/0.0.0.0/tcp/51977 + 2025-04-07 00:00:59,864 - pubsub-demo - INFO - Listening on: /ip4/127.0.0.1/tcp/51977 2025-04-07 00:00:59,864 - pubsub-demo - INFO - Initializing PubSub and GossipSub... 2025-04-07 00:00:59,864 - pubsub-demo - INFO - Pubsub and GossipSub services started. 2025-04-07 00:00:59,865 - pubsub-demo - INFO - Pubsub ready. diff --git a/docs/examples.random_walk.rst b/docs/examples.random_walk.rst index baa3f81f..ea9ea220 100644 --- a/docs/examples.random_walk.rst +++ b/docs/examples.random_walk.rst @@ -23,7 +23,7 @@ The Random Walk implementation performs the following key operations: 2025-08-12 19:51:25,424 - random-walk-example - INFO - Mode: server, Port: 0 Demo interval: 30s 2025-08-12 19:51:25,426 - random-walk-example - INFO - Starting server node on port 45123 2025-08-12 19:51:25,426 - random-walk-example - INFO - Node peer ID: 16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef - 2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/0.0.0.0/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef + 2025-08-12 19:51:25,426 - random-walk-example - INFO - Node address: /ip4/127.0.0.1/tcp/45123/p2p/16Uiu2HAm7EsNv5vvjPAehGAVfChjYjD63ZHyWogQRdzntSbAg9ef 2025-08-12 19:51:25,427 - random-walk-example - INFO - Initial routing table size: 0 2025-08-12 19:51:25,427 - random-walk-example - INFO - DHT service started in SERVER mode 2025-08-12 19:51:25,430 - libp2p.discovery.random_walk.rt_refresh_manager - INFO - RT Refresh Manager started diff --git a/docs/examples.rst b/docs/examples.rst index 74864cbe..9f149ad0 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -9,6 +9,7 @@ Examples examples.identify_push examples.chat examples.echo + examples.echo_quic examples.ping examples.pubsub examples.circuit_relay diff --git a/docs/getting_started.rst b/docs/getting_started.rst index a8303ce0..b5de85bc 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -28,6 +28,11 @@ For Python, the most common transport is TCP. Here's how to set up a basic TCP t .. literalinclude:: ../examples/doc-examples/example_transport.py :language: python +Also, QUIC is a modern transport protocol that provides built-in TLS security and stream multiplexing over UDP: + +.. literalinclude:: ../examples/doc-examples/example_quic_transport.py + :language: python + Connection Encryption ^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/libp2p.transport.quic.rst b/docs/libp2p.transport.quic.rst new file mode 100644 index 00000000..b7b4b561 --- /dev/null +++ b/docs/libp2p.transport.quic.rst @@ -0,0 +1,77 @@ +libp2p.transport.quic package +============================= + +Submodules +---------- + +libp2p.transport.quic.config module +----------------------------------- + +.. automodule:: libp2p.transport.quic.config + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.connection module +--------------------------------------- + +.. automodule:: libp2p.transport.quic.connection + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.exceptions module +--------------------------------------- + +.. automodule:: libp2p.transport.quic.exceptions + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.listener module +------------------------------------- + +.. automodule:: libp2p.transport.quic.listener + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.security module +------------------------------------- + +.. automodule:: libp2p.transport.quic.security + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.stream module +----------------------------------- + +.. automodule:: libp2p.transport.quic.stream + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.transport module +-------------------------------------- + +.. automodule:: libp2p.transport.quic.transport + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.utils module +---------------------------------- + +.. automodule:: libp2p.transport.quic.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.transport.quic + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.transport.rst b/docs/libp2p.transport.rst index 0d92c48f..2a468143 100644 --- a/docs/libp2p.transport.rst +++ b/docs/libp2p.transport.rst @@ -9,6 +9,11 @@ Subpackages libp2p.transport.tcp +.. toctree:: + :maxdepth: 4 + + libp2p.transport.quic + Submodules ---------- diff --git a/examples/advanced/network_discover.py b/examples/advanced/network_discover.py index 87b44ddf..945ed12c 100644 --- a/examples/advanced/network_discover.py +++ b/examples/advanced/network_discover.py @@ -14,11 +14,26 @@ try: expand_wildcard_address, get_available_interfaces, get_optimal_binding_address, + get_wildcard_address, ) except ImportError: - # Fallbacks if utilities are missing + # Fallbacks if utilities are missing - use minimal network discovery + import socket + def get_available_interfaces(port: int, protocol: str = "tcp"): - return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")] + # Try to get local network interfaces, fallback to loopback + addrs = [] + try: + # Get hostname IP (better than hardcoded localhost) + hostname = socket.gethostname() + local_ip = socket.gethostbyname(hostname) + if local_ip != "127.0.0.1": + addrs.append(Multiaddr(f"/ip4/{local_ip}/{protocol}/{port}")) + except Exception: + pass + # Always include loopback as fallback + addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) + return addrs def expand_wildcard_address(addr: Multiaddr, port: int | None = None): if port is None: @@ -27,6 +42,15 @@ except ImportError: return [Multiaddr(addr_str + f"/{port}")] def get_optimal_binding_address(port: int, protocol: str = "tcp"): + # Try to get a non-loopback address first + interfaces = get_available_interfaces(port, protocol) + for addr in interfaces: + if "127.0.0.1" not in str(addr): + return addr + # Fallback to loopback if no other interfaces found + return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}") + + def get_wildcard_address(port: int, protocol: str = "tcp"): return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") @@ -37,7 +61,10 @@ def main() -> None: for a in interfaces: print(f" - {a}") - wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + # Demonstrate wildcard address as a feature + wildcard_v4 = get_wildcard_address(port) + print(f"\nWildcard address (feature): {wildcard_v4}") + expanded_v4 = expand_wildcard_address(wildcard_v4) print("\nExpanded IPv4 wildcard:") for a in expanded_v4: diff --git a/examples/bootstrap/bootstrap.py b/examples/bootstrap/bootstrap.py index af7d08cc..70ac3b0a 100644 --- a/examples/bootstrap/bootstrap.py +++ b/examples/bootstrap/bootstrap.py @@ -2,7 +2,6 @@ import argparse import logging import secrets -import multiaddr import trio from libp2p import new_host @@ -54,18 +53,26 @@ BOOTSTRAP_PEERS = [ async def run(port: int, bootstrap_addrs: list[str]) -> None: """Run the bootstrap discovery example.""" + from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, + get_optimal_binding_address, + ) + + if port <= 0: + port = find_free_port() + # Generate key pair secret = secrets.token_bytes(32) key_pair = create_new_key_pair(secret) - # Create listen address - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + # Create listen addresses for all available interfaces + listen_addrs = get_available_interfaces(port) # Register peer discovery handler peerDiscovery.register_peer_discovered_handler(on_peer_discovery) logger.info("šŸš€ Starting Bootstrap Discovery Example") - logger.info(f"šŸ“ Listening on: {listen_addr}") logger.info(f"🌐 Bootstrap peers: {len(bootstrap_addrs)}") print("\n" + "=" * 60) @@ -80,7 +87,22 @@ async def run(port: int, bootstrap_addrs: list[str]) -> None: host = new_host(key_pair=key_pair, bootstrap=bootstrap_addrs) try: - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + logger.info("Listener ready, listening on:") + print("Listener ready, listening on:") + for addr in all_addrs: + logger.info(f"{addr}") + print(f"{addr}") + + # Display optimal address for reference + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" + logger.info(f"Optimal address: {optimal_addr_with_peer}") + print(f"Optimal address: {optimal_addr_with_peer}") + # Keep running and log peer discovery events await trio.sleep_forever() except KeyboardInterrupt: @@ -98,7 +120,7 @@ def main() -> None: Usage: python bootstrap.py -p 8000 python bootstrap.py -p 8001 --custom-bootstrap \\ - "/ip4/127.0.0.1/tcp/8000/p2p/QmYourPeerID" + "/ip4/[HOST_IP]/tcp/8000/p2p/QmYourPeerID" """ parser = argparse.ArgumentParser( diff --git a/examples/chat/chat.py b/examples/chat/chat.py index 05a9b918..80b627e5 100755 --- a/examples/chat/chat.py +++ b/examples/chat/chat.py @@ -1,4 +1,5 @@ import argparse +import logging import sys import multiaddr @@ -17,6 +18,11 @@ from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + PROTOCOL_ID = TProtocol("/chat/1.0.0") MAX_READ_LEN = 2**32 - 1 @@ -40,9 +46,18 @@ async def write_data(stream: INetStream) -> None: async def run(port: int, destination: str) -> None: - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, + get_optimal_binding_address, + ) + + if port <= 0: + port = find_free_port() + + listen_addrs = get_available_interfaces(port) host = new_host() - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -54,10 +69,19 @@ async def run(port: int, destination: str) -> None: host.set_stream_handler(PROTOCOL_ID, stream_handler) + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + print("Listener ready, listening on:\n") + for addr in all_addrs: + print(f"{addr}") + + # Use optimal address for the client command + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" print( - "Run this from the same folder in another console:\n\n" - f"chat-demo " - f"-d {host.get_addrs()[0]}\n" + f"\nRun this from the same folder in another console:\n\n" + f"chat-demo -d {optimal_addr_with_peer}\n" ) print("Waiting for incoming connection...") @@ -86,7 +110,7 @@ def main() -> None: where is the multiaddress of the previous listener host. """ example_maddr = ( - "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + "/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) parser.add_argument("-p", "--port", default=0, type=int, help="source port number") diff --git a/examples/doc-examples/example_encryption_insecure.py b/examples/doc-examples/example_encryption_insecure.py index c1536808..859ab295 100644 --- a/examples/doc-examples/example_encryption_insecure.py +++ b/examples/doc-examples/example_encryption_insecure.py @@ -1,6 +1,5 @@ import secrets -import multiaddr import trio from libp2p import ( @@ -9,9 +8,10 @@ from libp2p import ( from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) -from libp2p.security.insecure.transport import ( - PLAINTEXT_PROTOCOL_ID, - InsecureTransport, +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, ) @@ -38,17 +38,19 @@ async def main(): # Create a host with the key pair and insecure transport host = new_host(key_pair=key_pair, sec_opt=security_options) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print( "libp2p has started with insecure transport " "(not recommended for production)" ) print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Keep the host running await trio.sleep_forever() diff --git a/examples/doc-examples/example_encryption_noise.py b/examples/doc-examples/example_encryption_noise.py index a2a4318c..4138354f 100644 --- a/examples/doc-examples/example_encryption_noise.py +++ b/examples/doc-examples/example_encryption_noise.py @@ -1,6 +1,5 @@ import secrets -import multiaddr import trio from libp2p import ( @@ -13,6 +12,10 @@ from libp2p.security.noise.transport import ( PROTOCOL_ID as NOISE_PROTOCOL_ID, Transport as NoiseTransport, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) async def main(): @@ -39,14 +42,16 @@ async def main(): # Create a host with the key pair and Noise security host = new_host(key_pair=key_pair, sec_opt=security_options) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print("libp2p has started with Noise encryption") print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Keep the host running await trio.sleep_forever() diff --git a/examples/doc-examples/example_encryption_secio.py b/examples/doc-examples/example_encryption_secio.py index 603ad6ea..b90c28bb 100644 --- a/examples/doc-examples/example_encryption_secio.py +++ b/examples/doc-examples/example_encryption_secio.py @@ -1,6 +1,5 @@ import secrets -import multiaddr import trio from libp2p import ( @@ -13,6 +12,10 @@ from libp2p.security.secio.transport import ( ID as SECIO_PROTOCOL_ID, Transport as SecioTransport, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) async def main(): @@ -32,14 +35,16 @@ async def main(): # Create a host with the key pair and SECIO security host = new_host(key_pair=key_pair, sec_opt=security_options) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print("libp2p has started with SECIO encryption") print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Keep the host running await trio.sleep_forever() diff --git a/examples/doc-examples/example_multiplexer.py b/examples/doc-examples/example_multiplexer.py index 0d6f2662..63a29fc5 100644 --- a/examples/doc-examples/example_multiplexer.py +++ b/examples/doc-examples/example_multiplexer.py @@ -1,6 +1,5 @@ import secrets -import multiaddr import trio from libp2p import ( @@ -13,6 +12,10 @@ from libp2p.security.noise.transport import ( PROTOCOL_ID as NOISE_PROTOCOL_ID, Transport as NoiseTransport, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) async def main(): @@ -39,14 +42,16 @@ async def main(): # Create a host with the key pair, Noise security, and mplex multiplexer host = new_host(key_pair=key_pair, sec_opt=security_options) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print("libp2p has started with Noise encryption and mplex multiplexing") print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Keep the host running await trio.sleep_forever() diff --git a/examples/doc-examples/example_net_stream.py b/examples/doc-examples/example_net_stream.py index d8842bea..edd2ac90 100644 --- a/examples/doc-examples/example_net_stream.py +++ b/examples/doc-examples/example_net_stream.py @@ -38,6 +38,10 @@ from libp2p.network.stream.net_stream import ( from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) PROTOCOL_ID = TProtocol("/echo/1.0.0") @@ -173,7 +177,9 @@ async def run_enhanced_demo( """ Run enhanced echo demo with NetStream state management. """ - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + # Use the new address paradigm + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Generate or use provided key if seed: @@ -185,7 +191,7 @@ async def run_enhanced_demo( host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print(f"Host ID: {host.get_id().to_string()}") print("=" * 60) @@ -196,10 +202,12 @@ async def run_enhanced_demo( # type: ignore: Stream is type of NetStream host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler) + # Use optimal address for client command + optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" print( "Run client from another console:\n" f"python3 example_net_stream.py " - f"-d {host.get_addrs()[0]}\n" + f"-d {optimal_addr_with_peer}\n" ) print("Waiting for connections...") print("Press Ctrl+C to stop server") @@ -226,7 +234,7 @@ async def run_enhanced_demo( def main() -> None: example_maddr = ( - "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + "/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser( diff --git a/examples/doc-examples/example_peer_discovery.py b/examples/doc-examples/example_peer_discovery.py index 7ceec375..a85796c0 100644 --- a/examples/doc-examples/example_peer_discovery.py +++ b/examples/doc-examples/example_peer_discovery.py @@ -1,6 +1,6 @@ import secrets -import multiaddr +from multiaddr import Multiaddr import trio from libp2p import ( @@ -16,6 +16,10 @@ from libp2p.security.noise.transport import ( PROTOCOL_ID as NOISE_PROTOCOL_ID, Transport as NoiseTransport, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) async def main(): @@ -42,14 +46,16 @@ async def main(): # Create a host with the key pair, Noise security, and mplex multiplexer host = new_host(key_pair=key_pair, sec_opt=security_options) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print("libp2p has started") print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Connect to bootstrap peers manually bootstrap_list = [ @@ -61,7 +67,7 @@ async def main(): for addr in bootstrap_list: try: - peer_info = info_from_p2p_addr(multiaddr.Multiaddr(addr)) + peer_info = info_from_p2p_addr(Multiaddr(addr)) await host.connect(peer_info) print(f"Connected to {peer_info.peer_id.to_string()}") except Exception as e: diff --git a/examples/doc-examples/example_quic_transport.py b/examples/doc-examples/example_quic_transport.py new file mode 100644 index 00000000..15fef1a3 --- /dev/null +++ b/examples/doc-examples/example_quic_transport.py @@ -0,0 +1,49 @@ +import secrets + +import trio + +from libp2p import ( + new_host, +) +from libp2p.crypto.secp256k1 import ( + create_new_key_pair, +) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) + + +async def main(): + # Create a key pair for the host + secret = secrets.token_bytes(32) + key_pair = create_new_key_pair(secret) + + # Create a host with the key pair + host = new_host(key_pair=key_pair, enable_quic=True) + + # Configure the listening address using the new paradigm + port = 8000 + listen_addrs = get_available_interfaces(port, protocol="udp") + # Convert TCP addresses to QUIC-v1 addresses + quic_addrs = [] + for addr in listen_addrs: + addr_str = str(addr).replace("/tcp/", "/udp/") + "/quic-v1" + from multiaddr import Multiaddr + + quic_addrs.append(Multiaddr(addr_str)) + + optimal_addr = get_optimal_binding_address(port, protocol="udp") + optimal_quic_str = str(optimal_addr).replace("/tcp/", "/udp/") + "/quic-v1" + + # Start the host + async with host.run(listen_addrs=quic_addrs): + print("libp2p has started with QUIC transport") + print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_quic_str}") + # Keep the host running + await trio.sleep_forever() + + +# Run the async function +trio.run(main) diff --git a/examples/doc-examples/example_running.py b/examples/doc-examples/example_running.py index a0169931..2f495979 100644 --- a/examples/doc-examples/example_running.py +++ b/examples/doc-examples/example_running.py @@ -1,6 +1,5 @@ import secrets -import multiaddr import trio from libp2p import ( @@ -13,6 +12,10 @@ from libp2p.security.noise.transport import ( PROTOCOL_ID as NOISE_PROTOCOL_ID, Transport as NoiseTransport, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) async def main(): @@ -39,14 +42,16 @@ async def main(): # Create a host with the key pair, Noise security, and mplex multiplexer host = new_host(key_pair=key_pair, sec_opt=security_options) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print("libp2p has started") print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Keep the host running await trio.sleep_forever() diff --git a/examples/doc-examples/example_transport.py b/examples/doc-examples/example_transport.py index e981fa7d..9d29d457 100644 --- a/examples/doc-examples/example_transport.py +++ b/examples/doc-examples/example_transport.py @@ -1,6 +1,5 @@ import secrets -import multiaddr import trio from libp2p import ( @@ -9,6 +8,10 @@ from libp2p import ( from libp2p.crypto.secp256k1 import ( create_new_key_pair, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, +) async def main(): @@ -19,14 +22,16 @@ async def main(): # Create a host with the key pair host = new_host(key_pair=key_pair) - # Configure the listening address + # Configure the listening address using the new paradigm port = 8000 - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) + optimal_addr = get_optimal_binding_address(port) # Start the host - async with host.run(listen_addrs=[listen_addr]): + async with host.run(listen_addrs=listen_addrs): print("libp2p has started with TCP transport") print("libp2p is listening on:", host.get_addrs()) + print(f"Optimal address: {optimal_addr}") # Keep the host running await trio.sleep_forever() diff --git a/examples/doc-examples/multiple_connections_example.py b/examples/doc-examples/multiple_connections_example.py index f0738283..20a7fd86 100644 --- a/examples/doc-examples/multiple_connections_example.py +++ b/examples/doc-examples/multiple_connections_example.py @@ -7,6 +7,7 @@ This example shows how to: 2. Use different load balancing strategies 3. Access multiple connections through the new API 4. Maintain backward compatibility +5. Use the new address paradigm for network configuration """ import logging @@ -15,6 +16,7 @@ import trio from libp2p import new_swarm from libp2p.network.swarm import ConnectionConfig, RetryConfig +from libp2p.utils import get_available_interfaces, get_optimal_binding_address # Set up logging logging.basicConfig(level=logging.INFO) @@ -103,10 +105,45 @@ async def example_backward_compatibility() -> None: logger.info("Backward compatibility example completed") +async def example_network_address_paradigm() -> None: + """Example of using the new address paradigm with multiple connections.""" + logger.info("Demonstrating network address paradigm...") + + # Get available interfaces using the new paradigm + port = 8000 # Example port + available_interfaces = get_available_interfaces(port) + logger.info(f"Available interfaces: {available_interfaces}") + + # Get optimal binding address + optimal_address = get_optimal_binding_address(port) + logger.info(f"Optimal binding address: {optimal_address}") + + # Create connection config for multiple connections with network awareness + connection_config = ConnectionConfig( + max_connections_per_peer=3, load_balancing_strategy="round_robin" + ) + + # Create swarm with address paradigm + swarm = new_swarm(connection_config=connection_config) + + logger.info("Network address paradigm features:") + logger.info(" - get_available_interfaces() for interface discovery") + logger.info(" - get_optimal_binding_address() for smart address selection") + logger.info(" - Multiple connections with proper network binding") + + await swarm.close() + logger.info("Network address paradigm example completed") + + async def example_production_ready_config() -> None: """Example of production-ready configuration.""" logger.info("Creating swarm with production-ready configuration...") + # Get optimal network configuration using the new paradigm + port = 8001 # Example port + optimal_address = get_optimal_binding_address(port) + logger.info(f"Using optimal binding address: {optimal_address}") + # Production-ready retry configuration retry_config = RetryConfig( max_retries=3, # Reasonable retry limit @@ -156,6 +193,9 @@ async def main() -> None: await example_backward_compatibility() logger.info("-" * 30) + await example_network_address_paradigm() + logger.info("-" * 30) + await example_production_ready_config() logger.info("-" * 30) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 19e98377..f95c9add 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,4 +1,5 @@ import argparse +import logging import random import secrets @@ -26,8 +27,14 @@ from libp2p.peer.peerinfo import ( from libp2p.utils.address_validation import ( find_free_port, get_available_interfaces, + get_optimal_binding_address, ) +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + PROTOCOL_ID = TProtocol("/echo/1.0.0") MAX_READ_LEN = 2**32 - 1 @@ -76,9 +83,13 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: for addr in listen_addr: print(f"{addr}/p2p/{peer_id}") + # Get optimal address for display + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = f"{optimal_addr}/p2p/{peer_id}" + print( "\nRun this from the same folder in another console:\n\n" - f"echo-demo -d {host.get_addrs()[0]}\n" + f"echo-demo -d {optimal_addr_with_peer}\n" ) print("Waiting for incoming connections...") await trio.sleep_forever() @@ -114,7 +125,7 @@ def main() -> None: where is the multiaddress of the previous listener host. """ example_maddr = ( - "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + "/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) parser.add_argument("-p", "--port", default=0, type=int, help="source port number") diff --git a/examples/echo/echo_quic.py b/examples/echo/echo_quic.py new file mode 100644 index 00000000..87618fbb --- /dev/null +++ b/examples/echo/echo_quic.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +QUIC Echo Example - Fixed version with proper client/server separation + +This program demonstrates a simple echo protocol using QUIC transport where a peer +listens for connections and copies back any input received on a stream. + +Fixed to properly separate client and server modes - clients don't start listeners. +""" + +import argparse +import logging + +from multiaddr import Multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.network.stream.net_stream import INetStream +from libp2p.peer.peerinfo import info_from_p2p_addr + +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + +PROTOCOL_ID = TProtocol("/echo/1.0.0") + + +async def _echo_stream_handler(stream: INetStream) -> None: + try: + msg = await stream.read() + await stream.write(msg) + await stream.close() + except Exception as e: + print(f"Echo handler error: {e}") + try: + await stream.close() + except: # noqa: E722 + pass + + +async def run_server(port: int, seed: int | None = None) -> None: + """Run echo server with QUIC transport.""" + from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, + get_optimal_binding_address, + ) + + if port <= 0: + port = find_free_port() + + # For QUIC, we need UDP addresses - use the new address paradigm + tcp_addrs = get_available_interfaces(port) + # Convert TCP addresses to QUIC addresses + quic_addrs = [] + for addr in tcp_addrs: + addr_str = str(addr).replace("/tcp/", "/udp/") + "/quic" + quic_addrs.append(Multiaddr(addr_str)) + + if seed: + import random + + random.seed(seed) + secret_number = random.getrandbits(32 * 8) + secret = secret_number.to_bytes(length=32, byteorder="big") + else: + import secrets + + secret = secrets.token_bytes(32) + + # Create host with QUIC transport + host = new_host( + enable_quic=True, + key_pair=create_new_key_pair(secret), + ) + + # Server mode: start listener + async with host.run(listen_addrs=quic_addrs): + try: + print(f"I am {host.get_id().to_string()}") + host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + print("Listener ready, listening on:") + for addr in all_addrs: + print(f"{addr}") + + # Use optimal address for the client command + optimal_tcp = get_optimal_binding_address(port) + optimal_quic_str = str(optimal_tcp).replace("/tcp/", "/udp/") + "/quic" + peer_id = host.get_id().to_string() + optimal_quic_with_peer = f"{optimal_quic_str}/p2p/{peer_id}" + print( + f"\nRun this from the same folder in another console:\n\n" + f"python3 ./examples/echo/echo_quic.py -d {optimal_quic_with_peer}\n" + ) + print("Waiting for incoming QUIC connections...") + await trio.sleep_forever() + except KeyboardInterrupt: + print("Closing server gracefully...") + await host.close() + return + + +async def run_client(destination: str, seed: int | None = None) -> None: + """Run echo client with QUIC transport.""" + if seed: + import random + + random.seed(seed) + secret_number = random.getrandbits(32 * 8) + secret = secret_number.to_bytes(length=32, byteorder="big") + else: + import secrets + + secret = secrets.token_bytes(32) + + # Create host with QUIC transport + host = new_host( + enable_quic=True, + key_pair=create_new_key_pair(secret), + ) + + # Client mode: NO listener, just connect + async with host.run(listen_addrs=[]): # Empty listen_addrs for client + print(f"I am {host.get_id().to_string()}") + + maddr = Multiaddr(destination) + info = info_from_p2p_addr(maddr) + + # Connect to server + print("STARTING CLIENT CONNECTION PROCESS") + await host.connect(info) + print("CLIENT CONNECTED TO SERVER") + + # Start a stream with the destination + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + + msg = b"hi, there!\n" + + await stream.write(msg) + response = await stream.read() + + print(f"Sent: {msg.decode('utf-8')}") + print(f"Got: {response.decode('utf-8')}") + await stream.close() + await host.disconnect(info.peer_id) + + +async def run(port: int, destination: str, seed: int | None = None) -> None: + """ + Run echo server or client with QUIC transport. + + Fixed version that properly separates client and server modes. + """ + if not destination: # Server mode + await run_server(port, seed) + else: # Client mode + await run_client(destination, seed) + + +def main() -> None: + """Main function - help text updated for QUIC.""" + description = """ + This program demonstrates a simple echo protocol using QUIC + transport where a peer listens for connections and copies back + any input received on a stream. + + QUIC provides built-in TLS security and stream multiplexing over UDP. + + To use it, first run 'echo-quic-demo -p ', where is + the UDP port number. Then, run another host with , + 'echo-quic-demo -d ' + where is the QUIC multiaddress of the previous listener host. + """ + + example_maddr = "/ip4/[HOST_IP]/udp/8000/quic/p2p/QmQn4SwGkDZKkUEpBRBv" + + parser = argparse.ArgumentParser(description=description) + parser.add_argument("-p", "--port", default=0, type=int, help="UDP port number") + parser.add_argument( + "-d", + "--destination", + type=str, + help=f"destination multiaddr string, e.g. {example_maddr}", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + help="provide a seed to the random number generator", + ) + args = parser.parse_args() + + try: + trio.run(run, args.port, args.destination, args.seed) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/examples/identify/identify.py b/examples/identify/identify.py index 98980f99..327ea4d6 100644 --- a/examples/identify/identify.py +++ b/examples/identify/identify.py @@ -20,6 +20,11 @@ from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + logger = logging.getLogger("libp2p.identity.identify-example") @@ -58,11 +63,19 @@ def print_identify_response(identify_response: Identify): async def run(port: int, destination: str, use_varint_format: bool = True) -> None: - localhost_ip = "0.0.0.0" + from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + ) if not destination: # Create first host (listener) - listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}") + if port <= 0: + from libp2p.utils.address_validation import find_free_port + + port = find_free_port() + + listen_addrs = get_available_interfaces(port) host_a = new_host() # Set up identify handler with specified format @@ -73,25 +86,49 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No host_a.set_stream_handler(IDENTIFY_PROTOCOL_ID, identify_handler) async with ( - host_a.run(listen_addrs=[listen_addr]), + host_a.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery, ): # Start the peer-store cleanup task nursery.start_soon(host_a.get_peerstore().start_cleanup_task, 60) - # Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client - # connections - server_addr = str(host_a.get_addrs()[0]) - client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") + # Get all available addresses with peer ID + all_addrs = host_a.get_addrs() - format_name = "length-prefixed" if use_varint_format else "raw protobuf" - format_flag = "--raw-format" if not use_varint_format else "" - print( - f"First host listening (using {format_name} format). " - f"Run this from another console:\n\n" - f"identify-demo {format_flag} -d {client_addr}\n" - ) - print("Waiting for incoming identify request...") + if use_varint_format: + format_name = "length-prefixed" + print(f"First host listening (using {format_name} format).") + print("Listener ready, listening on:\n") + for addr in all_addrs: + print(f"{addr}") + + # Use optimal address for the client command + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = ( + f"{optimal_addr}/p2p/{host_a.get_id().to_string()}" + ) + print( + f"\nRun this from the same folder in another console:\n\n" + f"identify-demo -d {optimal_addr_with_peer}\n" + ) + print("Waiting for incoming identify request...") + else: + format_name = "raw protobuf" + print(f"First host listening (using {format_name} format).") + print("Listener ready, listening on:\n") + for addr in all_addrs: + print(f"{addr}") + + # Use optimal address for the client command + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = ( + f"{optimal_addr}/p2p/{host_a.get_id().to_string()}" + ) + print( + f"\nRun this from the same folder in another console:\n\n" + f"identify-demo -d {optimal_addr_with_peer}\n" + ) + print("Waiting for incoming identify request...") # Add a custom handler to show connection events async def custom_identify_handler(stream): @@ -134,11 +171,20 @@ async def run(port: int, destination: str, use_varint_format: bool = True) -> No else: # Create second host (dialer) - listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}") + from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, + get_optimal_binding_address, + ) + + if port <= 0: + port = find_free_port() + + listen_addrs = get_available_interfaces(port) host_b = new_host() async with ( - host_b.run(listen_addrs=[listen_addr]), + host_b.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery, ): # Start the peer-store cleanup task @@ -234,7 +280,7 @@ def main() -> None: """ example_maddr = ( - "/ip4/127.0.0.1/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + "/ip4/[HOST_IP]/tcp/8888/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) @@ -258,7 +304,7 @@ def main() -> None: # Determine format: use varint (length-prefixed) if --raw-format is specified, # otherwise use raw protobuf format (old format) - use_varint_format = args.raw_format + use_varint_format = not args.raw_format try: if args.destination: diff --git a/examples/identify_push/identify_push_demo.py b/examples/identify_push/identify_push_demo.py index ccd8b29d..98e1e937 100644 --- a/examples/identify_push/identify_push_demo.py +++ b/examples/identify_push/identify_push_demo.py @@ -36,6 +36,9 @@ from libp2p.identity.identify_push import ( from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, +) # Configure logging logger = logging.getLogger(__name__) @@ -207,13 +210,13 @@ async def main() -> None: ID_PUSH, create_custom_identify_push_handler(host_2, "Host 2") ) - # Start listening on random ports using the run context manager - listen_addr_1 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") - listen_addr_2 = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0") + # Start listening on available interfaces using random ports + listen_addrs_1 = get_available_interfaces(0) # 0 for random port + listen_addrs_2 = get_available_interfaces(0) # 0 for random port async with ( - host_1.run([listen_addr_1]), - host_2.run([listen_addr_2]), + host_1.run(listen_addrs_1), + host_2.run(listen_addrs_2), trio.open_nursery() as nursery, ): # Start the peer-store cleanup task diff --git a/examples/identify_push/identify_push_listener_dialer.py b/examples/identify_push/identify_push_listener_dialer.py index c23e62bb..079457a2 100644 --- a/examples/identify_push/identify_push_listener_dialer.py +++ b/examples/identify_push/identify_push_listener_dialer.py @@ -14,7 +14,7 @@ Usage: python identify_push_listener_dialer.py # Then in another console, run as a dialer (default port 8889): - python identify_push_listener_dialer.py -d /ip4/127.0.0.1/tcp/8888/p2p/PEER_ID + python identify_push_listener_dialer.py -d /ip4/[HOST_IP]/tcp/8888/p2p/PEER_ID (where PEER_ID is the peer ID displayed by the listener) """ @@ -56,6 +56,11 @@ from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + # Configure logging logger = logging.getLogger("libp2p.identity.identify-push-example") @@ -194,6 +199,11 @@ async def run_listener( port: int, use_varint_format: bool = True, raw_format_flag: bool = False ) -> None: """Run a host in listener mode.""" + from libp2p.utils.address_validation import find_free_port, get_available_interfaces + + if port <= 0: + port = find_free_port() + format_name = "length-prefixed" if use_varint_format else "raw protobuf" print( f"\n==== Starting Identify-Push Listener on port {port} " @@ -215,26 +225,33 @@ async def run_listener( custom_identify_push_handler_for(host, use_varint_format=use_varint_format), ) - # Start listening - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + # Start listening on all available interfaces + listen_addrs = get_available_interfaces(port) try: - async with host.run([listen_addr]): - addr = host.get_addrs()[0] + async with host.run(listen_addrs): + all_addrs = host.get_addrs() logger.info("Listener host ready!") print("Listener host ready!") - logger.info(f"Listening on: {addr}") - print(f"Listening on: {addr}") + logger.info("Listener ready, listening on:") + print("Listener ready, listening on:") + for addr in all_addrs: + logger.info(f"{addr}") + print(f"{addr}") logger.info(f"Peer ID: {host.get_id().pretty()}") print(f"Peer ID: {host.get_id().pretty()}") - print("\nRun dialer with command:") + # Use the first address as the default for the dialer command + default_addr = all_addrs[0] + print("\nRun this from the same folder in another console:") if raw_format_flag: - print(f"identify-push-listener-dialer-demo -d {addr} --raw-format") + print( + f"identify-push-listener-dialer-demo -d {default_addr} --raw-format" + ) else: - print(f"identify-push-listener-dialer-demo -d {addr}") + print(f"identify-push-listener-dialer-demo -d {default_addr}") print("\nWaiting for incoming identify/push requests... (Ctrl+C to exit)") # Keep running until interrupted @@ -274,10 +291,12 @@ async def run_dialer( identify_push_handler_for(host, use_varint_format=use_varint_format), ) - # Start listening on a different port - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + # Start listening on available interfaces + from libp2p.utils.address_validation import get_available_interfaces - async with host.run([listen_addr]): + listen_addrs = get_available_interfaces(port) + + async with host.run(listen_addrs): logger.info("Dialer host ready!") print("Dialer host ready!") diff --git a/examples/kademlia/kademlia.py b/examples/kademlia/kademlia.py index faaa66be..cf4b2988 100644 --- a/examples/kademlia/kademlia.py +++ b/examples/kademlia/kademlia.py @@ -150,26 +150,43 @@ async def run_node( key_pair = create_new_key_pair(secrets.token_bytes(32)) host = new_host(key_pair=key_pair) - listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}") - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + ) + + listen_addrs = get_available_interfaces(port) + + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) peer_id = host.get_id().pretty() - addr_str = f"/ip4/127.0.0.1/tcp/{port}/p2p/{peer_id}" + + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + logger.info("Listener ready, listening on:") + for addr in all_addrs: + logger.info(f"{addr}") + + # Use optimal address for the bootstrap command + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" + bootstrap_cmd = f"--bootstrap {optimal_addr_with_peer}" + logger.info("To connect to this node, use: %s", bootstrap_cmd) + await connect_to_bootstrap_nodes(host, bootstrap_nodes) dht = KadDHT(host, dht_mode) # take all peer ids from the host and add them to the dht for peer_id in host.get_peerstore().peer_ids(): await dht.routing_table.add_peer(peer_id) logger.info(f"Connected to bootstrap nodes: {host.get_connected_peers()}") - bootstrap_cmd = f"--bootstrap {addr_str}" - logger.info("To connect to this node, use: %s", bootstrap_cmd) # Save server address in server mode if dht_mode == DHTMode.SERVER: - save_server_addr(addr_str) + save_server_addr(str(optimal_addr_with_peer)) # Start the DHT service async with background_trio_service(dht): diff --git a/examples/mDNS/mDNS.py b/examples/mDNS/mDNS.py index d3f11b56..9f0cf74b 100644 --- a/examples/mDNS/mDNS.py +++ b/examples/mDNS/mDNS.py @@ -2,7 +2,6 @@ import argparse import logging import secrets -import multiaddr import trio from libp2p import ( @@ -14,6 +13,11 @@ from libp2p.crypto.secp256k1 import ( ) from libp2p.discovery.events.peerDiscovery import peerDiscovery +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + logger = logging.getLogger("libp2p.discovery.mdns") logger.setLevel(logging.INFO) handler = logging.StreamHandler() @@ -22,34 +26,43 @@ handler.setFormatter( ) logger.addHandler(handler) -# Set root logger to DEBUG to capture all logs from dependencies -logging.getLogger().setLevel(logging.DEBUG) - def onPeerDiscovery(peerinfo: PeerInfo): logger.info(f"Discovered: {peerinfo.peer_id}") async def run(port: int) -> None: + from libp2p.utils.address_validation import find_free_port, get_available_interfaces + + if port <= 0: + port = find_free_port() + secret = secrets.token_bytes(32) key_pair = create_new_key_pair(secret) - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) peerDiscovery.register_peer_discovered_handler(onPeerDiscovery) - print( - "Run this from the same folder in another console to " - "start another peer on a different port:\n\n" - "mdns-demo -p \n" - ) - print("Waiting for mDNS peer discovery events...\n") - logger.info("Starting peer Discovery") host = new_host(key_pair=key_pair, enable_mDNS=True) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + print("Listener ready, listening on:") + for addr in all_addrs: + print(f"{addr}") + + print( + "\nRun this from the same folder in another console to " + "start another peer on a different port:\n\n" + "mdns-demo -p \n" + ) + print("Waiting for mDNS peer discovery events...\n") + await trio.sleep_forever() diff --git a/examples/ping/ping.py b/examples/ping/ping.py index d1a5daae..f62689aa 100644 --- a/examples/ping/ping.py +++ b/examples/ping/ping.py @@ -1,4 +1,5 @@ import argparse +import logging import multiaddr import trio @@ -16,6 +17,11 @@ from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +# Configure minimal logging +logging.basicConfig(level=logging.WARNING) +logging.getLogger("multiaddr").setLevel(logging.WARNING) +logging.getLogger("libp2p").setLevel(logging.WARNING) + PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0") PING_LENGTH = 32 RESP_TIMEOUT = 60 @@ -55,20 +61,38 @@ async def send_ping(stream: INetStream) -> None: async def run(port: int, destination: str) -> None: - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") - host = new_host(listen_addrs=[listen_addr]) + from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, + get_optimal_binding_address, + ) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + if port <= 0: + port = find_free_port() + + listen_addrs = get_available_interfaces(port) + host = new_host(listen_addrs=listen_addrs) + + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) if not destination: host.set_stream_handler(PING_PROTOCOL_ID, handle_ping) + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + print("Listener ready, listening on:\n") + for addr in all_addrs: + print(f"{addr}") + + # Use optimal address for the client command + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" print( - "Run this from the same folder in another console:\n\n" - f"ping-demo " - f"-d {host.get_addrs()[0]}\n" + f"\nRun this from the same folder in another console:\n\n" + f"ping-demo -d {optimal_addr_with_peer}\n" ) print("Waiting for incoming connection...") @@ -94,7 +118,7 @@ def main() -> None: """ example_maddr = ( - "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + "/ip4/[HOST_IP]/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" ) parser = argparse.ArgumentParser(description=description) diff --git a/examples/pubsub/pubsub.py b/examples/pubsub/pubsub.py index 41545658..adb3a1d0 100644 --- a/examples/pubsub/pubsub.py +++ b/examples/pubsub/pubsub.py @@ -102,14 +102,16 @@ async def monitor_peer_topics(pubsub, nursery, termination_event): async def run(topic: str, destination: str | None, port: int | None) -> None: - # Initialize network settings - localhost_ip = "127.0.0.1" + from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + ) if port is None or port == 0: port = find_free_port() logger.info(f"Using random available port: {port}") - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + listen_addrs = get_available_interfaces(port) # Create a new libp2p host host = new_host( @@ -138,12 +140,11 @@ async def run(topic: str, destination: str | None, port: int | None) -> None: pubsub = Pubsub(host, gossipsub) termination_event = trio.Event() # Event to signal termination - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) logger.info(f"Node started with peer ID: {host.get_id()}") - logger.info(f"Listening on: {listen_addr}") logger.info("Initializing PubSub and GossipSub...") async with background_trio_service(pubsub): async with background_trio_service(gossipsub): @@ -157,10 +158,21 @@ async def run(topic: str, destination: str | None, port: int | None) -> None: if not destination: # Server mode + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + + logger.info("Listener ready, listening on:") + for addr in all_addrs: + logger.info(f"{addr}") + + # Use optimal address for the client command + optimal_addr = get_optimal_binding_address(port) + optimal_addr_with_peer = ( + f"{optimal_addr}/p2p/{host.get_id().to_string()}" + ) logger.info( - "Run this script in another console with:\n" - f"pubsub-demo " - f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id()}\n" + f"\nRun this from the same folder in another console:\n\n" + f"pubsub-demo -d {optimal_addr_with_peer}\n" ) logger.info("Waiting for peers...") @@ -182,11 +194,6 @@ async def run(topic: str, destination: str | None, port: int | None) -> None: f"Connecting to peer: {info.peer_id} " f"using protocols: {protocols_in_maddr}" ) - logger.info( - "Run this script in another console with:\n" - f"pubsub-demo " - f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id()}\n" - ) try: await host.connect(info) logger.info(f"Connected to peer: {info.peer_id}") diff --git a/examples/random_walk/random_walk.py b/examples/random_walk/random_walk.py index 845ccd57..d2278b16 100644 --- a/examples/random_walk/random_walk.py +++ b/examples/random_walk/random_walk.py @@ -16,7 +16,6 @@ import random import secrets import sys -from multiaddr import Multiaddr import trio from libp2p import new_host @@ -130,16 +129,24 @@ async def run_node(port: int, mode: str, demo_interval: int = 30) -> None: # Create host and DHT key_pair = create_new_key_pair(secrets.token_bytes(32)) host = new_host(key_pair=key_pair, bootstrap=DEFAULT_BOOTSTRAP_NODES) - listen_addr = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + from libp2p.utils.address_validation import get_available_interfaces + + listen_addrs = get_available_interfaces(port) + + async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start maintenance tasks nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) nursery.start_soon(maintain_connections, host) peer_id = host.get_id().pretty() logger.info(f"Node peer ID: {peer_id}") - logger.info(f"Node address: /ip4/0.0.0.0/tcp/{port}/p2p/{peer_id}") + + # Get all available addresses with peer ID + all_addrs = host.get_addrs() + logger.info("Listener ready, listening on:") + for addr in all_addrs: + logger.info(f"{addr}") # Create and start DHT with Random Walk enabled dht = KadDHT(host, dht_mode, enable_random_walk=True) diff --git a/examples/test_tcp_data_transfer.py b/examples/test_tcp_data_transfer.py new file mode 100644 index 00000000..634386bd --- /dev/null +++ b/examples/test_tcp_data_transfer.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +""" +TCP P2P Data Transfer Test + +This test proves that TCP peer-to-peer data transfer works correctly in libp2p. +This serves as a baseline to compare with WebSocket tests. +""" + +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p import create_yamux_muxer_option, new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport + +# Test protocol for data exchange +TCP_DATA_PROTOCOL = TProtocol("/test/tcp-data-exchange/1.0.0") + + +async def create_tcp_host_pair(): + """Create a pair of hosts configured for TCP communication.""" + # Create key pairs + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + + # Create security options (using plaintext for simplicity) + def security_options(kp): + return { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=kp, secure_bytes_provider=None, peerstore=None + ) + } + + # Host A (listener) - TCP transport (default) + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options(key_pair_a), + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")], + ) + + # Host B (dialer) - TCP transport (default) + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options(key_pair_b), + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")], + ) + + return host_a, host_b + + +@pytest.mark.trio +async def test_tcp_basic_connection(): + """Test basic TCP connection establishment.""" + host_a, host_b = await create_tcp_host_pair() + + connection_established = False + + async def connection_handler(stream): + nonlocal connection_established + connection_established = True + await stream.close() + + host_a.set_stream_handler(TCP_DATA_PROTOCOL, connection_handler) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert listen_addrs, "Host A should have listen addresses" + + # Extract TCP address + tcp_addr = None + for addr in listen_addrs: + if "/tcp/" in str(addr) and "/ws" not in str(addr): + tcp_addr = addr + break + + assert tcp_addr, f"No TCP address found in {listen_addrs}" + print(f"šŸ”— Host A listening on: {tcp_addr}") + + # Create peer info for host A + peer_info = info_from_p2p_addr(tcp_addr) + + # Host B connects to host A + await host_b.connect(peer_info) + print("āœ… TCP connection established") + + # Open a stream to test the connection + stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL]) + await stream.close() + + # Wait a bit for the handler to be called + await trio.sleep(0.1) + + assert connection_established, "TCP connection handler should have been called" + print("āœ… TCP basic connection test successful!") + + +@pytest.mark.trio +async def test_tcp_data_transfer(): + """Test TCP peer-to-peer data transfer.""" + host_a, host_b = await create_tcp_host_pair() + + # Test data + test_data = b"Hello TCP P2P Data Transfer! This is a test message." + received_data = None + transfer_complete = trio.Event() + + async def data_handler(stream): + nonlocal received_data + try: + # Read the incoming data + received_data = await stream.read(len(test_data)) + # Echo it back to confirm successful transfer + await stream.write(received_data) + await stream.close() + transfer_complete.set() + except Exception as e: + print(f"Handler error: {e}") + transfer_complete.set() + + host_a.set_stream_handler(TCP_DATA_PROTOCOL, data_handler) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert listen_addrs, "Host A should have listen addresses" + + # Extract TCP address + tcp_addr = None + for addr in listen_addrs: + if "/tcp/" in str(addr) and "/ws" not in str(addr): + tcp_addr = addr + break + + assert tcp_addr, f"No TCP address found in {listen_addrs}" + print(f"šŸ”— Host A listening on: {tcp_addr}") + + # Create peer info for host A + peer_info = info_from_p2p_addr(tcp_addr) + + # Host B connects to host A + await host_b.connect(peer_info) + print("āœ… TCP connection established") + + # Open a stream for data transfer + stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL]) + print("āœ… TCP stream opened") + + # Send test data + await stream.write(test_data) + print(f"šŸ“¤ Sent data: {test_data}") + + # Read echoed data back + echoed_data = await stream.read(len(test_data)) + print(f"šŸ“„ Received echo: {echoed_data}") + + await stream.close() + + # Wait for transfer to complete + with trio.fail_after(5.0): # 5 second timeout + await transfer_complete.wait() + + # Verify data transfer + assert received_data == test_data, ( + f"Data mismatch: {received_data} != {test_data}" + ) + assert echoed_data == test_data, f"Echo mismatch: {echoed_data} != {test_data}" + + print("āœ… TCP P2P data transfer successful!") + print(f" Original: {test_data}") + print(f" Received: {received_data}") + print(f" Echoed: {echoed_data}") + + +@pytest.mark.trio +async def test_tcp_large_data_transfer(): + """Test TCP with larger data payloads.""" + host_a, host_b = await create_tcp_host_pair() + + # Large test data (10KB) + test_data = b"TCP Large Data Test! " * 500 # ~10KB + received_data = None + transfer_complete = trio.Event() + + async def large_data_handler(stream): + nonlocal received_data + try: + # Read data in chunks + chunks = [] + total_received = 0 + expected_size = len(test_data) + + while total_received < expected_size: + chunk = await stream.read(min(1024, expected_size - total_received)) + if not chunk: + break + chunks.append(chunk) + total_received += len(chunk) + + received_data = b"".join(chunks) + + # Send back confirmation + await stream.write(b"RECEIVED_OK") + await stream.close() + transfer_complete.set() + except Exception as e: + print(f"Large data handler error: {e}") + transfer_complete.set() + + host_a.set_stream_handler(TCP_DATA_PROTOCOL, large_data_handler) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert listen_addrs, "Host A should have listen addresses" + + # Extract TCP address + tcp_addr = None + for addr in listen_addrs: + if "/tcp/" in str(addr) and "/ws" not in str(addr): + tcp_addr = addr + break + + assert tcp_addr, f"No TCP address found in {listen_addrs}" + print(f"šŸ”— Host A listening on: {tcp_addr}") + print(f"šŸ“Š Test data size: {len(test_data)} bytes") + + # Create peer info for host A + peer_info = info_from_p2p_addr(tcp_addr) + + # Host B connects to host A + await host_b.connect(peer_info) + print("āœ… TCP connection established") + + # Open a stream for data transfer + stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL]) + print("āœ… TCP stream opened") + + # Send large test data in chunks + chunk_size = 1024 + sent_bytes = 0 + for i in range(0, len(test_data), chunk_size): + chunk = test_data[i : i + chunk_size] + await stream.write(chunk) + sent_bytes += len(chunk) + if sent_bytes % (chunk_size * 4) == 0: # Progress every 4KB + print(f"šŸ“¤ Sent {sent_bytes}/{len(test_data)} bytes") + + print(f"šŸ“¤ Sent all {len(test_data)} bytes") + + # Read confirmation + confirmation = await stream.read(1024) + print(f"šŸ“„ Received confirmation: {confirmation}") + + await stream.close() + + # Wait for transfer to complete + with trio.fail_after(10.0): # 10 second timeout for large data + await transfer_complete.wait() + + # Verify data transfer + assert received_data is not None, "No data was received" + assert received_data == test_data, ( + "Large data transfer failed:" + + f" sizes {len(received_data)} != {len(test_data)}" + ) + assert confirmation == b"RECEIVED_OK", f"Confirmation failed: {confirmation}" + + print("āœ… TCP large data transfer successful!") + print(f" Data size: {len(test_data)} bytes") + print(f" Received: {len(received_data)} bytes") + print(f" Match: {received_data == test_data}") + + +@pytest.mark.trio +async def test_tcp_bidirectional_transfer(): + """Test bidirectional data transfer over TCP.""" + host_a, host_b = await create_tcp_host_pair() + + # Test data + data_a_to_b = b"Message from Host A to Host B via TCP" + data_b_to_a = b"Response from Host B to Host A via TCP" + + received_on_a = None + received_on_b = None + transfer_complete_a = trio.Event() + transfer_complete_b = trio.Event() + + async def handler_a(stream): + nonlocal received_on_a + try: + # Read data from B + received_on_a = await stream.read(len(data_b_to_a)) + print(f"šŸ…°ļø Host A received: {received_on_a}") + await stream.close() + transfer_complete_a.set() + except Exception as e: + print(f"Handler A error: {e}") + transfer_complete_a.set() + + async def handler_b(stream): + nonlocal received_on_b + try: + # Read data from A + received_on_b = await stream.read(len(data_a_to_b)) + print(f"šŸ…±ļø Host B received: {received_on_b}") + await stream.close() + transfer_complete_b.set() + except Exception as e: + print(f"Handler B error: {e}") + transfer_complete_b.set() + + # Set up handlers on both hosts + protocol_a_to_b = TProtocol("/test/tcp-a-to-b/1.0.0") + protocol_b_to_a = TProtocol("/test/tcp-b-to-a/1.0.0") + + host_a.set_stream_handler(protocol_b_to_a, handler_a) + host_b.set_stream_handler(protocol_a_to_b, handler_b) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + ): + # Get addresses + addrs_a = host_a.get_addrs() + addrs_b = host_b.get_addrs() + + assert addrs_a and addrs_b, "Both hosts should have addresses" + + # Extract TCP addresses + tcp_addr_a = next( + ( + addr + for addr in addrs_a + if "/tcp/" in str(addr) and "/ws" not in str(addr) + ), + None, + ) + tcp_addr_b = next( + ( + addr + for addr in addrs_b + if "/tcp/" in str(addr) and "/ws" not in str(addr) + ), + None, + ) + + assert tcp_addr_a and tcp_addr_b, ( + f"TCP addresses not found: A={addrs_a}, B={addrs_b}" + ) + print(f"šŸ”— Host A listening on: {tcp_addr_a}") + print(f"šŸ”— Host B listening on: {tcp_addr_b}") + + # Create peer infos + peer_info_a = info_from_p2p_addr(tcp_addr_a) + peer_info_b = info_from_p2p_addr(tcp_addr_b) + + # Establish connections + await host_b.connect(peer_info_a) + await host_a.connect(peer_info_b) + print("āœ… Bidirectional TCP connections established") + + # Send data A -> B + stream_a_to_b = await host_a.new_stream(peer_info_b.peer_id, [protocol_a_to_b]) + await stream_a_to_b.write(data_a_to_b) + print(f"šŸ“¤ A->B: {data_a_to_b}") + await stream_a_to_b.close() + + # Send data B -> A + stream_b_to_a = await host_b.new_stream(peer_info_a.peer_id, [protocol_b_to_a]) + await stream_b_to_a.write(data_b_to_a) + print(f"šŸ“¤ B->A: {data_b_to_a}") + await stream_b_to_a.close() + + # Wait for both transfers to complete + with trio.fail_after(5.0): + await transfer_complete_a.wait() + await transfer_complete_b.wait() + + # Verify bidirectional transfer + assert received_on_a == data_b_to_a, f"A received wrong data: {received_on_a}" + assert received_on_b == data_a_to_b, f"B received wrong data: {received_on_b}" + + print("āœ… TCP bidirectional data transfer successful!") + print(f" A->B: {data_a_to_b}") + print(f" B->A: {data_b_to_a}") + print(f" āœ“ A got: {received_on_a}") + print(f" āœ“ B got: {received_on_b}") + + +if __name__ == "__main__": + # Run tests directly + import logging + + logging.basicConfig(level=logging.INFO) + + print("🧪 Running TCP P2P Data Transfer Tests") + print("=" * 50) + + async def run_all_tcp_tests(): + try: + print("\n1. Testing basic TCP connection...") + await test_tcp_basic_connection() + except Exception as e: + print(f"āŒ Basic TCP connection test failed: {e}") + return + + try: + print("\n2. Testing TCP data transfer...") + await test_tcp_data_transfer() + except Exception as e: + print(f"āŒ TCP data transfer test failed: {e}") + return + + try: + print("\n3. Testing TCP large data transfer...") + await test_tcp_large_data_transfer() + except Exception as e: + print(f"āŒ TCP large data transfer test failed: {e}") + return + + try: + print("\n4. Testing TCP bidirectional transfer...") + await test_tcp_bidirectional_transfer() + except Exception as e: + print(f"āŒ TCP bidirectional transfer test failed: {e}") + return + + print("\n" + "=" * 50) + print("šŸ TCP P2P Tests Complete - All Tests PASSED!") + + trio.run(run_all_tcp_tests) diff --git a/examples/transport_integration_demo.py b/examples/transport_integration_demo.py new file mode 100644 index 00000000..424979e9 --- /dev/null +++ b/examples/transport_integration_demo.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +""" +Demo script showing the new transport integration capabilities in py-libp2p. + +This script demonstrates: +1. How to use the transport registry +2. How to create transports dynamically based on multiaddrs +3. How to register custom transports +4. How the new system automatically selects the right transport +""" + +import asyncio +import logging +from pathlib import Path +import sys + +# Add the libp2p directory to the path so we can import it +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import multiaddr + +from libp2p.transport import ( + create_transport, + create_transport_for_multiaddr, + get_supported_transport_protocols, + get_transport_registry, + register_transport, +) +from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.upgrader import TransportUpgrader + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def demo_transport_registry(): + """Demonstrate the transport registry functionality.""" + print("šŸ”§ Transport Registry Demo") + print("=" * 50) + + # Get the global registry + registry = get_transport_registry() + + # Show supported protocols + supported = get_supported_transport_protocols() + print(f"Supported transport protocols: {supported}") + + # Show registered transports + print("\nRegistered transports:") + for protocol in supported: + transport_class = registry.get_transport(protocol) + class_name = transport_class.__name__ if transport_class else "None" + print(f" {protocol}: {class_name}") + + print() + + +def demo_transport_factory(): + """Demonstrate the transport factory functions.""" + print("šŸ­ Transport Factory Demo") + print("=" * 50) + + # Create a dummy upgrader for WebSocket transport + upgrader = TransportUpgrader({}, {}) + + # Create transports using the factory function + try: + tcp_transport = create_transport("tcp") + print(f"āœ… Created TCP transport: {type(tcp_transport).__name__}") + + ws_transport = create_transport("ws", upgrader) + print(f"āœ… Created WebSocket transport: {type(ws_transport).__name__}") + + except Exception as e: + print(f"āŒ Error creating transport: {e}") + + print() + + +def demo_multiaddr_transport_selection(): + """Demonstrate automatic transport selection based on multiaddrs.""" + print("šŸŽÆ Multiaddr Transport Selection Demo") + print("=" * 50) + + # Create a dummy upgrader + upgrader = TransportUpgrader({}, {}) + + # Test different multiaddr types + test_addrs = [ + "/ip4/127.0.0.1/tcp/8080", + "/ip4/127.0.0.1/tcp/8080/ws", + "/ip6/::1/tcp/8080/ws", + "/dns4/example.com/tcp/443/ws", + ] + + for addr_str in test_addrs: + try: + maddr = multiaddr.Multiaddr(addr_str) + transport = create_transport_for_multiaddr(maddr, upgrader) + + if transport: + print(f"āœ… {addr_str} -> {type(transport).__name__}") + else: + print(f"āŒ {addr_str} -> No transport found") + + except Exception as e: + print(f"āŒ {addr_str} -> Error: {e}") + + print() + + +def demo_custom_transport_registration(): + """Demonstrate how to register custom transports.""" + print("šŸ”§ Custom Transport Registration Demo") + print("=" * 50) + + # Show current supported protocols + print(f"Before registration: {get_supported_transport_protocols()}") + + # Register a custom transport (using TCP as an example) + class CustomTCPTransport(TCP): + """Custom TCP transport for demonstration.""" + + def __init__(self): + super().__init__() + self.custom_flag = True + + # Register the custom transport + register_transport("custom_tcp", CustomTCPTransport) + + # Show updated supported protocols + print(f"After registration: {get_supported_transport_protocols()}") + + # Test creating the custom transport + try: + custom_transport = create_transport("custom_tcp") + print(f"āœ… Created custom transport: {type(custom_transport).__name__}") + # Check if it has the custom flag (type-safe way) + if hasattr(custom_transport, "custom_flag"): + flag_value = getattr(custom_transport, "custom_flag", "Not found") + print(f" Custom flag: {flag_value}") + else: + print(" Custom flag: Not found") + except Exception as e: + print(f"āŒ Error creating custom transport: {e}") + + print() + + +def demo_integration_with_libp2p(): + """Demonstrate how the new system integrates with libp2p.""" + print("šŸš€ Libp2p Integration Demo") + print("=" * 50) + + print("The new transport system integrates seamlessly with libp2p:") + print() + print("1. āœ… Automatic transport selection based on multiaddr") + print("2. āœ… Support for WebSocket (/ws) protocol") + print("3. āœ… Fallback to TCP for backward compatibility") + print("4. āœ… Easy registration of new transport protocols") + print("5. āœ… No changes needed to existing libp2p code") + print() + + print("Example usage in libp2p:") + print(" # This will automatically use WebSocket transport") + print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])") + print() + print(" # This will automatically use TCP transport") + print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])") + print() + + print() + + +async def main(): + """Run all demos.""" + print("šŸŽ‰ Py-libp2p Transport Integration Demo") + print("=" * 60) + print() + + # Run all demos + demo_transport_registry() + demo_transport_factory() + demo_multiaddr_transport_selection() + demo_custom_transport_registration() + demo_integration_with_libp2p() + + print("šŸŽÆ Summary of New Features:") + print("=" * 40) + print("āœ… Transport Registry: Central registry for all transport implementations") + print("āœ… Dynamic Transport Selection: Automatic selection based on multiaddr") + print("āœ… WebSocket Support: Full /ws protocol support") + print("āœ… Extensible Architecture: Easy to add new transport protocols") + print("āœ… Backward Compatibility: Existing TCP code continues to work") + print("āœ… Factory Functions: Simple API for creating transports") + print() + print("šŸš€ The transport system is now ready for production use!") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nšŸ‘‹ Demo interrupted by user") + except Exception as e: + print(f"\nāŒ Demo failed with error: {e}") + import traceback + + traceback.print_exc() diff --git a/examples/websocket/test_tcp_echo.py b/examples/websocket/test_tcp_echo.py new file mode 100644 index 00000000..20728bf6 --- /dev/null +++ b/examples/websocket/test_tcp_echo.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +""" +Simple TCP echo demo to verify basic libp2p functionality. +""" + +import argparse +import logging +import traceback + +import multiaddr +import trio + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.host.basic_host import BasicHost +from libp2p.network.swarm import Swarm +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.peer.peerstore import PeerStore +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.stream_muxer.yamux.yamux import Yamux +from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.upgrader import TransportUpgrader + +# Enable debug logging +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger("libp2p.tcp-example") + +# Simple echo protocol +ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") + + +async def echo_handler(stream): + """Simple echo handler that echoes back any data received.""" + try: + data = await stream.read(1024) + if data: + message = data.decode("utf-8", errors="replace") + print(f"šŸ“„ Received: {message}") + print(f"šŸ“¤ Echoing back: {message}") + await stream.write(data) + await stream.close() + except Exception as e: + logger.error(f"Echo handler error: {e}") + await stream.close() + + +def create_tcp_host(): + """Create a host with TCP transport.""" + # Create key pair and peer store + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + peer_store = PeerStore() + peer_store.add_key_pair(peer_id, key_pair) + + # Create transport upgrader with plaintext security + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + + # Create TCP transport + transport = TCP() + + # Create swarm and host + swarm = Swarm(peer_id, peer_store, upgrader, transport) + host = BasicHost(swarm) + + return host + + +async def run(port: int, destination: str) -> None: + localhost_ip = "0.0.0.0" + + if not destination: + # Create first host (listener) with TCP transport + listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}") + + try: + host = create_tcp_host() + logger.debug("Created TCP host") + + # Set up echo handler + host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler) + + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): + # Start the peer-store cleanup task + nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + + # Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client + # connections + addrs = host.get_addrs() + logger.debug(f"Host addresses: {addrs}") + if not addrs: + print("āŒ Error: No addresses found for the host") + return + + server_addr = str(addrs[0]) + client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") + + print("🌐 TCP Server Started Successfully!") + print("=" * 50) + print(f"šŸ“ Server Address: {client_addr}") + print("šŸ”§ Protocol: /echo/1.0.0") + print("šŸš€ Transport: TCP") + print() + print("šŸ“‹ To test the connection, run this in another terminal:") + print(f" python test_tcp_echo.py -d {client_addr}") + print() + print("ā³ Waiting for incoming TCP connections...") + print("─" * 50) + + await trio.sleep_forever() + + except Exception as e: + print(f"āŒ Error creating TCP server: {e}") + traceback.print_exc() + return + + else: + # Create second host (dialer) with TCP transport + listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}") + + try: + # Create a single host for client operations + host = create_tcp_host() + + # Start the host for client operations + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): + # Start the peer-store cleanup task + nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + print("šŸ”Œ TCP Client Starting...") + print("=" * 40) + print(f"šŸŽÆ Target Peer: {info.peer_id}") + print(f"šŸ“ Target Address: {destination}") + print() + + try: + print("šŸ”— Connecting to TCP server...") + await host.connect(info) + print("āœ… Successfully connected to TCP server!") + except Exception as e: + error_msg = str(e) + print("\nāŒ Connection Failed!") + print(f" Peer ID: {info.peer_id}") + print(f" Address: {destination}") + print(f" Error: {error_msg}") + return + + # Create a stream and send test data + try: + stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID]) + except Exception as e: + print(f"āŒ Failed to create stream: {e}") + return + + try: + print("šŸš€ Starting Echo Protocol Test...") + print("─" * 40) + + # Send test data + test_message = b"Hello TCP Transport!" + print(f"šŸ“¤ Sending message: {test_message.decode('utf-8')}") + await stream.write(test_message) + + # Read response + print("ā³ Waiting for server response...") + response = await stream.read(1024) + print(f"šŸ“„ Received response: {response.decode('utf-8')}") + + await stream.close() + + print("─" * 40) + if response == test_message: + print("šŸŽ‰ Echo test successful!") + print("āœ… TCP transport is working perfectly!") + else: + print("āŒ Echo test failed!") + + except Exception as e: + print(f"Echo protocol error: {e}") + traceback.print_exc() + + print("āœ… TCP demo completed successfully!") + + except Exception as e: + print(f"āŒ Error creating TCP client: {e}") + traceback.print_exc() + return + + +def main() -> None: + description = "Simple TCP echo demo for libp2p" + parser = argparse.ArgumentParser(description=description) + parser.add_argument("-p", "--port", default=0, type=int, help="source port number") + parser.add_argument( + "-d", "--destination", type=str, help="destination multiaddr string" + ) + + args = parser.parse_args() + + try: + trio.run(run, args.port, args.destination) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + main() diff --git a/examples/websocket/test_websocket_transport.py b/examples/websocket/test_websocket_transport.py new file mode 100644 index 00000000..86353ef9 --- /dev/null +++ b/examples/websocket/test_websocket_transport.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +Simple test script to verify WebSocket transport functionality. +""" + +import asyncio +import logging +from pathlib import Path +import sys + +# Add the libp2p directory to the path so we can import it +sys.path.insert(0, str(Path(__file__).parent)) + +import multiaddr + +from libp2p.transport import create_transport, create_transport_for_multiaddr +from libp2p.transport.upgrader import TransportUpgrader + +# Set up logging +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +async def test_websocket_transport(): + """Test basic WebSocket transport functionality.""" + print("🧪 Testing WebSocket Transport Functionality") + print("=" * 50) + + # Create a dummy upgrader + upgrader = TransportUpgrader({}, {}) + + # Test creating WebSocket transport + try: + ws_transport = create_transport("ws", upgrader) + print(f"āœ… WebSocket transport created: {type(ws_transport).__name__}") + + # Test creating transport from multiaddr + ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader) + print( + f"āœ… WebSocket transport from multiaddr: " + f"{type(ws_transport_from_maddr).__name__}" + ) + + # Test creating listener + handler_called = False + + async def test_handler(conn): + nonlocal handler_called + handler_called = True + print(f"āœ… Connection handler called with: {type(conn).__name__}") + await conn.close() + + listener = ws_transport.create_listener(test_handler) + print(f"āœ… WebSocket listener created: {type(listener).__name__}") + + # Test that the transport can be used + print( + f"āœ… WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}" + ) + print( + f"āœ… WebSocket transport supports listening: " + f"{hasattr(ws_transport, 'create_listener')}" + ) + + print("\nšŸŽÆ WebSocket Transport Test Results:") + print("āœ… Transport creation: PASS") + print("āœ… Multiaddr parsing: PASS") + print("āœ… Listener creation: PASS") + print("āœ… Interface compliance: PASS") + + except Exception as e: + print(f"āŒ WebSocket transport test failed: {e}") + import traceback + + traceback.print_exc() + return False + + return True + + +async def test_transport_registry(): + """Test the transport registry functionality.""" + print("\nšŸ”§ Testing Transport Registry") + print("=" * 30) + + from libp2p.transport import ( + get_supported_transport_protocols, + get_transport_registry, + ) + + registry = get_transport_registry() + supported = get_supported_transport_protocols() + + print(f"Supported protocols: {supported}") + + # Test getting transports + for protocol in supported: + transport_class = registry.get_transport(protocol) + class_name = transport_class.__name__ if transport_class else "None" + print(f" {protocol}: {class_name}") + + # Test creating transports through registry + upgrader = TransportUpgrader({}, {}) + + for protocol in supported: + try: + transport = registry.create_transport(protocol, upgrader) + if transport: + print(f"āœ… {protocol}: Created successfully") + else: + print(f"āŒ {protocol}: Failed to create") + except Exception as e: + print(f"āŒ {protocol}: Error - {e}") + + +async def main(): + """Run all tests.""" + print("šŸš€ WebSocket Transport Integration Test Suite") + print("=" * 60) + print() + + # Run tests + success = await test_websocket_transport() + await test_transport_registry() + + print("\n" + "=" * 60) + if success: + print("šŸŽ‰ All tests passed! WebSocket transport is working correctly.") + else: + print("āŒ Some tests failed. Check the output above for details.") + + print("\nšŸš€ WebSocket transport is ready for use in py-libp2p!") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nšŸ‘‹ Test interrupted by user") + except Exception as e: + print(f"\nāŒ Test failed with error: {e}") + import traceback + + traceback.print_exc() diff --git a/examples/websocket/websocket_demo.py b/examples/websocket/websocket_demo.py new file mode 100644 index 00000000..bd13a881 --- /dev/null +++ b/examples/websocket/websocket_demo.py @@ -0,0 +1,448 @@ +import argparse +import logging +import signal +import sys +import traceback + +import multiaddr +import trio + +from libp2p.abc import INotifee +from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.host.basic_host import BasicHost +from libp2p.network.swarm import Swarm +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.peer.peerstore import PeerStore +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) +from libp2p.stream_muxer.yamux.yamux import Yamux +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport + +# Enable debug logging +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger("libp2p.websocket-example") + + +# Suppress KeyboardInterrupt by handling SIGINT directly +def signal_handler(signum, frame): + print("āœ… Clean exit completed.") + sys.exit(0) + + +signal.signal(signal.SIGINT, signal_handler) + +# Simple echo protocol +ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") + + +async def echo_handler(stream): + """Simple echo handler that echoes back any data received.""" + try: + data = await stream.read(1024) + if data: + message = data.decode("utf-8", errors="replace") + print(f"šŸ“„ Received: {message}") + print(f"šŸ“¤ Echoing back: {message}") + await stream.write(data) + await stream.close() + except Exception as e: + logger.error(f"Echo handler error: {e}") + await stream.close() + + +def create_websocket_host(listen_addrs=None, use_plaintext=False): + """Create a host with WebSocket transport.""" + # Create key pair and peer store + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + peer_store = PeerStore() + peer_store.add_key_pair(peer_id, key_pair) + + if use_plaintext: + # Create transport upgrader with plaintext security + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + else: + # Create separate Ed25519 key for Noise protocol + noise_key_pair = create_ed25519_key_pair() + + # Create Noise transport + noise_transport = NoiseTransport( + libp2p_keypair=key_pair, + noise_privkey=noise_key_pair.private_key, + early_data=None, + with_noise_pipes=False, + ) + + # Create transport upgrader with Noise security + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(NOISE_PROTOCOL_ID): noise_transport + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + + # Create WebSocket transport + transport = WebsocketTransport(upgrader) + + # Create swarm and host + swarm = Swarm(peer_id, peer_store, upgrader, transport) + host = BasicHost(swarm) + + return host + + +async def run(port: int, destination: str, use_plaintext: bool = False) -> None: + localhost_ip = "0.0.0.0" + + if not destination: + # Create first host (listener) with WebSocket transport + listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws") + + try: + host = create_websocket_host(use_plaintext=use_plaintext) + logger.debug(f"Created host with use_plaintext={use_plaintext}") + + # Set up echo handler + host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler) + + # Add connection event handlers for debugging + class DebugNotifee(INotifee): + async def opened_stream(self, network, stream): + pass + + async def closed_stream(self, network, stream): + pass + + async def connected(self, network, conn): + print( + f"šŸ”— New libp2p connection established: " + f"{conn.muxed_conn.peer_id}" + ) + if hasattr(conn.muxed_conn, "get_security_protocol"): + security = conn.muxed_conn.get_security_protocol() + else: + security = "Unknown" + + print(f" Security: {security}") + + async def disconnected(self, network, conn): + print(f"šŸ”Œ libp2p connection closed: {conn.muxed_conn.peer_id}") + + async def listen(self, network, multiaddr): + pass + + async def listen_close(self, network, multiaddr): + pass + + host.get_network().register_notifee(DebugNotifee()) + + # Create a cancellation token for clean shutdown + cancel_scope = trio.CancelScope() + + async def signal_handler(): + with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as ( + signal_receiver + ): + async for sig in signal_receiver: + print(f"\nšŸ›‘ Received signal {sig}") + print("āœ… Shutting down WebSocket server...") + cancel_scope.cancel() + return + + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): + # Start the peer-store cleanup task + nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + + # Start the signal handler + nursery.start_soon(signal_handler) + + # Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client + # connections + addrs = host.get_addrs() + logger.debug(f"Host addresses: {addrs}") + if not addrs: + print("āŒ Error: No addresses found for the host") + print("Debug: host.get_addrs() returned empty list") + return + + server_addr = str(addrs[0]) + client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") + + print("🌐 WebSocket Server Started Successfully!") + print("=" * 50) + print(f"šŸ“ Server Address: {client_addr}") + print("šŸ”§ Protocol: /echo/1.0.0") + print("šŸš€ Transport: WebSocket (/ws)") + print() + print("šŸ“‹ To test the connection, run this in another terminal:") + plaintext_flag = " --plaintext" if use_plaintext else "" + print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}") + print() + print("ā³ Waiting for incoming WebSocket connections...") + print("─" * 50) + + # Add a custom handler to show connection events + async def custom_echo_handler(stream): + peer_id = stream.muxed_conn.peer_id + print("\nšŸ”— New WebSocket Connection!") + print(f" Peer ID: {peer_id}") + print(" Protocol: /echo/1.0.0") + + # Show remote address in multiaddr format + try: + remote_address = stream.get_remote_address() + if remote_address: + print(f" Remote: {remote_address}") + except Exception: + print(" Remote: Unknown") + + print(" ─" * 40) + + # Call the original handler + await echo_handler(stream) + + print(" ─" * 40) + print(f"āœ… Echo request completed for peer: {peer_id}") + print() + + # Replace the handler with our custom one + host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler) + + # Wait indefinitely or until cancelled + with cancel_scope: + await trio.sleep_forever() + + except Exception as e: + print(f"āŒ Error creating WebSocket server: {e}") + traceback.print_exc() + return + + else: + # Create second host (dialer) with WebSocket transport + listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws") + + try: + # Create a single host for client operations + host = create_websocket_host(use_plaintext=use_plaintext) + + # Start the host for client operations + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): + # Start the peer-store cleanup task + nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + + # Add connection event handlers for debugging + class ClientDebugNotifee(INotifee): + async def opened_stream(self, network, stream): + pass + + async def closed_stream(self, network, stream): + pass + + async def connected(self, network, conn): + print( + f"šŸ”— Client: libp2p connection established: " + f"{conn.muxed_conn.peer_id}" + ) + + async def disconnected(self, network, conn): + print( + f"šŸ”Œ Client: libp2p connection closed: " + f"{conn.muxed_conn.peer_id}" + ) + + async def listen(self, network, multiaddr): + pass + + async def listen_close(self, network, multiaddr): + pass + + host.get_network().register_notifee(ClientDebugNotifee()) + + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + print("šŸ”Œ WebSocket Client Starting...") + print("=" * 40) + print(f"šŸŽÆ Target Peer: {info.peer_id}") + print(f"šŸ“ Target Address: {destination}") + print() + + try: + print("šŸ”— Connecting to WebSocket server...") + print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}") + await host.connect(info) + print("āœ… Successfully connected to WebSocket server!") + except Exception as e: + error_msg = str(e) + print("\nāŒ Connection Failed!") + print(f" Peer ID: {info.peer_id}") + print(f" Address: {destination}") + print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}") + print(f" Error: {error_msg}") + print(f" Error type: {type(e).__name__}") + + # Add more detailed error information for debugging + if hasattr(e, "__cause__") and e.__cause__: + print(f" Root cause: {e.__cause__}") + print(f" Root cause type: {type(e.__cause__).__name__}") + + print() + print("šŸ’” Troubleshooting:") + print(" • Make sure the WebSocket server is running") + print(" • Check that the server address is correct") + print(" • Verify the server is listening on the right port") + print( + " • Ensure both client and server use the same sec protocol" + ) + if not use_plaintext: + print(" • Noise over WebSocket may have compatibility issues") + return + + # Create a stream and send test data + try: + stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID]) + except Exception as e: + print(f"āŒ Failed to create stream: {e}") + return + + try: + print("šŸš€ Starting Echo Protocol Test...") + print("─" * 40) + + # Send test data + test_message = b"Hello WebSocket Transport!" + print(f"šŸ“¤ Sending message: {test_message.decode('utf-8')}") + await stream.write(test_message) + + # Read response + print("ā³ Waiting for server response...") + response = await stream.read(1024) + print(f"šŸ“„ Received response: {response.decode('utf-8')}") + + await stream.close() + + print("─" * 40) + if response == test_message: + print("šŸŽ‰ Echo test successful!") + print("āœ… WebSocket transport is working perfectly!") + print("āœ… Client completed successfully, exiting.") + else: + print("āŒ Echo test failed!") + print(" Response doesn't match sent data.") + print(f" Sent: {test_message}") + print(f" Received: {response}") + + except Exception as e: + error_msg = str(e) + print(f"Echo protocol error: {error_msg}") + traceback.print_exc() + finally: + # Ensure stream is closed + try: + if stream: + # Check if stream has is_closed method and use it + has_is_closed = hasattr(stream, "is_closed") and callable( + getattr(stream, "is_closed") + ) + if has_is_closed: + # type: ignore[attr-defined] + if not await stream.is_closed(): + await stream.close() + else: + # Fallback: just try to close the stream + await stream.close() + except Exception: + pass + + # host.run() context manager handles cleanup automatically + print() + print("šŸŽ‰ WebSocket Demo Completed Successfully!") + print("=" * 50) + print("āœ… WebSocket transport is working perfectly!") + print("āœ… Echo protocol communication successful!") + print("āœ… libp2p integration verified!") + print() + print("šŸš€ Your WebSocket transport is ready for production use!") + + # Add a small delay to ensure all cleanup is complete + await trio.sleep(0.1) + + except Exception as e: + print(f"āŒ Error creating WebSocket client: {e}") + traceback.print_exc() + return + + +def main() -> None: + description = """ + This program demonstrates the libp2p WebSocket transport. + First run + 'python websocket_demo.py -p [--plaintext]' to start a WebSocket server. + Then run + 'python websocket_demo.py -d [--plaintext]' + where is the multiaddress shown by the server. + + By default, this example uses Noise encryption for secure communication. + Use --plaintext for testing with unencrypted communication + (not recommended for production). + """ + + example_maddr = ( + "/ip4/127.0.0.1/tcp/8888/ws/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" + ) + + parser = argparse.ArgumentParser(description=description) + parser.add_argument("-p", "--port", default=0, type=int, help="source port number") + parser.add_argument( + "-d", + "--destination", + type=str, + help=f"destination multiaddr string, e.g. {example_maddr}", + ) + parser.add_argument( + "--plaintext", + action="store_true", + help=( + "use plaintext security instead of Noise encryption " + "(not recommended for production)" + ), + ) + + args = parser.parse_args() + + # Determine security mode: use Noise by default, + # plaintext if --plaintext is specified + use_plaintext = args.plaintext + + try: + trio.run(run, args.port, args.destination, use_plaintext) + except KeyboardInterrupt: + # This is expected when Ctrl+C is pressed + # The signal handler already printed the shutdown message + print("āœ… Clean exit completed.") + return + except Exception as e: + print(f"āŒ Unexpected error: {e}") + return + + +if __name__ == "__main__": + main() diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 1fbb7a62..11378aca 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,5 +1,12 @@ """Libp2p Python implementation.""" +import logging +import ssl + +from libp2p.transport.quic.utils import is_quic_multiaddr +from typing import Any +from libp2p.transport.quic.transport import QUICTransport +from libp2p.transport.quic.config import QUICTransportConfig from collections.abc import ( Mapping, Sequence, @@ -18,6 +25,7 @@ from libp2p.abc import ( IPeerRouting, IPeerStore, ISecureTransport, + ITransport, ) from libp2p.crypto.keys import ( KeyPair, @@ -38,10 +46,12 @@ from libp2p.host.routed_host import ( RoutedHost, ) from libp2p.network.swarm import ( - ConnectionConfig, - RetryConfig, Swarm, ) +from libp2p.network.config import ( + ConnectionConfig, + RetryConfig +) from libp2p.peer.id import ( ID, ) @@ -72,6 +82,10 @@ from libp2p.transport.tcp.tcp import ( from libp2p.transport.upgrader import ( TransportUpgrader, ) +from libp2p.transport.transport_registry import ( + create_transport_for_multiaddr, + get_supported_transport_protocols, +) from libp2p.utils.logging import ( setup_logging, ) @@ -87,6 +101,7 @@ MUXER_YAMUX = "YAMUX" MUXER_MPLEX = "MPLEX" DEFAULT_NEGOTIATE_TIMEOUT = 5 +logger = logging.getLogger(__name__) def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None: """ @@ -162,9 +177,13 @@ def new_swarm( peerstore_opt: IPeerStore | None = None, muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, + enable_quic: bool = False, retry_config: Optional["RetryConfig"] = None, - connection_config: Optional["ConnectionConfig"] = None, + connection_config: ConnectionConfig | QUICTransportConfig | None = None, + tls_client_config: ssl.SSLContext | None = None, + tls_server_config: ssl.SSLContext | None = None, ) -> INetworkService: + logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}") """ Create a swarm instance based on the parameters. @@ -174,6 +193,8 @@ def new_swarm( :param peerstore_opt: optional peerstore :param muxer_preference: optional explicit muxer preference :param listen_addrs: optional list of multiaddrs to listen on + :param enable_quic: enable quic for transport + :param quic_transport_opt: options for transport :return: return a default swarm instance Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer @@ -186,16 +207,48 @@ def new_swarm( id_opt = generate_peer_id_from(key_pair) + transport: TCP | QUICTransport | ITransport + quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None + if listen_addrs is None: - transport = TCP() - else: - addr = listen_addrs[0] - if addr.__contains__("tcp"): - transport = TCP() - elif addr.__contains__("quic"): - raise ValueError("QUIC not yet supported") + if enable_quic: + transport = QUICTransport(key_pair.private_key, config=quic_transport_opt) else: - raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}") + transport = TCP() + else: + # Use transport registry to select the appropriate transport + from libp2p.transport.transport_registry import create_transport_for_multiaddr + + # Create a temporary upgrader for transport selection + # We'll create the real upgrader later with the proper configuration + temp_upgrader = TransportUpgrader( + secure_transports_by_protocol={}, + muxer_transports_by_protocol={} + ) + + addr = listen_addrs[0] + logger.debug(f"new_swarm: Creating transport for address: {addr}") + transport_maybe = create_transport_for_multiaddr( + addr, + temp_upgrader, + private_key=key_pair.private_key, + config=quic_transport_opt, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config + ) + + if transport_maybe is None: + raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}") + + transport = transport_maybe + logger.debug(f"new_swarm: Created transport: {type(transport)}") + + # If enable_quic is True but we didn't get a QUIC transport, force QUIC + if enable_quic and not isinstance(transport, QUICTransport): + logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})") + transport = QUICTransport(key_pair.private_key, config=quic_transport_opt) + + logger.debug(f"new_swarm: Final transport type: {type(transport)}") # Generate X25519 keypair for Noise noise_key_pair = create_new_x25519_key_pair() @@ -236,6 +289,7 @@ def new_swarm( muxer_transports_by_protocol=muxer_transports_by_protocol, ) + peerstore = peerstore_opt or PeerStore() # Store our key pair in peerstore peerstore.add_key_pair(id_opt, key_pair) @@ -261,6 +315,10 @@ def new_host( enable_mDNS: bool = False, bootstrap: list[str] | None = None, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, + enable_quic: bool = False, + quic_transport_opt: QUICTransportConfig | None = None, + tls_client_config: ssl.SSLContext | None = None, + tls_server_config: ssl.SSLContext | None = None, ) -> IHost: """ Create a new libp2p host based on the given parameters. @@ -274,15 +332,27 @@ def new_host( :param listen_addrs: optional list of multiaddrs to listen on :param enable_mDNS: whether to enable mDNS discovery :param bootstrap: optional list of bootstrap peer addresses as strings + :param enable_quic: optinal choice to use QUIC for transport + :param quic_transport_opt: optional configuration for quic transport + :param tls_client_config: optional TLS client configuration for WebSocket transport + :param tls_server_config: optional TLS server configuration for WebSocket transport :return: return a host instance """ + + if not enable_quic and quic_transport_opt is not None: + logger.warning(f"QUIC config provided but QUIC not enabled, ignoring QUIC config") + swarm = new_swarm( + enable_quic=enable_quic, key_pair=key_pair, muxer_opt=muxer_opt, sec_opt=sec_opt, peerstore_opt=peerstore_opt, muxer_preference=muxer_preference, listen_addrs=listen_addrs, + connection_config=quic_transport_opt if enable_quic else None, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config ) if disc_opt is not None: diff --git a/libp2p/custom_types.py b/libp2p/custom_types.py index 0b844133..d8e1a1d9 100644 --- a/libp2p/custom_types.py +++ b/libp2p/custom_types.py @@ -5,17 +5,17 @@ from collections.abc import ( ) from typing import TYPE_CHECKING, NewType, Union, cast +from libp2p.transport.quic.stream import QUICStream + if TYPE_CHECKING: - from libp2p.abc import ( - IMuxedConn, - INetStream, - ISecureTransport, - ) + from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport + from libp2p.transport.quic.connection import QUICConnection else: IMuxedConn = cast(type, object) INetStream = cast(type, object) ISecureTransport = cast(type, object) - + IMuxedStream = cast(type, object) + QUICConnection = cast(type, object) from libp2p.io.abc import ( ReadWriteCloser, @@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] UnsubscribeFn = Callable[[], Awaitable[None]] +TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]] +TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]] +MessageID = NewType("MessageID", str) diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index e370a3de..6b7eb1d3 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -213,7 +213,6 @@ class BasicHost(IHost): self, peer_id: ID, protocol_ids: Sequence[TProtocol], - negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ) -> INetStream: """ :param peer_id: peer_id that host is connecting @@ -227,7 +226,7 @@ class BasicHost(IHost): selected_protocol = await self.multiselect_client.select_one_of( list(protocol_ids), MultiselectCommunicator(net_stream), - negotitate_timeout, + self.negotiate_timeout, ) except MultiselectClientError as error: logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) diff --git a/libp2p/network/config.py b/libp2p/network/config.py new file mode 100644 index 00000000..e0fad33c --- /dev/null +++ b/libp2p/network/config.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass + + +@dataclass +class RetryConfig: + """ + Configuration for retry logic with exponential backoff. + + This configuration controls how connection attempts are retried when they fail. + The retry mechanism uses exponential backoff with jitter to prevent thundering + herd problems in distributed systems. + + Attributes: + max_retries: Maximum number of retry attempts before giving up. + Default: 3 attempts + initial_delay: Initial delay in seconds before the first retry. + Default: 0.1 seconds (100ms) + max_delay: Maximum delay cap in seconds to prevent excessive wait times. + Default: 30.0 seconds + backoff_multiplier: Multiplier for exponential backoff (each retry multiplies + the delay by this factor). Default: 2.0 (doubles each time) + jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays + and prevent synchronized retries. Default: 0.1 (10% jitter) + + """ + + max_retries: int = 3 + initial_delay: float = 0.1 + max_delay: float = 30.0 + backoff_multiplier: float = 2.0 + jitter_factor: float = 0.1 + + +@dataclass +class ConnectionConfig: + """ + Configuration for multi-connection support. + + This configuration controls how multiple connections per peer are managed, + including connection limits, timeouts, and load balancing strategies. + + Attributes: + max_connections_per_peer: Maximum number of connections allowed to a single + peer. Default: 3 connections + connection_timeout: Timeout in seconds for establishing new connections. + Default: 30.0 seconds + load_balancing_strategy: Strategy for distributing streams across connections. + Options: "round_robin" (default) or "least_loaded" + + """ + + max_connections_per_peer: int = 3 + connection_timeout: float = 30.0 + load_balancing_strategy: str = "round_robin" # or "least_loaded" + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + if not ( + self.load_balancing_strategy == "round_robin" + or self.load_balancing_strategy == "least_loaded" + ): + raise ValueError( + "Load balancing strategy can only be 'round_robin' or 'least_loaded'" + ) + + if self.max_connections_per_peer < 1: + raise ValueError("Max connection per peer should be atleast 1") + + if self.connection_timeout < 0: + raise ValueError("Connection timeout should be positive") diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index b54fdda4..49daab9c 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import ( MuxedStreamError, MuxedStreamReset, ) +from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError from .exceptions import ( StreamClosed, @@ -170,7 +171,7 @@ class NetStream(INetStream): elif self.__stream_state == StreamState.OPEN: self.__stream_state = StreamState.CLOSE_READ raise StreamEOF() from error - except MuxedStreamReset as error: + except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error: async with self._state_lock: if self.__stream_state in [ StreamState.OPEN, @@ -199,7 +200,12 @@ class NetStream(INetStream): try: await self.muxed_stream.write(data) - except (MuxedStreamClosed, MuxedStreamError) as error: + except ( + MuxedStreamClosed, + MuxedStreamError, + QUICStreamClosedError, + QUICStreamResetError, + ) as error: async with self._state_lock: if self.__stream_state == StreamState.OPEN: self.__stream_state = StreamState.CLOSE_WRITE diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 5a3ce7bb..94d9c7a3 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -2,9 +2,9 @@ from collections.abc import ( Awaitable, Callable, ) -from dataclasses import dataclass import logging import random +from typing import cast from multiaddr import ( Multiaddr, @@ -27,6 +27,7 @@ from libp2p.custom_types import ( from libp2p.io.abc import ( ReadWriteCloser, ) +from libp2p.network.config import ConnectionConfig, RetryConfig from libp2p.peer.id import ( ID, ) @@ -41,6 +42,9 @@ from libp2p.transport.exceptions import ( OpenConnectionError, SecurityUpgradeFailure, ) +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.transport import QUICTransport from libp2p.transport.upgrader import ( TransportUpgrader, ) @@ -61,59 +65,6 @@ from .exceptions import ( logger = logging.getLogger("libp2p.network.swarm") -@dataclass -class RetryConfig: - """ - Configuration for retry logic with exponential backoff. - - This configuration controls how connection attempts are retried when they fail. - The retry mechanism uses exponential backoff with jitter to prevent thundering - herd problems in distributed systems. - - Attributes: - max_retries: Maximum number of retry attempts before giving up. - Default: 3 attempts - initial_delay: Initial delay in seconds before the first retry. - Default: 0.1 seconds (100ms) - max_delay: Maximum delay cap in seconds to prevent excessive wait times. - Default: 30.0 seconds - backoff_multiplier: Multiplier for exponential backoff (each retry multiplies - the delay by this factor). Default: 2.0 (doubles each time) - jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays - and prevent synchronized retries. Default: 0.1 (10% jitter) - - """ - - max_retries: int = 3 - initial_delay: float = 0.1 - max_delay: float = 30.0 - backoff_multiplier: float = 2.0 - jitter_factor: float = 0.1 - - -@dataclass -class ConnectionConfig: - """ - Configuration for multi-connection support. - - This configuration controls how multiple connections per peer are managed, - including connection limits, timeouts, and load balancing strategies. - - Attributes: - max_connections_per_peer: Maximum number of connections allowed to a single - peer. Default: 3 connections - connection_timeout: Timeout in seconds for establishing new connections. - Default: 30.0 seconds - load_balancing_strategy: Strategy for distributing streams across connections. - Options: "round_robin" (default) or "least_loaded" - - """ - - max_connections_per_peer: int = 3 - connection_timeout: float = 30.0 - load_balancing_strategy: str = "round_robin" # or "least_loaded" - - def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn: async def stream_handler(stream: INetStream) -> None: await network.get_manager().wait_finished() @@ -126,8 +77,7 @@ class Swarm(Service, INetworkService): peerstore: IPeerStore upgrader: TransportUpgrader transport: ITransport - # Enhanced: Support for multiple connections per peer - connections: dict[ID, list[INetConn]] # Multiple connections per peer + connections: dict[ID, list[INetConn]] listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn listener_nursery: trio.Nursery | None @@ -137,7 +87,7 @@ class Swarm(Service, INetworkService): # Enhanced: New configuration retry_config: RetryConfig - connection_config: ConnectionConfig + connection_config: ConnectionConfig | QUICTransportConfig _round_robin_index: dict[ID, int] def __init__( @@ -147,7 +97,7 @@ class Swarm(Service, INetworkService): upgrader: TransportUpgrader, transport: ITransport, retry_config: RetryConfig | None = None, - connection_config: ConnectionConfig | None = None, + connection_config: ConnectionConfig | QUICTransportConfig | None = None, ): self.self_id = peer_id self.peerstore = peerstore @@ -178,6 +128,11 @@ class Swarm(Service, INetworkService): # Create a nursery for listener tasks. self.listener_nursery = nursery self.event_listener_nursery_created.set() + + if isinstance(self.transport, QUICTransport): + self.transport.set_background_nursery(nursery) + self.transport.set_swarm(self) + try: await self.manager.wait_finished() finally: @@ -370,6 +325,7 @@ class Swarm(Service, INetworkService): # Dial peer (connection to peer does not yet exist) # Transport dials peer (gets back a raw conn) try: + addr = Multiaddr(f"{addr}/p2p/{peer_id}") raw_conn = await self.transport.dial(addr) except OpenConnectionError as error: logger.debug("fail to dial peer %s over base transport", peer_id) @@ -377,6 +333,15 @@ class Swarm(Service, INetworkService): f"fail to open connection to peer {peer_id}" ) from error + if isinstance(self.transport, QUICTransport) and isinstance( + raw_conn, IMuxedConn + ): + logger.info( + "Skipping upgrade for QUIC, QUIC connections are already multiplexed" + ) + swarm_conn = await self.add_conn(raw_conn) + return swarm_conn + logger.debug("dialed peer %s over base transport", peer_id) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure @@ -402,9 +367,7 @@ class Swarm(Service, INetworkService): logger.debug("upgraded mux for peer %s", peer_id) swarm_conn = await self.add_conn(muxed_conn) - logger.debug("successfully dialed peer %s", peer_id) - return swarm_conn async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn: @@ -427,7 +390,6 @@ class Swarm(Service, INetworkService): :return: net stream instance """ logger.debug("attempting to open a stream to peer %s", peer_id) - # Get existing connections or dial new ones connections = self.get_connections(peer_id) if not connections: @@ -436,6 +398,10 @@ class Swarm(Service, INetworkService): # Load balancing strategy at interface level connection = self._select_connection(connections, peer_id) + if isinstance(self.transport, QUICTransport) and connection is not None: + conn = cast(SwarmConn, connection) + return await conn.new_stream() + try: net_stream = await connection.new_stream() logger.debug("successfully opened a stream to peer %s", peer_id) @@ -515,18 +481,38 @@ class Swarm(Service, INetworkService): - Call listener listen with the multiaddr - Map multiaddr to listener """ + logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}") # We need to wait until `self.listener_nursery` is created. + logger.debug("Starting to listen") await self.event_listener_nursery_created.wait() success_count = 0 for maddr in multiaddrs: + logger.debug(f"Swarm.listen processing multiaddr: {maddr}") if str(maddr) in self.listeners: + logger.debug(f"Swarm.listen: listener already exists for {maddr}") success_count += 1 continue async def conn_handler( read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr ) -> None: + # No need to upgrade QUIC Connection + if isinstance(self.transport, QUICTransport): + try: + quic_conn = cast(QUICConnection, read_write_closer) + await self.add_conn(quic_conn) + peer_id = quic_conn.peer_id + logger.debug( + f"successfully opened quic connection to peer {peer_id}" + ) + # NOTE: This is a intentional barrier to prevent from the + # handler exiting and closing the connection. + await self.manager.wait_finished() + except Exception: + await read_write_closer.close() + return + raw_conn = RawConnection(read_write_closer, False) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first @@ -562,13 +548,18 @@ class Swarm(Service, INetworkService): try: # Success + logger.debug(f"Swarm.listen: creating listener for {maddr}") listener = self.transport.create_listener(conn_handler) + logger.debug(f"Swarm.listen: listener created for {maddr}") self.listeners[str(maddr)] = listener # TODO: `listener.listen` is not bounded with nursery. If we want to be # I/O agnostic, we should change the API. if self.listener_nursery is None: raise SwarmException("swarm instance hasn't been run") + assert self.listener_nursery is not None # For type checker + logger.debug(f"Swarm.listen: calling listener.listen for {maddr}") await listener.listen(maddr, self.listener_nursery) + logger.debug(f"Swarm.listen: listener.listen completed for {maddr}") # Call notifiers since event occurred await self.notify_listen(maddr) @@ -660,9 +651,10 @@ class Swarm(Service, INetworkService): muxed_conn, self, ) - + logger.debug("Swarm::add_conn | starting muxed connection") self.manager.run_task(muxed_conn.start) await muxed_conn.event_started.wait() + logger.debug("Swarm::add_conn | starting swarm connection") self.manager.run_task(swarm_conn.start) await swarm_conn.event_started.wait() diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py index 98a8129c..dff5b339 100644 --- a/libp2p/protocol_muxer/multiselect_communicator.py +++ b/libp2p/protocol_muxer/multiselect_communicator.py @@ -1,3 +1,5 @@ +from builtins import AssertionError + from libp2p.abc import ( IMultiselectCommunicator, ) @@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator): msg_bytes = encode_delim(msg_str.encode()) try: await self.read_writer.write(msg_bytes) - except IOException as error: + # Handle for connection close during ongoing negotiation in QUIC + except (IOException, AssertionError, ValueError) as error: raise MultiselectCommunicatorError( "fail to write to multiselect communicator" ) from error diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index a4c8c463..f0e84641 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -1,6 +1,3 @@ -from ast import ( - literal_eval, -) from collections import ( defaultdict, ) @@ -22,6 +19,7 @@ from libp2p.abc import ( IPubsubRouter, ) from libp2p.custom_types import ( + MessageID, TProtocol, ) from libp2p.peer.id import ( @@ -56,6 +54,10 @@ from .pb import ( from .pubsub import ( Pubsub, ) +from .utils import ( + parse_message_id_safe, + safe_parse_message_id, +) PROTOCOL_ID = TProtocol("/meshsub/1.0.0") PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0") @@ -306,7 +308,8 @@ class GossipSub(IPubsubRouter, Service): floodsub_peers: set[ID] = { peer_id for peer_id in self.pubsub.peer_topics[topic] - if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID + if peer_id in self.peer_protocol + and self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID } send_to.update(floodsub_peers) @@ -794,8 +797,8 @@ class GossipSub(IPubsubRouter, Service): # Add all unknown message ids (ids that appear in ihave_msg but not in # seen_seqnos) to list of messages we want to request - msg_ids_wanted: list[str] = [ - msg_id + msg_ids_wanted: list[MessageID] = [ + parse_message_id_safe(msg_id) for msg_id in ihave_msg.messageIDs if msg_id not in seen_seqnos_and_peers ] @@ -811,9 +814,9 @@ class GossipSub(IPubsubRouter, Service): Forwards all request messages that are present in mcache to the requesting peer. """ - # FIXME: Update type of message ID - # FIXME: Find a better way to parse the msg ids - msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs] + msg_ids: list[tuple[bytes, bytes]] = [ + safe_parse_message_id(msg) for msg in iwant_msg.messageIDs + ] msgs_to_forward: list[rpc_pb2.Message] = [] for msg_id_iwant in msg_ids: # Check if the wanted message ID is present in mcache diff --git a/libp2p/pubsub/utils.py b/libp2p/pubsub/utils.py index 6686ba69..6beaccc5 100644 --- a/libp2p/pubsub/utils.py +++ b/libp2p/pubsub/utils.py @@ -1,6 +1,10 @@ +import ast import logging from libp2p.abc import IHost +from libp2p.custom_types import ( + MessageID, +) from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ID from libp2p.pubsub.pb.rpc_pb2 import RPC @@ -48,3 +52,29 @@ def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool: logger.error("Failed to update the Certified-Addr-Book: %s", e) return False return True + + +def parse_message_id_safe(msg_id_str: str) -> MessageID: + """Safely handle message ID as string.""" + return MessageID(msg_id_str) + + +def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]: + """ + Safely parse message ID using ast.literal_eval with validation. + :param msg_id_str: String representation of message ID + :return: Tuple of (seqno, from_id) as bytes + :raises ValueError: If parsing fails + """ + try: + parsed = ast.literal_eval(msg_id_str) + if not isinstance(parsed, tuple) or len(parsed) != 2: + raise ValueError("Invalid message ID format") + + seqno, from_id = parsed + if not isinstance(seqno, bytes) or not isinstance(from_id, bytes): + raise ValueError("Message ID components must be bytes") + + return (seqno, from_id) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid message ID format: {e}") diff --git a/libp2p/relay/circuit_v2/config.py b/libp2p/relay/circuit_v2/config.py index 3315c74f..d56839e0 100644 --- a/libp2p/relay/circuit_v2/config.py +++ b/libp2p/relay/circuit_v2/config.py @@ -9,6 +9,7 @@ from dataclasses import ( dataclass, field, ) +from enum import Flag, auto from libp2p.peer.peerinfo import ( PeerInfo, @@ -18,29 +19,118 @@ from .resources import ( RelayLimits, ) +DEFAULT_MIN_RELAYS = 3 +DEFAULT_MAX_RELAYS = 20 +DEFAULT_DISCOVERY_INTERVAL = 300 # seconds +DEFAULT_RESERVATION_TTL = 3600 # seconds +DEFAULT_MAX_CIRCUIT_DURATION = 3600 # seconds +DEFAULT_MAX_CIRCUIT_BYTES = 1024 * 1024 * 1024 # 1GB + +DEFAULT_MAX_CIRCUIT_CONNS = 8 +DEFAULT_MAX_RESERVATIONS = 4 + +MAX_RESERVATIONS_PER_IP = 8 +MAX_CIRCUITS_PER_IP = 16 +RESERVATION_RATE_PER_IP = 4 # per minute +CIRCUIT_RATE_PER_IP = 8 # per minute +MAX_CIRCUITS_TOTAL = 64 +MAX_RESERVATIONS_TOTAL = 32 +MAX_BANDWIDTH_PER_CIRCUIT = 1024 * 1024 # 1MB/s +MAX_BANDWIDTH_TOTAL = 10 * 1024 * 1024 # 10MB/s + +MIN_RELAY_SCORE = 0.5 +MAX_RELAY_LATENCY = 1.0 # seconds +ENABLE_AUTO_RELAY = True +AUTO_RELAY_TIMEOUT = 30 # seconds +MAX_AUTO_RELAY_ATTEMPTS = 3 +RESERVATION_REFRESH_THRESHOLD = 0.8 # Refresh at 80% of TTL +MAX_CONCURRENT_RESERVATIONS = 2 + +# Timeout constants for different components +DEFAULT_DISCOVERY_STREAM_TIMEOUT = 10 # seconds +DEFAULT_PEER_PROTOCOL_TIMEOUT = 5 # seconds +DEFAULT_PROTOCOL_READ_TIMEOUT = 15 # seconds +DEFAULT_PROTOCOL_WRITE_TIMEOUT = 15 # seconds +DEFAULT_PROTOCOL_CLOSE_TIMEOUT = 10 # seconds +DEFAULT_DCUTR_READ_TIMEOUT = 30 # seconds +DEFAULT_DCUTR_WRITE_TIMEOUT = 30 # seconds +DEFAULT_DIAL_TIMEOUT = 10 # seconds + + +@dataclass +class TimeoutConfig: + """Timeout configuration for different Circuit Relay v2 components.""" + + # Discovery timeouts + discovery_stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT + peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT + + # Core protocol timeouts + protocol_read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT + protocol_write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT + protocol_close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT + + # DCUtR timeouts + dcutr_read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT + dcutr_write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT + dial_timeout: int = DEFAULT_DIAL_TIMEOUT + + +# Relay roles enum +class RelayRole(Flag): + """ + Bit-flag enum that captures the three possible relay capabilities. + + A node can combine multiple roles using bit-wise OR, for example:: + + RelayRole.HOP | RelayRole.STOP + """ + + HOP = auto() # Act as a relay for others ("hop") + STOP = auto() # Accept relayed connections ("stop") + CLIENT = auto() # Dial through existing relays ("client") + @dataclass class RelayConfig: """Configuration for Circuit Relay v2.""" - # Role configuration - enable_hop: bool = False # Whether to act as a relay (hop) - enable_stop: bool = True # Whether to accept relayed connections (stop) - enable_client: bool = True # Whether to use relays for dialing + # Role configuration (bit-flags) + roles: RelayRole = RelayRole.STOP | RelayRole.CLIENT # Resource limits limits: RelayLimits | None = None # Discovery configuration bootstrap_relays: list[PeerInfo] = field(default_factory=list) - min_relays: int = 3 - max_relays: int = 20 - discovery_interval: int = 300 # seconds + min_relays: int = DEFAULT_MIN_RELAYS + max_relays: int = DEFAULT_MAX_RELAYS + discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL # Connection configuration - reservation_ttl: int = 3600 # seconds - max_circuit_duration: int = 3600 # seconds - max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB + reservation_ttl: int = DEFAULT_RESERVATION_TTL + max_circuit_duration: int = DEFAULT_MAX_CIRCUIT_DURATION + max_circuit_bytes: int = DEFAULT_MAX_CIRCUIT_BYTES + + # Timeout configuration + timeouts: TimeoutConfig = field(default_factory=TimeoutConfig) + + # --------------------------------------------------------------------- + # Backwards-compat boolean helpers. Existing code that still accesses + # ``cfg.enable_hop, cfg.enable_stop, cfg.enable_client`` will continue to work. + # --------------------------------------------------------------------- + + @property + def enable_hop(self) -> bool: # pragma: no cover – helper + return bool(self.roles & RelayRole.HOP) + + @property + def enable_stop(self) -> bool: # pragma: no cover – helper + return bool(self.roles & RelayRole.STOP) + + @property + def enable_client(self) -> bool: # pragma: no cover – helper + return bool(self.roles & RelayRole.CLIENT) def __post_init__(self) -> None: """Initialize default values.""" @@ -48,8 +138,8 @@ class RelayConfig: self.limits = RelayLimits( duration=self.max_circuit_duration, data=self.max_circuit_bytes, - max_circuit_conns=8, - max_reservations=4, + max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS, + max_reservations=DEFAULT_MAX_RESERVATIONS, ) @@ -58,20 +148,20 @@ class HopConfig: """Configuration specific to relay (hop) nodes.""" # Resource limits per IP - max_reservations_per_ip: int = 8 - max_circuits_per_ip: int = 16 + max_reservations_per_ip: int = MAX_RESERVATIONS_PER_IP + max_circuits_per_ip: int = MAX_CIRCUITS_PER_IP # Rate limiting - reservation_rate_per_ip: int = 4 # per minute - circuit_rate_per_ip: int = 8 # per minute + reservation_rate_per_ip: int = RESERVATION_RATE_PER_IP + circuit_rate_per_ip: int = CIRCUIT_RATE_PER_IP # Resource quotas - max_circuits_total: int = 64 - max_reservations_total: int = 32 + max_circuits_total: int = MAX_CIRCUITS_TOTAL + max_reservations_total: int = MAX_RESERVATIONS_TOTAL # Bandwidth limits - max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s - max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s + max_bandwidth_per_circuit: int = MAX_BANDWIDTH_PER_CIRCUIT + max_bandwidth_total: int = MAX_BANDWIDTH_TOTAL @dataclass @@ -79,14 +169,14 @@ class ClientConfig: """Configuration specific to relay clients.""" # Relay selection - min_relay_score: float = 0.5 - max_relay_latency: float = 1.0 # seconds + min_relay_score: float = MIN_RELAY_SCORE + max_relay_latency: float = MAX_RELAY_LATENCY # Auto-relay settings - enable_auto_relay: bool = True - auto_relay_timeout: int = 30 # seconds - max_auto_relay_attempts: int = 3 + enable_auto_relay: bool = ENABLE_AUTO_RELAY + auto_relay_timeout: int = AUTO_RELAY_TIMEOUT + max_auto_relay_attempts: int = MAX_AUTO_RELAY_ATTEMPTS # Reservation management - reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL - max_concurrent_reservations: int = 2 + reservation_refresh_threshold: float = RESERVATION_REFRESH_THRESHOLD + max_concurrent_reservations: int = MAX_CONCURRENT_RESERVATIONS diff --git a/libp2p/relay/circuit_v2/dcutr.py b/libp2p/relay/circuit_v2/dcutr.py index 48ba1a3f..1328ac49 100644 --- a/libp2p/relay/circuit_v2/dcutr.py +++ b/libp2p/relay/circuit_v2/dcutr.py @@ -29,6 +29,11 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.relay.circuit_v2.config import ( + DEFAULT_DCUTR_READ_TIMEOUT, + DEFAULT_DCUTR_WRITE_TIMEOUT, + DEFAULT_DIAL_TIMEOUT, +) from libp2p.relay.circuit_v2.nat import ( ReachabilityChecker, ) @@ -47,11 +52,7 @@ PROTOCOL_ID = TProtocol("/libp2p/dcutr") # Maximum message size for DCUtR (4KiB as per spec) MAX_MESSAGE_SIZE = 4 * 1024 -# Timeouts -STREAM_READ_TIMEOUT = 30 # seconds -STREAM_WRITE_TIMEOUT = 30 # seconds -DIAL_TIMEOUT = 10 # seconds - +# DCUtR protocol constants # Maximum number of hole punch attempts per peer MAX_HOLE_PUNCH_ATTEMPTS = 5 @@ -70,7 +71,13 @@ class DCUtRProtocol(Service): hole punching, after they have established an initial connection through a relay. """ - def __init__(self, host: IHost): + def __init__( + self, + host: IHost, + read_timeout: int = DEFAULT_DCUTR_READ_TIMEOUT, + write_timeout: int = DEFAULT_DCUTR_WRITE_TIMEOUT, + dial_timeout: int = DEFAULT_DIAL_TIMEOUT, + ): """ Initialize the DCUtR protocol. @@ -78,10 +85,19 @@ class DCUtRProtocol(Service): ---------- host : IHost The libp2p host this protocol is running on + read_timeout : int + Timeout for stream read operations, in seconds + write_timeout : int + Timeout for stream write operations, in seconds + dial_timeout : int + Timeout for dial operations, in seconds """ super().__init__() self.host = host + self.read_timeout = read_timeout + self.write_timeout = write_timeout + self.dial_timeout = dial_timeout self.event_started = trio.Event() self._hole_punch_attempts: dict[ID, int] = {} self._direct_connections: set[ID] = set() @@ -161,7 +177,7 @@ class DCUtRProtocol(Service): try: # Read the CONNECT message - with trio.fail_after(STREAM_READ_TIMEOUT): + with trio.fail_after(self.read_timeout): msg_bytes = await stream.read(MAX_MESSAGE_SIZE) # Parse the message @@ -196,7 +212,7 @@ class DCUtRProtocol(Service): response.type = HolePunch.CONNECT response.ObsAddrs.extend(our_addrs) - with trio.fail_after(STREAM_WRITE_TIMEOUT): + with trio.fail_after(self.write_timeout): await stream.write(response.SerializeToString()) logger.debug( @@ -206,7 +222,7 @@ class DCUtRProtocol(Service): ) # Wait for SYNC message - with trio.fail_after(STREAM_READ_TIMEOUT): + with trio.fail_after(self.read_timeout): sync_bytes = await stream.read(MAX_MESSAGE_SIZE) # Parse the SYNC message @@ -300,7 +316,7 @@ class DCUtRProtocol(Service): connect_msg.ObsAddrs.extend(our_addrs) start_time = time.time() - with trio.fail_after(STREAM_WRITE_TIMEOUT): + with trio.fail_after(self.write_timeout): await stream.write(connect_msg.SerializeToString()) logger.debug( @@ -310,7 +326,7 @@ class DCUtRProtocol(Service): ) # Receive the peer's CONNECT message - with trio.fail_after(STREAM_READ_TIMEOUT): + with trio.fail_after(self.read_timeout): resp_bytes = await stream.read(MAX_MESSAGE_SIZE) # Calculate RTT @@ -349,7 +365,7 @@ class DCUtRProtocol(Service): sync_msg = HolePunch() sync_msg.type = HolePunch.SYNC - with trio.fail_after(STREAM_WRITE_TIMEOUT): + with trio.fail_after(self.write_timeout): await stream.write(sync_msg.SerializeToString()) logger.debug("Sent SYNC message to %s", peer_id) @@ -468,7 +484,7 @@ class DCUtRProtocol(Service): peer_info = PeerInfo(peer_id, [addr]) # Try to connect with timeout - with trio.fail_after(DIAL_TIMEOUT): + with trio.fail_after(self.dial_timeout): await self.host.connect(peer_info) logger.info("Successfully connected to %s at %s", peer_id, addr) diff --git a/libp2p/relay/circuit_v2/discovery.py b/libp2p/relay/circuit_v2/discovery.py index a35eacdc..50ee8d90 100644 --- a/libp2p/relay/circuit_v2/discovery.py +++ b/libp2p/relay/circuit_v2/discovery.py @@ -31,6 +31,11 @@ from libp2p.tools.async_service import ( Service, ) +from .config import ( + DEFAULT_DISCOVERY_INTERVAL, + DEFAULT_DISCOVERY_STREAM_TIMEOUT, + DEFAULT_PEER_PROTOCOL_TIMEOUT, +) from .pb.circuit_pb2 import ( HopMessage, ) @@ -43,10 +48,8 @@ from .protocol_buffer import ( logger = logging.getLogger("libp2p.relay.circuit_v2.discovery") -# Constants +# Discovery constants MAX_RELAYS_TO_TRACK = 10 -DEFAULT_DISCOVERY_INTERVAL = 60 # seconds -STREAM_TIMEOUT = 10 # seconds # Extended interfaces for type checking @@ -86,6 +89,8 @@ class RelayDiscovery(Service): auto_reserve: bool = False, discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL, max_relays: int = MAX_RELAYS_TO_TRACK, + stream_timeout: int = DEFAULT_DISCOVERY_STREAM_TIMEOUT, + peer_protocol_timeout: int = DEFAULT_PEER_PROTOCOL_TIMEOUT, ) -> None: """ Initialize the discovery service. @@ -100,6 +105,10 @@ class RelayDiscovery(Service): How often to run discovery, in seconds max_relays : int Maximum number of relays to track + stream_timeout : int + Timeout for stream operations during discovery, in seconds + peer_protocol_timeout : int + Timeout for checking peer protocol support, in seconds """ super().__init__() @@ -107,6 +116,8 @@ class RelayDiscovery(Service): self.auto_reserve = auto_reserve self.discovery_interval = discovery_interval self.max_relays = max_relays + self.stream_timeout = stream_timeout + self.peer_protocol_timeout = peer_protocol_timeout self._discovered_relays: dict[ID, RelayInfo] = {} self._protocol_cache: dict[ ID, set[str] @@ -165,8 +176,8 @@ class RelayDiscovery(Service): self._discovered_relays[peer_id].last_seen = time.time() continue - # Check if peer supports the relay protocol - with trio.move_on_after(5): # Don't wait too long for protocol info + # Don't wait too long for protocol info + with trio.move_on_after(self.peer_protocol_timeout): if await self._supports_relay_protocol(peer_id): await self._add_relay(peer_id) @@ -264,7 +275,7 @@ class RelayDiscovery(Service): async def _check_via_direct_connection(self, peer_id: ID) -> bool | None: """Check protocol support via direct connection.""" try: - with trio.fail_after(STREAM_TIMEOUT): + with trio.fail_after(self.stream_timeout): stream = await self.host.new_stream(peer_id, [PROTOCOL_ID]) if stream: await stream.close() @@ -370,7 +381,7 @@ class RelayDiscovery(Service): # Open a stream to the relay with timeout try: - with trio.fail_after(STREAM_TIMEOUT): + with trio.fail_after(self.stream_timeout): stream = await self.host.new_stream(peer_id, [PROTOCOL_ID]) if not stream: logger.error("Failed to open stream to relay %s", peer_id) @@ -386,7 +397,7 @@ class RelayDiscovery(Service): peer=self.host.get_id().to_bytes(), ) - with trio.fail_after(STREAM_TIMEOUT): + with trio.fail_after(self.stream_timeout): await stream.write(request.SerializeToString()) # Wait for response diff --git a/libp2p/relay/circuit_v2/protocol.py b/libp2p/relay/circuit_v2/protocol.py index 1cf76efa..a6a80c20 100644 --- a/libp2p/relay/circuit_v2/protocol.py +++ b/libp2p/relay/circuit_v2/protocol.py @@ -5,6 +5,7 @@ This module implements the Circuit Relay v2 protocol as specified in: https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md """ +from enum import Enum, auto import logging import time from typing import ( @@ -37,6 +38,15 @@ from libp2p.tools.async_service import ( Service, ) +from .config import ( + DEFAULT_MAX_CIRCUIT_BYTES, + DEFAULT_MAX_CIRCUIT_CONNS, + DEFAULT_MAX_CIRCUIT_DURATION, + DEFAULT_MAX_RESERVATIONS, + DEFAULT_PROTOCOL_CLOSE_TIMEOUT, + DEFAULT_PROTOCOL_READ_TIMEOUT, + DEFAULT_PROTOCOL_WRITE_TIMEOUT, +) from .pb.circuit_pb2 import ( HopMessage, Limit, @@ -58,18 +68,22 @@ logger = logging.getLogger("libp2p.relay.circuit_v2") PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0") STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop") + +# Direction enum for data piping +class Pipe(Enum): + SRC_TO_DST = auto() + DST_TO_SRC = auto() + + # Default limits for relay resources DEFAULT_RELAY_LIMITS = RelayLimits( - duration=60 * 60, # 1 hour - data=1024 * 1024 * 1024, # 1GB - max_circuit_conns=8, - max_reservations=4, + duration=DEFAULT_MAX_CIRCUIT_DURATION, + data=DEFAULT_MAX_CIRCUIT_BYTES, + max_circuit_conns=DEFAULT_MAX_CIRCUIT_CONNS, + max_reservations=DEFAULT_MAX_RESERVATIONS, ) -# Stream operation timeouts -STREAM_READ_TIMEOUT = 15 # seconds -STREAM_WRITE_TIMEOUT = 15 # seconds -STREAM_CLOSE_TIMEOUT = 10 # seconds +# Stream operation constants MAX_READ_RETRIES = 5 # Maximum number of read retries @@ -113,6 +127,9 @@ class CircuitV2Protocol(Service): host: IHost, limits: RelayLimits | None = None, allow_hop: bool = False, + read_timeout: int = DEFAULT_PROTOCOL_READ_TIMEOUT, + write_timeout: int = DEFAULT_PROTOCOL_WRITE_TIMEOUT, + close_timeout: int = DEFAULT_PROTOCOL_CLOSE_TIMEOUT, ) -> None: """ Initialize a Circuit Relay v2 protocol instance. @@ -125,11 +142,20 @@ class CircuitV2Protocol(Service): Resource limits for the relay allow_hop : bool Whether to allow this node to act as a relay + read_timeout : int + Timeout for stream read operations, in seconds + write_timeout : int + Timeout for stream write operations, in seconds + close_timeout : int + Timeout for stream close operations, in seconds """ self.host = host self.limits = limits or DEFAULT_RELAY_LIMITS self.allow_hop = allow_hop + self.read_timeout = read_timeout + self.write_timeout = write_timeout + self.close_timeout = close_timeout self.resource_manager = RelayResourceManager(self.limits) self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {} self.event_started = trio.Event() @@ -174,7 +200,7 @@ class CircuitV2Protocol(Service): return try: - with trio.fail_after(STREAM_CLOSE_TIMEOUT): + with trio.fail_after(self.close_timeout): await stream.close() except Exception: try: @@ -216,7 +242,7 @@ class CircuitV2Protocol(Service): while retries < max_retries: try: - with trio.fail_after(STREAM_READ_TIMEOUT): + with trio.fail_after(self.read_timeout): # Try reading with timeout logger.debug( "Attempting to read from stream (attempt %d/%d)", @@ -293,7 +319,7 @@ class CircuitV2Protocol(Service): # First, handle the read timeout gracefully try: with trio.fail_after( - STREAM_READ_TIMEOUT * 2 + self.read_timeout * 2 ): # Double the timeout for reading msg_bytes = await stream.read() if not msg_bytes: @@ -414,7 +440,7 @@ class CircuitV2Protocol(Service): """ try: # Read the incoming message with timeout - with trio.fail_after(STREAM_READ_TIMEOUT): + with trio.fail_after(self.read_timeout): msg_bytes = await stream.read() stop_msg = StopMessage() stop_msg.ParseFromString(msg_bytes) @@ -458,8 +484,20 @@ class CircuitV2Protocol(Service): # Start relaying data async with trio.open_nursery() as nursery: - nursery.start_soon(self._relay_data, src_stream, stream, peer_id) - nursery.start_soon(self._relay_data, stream, src_stream, peer_id) + nursery.start_soon( + self._relay_data, + src_stream, + stream, + peer_id, + Pipe.SRC_TO_DST, + ) + nursery.start_soon( + self._relay_data, + stream, + src_stream, + peer_id, + Pipe.DST_TO_SRC, + ) except trio.TooSlowError: logger.error("Timeout reading from stop stream") @@ -509,7 +547,7 @@ class CircuitV2Protocol(Service): ttl = self.resource_manager.reserve(peer_id) # Send reservation success response - with trio.fail_after(STREAM_WRITE_TIMEOUT): + with trio.fail_after(self.write_timeout): status = create_status( code=StatusCode.OK, message="Reservation accepted" ) @@ -560,7 +598,7 @@ class CircuitV2Protocol(Service): # Always close the stream when done with reservation if cast(INetStreamWithExtras, stream).is_open(): try: - with trio.fail_after(STREAM_CLOSE_TIMEOUT): + with trio.fail_after(self.close_timeout): await stream.close() except Exception as close_err: logger.error("Error closing stream: %s", str(close_err)) @@ -596,7 +634,7 @@ class CircuitV2Protocol(Service): self._active_relays[peer_id] = (stream, None) # Try to connect to the destination with timeout - with trio.fail_after(STREAM_READ_TIMEOUT): + with trio.fail_after(self.read_timeout): dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID]) if not dst_stream: raise ConnectionError("Could not connect to destination") @@ -648,8 +686,20 @@ class CircuitV2Protocol(Service): # Start relaying data async with trio.open_nursery() as nursery: - nursery.start_soon(self._relay_data, stream, dst_stream, peer_id) - nursery.start_soon(self._relay_data, dst_stream, stream, peer_id) + nursery.start_soon( + self._relay_data, + stream, + dst_stream, + peer_id, + Pipe.SRC_TO_DST, + ) + nursery.start_soon( + self._relay_data, + dst_stream, + stream, + peer_id, + Pipe.DST_TO_SRC, + ) except (trio.TooSlowError, ConnectionError) as e: logger.error("Error establishing relay connection: %s", str(e)) @@ -685,6 +735,7 @@ class CircuitV2Protocol(Service): src_stream: INetStream, dst_stream: INetStream, peer_id: ID, + direction: Pipe, ) -> None: """ Relay data between two streams. @@ -698,24 +749,27 @@ class CircuitV2Protocol(Service): peer_id : ID ID of the peer being relayed + direction : Pipe + Direction of data flow (``Pipe.SRC_TO_DST`` or ``Pipe.DST_TO_SRC``) + """ try: while True: # Read data with retries data = await self._read_stream_with_retry(src_stream) if not data: - logger.info("Source stream closed/reset") + logger.info("%s closed/reset", direction.name) break # Write data with timeout try: - with trio.fail_after(STREAM_WRITE_TIMEOUT): + with trio.fail_after(self.write_timeout): await dst_stream.write(data) except trio.TooSlowError: - logger.error("Timeout writing to destination stream") + logger.error("Timeout writing in %s", direction.name) break except Exception as e: - logger.error("Error writing to destination stream: %s", str(e)) + logger.error("Error writing in %s: %s", direction.name, str(e)) break # Update resource usage @@ -744,7 +798,7 @@ class CircuitV2Protocol(Service): """Send a status message.""" try: logger.debug("Sending status message with code %s: %s", code, message) - with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout + with trio.fail_after(self.write_timeout * 2): # Double the timeout # Create a proto Status directly pb_status = PbStatus() pb_status.code = cast( @@ -782,7 +836,7 @@ class CircuitV2Protocol(Service): """Send a status message on a STOP stream.""" try: logger.debug("Sending stop status message with code %s: %s", code, message) - with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout + with trio.fail_after(self.write_timeout * 2): # Double the timeout # Create a proto Status directly pb_status = PbStatus() pb_status.code = cast( diff --git a/libp2p/relay/circuit_v2/resources.py b/libp2p/relay/circuit_v2/resources.py index 4da67ec6..d621990d 100644 --- a/libp2p/relay/circuit_v2/resources.py +++ b/libp2p/relay/circuit_v2/resources.py @@ -8,6 +8,7 @@ including reservations and connection limits. from dataclasses import ( dataclass, ) +from enum import Enum, auto import hashlib import os import time @@ -19,6 +20,18 @@ from libp2p.peer.id import ( # Import the protobuf definitions from .pb.circuit_pb2 import Reservation as PbReservation +RANDOM_BYTES_LENGTH = 16 # 128 bits of randomness +TIMESTAMP_MULTIPLIER = 1000000 # To convert seconds to microseconds + + +# Reservation status enum +class ReservationStatus(Enum): + """Lifecycle status of a relay reservation.""" + + ACTIVE = auto() + EXPIRED = auto() + REJECTED = auto() + @dataclass class RelayLimits: @@ -68,8 +81,8 @@ class Reservation: # - Peer ID to bind it to the specific peer # - Timestamp for uniqueness # - Hash everything for a fixed size output - random_bytes = os.urandom(16) # 128 bits of randomness - timestamp = str(int(self.created_at * 1000000)).encode() + random_bytes = os.urandom(RANDOM_BYTES_LENGTH) + timestamp = str(int(self.created_at * TIMESTAMP_MULTIPLIER)).encode() peer_bytes = self.peer_id.to_bytes() # Combine all elements and hash them @@ -84,6 +97,15 @@ class Reservation: """Check if the reservation has expired.""" return time.time() > self.expires_at + # Expose a friendly status enum + + @property + def status(self) -> ReservationStatus: + """Return the current status as a ``ReservationStatus`` enum.""" + return ( + ReservationStatus.EXPIRED if self.is_expired() else ReservationStatus.ACTIVE + ) + def can_accept_connection(self) -> bool: """Check if a new connection can be accepted.""" return ( diff --git a/libp2p/relay/circuit_v2/transport.py b/libp2p/relay/circuit_v2/transport.py index 8ac43d99..23454b89 100644 --- a/libp2p/relay/circuit_v2/transport.py +++ b/libp2p/relay/circuit_v2/transport.py @@ -89,6 +89,8 @@ class CircuitV2Transport(ITransport): auto_reserve=config.enable_client, discovery_interval=config.discovery_interval, max_relays=config.max_relays, + stream_timeout=config.timeouts.discovery_stream_timeout, + peer_protocol_timeout=config.timeouts.peer_protocol_timeout, ) self.relay_counter = 0 # for round robin load balancing diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index a24b6c74..18fbbcd5 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -1,3 +1,4 @@ +import logging from typing import ( cast, ) @@ -15,6 +16,8 @@ from libp2p.io.msgio import ( FixedSizeLenMsgReadWriter, ) +logger = logging.getLogger(__name__) + SIZE_NOISE_MESSAGE_LEN = 2 MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 SIZE_NOISE_MESSAGE_BODY_LEN = 2 @@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): self.noise_state = noise_state async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None: + logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes") data_encrypted = self.encrypt(msg) if prefix_encoded: # Manually add the prefix if needed data_encrypted = self.prefix + data_encrypted + logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes") await self.read_writer.write_msg(data_encrypted) + logger.debug("Noise write_msg: write completed successfully") async def read_msg(self, prefix_encoded: bool = False) -> bytes: + logger.debug("Noise read_msg: reading encrypted message") noise_msg_encrypted = await self.read_writer.read_msg() + logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes") if prefix_encoded: - return self.decrypt(noise_msg_encrypted[len(self.prefix) :]) + result = self.decrypt(noise_msg_encrypted[len(self.prefix) :]) else: - return self.decrypt(noise_msg_encrypted) + result = self.decrypt(noise_msg_encrypted) + logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes") + return result async def close(self) -> None: await self.read_writer.close() diff --git a/libp2p/security/noise/messages.py b/libp2p/security/noise/messages.py index 309b24b0..f7e2dceb 100644 --- a/libp2p/security/noise/messages.py +++ b/libp2p/security/noise/messages.py @@ -1,6 +1,7 @@ from dataclasses import ( dataclass, ) +import logging from libp2p.crypto.keys import ( PrivateKey, @@ -12,6 +13,8 @@ from libp2p.crypto.serialization import ( from .pb import noise_pb2 as noise_pb +logger = logging.getLogger(__name__) + SIGNED_DATA_PREFIX = "noise-libp2p-static-key:" @@ -48,6 +51,8 @@ def make_handshake_payload_sig( id_privkey: PrivateKey, noise_static_pubkey: PublicKey ) -> bytes: data = make_data_to_be_signed(noise_static_pubkey) + logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}") + logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}") return id_privkey.sign(data) @@ -60,4 +65,27 @@ def verify_handshake_payload_sig( 2. signed by the private key corresponding to `id_pubkey` """ expected_data = make_data_to_be_signed(noise_static_pubkey) - return payload.id_pubkey.verify(expected_data, payload.id_sig) + logger.debug( + f"verify_handshake_payload_sig: payload.id_pubkey type: " + f"{type(payload.id_pubkey)}" + ) + logger.debug( + f"verify_handshake_payload_sig: noise_static_pubkey type: " + f"{type(noise_static_pubkey)}" + ) + logger.debug( + f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}" + ) + logger.debug( + f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}" + ) + logger.debug( + f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}" + ) + try: + result = payload.id_pubkey.verify(expected_data, payload.id_sig) + logger.debug(f"verify_handshake_payload_sig: verification result: {result}") + return result + except Exception as e: + logger.error(f"verify_handshake_payload_sig: verification exception: {e}") + return False diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index 00f51d06..d51332a4 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -2,6 +2,7 @@ from abc import ( ABC, abstractmethod, ) +import logging from cryptography.hazmat.primitives import ( serialization, @@ -46,6 +47,8 @@ from .messages import ( verify_handshake_payload_sig, ) +logger = logging.getLogger(__name__) + class IPattern(ABC): @abstractmethod @@ -95,6 +98,7 @@ class PatternXX(BasePattern): self.early_data = early_data async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: + logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}") noise_state = self.create_noise_state() noise_state.set_as_responder() noise_state.start_handshake() @@ -107,15 +111,22 @@ class PatternXX(BasePattern): read_writer = NoiseHandshakeReadWriter(conn, noise_state) # Consume msg#1. + logger.debug("Noise XX handshake_inbound: reading msg#1") await read_writer.read_msg() + logger.debug("Noise XX handshake_inbound: read msg#1 successfully") # Send msg#2, which should include our handshake payload. + logger.debug("Noise XX handshake_inbound: preparing msg#2") our_payload = self.make_handshake_payload() msg_2 = our_payload.serialize() + logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)") await read_writer.write_msg(msg_2) + logger.debug("Noise XX handshake_inbound: sent msg#2 successfully") # Receive and consume msg#3. + logger.debug("Noise XX handshake_inbound: reading msg#3") msg_3 = await read_writer.read_msg() + logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)") peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3) if handshake_state.rs is None: @@ -147,6 +158,7 @@ class PatternXX(BasePattern): async def handshake_outbound( self, conn: IRawConnection, remote_peer: ID ) -> ISecureConn: + logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}") noise_state = self.create_noise_state() read_writer = NoiseHandshakeReadWriter(conn, noise_state) @@ -159,11 +171,15 @@ class PatternXX(BasePattern): raise NoiseStateError("Handshake state is not initialized") # Send msg#1, which is *not* encrypted. + logger.debug("Noise XX handshake_outbound: sending msg#1") msg_1 = b"" await read_writer.write_msg(msg_1) + logger.debug("Noise XX handshake_outbound: sent msg#1 successfully") # Read msg#2 from the remote, which contains the public key of the peer. + logger.debug("Noise XX handshake_outbound: reading msg#2") msg_2 = await read_writer.read_msg() + logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)") peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2) if handshake_state.rs is None: @@ -174,8 +190,27 @@ class PatternXX(BasePattern): ) remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs) + logger.debug( + f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}" + ) + logger.debug( + f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}" + ) + id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex() + logger.debug( + f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: " + f"{id_pubkey_repr}" + ) if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey): + logger.error( + f"Noise XX handshake_outbound: signature verification failed for peer " + f"{remote_peer}" + ) raise InvalidSignature + logger.debug( + f"Noise XX handshake_outbound: signature verification successful for peer " + f"{remote_peer}" + ) remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey) if remote_peer_id_from_pubkey != remote_peer: raise PeerIDMismatchesPubkey( diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e8d0561d..150ae9dd 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,5 +1,3 @@ -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager from types import ( TracebackType, ) @@ -15,6 +13,7 @@ from libp2p.abc import ( from libp2p.stream_muxer.exceptions import ( MuxedConnUnavailable, ) +from libp2p.stream_muxer.rw_lock import ReadWriteLock from .constants import ( HeaderTags, @@ -34,72 +33,6 @@ if TYPE_CHECKING: ) -class ReadWriteLock: - """ - A read-write lock that allows multiple concurrent readers - or one exclusive writer, implemented using Trio primitives. - """ - - def __init__(self) -> None: - self._readers = 0 - self._readers_lock = trio.Lock() # Protects access to _readers count - self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time - - async def acquire_read(self) -> None: - """Acquire a read lock. Multiple readers can hold it simultaneously.""" - try: - async with self._readers_lock: - if self._readers == 0: - await self._writer_lock.acquire() - self._readers += 1 - except trio.Cancelled: - raise - - async def release_read(self) -> None: - """Release a read lock.""" - async with self._readers_lock: - if self._readers == 1: - self._writer_lock.release() - self._readers -= 1 - - async def acquire_write(self) -> None: - """Acquire an exclusive write lock.""" - try: - await self._writer_lock.acquire() - except trio.Cancelled: - raise - - def release_write(self) -> None: - """Release the exclusive write lock.""" - self._writer_lock.release() - - @asynccontextmanager - async def read_lock(self) -> AsyncGenerator[None, None]: - """Context manager for acquiring and releasing a read lock safely.""" - acquire = False - try: - await self.acquire_read() - acquire = True - yield - finally: - if acquire: - with trio.CancelScope() as scope: - scope.shield = True - await self.release_read() - - @asynccontextmanager - async def write_lock(self) -> AsyncGenerator[None, None]: - """Context manager for acquiring and releasing a write lock safely.""" - acquire = False - try: - await self.acquire_write() - acquire = True - yield - finally: - if acquire: - self.release_write() - - class MplexStream(IMuxedStream): """ reference: https://github.com/libp2p/go-mplex/blob/master/stream.go diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py index ef90fac0..2d206141 100644 --- a/libp2p/stream_muxer/muxer_multistream.py +++ b/libp2p/stream_muxer/muxer_multistream.py @@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import ( MultiselectError, ) from libp2p.protocol_muxer.multiselect import ( + DEFAULT_NEGOTIATE_TIMEOUT, Multiselect, ) from libp2p.protocol_muxer.multiselect_client import ( @@ -46,11 +47,17 @@ class MuxerMultistream: transports: "OrderedDict[TProtocol, TMuxerClass]" multiselect: Multiselect multiselect_client: MultiselectClient + negotiate_timeout: int - def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None: + def __init__( + self, + muxer_transports_by_protocol: TMuxerOptions, + negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, + ) -> None: self.transports = OrderedDict() self.multiselect = Multiselect() self.multistream_client = MultiselectClient() + self.negotiate_timeout = negotiate_timeout for protocol, transport in muxer_transports_by_protocol.items(): self.add_transport(protocol, transport) @@ -80,10 +87,12 @@ class MuxerMultistream: communicator = MultiselectCommunicator(conn) if conn.is_initiator: protocol = await self.multiselect_client.select_one_of( - tuple(self.transports.keys()), communicator + tuple(self.transports.keys()), communicator, self.negotiate_timeout ) else: - protocol, _ = await self.multiselect.negotiate(communicator) + protocol, _ = await self.multiselect.negotiate( + communicator, self.negotiate_timeout + ) if protocol is None: raise MultiselectError( "Fail to negotiate a stream muxer protocol: no protocol selected" @@ -93,7 +102,7 @@ class MuxerMultistream: async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: communicator = MultiselectCommunicator(conn) protocol = await self.multistream_client.select_one_of( - tuple(self.transports.keys()), communicator + tuple(self.transports.keys()), communicator, self.negotiate_timeout ) transport_class = self.transports[protocol] if protocol == PROTOCOL_ID: diff --git a/libp2p/stream_muxer/rw_lock.py b/libp2p/stream_muxer/rw_lock.py new file mode 100644 index 00000000..7910a144 --- /dev/null +++ b/libp2p/stream_muxer/rw_lock.py @@ -0,0 +1,70 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import trio + + +class ReadWriteLock: + """ + A read-write lock that allows multiple concurrent readers + or one exclusive writer, implemented using Trio primitives. + """ + + def __init__(self) -> None: + self._readers = 0 + self._readers_lock = trio.Lock() # Protects access to _readers count + self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time + + async def acquire_read(self) -> None: + """Acquire a read lock. Multiple readers can hold it simultaneously.""" + try: + async with self._readers_lock: + if self._readers == 0: + await self._writer_lock.acquire() + self._readers += 1 + except trio.Cancelled: + raise + + async def release_read(self) -> None: + """Release a read lock.""" + async with self._readers_lock: + if self._readers == 1: + self._writer_lock.release() + self._readers -= 1 + + async def acquire_write(self) -> None: + """Acquire an exclusive write lock.""" + try: + await self._writer_lock.acquire() + except trio.Cancelled: + raise + + def release_write(self) -> None: + """Release the exclusive write lock.""" + self._writer_lock.release() + + @asynccontextmanager + async def read_lock(self) -> AsyncGenerator[None, None]: + """Context manager for acquiring and releasing a read lock safely.""" + acquire = False + try: + await self.acquire_read() + acquire = True + yield + finally: + if acquire: + with trio.CancelScope() as scope: + scope.shield = True + await self.release_read() + + @asynccontextmanager + async def write_lock(self) -> AsyncGenerator[None, None]: + """Context manager for acquiring and releasing a write lock safely.""" + acquire = False + try: + await self.acquire_write() + acquire = True + yield + finally: + if acquire: + self.release_write() diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index b2711e1a..bb84a5db 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import ( MuxedStreamError, MuxedStreamReset, ) +from libp2p.stream_muxer.rw_lock import ReadWriteLock # Configure logger for this module logger = logging.getLogger("libp2p.stream_muxer.yamux") @@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream): self.send_window = DEFAULT_WINDOW_SIZE self.recv_window = DEFAULT_WINDOW_SIZE self.window_lock = trio.Lock() + self.rw_lock = ReadWriteLock() + self.close_lock = trio.Lock() async def __aenter__(self) -> "YamuxStream": """Enter the async context manager.""" @@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream): await self.close() async def write(self, data: bytes) -> None: - if self.send_closed: - raise MuxedStreamError("Stream is closed for sending") + async with self.rw_lock.write_lock(): + if self.send_closed: + raise MuxedStreamError("Stream is closed for sending") - # Flow control: Check if we have enough send window - total_len = len(data) - sent = 0 - logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") - while sent < total_len: - # Wait for available window with timeout - timeout = False - async with self.window_lock: - if self.send_window == 0: - logger.debug( - f"Stream {self.stream_id}: Window is zero, waiting for update" + # Flow control: Check if we have enough send window + total_len = len(data) + sent = 0 + logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") + while sent < total_len: + # Wait for available window with timeout + timeout = False + async with self.window_lock: + if self.send_window == 0: + logger.debug( + f"Stream {self.stream_id}: " + "Window is zero, waiting for update" + ) + # Release lock and wait with timeout + self.window_lock.release() + # To avoid re-acquiring the lock immediately, + with trio.move_on_after(5.0) as cancel_scope: + while self.send_window == 0 and not self.closed: + await trio.sleep(0.01) + # If we timed out, cancel the scope + timeout = cancel_scope.cancelled_caught + # Re-acquire lock + await self.window_lock.acquire() + + # If we timed out waiting for window update, raise an error + if timeout: + raise MuxedStreamError( + "Timed out waiting for window update after 5 seconds." + ) + + if self.closed: + raise MuxedStreamError("Stream is closed") + + # Calculate how much we can send now + to_send = min(self.send_window, total_len - sent) + chunk = data[sent : sent + to_send] + self.send_window -= to_send + + # Send the data + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk) ) - # Release lock and wait with timeout - self.window_lock.release() - # To avoid re-acquiring the lock immediately, - with trio.move_on_after(5.0) as cancel_scope: - while self.send_window == 0 and not self.closed: - await trio.sleep(0.01) - # If we timed out, cancel the scope - timeout = cancel_scope.cancelled_caught - # Re-acquire lock - await self.window_lock.acquire() - - # If we timed out waiting for window update, raise an error - if timeout: - raise MuxedStreamError( - "Timed out waiting for window update after 5 seconds." - ) - - if self.closed: - raise MuxedStreamError("Stream is closed") - - # Calculate how much we can send now - to_send = min(self.send_window, total_len - sent) - chunk = data[sent : sent + to_send] - self.send_window -= to_send - - # Send the data - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk) - ) - await self.conn.secured_conn.write(header + chunk) - sent += to_send + await self.conn.secured_conn.write(header + chunk) + sent += to_send async def send_window_update(self, increment: int, skip_lock: bool = False) -> None: """ @@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream): return data async def close(self) -> None: - if not self.send_closed: - logger.debug(f"Half-closing stream {self.stream_id} (local end)") - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 - ) - await self.conn.secured_conn.write(header) - self.send_closed = True + async with self.close_lock: + if not self.send_closed: + logger.debug(f"Half-closing stream {self.stream_id} (local end)") + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0 + ) + await self.conn.secured_conn.write(header) + self.send_closed = True - # Only set fully closed if both directions are closed - if self.send_closed and self.recv_closed: - self.closed = True - else: - # Stream is half-closed but not fully closed - self.closed = False + # Only set fully closed if both directions are closed + if self.send_closed and self.recv_closed: + self.closed = True + else: + # Stream is half-closed but not fully closed + self.closed = False async def reset(self) -> None: if not self.closed: - logger.debug(f"Resetting stream {self.stream_id}") - header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 - ) - await self.conn.secured_conn.write(header) - self.closed = True - self.reset_received = True # Mark as reset + async with self.close_lock: + logger.debug(f"Resetting stream {self.stream_id}") + header = struct.pack( + YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0 + ) + await self.conn.secured_conn.write(header) + self.closed = True + self.reset_received = True # Mark as reset def set_deadline(self, ttl: int) -> bool: """ diff --git a/libp2p/transport/__init__.py b/libp2p/transport/__init__.py index e69de29b..ebc587e5 100644 --- a/libp2p/transport/__init__.py +++ b/libp2p/transport/__init__.py @@ -0,0 +1,57 @@ +from typing import Any + +from .tcp.tcp import TCP +from .websocket.transport import WebsocketTransport +from .transport_registry import ( + TransportRegistry, + create_transport_for_multiaddr, + get_transport_registry, + register_transport, + get_supported_transport_protocols, +) +from .upgrader import TransportUpgrader +from libp2p.abc import ITransport + +def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport: + """ + Convenience function to create a transport instance. + + :param protocol: The transport protocol ("tcp", "ws", "wss", or custom) + :param upgrader: Optional transport upgrader (required for WebSocket) + :param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config) + :return: Transport instance + """ + # First check if it's a built-in protocol + if protocol in ["ws", "wss"]: + if upgrader is None: + raise ValueError(f"WebSocket transport requires an upgrader") + return WebsocketTransport( + upgrader, + tls_client_config=kwargs.get("tls_client_config"), + tls_server_config=kwargs.get("tls_server_config"), + handshake_timeout=kwargs.get("handshake_timeout", 15.0) + ) + elif protocol == "tcp": + return TCP() + else: + # Check if it's a custom registered transport + registry = get_transport_registry() + transport_class = registry.get_transport(protocol) + if transport_class: + transport = registry.create_transport(protocol, upgrader, **kwargs) + if transport is None: + raise ValueError(f"Failed to create transport for protocol: {protocol}") + return transport + else: + raise ValueError(f"Unsupported transport protocol: {protocol}") + +__all__ = [ + "TCP", + "WebsocketTransport", + "TransportRegistry", + "create_transport_for_multiaddr", + "create_transport", + "get_transport_registry", + "register_transport", + "get_supported_transport_protocols", +] diff --git a/libp2p/transport/quic/__init__.py b/libp2p/transport/quic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libp2p/transport/quic/config.py b/libp2p/transport/quic/config.py new file mode 100644 index 00000000..e0c87adf --- /dev/null +++ b/libp2p/transport/quic/config.py @@ -0,0 +1,345 @@ +""" +Configuration classes for QUIC transport. +""" + +from dataclasses import ( + dataclass, + field, +) +import ssl +from typing import Any, Literal, TypedDict + +from libp2p.custom_types import TProtocol +from libp2p.network.config import ConnectionConfig + + +class QUICTransportKwargs(TypedDict, total=False): + """Type definition for kwargs accepted by new_transport function.""" + + # Connection settings + idle_timeout: float + max_datagram_size: int + local_port: int | None + + # Protocol version support + enable_draft29: bool + enable_v1: bool + + # TLS settings + verify_mode: ssl.VerifyMode + alpn_protocols: list[str] + + # Performance settings + max_concurrent_streams: int + connection_window: int + stream_window: int + + # Logging and debugging + enable_qlog: bool + qlog_dir: str | None + + # Connection management + max_connections: int + connection_timeout: float + + # Protocol identifiers + PROTOCOL_QUIC_V1: TProtocol + PROTOCOL_QUIC_DRAFT29: TProtocol + + +@dataclass +class QUICTransportConfig(ConnectionConfig): + """Configuration for QUIC transport.""" + + # Connection settings + idle_timeout: float = 30.0 # Seconds before an idle connection is closed. + max_datagram_size: int = ( + 1200 # Maximum size of UDP datagrams to avoid IP fragmentation. + ) + local_port: int | None = ( + None # Local port to bind to. If None, a random port is chosen. + ) + + # Protocol version support + enable_draft29: bool = True # Enable QUIC draft-29 for compatibility + enable_v1: bool = True # Enable QUIC v1 (RFC 9000) + + # TLS settings + verify_mode: ssl.VerifyMode = ssl.CERT_NONE + alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"]) + + # Performance settings + max_concurrent_streams: int = 100 # Maximum concurrent streams per connection + connection_window: int = 1024 * 1024 # Connection flow control window + stream_window: int = 64 * 1024 # Stream flow control window + + # Logging and debugging + enable_qlog: bool = False # Enable QUIC logging + qlog_dir: str | None = None # Directory for QUIC logs + + # Connection management + max_connections: int = 1000 # Maximum number of connections + connection_timeout: float = 10.0 # Connection establishment timeout + + MAX_CONCURRENT_STREAMS: int = 1000 + """Maximum number of concurrent streams per connection.""" + + MAX_INCOMING_STREAMS: int = 1000 + """Maximum number of incoming streams per connection.""" + + CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0 + """Timeout for connection handshake (seconds).""" + + MAX_OUTGOING_STREAMS: int = 1000 + """Maximum number of outgoing streams per connection.""" + + CONNECTION_CLOSE_TIMEOUT: int = 10 + """Timeout for opening new connection (seconds).""" + + # Stream timeouts + STREAM_OPEN_TIMEOUT: float = 5.0 + """Timeout for opening new streams (seconds).""" + + STREAM_ACCEPT_TIMEOUT: float = 30.0 + """Timeout for accepting incoming streams (seconds).""" + + STREAM_READ_TIMEOUT: float = 30.0 + """Default timeout for stream read operations (seconds).""" + + STREAM_WRITE_TIMEOUT: float = 30.0 + """Default timeout for stream write operations (seconds).""" + + STREAM_CLOSE_TIMEOUT: float = 10.0 + """Timeout for graceful stream close (seconds).""" + + # Flow control configuration + STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB + """Per-stream flow control window size.""" + + CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB + """Connection-wide flow control window size.""" + + # Buffer management + MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB + """Maximum receive buffer size per stream.""" + + STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB + """Low watermark for stream receive buffer.""" + + STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB + """High watermark for stream receive buffer.""" + + # Stream lifecycle configuration + ENABLE_STREAM_RESET_ON_ERROR: bool = True + """Whether to automatically reset streams on errors.""" + + STREAM_RESET_ERROR_CODE: int = 1 + """Default error code for stream resets.""" + + ENABLE_STREAM_KEEP_ALIVE: bool = False + """Whether to enable stream keep-alive mechanisms.""" + + STREAM_KEEP_ALIVE_INTERVAL: float = 30.0 + """Interval for stream keep-alive pings (seconds).""" + + # Resource management + ENABLE_STREAM_RESOURCE_TRACKING: bool = True + """Whether to track stream resource usage.""" + + STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB + """Memory limit per individual stream.""" + + STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB + """Total memory limit for all streams per connection.""" + + # Concurrency and performance + ENABLE_STREAM_BATCHING: bool = True + """Whether to batch multiple stream operations.""" + + STREAM_BATCH_SIZE: int = 10 + """Number of streams to process in a batch.""" + + STREAM_PROCESSING_CONCURRENCY: int = 100 + """Maximum concurrent stream processing tasks.""" + + # Debugging and monitoring + ENABLE_STREAM_METRICS: bool = True + """Whether to collect stream metrics.""" + + ENABLE_STREAM_TIMELINE_TRACKING: bool = True + """Whether to track stream lifecycle timelines.""" + + STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0 + """Interval for collecting stream metrics (seconds).""" + + # Error handling configuration + STREAM_ERROR_RETRY_ATTEMPTS: int = 3 + """Number of retry attempts for recoverable stream errors.""" + + STREAM_ERROR_RETRY_DELAY: float = 1.0 + """Initial delay between stream error retries (seconds).""" + + STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0 + """Backoff factor for stream error retries.""" + + # Protocol identifiers matching go-libp2p + PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000 + PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29 + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + if not (self.enable_draft29 or self.enable_v1): + raise ValueError("At least one QUIC version must be enabled") + + if self.idle_timeout <= 0: + raise ValueError("Idle timeout must be positive") + + if self.max_datagram_size < 1200: + raise ValueError("Max datagram size must be at least 1200 bytes") + + # Validate timeouts + timeout_fields = [ + "STREAM_OPEN_TIMEOUT", + "STREAM_ACCEPT_TIMEOUT", + "STREAM_READ_TIMEOUT", + "STREAM_WRITE_TIMEOUT", + "STREAM_CLOSE_TIMEOUT", + ] + for timeout_field in timeout_fields: + if getattr(self, timeout_field) <= 0: + raise ValueError(f"{timeout_field} must be positive") + + # Validate flow control windows + if self.STREAM_FLOW_CONTROL_WINDOW <= 0: + raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive") + + if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW: + raise ValueError( + "CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW" + ) + + # Validate buffer sizes + if self.MAX_STREAM_RECEIVE_BUFFER <= 0: + raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive") + + if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER: + raise ValueError( + "STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__( + "exceed MAX_STREAM_RECEIVE_BUFFER" + ) + ) + + if ( + self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK + >= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK + ): + raise ValueError( + "STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK" + ) + + # Validate memory limits + if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0: + raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive") + + if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0: + raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive") + + expected_stream_memory = ( + self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM + ) + if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2: + # Allow some headroom, but warn if configuration seems inconsistent + import logging + + logger = logging.getLogger(__name__) + logger.warning( + "Stream memory configuration may be inconsistent: " + f"{self.MAX_CONCURRENT_STREAMS} streams Ɨ" + "{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes " + "could exceed connection limit of" + f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes" + ) + + def get_stream_config_dict(self) -> dict[str, Any]: + """Get stream-specific configuration as dictionary.""" + stream_config = {} + for attr_name in dir(self): + if attr_name.startswith( + ("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW") + ): + stream_config[attr_name.lower()] = getattr(self, attr_name) + return stream_config + + +# Additional configuration classes for specific stream features + + +class QUICStreamFlowControlConfig: + """Configuration for QUIC stream flow control.""" + + def __init__( + self, + initial_window_size: int = 512 * 1024, + max_window_size: int = 2 * 1024 * 1024, + window_update_threshold: float = 0.5, + enable_auto_tuning: bool = True, + ): + self.initial_window_size = initial_window_size + self.max_window_size = max_window_size + self.window_update_threshold = window_update_threshold + self.enable_auto_tuning = enable_auto_tuning + + +def create_stream_config_for_use_case( + use_case: Literal[ + "high_throughput", "low_latency", "many_streams", "memory_constrained" + ], +) -> QUICTransportConfig: + """ + Create optimized stream configuration for specific use cases. + + Args: + use_case: One of "high_throughput", "low_latency", "many_streams"," + "memory_constrained" + + Returns: + Optimized QUICTransportConfig + + """ + base_config = QUICTransportConfig() + + if use_case == "high_throughput": + # Optimize for high throughput + base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB + base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB + base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB + base_config.STREAM_PROCESSING_CONCURRENCY = 200 + + elif use_case == "low_latency": + # Optimize for low latency + base_config.STREAM_OPEN_TIMEOUT = 1.0 + base_config.STREAM_READ_TIMEOUT = 5.0 + base_config.STREAM_WRITE_TIMEOUT = 5.0 + base_config.ENABLE_STREAM_BATCHING = False + base_config.STREAM_BATCH_SIZE = 1 + + elif use_case == "many_streams": + # Optimize for many concurrent streams + base_config.MAX_CONCURRENT_STREAMS = 5000 + base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB + base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB + base_config.STREAM_PROCESSING_CONCURRENCY = 500 + + elif use_case == "memory_constrained": + # Optimize for low memory usage + base_config.MAX_CONCURRENT_STREAMS = 100 + base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB + base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB + base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB + base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB + base_config.STREAM_PROCESSING_CONCURRENCY = 50 + + else: + raise ValueError(f"Unknown use case: {use_case}") + + return base_config diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py new file mode 100644 index 00000000..fb4cff4a --- /dev/null +++ b/libp2p/transport/quic/connection.py @@ -0,0 +1,1489 @@ +""" +QUIC Connection implementation. +Manages bidirectional QUIC connections with integrated stream multiplexing. +""" + +from collections import defaultdict +from collections.abc import Awaitable, Callable +import logging +import socket +import time +from typing import TYPE_CHECKING, Any, Optional + +from aioquic.quic import events +from aioquic.quic.connection import QuicConnection +from aioquic.quic.events import QuicEvent +from cryptography import x509 +import multiaddr +import trio + +from libp2p.abc import IMuxedConn, IRawConnection +from libp2p.custom_types import TQUICStreamHandlerFn +from libp2p.peer.id import ID +from libp2p.stream_muxer.exceptions import MuxedConnUnavailable + +from .exceptions import ( + QUICConnectionClosedError, + QUICConnectionError, + QUICConnectionTimeoutError, + QUICErrorContext, + QUICPeerVerificationError, + QUICStreamError, + QUICStreamLimitError, + QUICStreamTimeoutError, +) +from .stream import QUICStream, StreamDirection + +if TYPE_CHECKING: + from .security import QUICTLSConfigManager + from .transport import QUICTransport + +logger = logging.getLogger(__name__) + + +class QUICConnection(IRawConnection, IMuxedConn): + """ + QUIC connection implementing both raw connection and muxed connection interfaces. + + Uses aioquic's sans-IO core with trio for native async support. + QUIC natively provides stream multiplexing, so this connection acts as both + a raw connection (for transport layer) and muxed connection (for upper layers). + + Features: + - Native QUIC stream multiplexing + - Integrated libp2p TLS security with peer identity verification + - Resource-aware stream management + - Comprehensive error handling + - Flow control integration + - Connection migration support + - Performance monitoring + - COMPLETE connection ID management (fixes the original issue) + """ + + def __init__( + self, + quic_connection: QuicConnection, + remote_addr: tuple[str, int], + remote_peer_id: ID | None, + local_peer_id: ID, + is_initiator: bool, + maddr: multiaddr.Multiaddr, + transport: "QUICTransport", + security_manager: Optional["QUICTLSConfigManager"] = None, + resource_scope: Any | None = None, + listener_socket: trio.socket.SocketType | None = None, + ): + """ + Initialize QUIC connection with security integration. + + Args: + quic_connection: aioquic QuicConnection instance + remote_addr: Remote peer address + remote_peer_id: Remote peer ID (may be None initially) + local_peer_id: Local peer ID + is_initiator: Whether this is the connection initiator + maddr: Multiaddr for this connection + transport: Parent QUIC transport + security_manager: Security manager for TLS/certificate handling + resource_scope: Resource manager scope for tracking + listener_socket: Socket of listener to transmit data + + """ + self._quic = quic_connection + self._remote_addr = remote_addr + self._remote_peer_id = remote_peer_id + self._local_peer_id = local_peer_id + self.peer_id = remote_peer_id or local_peer_id + self._is_initiator = is_initiator + self._maddr = maddr + self._transport = transport + self._security_manager = security_manager + self._resource_scope = resource_scope + + # Trio networking - socket may be provided by listener + self._socket = listener_socket if listener_socket else None + self._owns_socket = listener_socket is None + self._connected_event = trio.Event() + self._closed_event = trio.Event() + + self._streams: dict[int, QUICStream] = {} + self._stream_cache: dict[int, QUICStream] = {} # Cache for frequent lookups + self._next_stream_id: int = self._calculate_initial_stream_id() + self._stream_handler: TQUICStreamHandlerFn | None = None + + # Single lock for all stream operations + self._stream_lock = trio.Lock() + + # Stream counting and limits + self._outbound_stream_count = 0 + self._inbound_stream_count = 0 + + # Stream acceptance for incoming streams + self._stream_accept_queue: list[QUICStream] = [] + self._stream_accept_event = trio.Event() + + # Connection state + self._closed: bool = False + self._established = False + self._started = False + self._handshake_completed = False + self._peer_verified = False + + # Security state + self._peer_certificate: x509.Certificate | None = None + self._handshake_events: list[events.HandshakeCompleted] = [] + + # Background task management + self._background_tasks_started = False + self._nursery: trio.Nursery | None = None + self._event_processing_task: Any | None = None + self.on_close: Callable[[], Awaitable[None]] | None = None + self.event_started = trio.Event() + + self._available_connection_ids: set[bytes] = set() + self._current_connection_id: bytes | None = None + self._retired_connection_ids: set[bytes] = set() + self._connection_id_sequence_numbers: set[int] = set() + + # Event processing control with batching + self._event_processing_active = False + self._event_batch: list[events.QuicEvent] = [] + self._event_batch_size = 10 + self._last_event_time = 0.0 + + # Set quic connection configuration + self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT + self.MAX_INCOMING_STREAMS = transport._config.MAX_INCOMING_STREAMS + self.MAX_OUTGOING_STREAMS = transport._config.MAX_OUTGOING_STREAMS + self.CONNECTION_HANDSHAKE_TIMEOUT = ( + transport._config.CONNECTION_HANDSHAKE_TIMEOUT + ) + self.MAX_CONCURRENT_STREAMS = transport._config.MAX_CONCURRENT_STREAMS + + # Performance and monitoring + self._connection_start_time = time.time() + self._stats = { + "streams_opened": 0, + "streams_accepted": 0, + "streams_closed": 0, + "streams_reset": 0, + "bytes_sent": 0, + "bytes_received": 0, + "packets_sent": 0, + "packets_received": 0, + "connection_ids_issued": 0, + "connection_ids_retired": 0, + "connection_id_changes": 0, + } + + logger.debug( + f"Created QUIC connection to {remote_peer_id} " + f"(initiator: {is_initiator}, addr: {remote_addr}, " + "security: {security_manager is not None})" + ) + + def _calculate_initial_stream_id(self) -> int: + """ + Calculate the initial stream ID based on QUIC specification. + + QUIC stream IDs: + - Client-initiated bidirectional: 0, 4, 8, 12, ... + - Server-initiated bidirectional: 1, 5, 9, 13, ... + - Client-initiated unidirectional: 2, 6, 10, 14, ... + - Server-initiated unidirectional: 3, 7, 11, 15, ... + + For libp2p, we primarily use bidirectional streams. + """ + if self._is_initiator: + return 0 + else: + return 1 + + @property + def is_initiator(self) -> bool: # type: ignore + """Check if this connection is the initiator.""" + return self._is_initiator + + @property + def is_closed(self) -> bool: + """Check if connection is closed.""" + return self._closed + + @property + def is_established(self) -> bool: + """Check if connection is established (handshake completed).""" + return self._established and self._handshake_completed + + @property + def is_started(self) -> bool: + """Check if connection has been started.""" + return self._started + + @property + def is_peer_verified(self) -> bool: + """Check if peer identity has been verified.""" + return self._peer_verified + + def multiaddr(self) -> multiaddr.Multiaddr: + """Get the multiaddr for this connection.""" + return self._maddr + + def local_peer_id(self) -> ID: + """Get the local peer ID.""" + return self._local_peer_id + + def remote_peer_id(self) -> ID | None: + """Get the remote peer ID.""" + return self._remote_peer_id + + def get_connection_id_stats(self) -> dict[str, Any]: + """Get connection ID statistics and current state.""" + return { + "available_connection_ids": len(self._available_connection_ids), + "current_connection_id": self._current_connection_id.hex() + if self._current_connection_id + else None, + "retired_connection_ids": len(self._retired_connection_ids), + "connection_ids_issued": self._stats["connection_ids_issued"], + "connection_ids_retired": self._stats["connection_ids_retired"], + "connection_id_changes": self._stats["connection_id_changes"], + "available_cid_list": [cid.hex() for cid in self._available_connection_ids], + } + + def get_current_connection_id(self) -> bytes | None: + """Get the current connection ID.""" + return self._current_connection_id + + # Fast stream lookup with caching + def _get_stream_fast(self, stream_id: int) -> QUICStream | None: + """Get stream with caching for performance.""" + # Try cache first + stream = self._stream_cache.get(stream_id) + if stream is not None: + return stream + + # Fallback to main dict + stream = self._streams.get(stream_id) + if stream is not None: + self._stream_cache[stream_id] = stream + + return stream + + # Connection lifecycle methods + + async def start(self) -> None: + """ + Start the connection and its background tasks. + + This method implements the IMuxedConn.start() interface. + It should be called to begin processing connection events. + """ + if self._started: + logger.warning("Connection already started") + return + + if self._closed: + raise QUICConnectionError("Cannot start a closed connection") + + self._started = True + self.event_started.set() + logger.debug(f"Starting QUIC connection to {self._remote_peer_id}") + + try: + # If this is a client connection, we need to establish the connection + if self._is_initiator: + await self._initiate_connection() + else: + # For server connections, we're already connected via the listener + self._established = True + self._connected_event.set() + + logger.debug(f"QUIC connection to {self._remote_peer_id} started") + + except Exception as e: + logger.error(f"Failed to start connection: {e}") + raise QUICConnectionError(f"Connection start failed: {e}") from e + + async def _initiate_connection(self) -> None: + """Initiate client-side connection, reusing listener socket if available.""" + try: + with QUICErrorContext("connection_initiation", "connection"): + if not self._socket: + logger.debug("Creating new socket for outbound connection") + self._socket = trio.socket.socket( + family=socket.AF_INET, type=socket.SOCK_DGRAM + ) + + await self._socket.bind(("0.0.0.0", 0)) + + self._quic.connect(self._remote_addr, now=time.time()) + + # Send initial packet(s) + await self._transmit() + + logger.debug(f"Initiated QUIC connection to {self._remote_addr}") + + except Exception as e: + logger.error(f"Failed to initiate connection: {e}") + raise QUICConnectionError(f"Connection initiation failed: {e}") from e + + async def connect(self, nursery: trio.Nursery) -> None: + """ + Establish the QUIC connection using trio nursery for background tasks. + + Args: + nursery: Trio nursery for managing connection background tasks + + """ + if self._closed: + raise QUICConnectionClosedError("Connection is closed") + + self._nursery = nursery + + try: + with QUICErrorContext("connection_establishment", "connection"): + # Start the connection if not already started + logger.debug("STARTING TO CONNECT") + if not self._started: + await self.start() + + # Start background event processing + if not self._background_tasks_started: + logger.debug("STARTING BACKGROUND TASK") + await self._start_background_tasks() + else: + logger.debug("BACKGROUND TASK ALREADY STARTED") + + # Wait for handshake completion with timeout + with trio.move_on_after( + self.CONNECTION_HANDSHAKE_TIMEOUT + ) as cancel_scope: + await self._connected_event.wait() + + if cancel_scope.cancelled_caught: + raise QUICConnectionTimeoutError( + "Connection handshake timed out after" + f"{self.CONNECTION_HANDSHAKE_TIMEOUT}s" + ) + + logger.debug( + "QUICConnection: Verifying peer identity with security manager" + ) + # Verify peer identity using security manager + peer_id = await self._verify_peer_identity_with_security() + + if peer_id: + self.peer_id = peer_id + + logger.debug(f"QUICConnection {id(self)}: Peer identity verified") + self._established = True + logger.debug(f"QUIC connection established with {self._remote_peer_id}") + + except Exception as e: + logger.error(f"Failed to establish connection: {e}") + await self.close() + raise + + async def _start_background_tasks(self) -> None: + """Start background tasks for connection management.""" + if self._background_tasks_started or not self._nursery: + return + + self._background_tasks_started = True + + if self._is_initiator: + self._nursery.start_soon(async_fn=self._client_packet_receiver) + + self._nursery.start_soon(async_fn=self._event_processing_loop) + self._nursery.start_soon(async_fn=self._periodic_maintenance) + + logger.debug("Started background tasks for QUIC connection") + + async def _event_processing_loop(self) -> None: + """Main event processing loop for the connection.""" + logger.debug( + f"Started QUIC event processing loop for connection id: {id(self)} " + f"and local peer id {str(self.local_peer_id())}" + ) + + try: + while not self._closed: + # Batch process events + await self._process_quic_events_batched() + + # Handle timer events + await self._handle_timer_events() + + # Transmit any pending data + await self._transmit() + + # Short sleep to prevent busy waiting + await trio.sleep(0.01) + + except Exception as e: + logger.error(f"Error in event processing loop: {e}") + await self._handle_connection_error(e) + finally: + logger.debug("QUIC event processing loop finished") + + async def _periodic_maintenance(self) -> None: + """Perform periodic connection maintenance.""" + try: + while not self._closed: + # Update connection statistics + self._update_stats() + + # Check for idle streams that can be cleaned up + await self._cleanup_idle_streams() + + if logger.isEnabledFor(logging.DEBUG): + cid_stats = self.get_connection_id_stats() + logger.debug(f"Connection ID stats: {cid_stats}") + + # Clean cache periodically + await self._cleanup_cache() + + # Sleep for maintenance interval + await trio.sleep(30.0) # 30 seconds + + except Exception as e: + logger.error(f"Error in periodic maintenance: {e}") + + async def _cleanup_cache(self) -> None: + """Clean up stream cache periodically to prevent memory leaks.""" + if len(self._stream_cache) > 100: # Arbitrary threshold + # Remove closed streams from cache + closed_stream_ids = [ + sid for sid, stream in self._stream_cache.items() if stream.is_closed() + ] + for sid in closed_stream_ids: + self._stream_cache.pop(sid, None) + + async def _client_packet_receiver(self) -> None: + """Receive packets for client connections.""" + logger.debug("Starting client packet receiver") + logger.debug("Started QUIC client packet receiver") + + try: + while not self._closed and self._socket: + try: + # Receive UDP packets + data, addr = await self._socket.recvfrom(65536) + logger.debug(f"Client received {len(data)} bytes from {addr}") + + # Feed packet to QUIC connection + self._quic.receive_datagram(data, addr, now=time.time()) + + # Batch process events + await self._process_quic_events_batched() + + # Send any response packets + await self._transmit() + + except trio.ClosedResourceError: + logger.debug("Client socket closed") + break + except Exception as e: + logger.error(f"Error receiving client packet: {e}") + await trio.sleep(0.01) + + except trio.Cancelled: + logger.debug("Client packet receiver cancelled") + raise + finally: + logger.debug("Client packet receiver terminated") + + # Security and identity methods + + async def _verify_peer_identity_with_security(self) -> ID | None: + """ + Verify peer identity using integrated security manager. + + Raises: + QUICPeerVerificationError: If peer verification fails + + """ + logger.debug("VERIFYING PEER IDENTITY") + if not self._security_manager: + logger.debug("No security manager available for peer verification") + return None + + try: + # Extract peer certificate from TLS handshake + await self._extract_peer_certificate() + + if not self._peer_certificate: + logger.debug("No peer certificate available for verification") + return None + + # Validate certificate format and accessibility + if not self._validate_peer_certificate(): + logger.debug("Validation Failed for peer cerificate") + raise QUICPeerVerificationError("Peer certificate validation failed") + + # Verify peer identity using security manager + verified_peer_id = self._security_manager.verify_peer_identity( + self._peer_certificate, + self._remote_peer_id, # Expected peer ID for outbound connections + ) + + # Update peer ID if it wasn't known (inbound connections) + if not self._remote_peer_id: + self._remote_peer_id = verified_peer_id + logger.debug(f"Discovered peer ID from certificate: {verified_peer_id}") + elif self._remote_peer_id != verified_peer_id: + raise QUICPeerVerificationError( + f"Peer ID mismatch: expected {self._remote_peer_id}, " + "got {verified_peer_id}" + ) + + self._peer_verified = True + logger.debug(f"Peer identity verified successfully: {verified_peer_id}") + + return verified_peer_id + + except QUICPeerVerificationError: + # Re-raise verification errors as-is + raise + except Exception as e: + # Wrap other errors in verification error + raise QUICPeerVerificationError(f"Peer verification failed: {e}") from e + + async def _extract_peer_certificate(self) -> None: + """Extract peer certificate from completed TLS handshake.""" + try: + # Get peer certificate from aioquic TLS context + if self._quic.tls: + tls_context = self._quic.tls + + if tls_context._peer_certificate: + # aioquic stores the peer certificate as cryptography + # x509.Certificate + self._peer_certificate = tls_context._peer_certificate + logger.debug( + f"Extracted peer certificate: {self._peer_certificate.subject}" + ) + else: + logger.debug("No peer certificate found in TLS context") + + else: + logger.debug("No TLS context available for certificate extraction") + + except Exception as e: + logger.warning(f"Failed to extract peer certificate: {e}") + + # Try alternative approach - check if certificate is in handshake events + try: + # Some versions of aioquic might expose certificate differently + config = self._quic.configuration + if hasattr(config, "certificate") and config.certificate: + # This would be the local certificate, not peer certificate + # but we can use it for debugging + logger.debug("Found local certificate in configuration") + + except Exception as inner_e: + logger.error( + f"Alternative certificate extraction also failed: {inner_e}" + ) + + async def get_peer_certificate(self) -> x509.Certificate | None: + """ + Get the peer's TLS certificate. + + Returns: + The peer's X.509 certificate, or None if not available + + """ + # If we don't have a certificate yet, try to extract it + if not self._peer_certificate and self._handshake_completed: + await self._extract_peer_certificate() + + return self._peer_certificate + + def _validate_peer_certificate(self) -> bool: + """ + Validate that the peer certificate is properly formatted and accessible. + + Returns: + True if certificate is valid and accessible, False otherwise + + """ + if not self._peer_certificate: + return False + + try: + # Basic validation - try to access certificate properties + subject = self._peer_certificate.subject + serial_number = self._peer_certificate.serial_number + + logger.debug( + f"Certificate validation - Subject: {subject}, Serial: {serial_number}" + ) + return True + + except Exception as e: + logger.error(f"Certificate validation failed: {e}") + return False + + def get_security_manager(self) -> Optional["QUICTLSConfigManager"]: + """Get the security manager for this connection.""" + return self._security_manager + + def get_security_info(self) -> dict[str, Any]: + """Get security-related information about the connection.""" + info: dict[str, bool | Any | None] = { + "peer_verified": self._peer_verified, + "handshake_complete": self._handshake_completed, + "peer_id": str(self._remote_peer_id) if self._remote_peer_id else None, + "local_peer_id": str(self._local_peer_id), + "is_initiator": self._is_initiator, + "has_certificate": self._peer_certificate is not None, + "security_manager_available": self._security_manager is not None, + } + + # Add certificate details if available + if self._peer_certificate: + try: + info.update( + { + "certificate_subject": str(self._peer_certificate.subject), + "certificate_issuer": str(self._peer_certificate.issuer), + "certificate_serial": str(self._peer_certificate.serial_number), + "certificate_not_before": ( + self._peer_certificate.not_valid_before.isoformat() + ), + "certificate_not_after": ( + self._peer_certificate.not_valid_after.isoformat() + ), + } + ) + except Exception as e: + info["certificate_error"] = str(e) + + # Add TLS context debug info + try: + if hasattr(self._quic, "tls") and self._quic.tls: + tls_info = { + "tls_context_available": True, + "tls_state": getattr(self._quic.tls, "state", None), + } + + # Check for peer certificate in TLS context + if hasattr(self._quic.tls, "_peer_certificate"): + tls_info["tls_peer_certificate_available"] = ( + self._quic.tls._peer_certificate is not None + ) + + info["tls_debug"] = tls_info + else: + info["tls_debug"] = {"tls_context_available": False} + + except Exception as e: + info["tls_debug"] = {"error": str(e)} + + return info + + # Stream management methods (IMuxedConn interface) + + async def open_stream(self, timeout: float = 5.0) -> QUICStream: + """ + Open a new outbound stream + + Args: + timeout: Timeout for stream creation + + Returns: + New QUIC stream + + Raises: + QUICStreamLimitError: Too many concurrent streams + QUICConnectionClosedError: Connection is closed + QUICStreamTimeoutError: Stream creation timed out + + """ + if self._closed: + raise QUICConnectionClosedError("Connection is closed") + + if not self._started: + raise QUICConnectionError("Connection not started") + + # Use single lock for all stream operations + with trio.move_on_after(timeout): + async with self._stream_lock: + # Check stream limits inside lock + if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS: + raise QUICStreamLimitError( + "Maximum outbound streams " + f"({self.MAX_OUTGOING_STREAMS}) reached" + ) + + # Generate next stream ID + stream_id = self._next_stream_id + self._next_stream_id += 4 # Increment by 4 for bidirectional streams + + stream = QUICStream( + connection=self, + stream_id=stream_id, + direction=StreamDirection.OUTBOUND, + resource_scope=self._resource_scope, + remote_addr=self._remote_addr, + ) + + self._streams[stream_id] = stream + self._stream_cache[stream_id] = stream # Add to cache + + self._outbound_stream_count += 1 + self._stats["streams_opened"] += 1 + + logger.debug(f"Opened outbound QUIC stream {stream_id}") + return stream + + raise QUICStreamTimeoutError(f"Stream creation timed out after {timeout}s") + + async def accept_stream(self, timeout: float | None = None) -> QUICStream: + """ + Accept incoming stream. + + Args: + timeout: Optional timeout. If None, waits indefinitely. + + """ + if self._closed: + raise QUICConnectionClosedError("Connection is closed") + + if timeout is not None: + with trio.move_on_after(timeout): + return await self._accept_stream_impl() + # Timeout occurred + if self._closed_event.is_set() or self._closed: + raise MuxedConnUnavailable("QUIC connection closed during timeout") + else: + raise QUICStreamTimeoutError( + f"Stream accept timed out after {timeout}s" + ) + else: + # No timeout - wait indefinitely + return await self._accept_stream_impl() + + async def _accept_stream_impl(self) -> QUICStream: + while True: + if self._closed: + raise MuxedConnUnavailable("QUIC connection is closed") + + # Use single lock for stream acceptance + async with self._stream_lock: + if self._stream_accept_queue: + stream = self._stream_accept_queue.pop(0) + logger.debug(f"Accepted inbound stream {stream.stream_id}") + return stream + + if self._closed: + raise MuxedConnUnavailable("Connection closed while accepting stream") + + # Wait for new streams indefinitely + await self._stream_accept_event.wait() + + raise QUICConnectionError("Error occurred while waiting to accept stream") + + def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None: + """ + Set handler for incoming streams. + + Args: + handler_function: Function to handle new incoming streams + + """ + self._stream_handler = handler_function + logger.debug("Set stream handler for incoming streams") + + def _remove_stream(self, stream_id: int) -> None: + """ + Remove stream from connection registry. + Called by stream cleanup process. + """ + if stream_id in self._streams: + stream = self._streams.pop(stream_id) + # Remove from cache too + self._stream_cache.pop(stream_id, None) + + # Update stream counts asynchronously + async def update_counts() -> None: + async with self._stream_lock: + if stream.direction == StreamDirection.OUTBOUND: + self._outbound_stream_count = max( + 0, self._outbound_stream_count - 1 + ) + else: + self._inbound_stream_count = max( + 0, self._inbound_stream_count - 1 + ) + self._stats["streams_closed"] += 1 + + # Schedule count update if we're in a trio context + if self._nursery: + self._nursery.start_soon(update_counts) + + logger.debug(f"Removed stream {stream_id} from connection") + + # Batched event processing to reduce overhead + async def _process_quic_events_batched(self) -> None: + """Process QUIC events in batches for better performance.""" + if self._event_processing_active: + return # Prevent recursion + + self._event_processing_active = True + + try: + current_time = time.time() + events_processed = 0 + + # Collect events into batch + while events_processed < self._event_batch_size: + event = self._quic.next_event() + if event is None: + break + + self._event_batch.append(event) + events_processed += 1 + + # Process batch if we have events or timeout + if self._event_batch and ( + len(self._event_batch) >= self._event_batch_size + or current_time - self._last_event_time > 0.01 # 10ms timeout + ): + await self._process_event_batch() + self._event_batch.clear() + self._last_event_time = current_time + + finally: + self._event_processing_active = False + + async def _process_event_batch(self) -> None: + """Process a batch of events efficiently.""" + if not self._event_batch: + return + + # Group events by type for batch processing where possible + events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list) + for event in self._event_batch: + events_by_type[type(event).__name__].append(event) + + # Process events by type + for event_type, event_list in events_by_type.items(): + if event_type == type(events.StreamDataReceived).__name__: + # Filter to only StreamDataReceived events + stream_data_events = [ + e for e in event_list if isinstance(e, events.StreamDataReceived) + ] + await self._handle_stream_data_batch(stream_data_events) + else: + # Process other events individually + for event in event_list: + await self._handle_quic_event(event) + + logger.debug(f"Processed batch of {len(self._event_batch)} events") + + async def _handle_stream_data_batch( + self, events_list: list[events.StreamDataReceived] + ) -> None: + """Handle stream data events in batch for better performance.""" + # Group by stream ID + events_by_stream: defaultdict[int, list[QuicEvent]] = defaultdict(list) + for event in events_list: + events_by_stream[event.stream_id].append(event) + + # Process each stream's events + for stream_id, stream_events in events_by_stream.items(): + stream = self._get_stream_fast(stream_id) # Use fast lookup + + if not stream: + if self._is_incoming_stream(stream_id): + try: + stream = await self._create_inbound_stream(stream_id) + except QUICStreamLimitError: + # Reset stream if we can't handle it + self._quic.reset_stream(stream_id, error_code=0x04) + await self._transmit() + continue + else: + logger.error( + f"Unexpected outbound stream {stream_id} in data event" + ) + continue + + # Process all events for this stream + for received_event in stream_events: + if hasattr(received_event, "data"): + self._stats["bytes_received"] += len(received_event.data) # type: ignore + + if hasattr(received_event, "end_stream"): + await stream.handle_data_received( + received_event.data, # type: ignore + received_event.end_stream, # type: ignore + ) + + async def _create_inbound_stream(self, stream_id: int) -> QUICStream: + """Create inbound stream with proper limit checking.""" + async with self._stream_lock: + # Double-check stream doesn't exist + existing_stream = self._streams.get(stream_id) + if existing_stream: + return existing_stream + + # Check limits + if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS: + logger.warning(f"Rejecting inbound stream {stream_id}: limit reached") + raise QUICStreamLimitError("Too many inbound streams") + + # Create stream + stream = QUICStream( + connection=self, + stream_id=stream_id, + direction=StreamDirection.INBOUND, + resource_scope=self._resource_scope, + remote_addr=self._remote_addr, + ) + + self._streams[stream_id] = stream + self._stream_cache[stream_id] = stream # Add to cache + self._inbound_stream_count += 1 + self._stats["streams_accepted"] += 1 + + # Add to accept queue + self._stream_accept_queue.append(stream) + self._stream_accept_event.set() + + logger.debug(f"Created inbound stream {stream_id}") + return stream + + async def _process_quic_events(self) -> None: + """Process all pending QUIC events.""" + # Delegate to batched processing for better performance + await self._process_quic_events_batched() + + async def _handle_quic_event(self, event: events.QuicEvent) -> None: + """Handle a single QUIC event with COMPLETE event type coverage.""" + logger.debug(f"Handling QUIC event: {type(event).__name__}") + logger.debug(f"QUIC event: {type(event).__name__}") + + try: + if isinstance(event, events.ConnectionTerminated): + await self._handle_connection_terminated(event) + elif isinstance(event, events.HandshakeCompleted): + await self._handle_handshake_completed(event) + elif isinstance(event, events.StreamDataReceived): + await self._handle_stream_data(event) + elif isinstance(event, events.StreamReset): + await self._handle_stream_reset(event) + elif isinstance(event, events.DatagramFrameReceived): + await self._handle_datagram_received(event) + # *** NEW: Connection ID event handlers - CRITICAL FIX *** + elif isinstance(event, events.ConnectionIdIssued): + await self._handle_connection_id_issued(event) + elif isinstance(event, events.ConnectionIdRetired): + await self._handle_connection_id_retired(event) + # *** NEW: Additional event handlers for completeness *** + elif isinstance(event, events.PingAcknowledged): + await self._handle_ping_acknowledged(event) + elif isinstance(event, events.ProtocolNegotiated): + await self._handle_protocol_negotiated(event) + elif isinstance(event, events.StopSendingReceived): + await self._handle_stop_sending_received(event) + else: + logger.debug(f"Unhandled QUIC event type: {type(event).__name__}") + logger.debug(f"Unhandled QUIC event: {type(event).__name__}") + + except Exception as e: + logger.error(f"Error handling QUIC event {type(event).__name__}: {e}") + + async def _handle_connection_id_issued( + self, event: events.ConnectionIdIssued + ) -> None: + """ + Handle new connection ID issued by peer. + + This is the CRITICAL missing functionality that was causing your issue! + """ + logger.debug(f"šŸ†” NEW CONNECTION ID ISSUED: {event.connection_id.hex()}") + logger.debug(f"šŸ†” NEW CONNECTION ID ISSUED: {event.connection_id.hex()}") + + # Add to available connection IDs + self._available_connection_ids.add(event.connection_id) + + # If we don't have a current connection ID, use this one + if self._current_connection_id is None: + self._current_connection_id = event.connection_id + logger.debug( + f"šŸ†” Set current connection ID to: {event.connection_id.hex()}" + ) + logger.debug( + f"šŸ†” Set current connection ID to: {event.connection_id.hex()}" + ) + + # Update statistics + self._stats["connection_ids_issued"] += 1 + + logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}") + logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}") + + async def _handle_connection_id_retired( + self, event: events.ConnectionIdRetired + ) -> None: + """ + Handle connection ID retirement. + + This handles when the peer tells us to stop using a connection ID. + """ + logger.debug(f"šŸ—‘ļø CONNECTION ID RETIRED: {event.connection_id.hex()}") + + # Remove from available IDs and add to retired set + self._available_connection_ids.discard(event.connection_id) + self._retired_connection_ids.add(event.connection_id) + + # If this was our current connection ID, switch to another + if self._current_connection_id == event.connection_id: + if self._available_connection_ids: + self._current_connection_id = next(iter(self._available_connection_ids)) + if self._current_connection_id: + logger.debug( + "Switching to new connection ID: " + f"{self._current_connection_id.hex()}" + ) + self._stats["connection_id_changes"] += 1 + else: + logger.warning("āš ļø No available connection IDs after retirement!") + else: + self._current_connection_id = None + logger.warning("āš ļø No available connection IDs after retirement!") + + # Update statistics + self._stats["connection_ids_retired"] += 1 + + async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None: + """Handle ping acknowledgment.""" + logger.debug(f"Ping acknowledged: uid={event.uid}") + + async def _handle_protocol_negotiated( + self, event: events.ProtocolNegotiated + ) -> None: + """Handle protocol negotiation completion.""" + logger.debug(f"Protocol negotiated: {event.alpn_protocol}") + + async def _handle_stop_sending_received( + self, event: events.StopSendingReceived + ) -> None: + """Handle stop sending request from peer.""" + logger.debug( + "Stop sending received: " + f"stream_id={event.stream_id}, error_code={event.error_code}" + ) + + # Use fast lookup + stream = self._get_stream_fast(event.stream_id) + if stream: + # Handle stop sending on the stream if method exists + await stream.handle_stop_sending(event.error_code) + + async def _handle_handshake_completed( + self, event: events.HandshakeCompleted + ) -> None: + """Handle handshake completion with security integration.""" + logger.debug("QUIC handshake completed") + self._handshake_completed = True + + # Store handshake event for security verification + self._handshake_events.append(event) + + # Try to extract certificate information after handshake + await self._extract_peer_certificate() + + logger.debug("āœ… Setting connected event") + self._connected_event.set() + + async def _handle_connection_terminated( + self, event: events.ConnectionTerminated + ) -> None: + """Handle connection termination.""" + logger.debug(f"QUIC connection terminated: {event.reason_phrase}") + + # Close all streams + for stream in list(self._streams.values()): + if event.error_code: + await stream.handle_reset(event.error_code) + else: + await stream.close() + + self._streams.clear() + self._stream_cache.clear() # Clear cache too + self._closed = True + self._closed_event.set() + + self._stream_accept_event.set() + logger.debug(f"Woke up pending accept_stream() calls, {id(self)}") + + await self._notify_parent_of_termination() + + async def _handle_stream_data(self, event: events.StreamDataReceived) -> None: + """Handle stream data events - create streams and add to accept queue.""" + stream_id = event.stream_id + self._stats["bytes_received"] += len(event.data) + + try: + # Use fast lookup + stream = self._get_stream_fast(stream_id) + + if not stream: + if self._is_incoming_stream(stream_id): + logger.debug(f"Creating new incoming stream {stream_id}") + stream = await self._create_inbound_stream(stream_id) + else: + logger.error( + f"Unexpected outbound stream {stream_id} in data event" + ) + return + + await stream.handle_data_received(event.data, event.end_stream) + + except Exception as e: + logger.error(f"Error handling stream data for stream {stream_id}: {e}") + logger.debug(f"āŒ STREAM_DATA: Error: {e}") + + async def _get_or_create_stream(self, stream_id: int) -> QUICStream: + """Get existing stream or create new inbound stream.""" + # Use fast lookup + stream = self._get_stream_fast(stream_id) + if stream: + return stream + + # Check if this is an incoming stream + is_incoming = self._is_incoming_stream(stream_id) + + if not is_incoming: + # This shouldn't happen - outbound streams should be created by open_stream + raise QUICStreamError( + f"Received data for unknown outbound stream {stream_id}" + ) + + # Create new inbound stream + return await self._create_inbound_stream(stream_id) + + def _is_incoming_stream(self, stream_id: int) -> bool: + """ + Determine if a stream ID represents an incoming stream. + + For bidirectional streams: + - Even IDs are client-initiated + - Odd IDs are server-initiated + """ + if self._is_initiator: + # We're the client, so odd stream IDs are incoming + return stream_id % 2 == 1 + else: + # We're the server, so even stream IDs are incoming + return stream_id % 2 == 0 + + async def _handle_stream_reset(self, event: events.StreamReset) -> None: + """Stream reset handling.""" + stream_id = event.stream_id + self._stats["streams_reset"] += 1 + + # Use fast lookup + stream = self._get_stream_fast(stream_id) + if stream: + try: + await stream.handle_reset(event.error_code) + logger.debug( + f"Handled reset for stream {stream_id}" + f"with error code {event.error_code}" + ) + except Exception as e: + logger.error(f"Error handling stream reset for {stream_id}: {e}") + # Force remove the stream + self._remove_stream(stream_id) + else: + logger.debug(f"Received reset for unknown stream {stream_id}") + + async def _handle_datagram_received( + self, event: events.DatagramFrameReceived + ) -> None: + """Handle datagram frame (if using QUIC datagrams).""" + logger.debug(f"Datagram frame received: size={len(event.data)}") + # For now, just log. Could be extended for custom datagram handling + + async def _handle_timer_events(self) -> None: + """Handle QUIC timer events.""" + timer = self._quic.get_timer() + if timer is not None: + now = time.time() + if timer <= now: + self._quic.handle_timer(now=now) + + # Network transmission + + async def _transmit(self) -> None: + """Transmit pending QUIC packets using available socket.""" + sock = self._socket + if not sock: + logger.debug("No socket to transmit") + return + + try: + current_time = time.time() + datagrams = self._quic.datagrams_to_send(now=current_time) + + # Batch stats updates + packet_count = 0 + total_bytes = 0 + + for data, addr in datagrams: + await sock.sendto(data, addr) + packet_count += 1 + total_bytes += len(data) + + # Update stats in batch + if packet_count > 0: + self._stats["packets_sent"] += packet_count + self._stats["bytes_sent"] += total_bytes + + except Exception as e: + logger.error(f"Transmission error: {e}") + await self._handle_connection_error(e) + + # Additional methods for stream data processing + async def _process_quic_event(self, event: events.QuicEvent) -> None: + """Process a single QUIC event.""" + await self._handle_quic_event(event) + + async def _transmit_pending_data(self) -> None: + """Transmit any pending data.""" + await self._transmit() + + # Error handling + + async def _handle_connection_error(self, error: Exception) -> None: + """Handle connection-level errors.""" + logger.error(f"Connection error: {error}") + + if not self._closed: + try: + await self.close() + except Exception as close_error: + logger.error(f"Error during connection close: {close_error}") + + # Connection close + + async def close(self) -> None: + """Connection close with proper stream cleanup.""" + if self._closed: + return + + self._closed = True + logger.debug(f"Closing QUIC connection to {self._remote_peer_id}") + + try: + # Close all streams gracefully + stream_close_tasks = [] + for stream in list(self._streams.values()): + if stream.can_write() or stream.can_read(): + stream_close_tasks.append(stream.close) + + if stream_close_tasks and self._nursery: + try: + # Close streams concurrently with timeout + with trio.move_on_after(self.CONNECTION_CLOSE_TIMEOUT): + async with trio.open_nursery() as close_nursery: + for task in stream_close_tasks: + close_nursery.start_soon(task) + except Exception as e: + logger.warning(f"Error during graceful stream close: {e}") + # Force reset remaining streams + for stream in self._streams.values(): + try: + await stream.reset(error_code=0) + except Exception: + pass + + if self.on_close: + await self.on_close() + + # Close QUIC connection + self._quic.close() + + if self._socket: + await self._transmit() # Send close frames + + # Close socket + if self._socket and self._owns_socket: + self._socket.close() + self._socket = None + + self._streams.clear() + self._stream_cache.clear() # Clear cache + self._closed_event.set() + + logger.debug(f"QUIC connection to {self._remote_peer_id} closed") + + except Exception as e: + logger.error(f"Error during connection close: {e}") + + async def _notify_parent_of_termination(self) -> None: + """ + Notify the parent listener/transport to remove this connection from tracking. + + This ensures that terminated connections are cleaned up from the + 'established connections' list. + """ + try: + if self._transport: + await self._transport._cleanup_terminated_connection(self) + logger.debug("Notified transport of connection termination") + return + + for listener in self._transport._listeners: + try: + await listener._remove_connection_by_object(self) + logger.debug( + "Found and notified listener of connection termination" + ) + return + except Exception: + continue + + # Method 4: Use connection ID if we have one (most reliable) + if self._current_connection_id: + await self._cleanup_by_connection_id(self._current_connection_id) + return + + logger.warning( + "Could not notify parent of connection termination - no" + f" parent reference found for conn host {self._quic.host_cid.hex()}" + ) + + except Exception as e: + logger.error(f"Error notifying parent of connection termination: {e}") + + async def _cleanup_by_connection_id(self, connection_id: bytes) -> None: + """Cleanup using connection ID as a fallback method.""" + try: + for listener in self._transport._listeners: + for tracked_cid, tracked_conn in list(listener._connections.items()): + if tracked_conn is self: + await listener._remove_connection(tracked_cid) + logger.debug(f"Removed connection {tracked_cid.hex()}") + return + + logger.debug("Fallback cleanup by connection ID completed") + except Exception as e: + logger.error(f"Error in fallback cleanup: {e}") + + # IRawConnection interface (for compatibility) + + def get_remote_address(self) -> tuple[str, int]: + return self._remote_addr + + async def write(self, data: bytes) -> None: + """ + Write data to the connection. + For QUIC, this creates a new stream for each write operation. + """ + if self._closed: + raise QUICConnectionClosedError("Connection is closed") + + stream = await self.open_stream() + try: + await stream.write(data) + await stream.close_write() + except Exception: + await stream.reset() + raise + + async def read(self, n: int | None = -1) -> bytes: + """ + Read data from the stream. + + Args: + n: Maximum number of bytes to read. -1 means read all available. + + Returns: + Data bytes read from the stream. + + Raises: + QUICStreamClosedError: If stream is closed for reading. + QUICStreamResetError: If stream was reset. + QUICStreamTimeoutError: If read timeout occurs. + + """ + # It's here for interface compatibility but should not be used + raise NotImplementedError( + "Use streams for reading data from QUIC connections. " + "Call accept_stream() or open_stream() instead." + ) + + # Utility and monitoring methods + + def get_stream_stats(self) -> dict[str, Any]: + """Get stream statistics for monitoring.""" + return { + "total_streams": len(self._streams), + "outbound_streams": self._outbound_stream_count, + "inbound_streams": self._inbound_stream_count, + "max_streams": self.MAX_CONCURRENT_STREAMS, + "stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS, + "stats": self._stats.copy(), + "cache_size": len( + self._stream_cache + ), # Include cache metrics for monitoring + } + + def get_active_streams(self) -> list[QUICStream]: + """Get list of active streams.""" + return [stream for stream in self._streams.values() if not stream.is_closed()] + + def get_streams_by_protocol(self, protocol: str) -> list[QUICStream]: + """Get streams filtered by protocol.""" + return [ + stream + for stream in self._streams.values() + if hasattr(stream, "protocol") + and stream.protocol == protocol + and not stream.is_closed() + ] + + def _update_stats(self) -> None: + """Update connection statistics.""" + # Add any periodic stats updates here + pass + + async def _cleanup_idle_streams(self) -> None: + """Clean up idle streams that are no longer needed.""" + current_time = time.time() + streams_to_cleanup = [] + + for stream in self._streams.values(): + if stream.is_closed(): + # Check if stream has been closed for a while + if hasattr(stream, "_timeline") and stream._timeline.closed_at: + if current_time - stream._timeline.closed_at > 60: # 1 minute + streams_to_cleanup.append(stream.stream_id) + + for stream_id in streams_to_cleanup: + self._remove_stream(int(stream_id)) + + # String representation + + def __repr__(self) -> str: + current_cid: str | None = ( + self._current_connection_id.hex() if self._current_connection_id else None + ) + return ( + f"QUICConnection(peer={self._remote_peer_id}, " + f"addr={self._remote_addr}, " + f"initiator={self._is_initiator}, " + f"verified={self._peer_verified}, " + f"established={self._established}, " + f"streams={len(self._streams)}, " + f"current_cid={current_cid})" + ) + + def __str__(self) -> str: + return f"QUICConnection({self._remote_peer_id})" diff --git a/libp2p/transport/quic/exceptions.py b/libp2p/transport/quic/exceptions.py new file mode 100644 index 00000000..2df3dda5 --- /dev/null +++ b/libp2p/transport/quic/exceptions.py @@ -0,0 +1,391 @@ +""" +QUIC Transport exceptions +""" + +from typing import Any, Literal + + +class QUICError(Exception): + """Base exception for all QUIC transport errors.""" + + def __init__(self, message: str, error_code: int | None = None): + super().__init__(message) + self.error_code = error_code + + +# Transport-level exceptions + + +class QUICTransportError(QUICError): + """Base exception for QUIC transport operations.""" + + pass + + +class QUICDialError(QUICTransportError): + """Error occurred during QUIC connection establishment.""" + + pass + + +class QUICListenError(QUICTransportError): + """Error occurred during QUIC listener operations.""" + + pass + + +class QUICSecurityError(QUICTransportError): + """Error related to QUIC security/TLS operations.""" + + pass + + +# Connection-level exceptions + + +class QUICConnectionError(QUICError): + """Base exception for QUIC connection operations.""" + + pass + + +class QUICConnectionClosedError(QUICConnectionError): + """QUIC connection has been closed.""" + + pass + + +class QUICConnectionTimeoutError(QUICConnectionError): + """QUIC connection operation timed out.""" + + pass + + +class QUICHandshakeError(QUICConnectionError): + """Error during QUIC handshake process.""" + + pass + + +class QUICPeerVerificationError(QUICConnectionError): + """Error verifying peer identity during handshake.""" + + pass + + +# Stream-level exceptions + + +class QUICStreamError(QUICError): + """Base exception for QUIC stream operations.""" + + def __init__( + self, + message: str, + stream_id: str | None = None, + error_code: int | None = None, + ): + super().__init__(message, error_code) + self.stream_id = stream_id + + +class QUICStreamClosedError(QUICStreamError): + """Stream is closed and cannot be used for I/O operations.""" + + pass + + +class QUICStreamResetError(QUICStreamError): + """Stream was reset by local or remote peer.""" + + def __init__( + self, + message: str, + stream_id: str | None = None, + error_code: int | None = None, + reset_by_peer: bool = False, + ): + super().__init__(message, stream_id, error_code) + self.reset_by_peer = reset_by_peer + + +class QUICStreamTimeoutError(QUICStreamError): + """Stream operation timed out.""" + + pass + + +class QUICStreamBackpressureError(QUICStreamError): + """Stream write blocked due to flow control.""" + + pass + + +class QUICStreamLimitError(QUICStreamError): + """Stream limit reached (too many concurrent streams).""" + + pass + + +class QUICStreamStateError(QUICStreamError): + """Invalid operation for current stream state.""" + + def __init__( + self, + message: str, + stream_id: str | None = None, + current_state: str | None = None, + attempted_operation: str | None = None, + ): + super().__init__(message, stream_id) + self.current_state = current_state + self.attempted_operation = attempted_operation + + +# Flow control exceptions + + +class QUICFlowControlError(QUICError): + """Base exception for flow control related errors.""" + + pass + + +class QUICFlowControlViolationError(QUICFlowControlError): + """Flow control limits were violated.""" + + pass + + +class QUICFlowControlDeadlockError(QUICFlowControlError): + """Flow control deadlock detected.""" + + pass + + +# Resource management exceptions + + +class QUICResourceError(QUICError): + """Base exception for resource management errors.""" + + pass + + +class QUICMemoryLimitError(QUICResourceError): + """Memory limit exceeded.""" + + pass + + +class QUICConnectionLimitError(QUICResourceError): + """Connection limit exceeded.""" + + pass + + +# Multiaddr and addressing exceptions + + +class QUICAddressError(QUICError): + """Base exception for QUIC addressing errors.""" + + pass + + +class QUICInvalidMultiaddrError(QUICAddressError): + """Invalid multiaddr format for QUIC transport.""" + + pass + + +class QUICAddressResolutionError(QUICAddressError): + """Failed to resolve QUIC address.""" + + pass + + +class QUICProtocolError(QUICError): + """Base exception for QUIC protocol errors.""" + + pass + + +class QUICVersionNegotiationError(QUICProtocolError): + """QUIC version negotiation failed.""" + + pass + + +class QUICUnsupportedVersionError(QUICProtocolError): + """Unsupported QUIC version.""" + + pass + + +# Configuration exceptions + + +class QUICConfigurationError(QUICError): + """Base exception for QUIC configuration errors.""" + + pass + + +class QUICInvalidConfigError(QUICConfigurationError): + """Invalid QUIC configuration parameters.""" + + pass + + +class QUICCertificateError(QUICConfigurationError): + """Error with TLS certificate configuration.""" + + pass + + +def map_quic_error_code(error_code: int) -> str: + """ + Map QUIC error codes to human-readable descriptions. + Based on RFC 9000 Transport Error Codes. + """ + error_codes = { + 0x00: "NO_ERROR", + 0x01: "INTERNAL_ERROR", + 0x02: "CONNECTION_REFUSED", + 0x03: "FLOW_CONTROL_ERROR", + 0x04: "STREAM_LIMIT_ERROR", + 0x05: "STREAM_STATE_ERROR", + 0x06: "FINAL_SIZE_ERROR", + 0x07: "FRAME_ENCODING_ERROR", + 0x08: "TRANSPORT_PARAMETER_ERROR", + 0x09: "CONNECTION_ID_LIMIT_ERROR", + 0x0A: "PROTOCOL_VIOLATION", + 0x0B: "INVALID_TOKEN", + 0x0C: "APPLICATION_ERROR", + 0x0D: "CRYPTO_BUFFER_EXCEEDED", + 0x0E: "KEY_UPDATE_ERROR", + 0x0F: "AEAD_LIMIT_REACHED", + 0x10: "NO_VIABLE_PATH", + } + + return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}") + + +def create_stream_error( + error_type: str, + message: str, + stream_id: str | None = None, + error_code: int | None = None, +) -> QUICStreamError: + """ + Factory function to create appropriate stream error based on type. + + Args: + error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.) + message: Error message + stream_id: Stream identifier + error_code: QUIC error code + + Returns: + Appropriate QUICStreamError subclass + + """ + error_type = error_type.lower() + + if error_type in ("closed", "close"): + return QUICStreamClosedError(message, stream_id, error_code) + elif error_type == "reset": + return QUICStreamResetError(message, stream_id, error_code) + elif error_type == "timeout": + return QUICStreamTimeoutError(message, stream_id, error_code) + elif error_type in ("backpressure", "flow_control"): + return QUICStreamBackpressureError(message, stream_id, error_code) + elif error_type in ("limit", "stream_limit"): + return QUICStreamLimitError(message, stream_id, error_code) + elif error_type == "state": + return QUICStreamStateError(message, stream_id) + else: + return QUICStreamError(message, stream_id, error_code) + + +def create_connection_error( + error_type: str, message: str, error_code: int | None = None +) -> QUICConnectionError: + """ + Factory function to create appropriate connection error based on type. + + Args: + error_type: Type of error ("closed", "timeout", "handshake", etc.) + message: Error message + error_code: QUIC error code + + Returns: + Appropriate QUICConnectionError subclass + + """ + error_type = error_type.lower() + + if error_type in ("closed", "close"): + return QUICConnectionClosedError(message, error_code) + elif error_type == "timeout": + return QUICConnectionTimeoutError(message, error_code) + elif error_type == "handshake": + return QUICHandshakeError(message, error_code) + elif error_type in ("peer_verification", "verification"): + return QUICPeerVerificationError(message, error_code) + else: + return QUICConnectionError(message, error_code) + + +class QUICErrorContext: + """ + Context manager for handling QUIC errors with automatic error mapping. + Useful for converting low-level aioquic errors to py-libp2p QUIC errors. + """ + + def __init__(self, operation: str, component: str = "quic") -> None: + self.operation = operation + self.component = component + + def __enter__(self) -> "QUICErrorContext": + return self + + # TODO: Fix types for exc_type + def __exit__( + self, + exc_type: type[BaseException] | None | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> Literal[False]: + if exc_type is None: + return False + + if exc_val is None: + return False + + # Map common aioquic exceptions to our exceptions + if "ConnectionClosed" in str(exc_type): + raise QUICConnectionClosedError( + f"Connection closed during {self.operation}: {exc_val}" + ) from exc_val + elif "StreamReset" in str(exc_type): + raise QUICStreamResetError( + f"Stream reset during {self.operation}: {exc_val}" + ) from exc_val + elif "timeout" in str(exc_val).lower(): + if "stream" in self.component.lower(): + raise QUICStreamTimeoutError( + f"Timeout during {self.operation}: {exc_val}" + ) from exc_val + else: + raise QUICConnectionTimeoutError( + f"Timeout during {self.operation}: {exc_val}" + ) from exc_val + elif "flow control" in str(exc_val).lower(): + raise QUICStreamBackpressureError( + f"Flow control error during {self.operation}: {exc_val}" + ) from exc_val + + # Let other exceptions propagate + return False diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py new file mode 100644 index 00000000..42c8c662 --- /dev/null +++ b/libp2p/transport/quic/listener.py @@ -0,0 +1,1041 @@ +""" +QUIC Listener +""" + +import logging +import socket +import struct +import sys +import time +from typing import TYPE_CHECKING + +from aioquic.quic import events +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import QuicConnection +from aioquic.quic.packet import QuicPacketType +from multiaddr import Multiaddr +import trio + +from libp2p.abc import IListener +from libp2p.custom_types import ( + TProtocol, + TQUICConnHandlerFn, +) +from libp2p.transport.quic.security import ( + LIBP2P_TLS_EXTENSION_OID, + QUICTLSConfigManager, +) + +from .config import QUICTransportConfig +from .connection import QUICConnection +from .exceptions import QUICListenError +from .utils import ( + create_quic_multiaddr, + create_server_config_from_base, + custom_quic_version_to_wire_format, + is_quic_multiaddr, + multiaddr_to_quic_version, + quic_multiaddr_to_endpoint, +) + +if TYPE_CHECKING: + from .transport import QUICTransport + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class QUICPacketInfo: + """Information extracted from a QUIC packet header.""" + + def __init__( + self, + version: int, + destination_cid: bytes, + source_cid: bytes, + packet_type: QuicPacketType, + token: bytes | None = None, + ): + self.version = version + self.destination_cid = destination_cid + self.source_cid = source_cid + self.packet_type = packet_type + self.token = token + + +class QUICListener(IListener): + """ + QUIC Listener with connection ID handling and protocol negotiation. + """ + + def __init__( + self, + transport: "QUICTransport", + handler_function: TQUICConnHandlerFn, + quic_configs: dict[TProtocol, QuicConfiguration], + config: QUICTransportConfig, + security_manager: QUICTLSConfigManager | None = None, + ): + """Initialize enhanced QUIC listener.""" + self._transport = transport + self._handler = handler_function + self._quic_configs = quic_configs + self._config = config + self._security_manager = security_manager + + # Network components + self._socket: trio.socket.SocketType | None = None + self._bound_addresses: list[Multiaddr] = [] + + # Enhanced connection management with connection ID routing + self._connections: dict[ + bytes, QUICConnection + ] = {} # destination_cid -> connection + self._pending_connections: dict[ + bytes, QuicConnection + ] = {} # destination_cid -> quic_conn + self._addr_to_cid: dict[ + tuple[str, int], bytes + ] = {} # (host, port) -> destination_cid + self._cid_to_addr: dict[ + bytes, tuple[str, int] + ] = {} # destination_cid -> (host, port) + self._connection_lock = trio.Lock() + + # Version negotiation support + self._supported_versions = self._get_supported_versions() + + # Listener state + self._closed = False + self._listening = False + self._nursery: trio.Nursery | None = None + + # Performance tracking + self._stats = { + "connections_accepted": 0, + "connections_rejected": 0, + "version_negotiations": 0, + "bytes_received": 0, + "packets_processed": 0, + "invalid_packets": 0, + } + + def _get_supported_versions(self) -> set[int]: + """Get wire format versions for all supported QUIC configurations.""" + versions: set[int] = set() + for protocol in self._quic_configs: + try: + config = self._quic_configs[protocol] + wire_versions = config.supported_versions + for version in wire_versions: + versions.add(version) + except Exception as e: + logger.warning(f"Failed to get wire version for {protocol}: {e}") + return versions + + def parse_quic_packet(self, data: bytes) -> QUICPacketInfo | None: + """ + Parse QUIC packet header to extract connection IDs and version. + Based on RFC 9000 packet format. + """ + try: + if len(data) < 1: + return None + + # Read first byte to get packet type and flags + first_byte = data[0] + + # Check if this is a long header packet (version negotiation, initial, etc.) + is_long_header = (first_byte & 0x80) != 0 + + if not is_long_header: + cid_length = 8 # We are using standard CID length everywhere + + if len(data) < 1 + cid_length: + return None + + dest_cid = data[1 : 1 + cid_length] + + return QUICPacketInfo( + version=1, # Assume QUIC v1 for established connections + destination_cid=dest_cid, + source_cid=b"", # Not available in short header + packet_type=QuicPacketType.ONE_RTT, + token=b"", + ) + + # Long header packet parsing + offset = 1 + + # Extract version (4 bytes) + if len(data) < offset + 4: + return None + version = struct.unpack("!I", data[offset : offset + 4])[0] + offset += 4 + + # Extract destination connection ID length and value + if len(data) < offset + 1: + return None + dest_cid_len = data[offset] + offset += 1 + + if len(data) < offset + dest_cid_len: + return None + dest_cid = data[offset : offset + dest_cid_len] + offset += dest_cid_len + + # Extract source connection ID length and value + if len(data) < offset + 1: + return None + src_cid_len = data[offset] + offset += 1 + + if len(data) < offset + src_cid_len: + return None + src_cid = data[offset : offset + src_cid_len] + offset += src_cid_len + + # Determine packet type from first byte + packet_type_value = (first_byte & 0x30) >> 4 + + packet_value_to_type_mapping = { + 0: QuicPacketType.INITIAL, + 1: QuicPacketType.ZERO_RTT, + 2: QuicPacketType.HANDSHAKE, + 3: QuicPacketType.RETRY, + 4: QuicPacketType.VERSION_NEGOTIATION, + 5: QuicPacketType.ONE_RTT, + } + + # For Initial packets, extract token + token = b"" + if packet_type_value == 0: # Initial packet + if len(data) < offset + 1: + return None + # Token length is variable-length integer + token_len, token_len_bytes = self._decode_varint(data[offset:]) + offset += token_len_bytes + + if len(data) < offset + token_len: + return None + token = data[offset : offset + token_len] + + return QUICPacketInfo( + version=version, + destination_cid=dest_cid, + source_cid=src_cid, + packet_type=packet_value_to_type_mapping.get(packet_type_value) + or QuicPacketType.INITIAL, + token=token, + ) + + except Exception as e: + logger.debug(f"Failed to parse QUIC packet: {e}") + return None + + def _decode_varint(self, data: bytes) -> tuple[int, int]: + """Decode QUIC variable-length integer.""" + if len(data) < 1: + return 0, 0 + + first_byte = data[0] + length_bits = (first_byte & 0xC0) >> 6 + + if length_bits == 0: + return first_byte & 0x3F, 1 + elif length_bits == 1: + if len(data) < 2: + return 0, 0 + return ((first_byte & 0x3F) << 8) | data[1], 2 + elif length_bits == 2: + if len(data) < 4: + return 0, 0 + return ((first_byte & 0x3F) << 24) | (data[1] << 16) | ( + data[2] << 8 + ) | data[3], 4 + else: # length_bits == 3 + if len(data) < 8: + return 0, 0 + value = (first_byte & 0x3F) << 56 + for i in range(1, 8): + value |= data[i] << (8 * (7 - i)) + return value, 8 + + async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: + """Process incoming QUIC packet with optimized routing.""" + try: + self._stats["packets_processed"] += 1 + self._stats["bytes_received"] += len(data) + + packet_info = self.parse_quic_packet(data) + if packet_info is None: + self._stats["invalid_packets"] += 1 + return + + dest_cid = packet_info.destination_cid + + # Single lock acquisition with all lookups + async with self._connection_lock: + connection_obj = self._connections.get(dest_cid) + pending_quic_conn = self._pending_connections.get(dest_cid) + + if not connection_obj and not pending_quic_conn: + if packet_info.packet_type == QuicPacketType.INITIAL: + pending_quic_conn = await self._handle_new_connection( + data, addr, packet_info + ) + else: + return + + # Process outside the lock + if connection_obj: + await self._handle_established_connection_packet( + connection_obj, data, addr, dest_cid + ) + elif pending_quic_conn: + await self._handle_pending_connection_packet( + pending_quic_conn, data, addr, dest_cid + ) + + except Exception as e: + logger.error(f"Error processing packet from {addr}: {e}") + + async def _handle_established_connection_packet( + self, + connection_obj: QUICConnection, + data: bytes, + addr: tuple[str, int], + dest_cid: bytes, + ) -> None: + """Handle packet for established connection WITHOUT holding connection lock.""" + try: + await self._route_to_connection(connection_obj, data, addr) + + except Exception as e: + logger.error(f"Error handling established connection packet: {e}") + + async def _handle_pending_connection_packet( + self, + quic_conn: QuicConnection, + data: bytes, + addr: tuple[str, int], + dest_cid: bytes, + ) -> None: + """Handle packet for pending connection WITHOUT holding connection lock.""" + try: + logger.debug(f"Handling packet for pending connection {dest_cid.hex()}") + logger.debug(f"Packet size: {len(data)} bytes from {addr}") + + # Feed data to QUIC connection + quic_conn.receive_datagram(data, addr, now=time.time()) + logger.debug("PENDING: Datagram received by QUIC connection") + + # Process events - this is crucial for handshake progression + logger.debug("Processing QUIC events...") + await self._process_quic_events(quic_conn, addr, dest_cid) + + # Send any outgoing packets + logger.debug("Transmitting response...") + await self._transmit_for_connection(quic_conn, addr) + + # Check if handshake completed (with minimal locking) + if quic_conn._handshake_complete: + logger.debug("PENDING: Handshake completed, promoting connection") + await self._promote_pending_connection(quic_conn, addr, dest_cid) + else: + logger.debug("Handshake still in progress") + + except Exception as e: + logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}") + + async def _send_version_negotiation( + self, addr: tuple[str, int], source_cid: bytes + ) -> None: + """Send version negotiation packet to client.""" + try: + self._stats["version_negotiations"] += 1 + + # Construct version negotiation packet + packet = bytearray() + + # First byte: long header (1) + unused bits (0111) + packet.append(0x80 | 0x70) + + # Version: 0 for version negotiation + packet.extend(struct.pack("!I", 0)) + + # Destination connection ID (echo source CID from client) + packet.append(len(source_cid)) + packet.extend(source_cid) + + # Source connection ID (empty for version negotiation) + packet.append(0) + + # Supported versions + for version in sorted(self._supported_versions): + packet.extend(struct.pack("!I", version)) + + # Send the packet + if self._socket: + await self._socket.sendto(bytes(packet), addr) + logger.debug( + f"Sent version negotiation to {addr} " + f"with versions {sorted(self._supported_versions)}" + ) + + except Exception as e: + logger.error(f"Failed to send version negotiation to {addr}: {e}") + + async def _handle_new_connection( + self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo + ) -> QuicConnection | None: + """Handle new connection with proper connection ID handling.""" + try: + logger.debug(f"Starting handshake for {addr}") + + # Find appropriate QUIC configuration + quic_config = None + + for protocol, config in self._quic_configs.items(): + wire_versions = custom_quic_version_to_wire_format(protocol) + if wire_versions == packet_info.version: + quic_config = config + break + + if not quic_config: + logger.error( + f"No configuration found for version 0x{packet_info.version:08x}" + ) + await self._send_version_negotiation(addr, packet_info.source_cid) + return None + + if not quic_config: + raise QUICListenError("Cannot determine QUIC configuration") + + # Create server-side QUIC configuration + server_config = create_server_config_from_base( + base_config=quic_config, + security_manager=self._security_manager, + transport_config=self._config, + ) + + # Validate certificate has libp2p extension + if server_config.certificate: + cert = server_config.certificate + has_libp2p_ext = False + for ext in cert.extensions: + if ext.oid == LIBP2P_TLS_EXTENSION_OID: + has_libp2p_ext = True + break + logger.debug(f"Certificate has libp2p extension: {has_libp2p_ext}") + + if not has_libp2p_ext: + logger.error("Certificate missing libp2p extension!") + + logger.debug( + f"Original destination CID: {packet_info.destination_cid.hex()}" + ) + + quic_conn = QuicConnection( + configuration=server_config, + original_destination_connection_id=packet_info.destination_cid, + ) + + quic_conn._replenish_connection_ids() + # Use the first host CID as our routing CID + if quic_conn._host_cids: + destination_cid = quic_conn._host_cids[0].cid + logger.debug(f"Using host CID as routing CID: {destination_cid.hex()}") + else: + # Fallback to random if no host CIDs generated + import secrets + + destination_cid = secrets.token_bytes(8) + logger.debug(f"Fallback to random CID: {destination_cid.hex()}") + + logger.debug(f"Generated {len(quic_conn._host_cids)} host CIDs for client") + + logger.debug( + f"QUIC connection created for destination CID {destination_cid.hex()}" + ) + + # Store connection mapping using our generated CID + self._pending_connections[destination_cid] = quic_conn + self._addr_to_cid[addr] = destination_cid + self._cid_to_addr[destination_cid] = addr + + # Process initial packet + quic_conn.receive_datagram(data, addr, now=time.time()) + if quic_conn.tls: + if self._security_manager: + try: + quic_conn.tls._request_client_certificate = True + logger.debug( + "request_client_certificate set to True in server TLS" + ) + except Exception as e: + logger.error(f"FAILED to apply request_client_certificate: {e}") + + # Process events and send response + await self._process_quic_events(quic_conn, addr, destination_cid) + await self._transmit_for_connection(quic_conn, addr) + + logger.debug( + f"Started handshake for new connection from {addr} " + f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})" + ) + + return quic_conn + + except Exception as e: + logger.error(f"Error handling new connection from {addr}: {e}") + self._stats["connections_rejected"] += 1 + return None + + async def _handle_short_header_packet( + self, data: bytes, addr: tuple[str, int] + ) -> None: + """Handle short header packets for established connections.""" + try: + logger.debug(f" SHORT_HDR: Handling short header packet from {addr}") + + # First, try address-based lookup + dest_cid = self._addr_to_cid.get(addr) + if dest_cid and dest_cid in self._connections: + connection = self._connections[dest_cid] + await self._route_to_connection(connection, data, addr) + return + + # Fallback: try to extract CID from packet + if len(data) >= 9: # 1 byte header + 8 byte CID + potential_cid = data[1:9] + + if potential_cid in self._connections: + connection = self._connections[potential_cid] + + # Update mappings for future packets + self._addr_to_cid[addr] = potential_cid + self._cid_to_addr[potential_cid] = addr + + await self._route_to_connection(connection, data, addr) + return + + logger.debug(f"āŒ SHORT_HDR: No matching connection found for {addr}") + + except Exception as e: + logger.error(f"Error handling short header packet from {addr}: {e}") + + async def _route_to_connection( + self, connection: QUICConnection, data: bytes, addr: tuple[str, int] + ) -> None: + """Route packet to existing connection.""" + try: + # Feed data to the connection's QUIC instance + connection._quic.receive_datagram(data, addr, now=time.time()) + + # Process events and handle responses + await connection._process_quic_events() + await connection._transmit() + + except Exception as e: + logger.error(f"Error routing packet to connection {addr}: {e}") + # Remove problematic connection + await self._remove_connection_by_addr(addr) + + async def _handle_pending_connection( + self, + quic_conn: QuicConnection, + data: bytes, + addr: tuple[str, int], + dest_cid: bytes, + ) -> None: + """Handle packet for a pending (handshaking) connection.""" + try: + logger.debug(f"Handling packet for pending connection {dest_cid.hex()}") + + # Feed data to QUIC connection + quic_conn.receive_datagram(data, addr, now=time.time()) + + if quic_conn.tls: + logger.debug(f"TLS state after: {quic_conn.tls.state}") + + # Process events - this is crucial for handshake progression + await self._process_quic_events(quic_conn, addr, dest_cid) + + # Send any outgoing packets - this is where the response should be sent + await self._transmit_for_connection(quic_conn, addr) + + # Check if handshake completed + if quic_conn._handshake_complete: + logger.debug("PENDING: Handshake completed, promoting connection") + await self._promote_pending_connection(quic_conn, addr, dest_cid) + + except Exception as e: + logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}") + + # Remove problematic pending connection + logger.error(f"Removing problematic connection {dest_cid.hex()}") + await self._remove_pending_connection(dest_cid) + + async def _process_quic_events( + self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes + ) -> None: + """Process QUIC events with enhanced debugging.""" + try: + events_processed = 0 + while True: + event = quic_conn.next_event() + if event is None: + break + + events_processed += 1 + logger.debug( + "QUIC EVENT: Processing event " + f"{events_processed}: {type(event).__name__}" + ) + + if isinstance(event, events.ConnectionTerminated): + logger.debug( + "QUIC EVENT: Connection terminated " + f"- code: {event.error_code}, reason: {event.reason_phrase}" + f"Connection {dest_cid.hex()} from {addr} " + f"terminated: {event.reason_phrase}" + ) + await self._remove_connection(dest_cid) + break + + elif isinstance(event, events.HandshakeCompleted): + logger.debug( + "QUIC EVENT: Handshake completed for connection " + f"{dest_cid.hex()}" + ) + logger.debug(f"Handshake completed for connection {dest_cid.hex()}") + await self._promote_pending_connection(quic_conn, addr, dest_cid) + + elif isinstance(event, events.StreamDataReceived): + logger.debug( + f"QUIC EVENT: Stream data received on stream {event.stream_id}" + ) + if dest_cid in self._connections: + connection = self._connections[dest_cid] + await connection._handle_stream_data(event) + + elif isinstance(event, events.StreamReset): + logger.debug( + f"QUIC EVENT: Stream reset on stream {event.stream_id}" + ) + if dest_cid in self._connections: + connection = self._connections[dest_cid] + await connection._handle_stream_reset(event) + + elif isinstance(event, events.ConnectionIdIssued): + logger.debug( + f"QUIC EVENT: Connection ID issued: {event.connection_id.hex()}" + ) + # Add new CID to the same address mapping + taddr = self._cid_to_addr.get(dest_cid) + if taddr: + # Don't overwrite, but this CID is also valid for this address + logger.debug( + f"QUIC EVENT: New CID {event.connection_id.hex()} " + f"available for {taddr}" + ) + + elif isinstance(event, events.ConnectionIdRetired): + logger.info(f"Connection ID retired: {event.connection_id.hex()}") + retired_cid = event.connection_id + if retired_cid in self._cid_to_addr: + addr = self._cid_to_addr[retired_cid] + del self._cid_to_addr[retired_cid] + # Only remove addr mapping if this was the active CID + if self._addr_to_cid.get(addr) == retired_cid: + del self._addr_to_cid[addr] + else: + logger.warning(f"Unhandled event type: {type(event).__name__}") + + except Exception as e: + logger.debug(f"āŒ EVENT: Error processing events: {e}") + + async def _promote_pending_connection( + self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes + ) -> None: + """Promote pending connection - avoid duplicate creation.""" + try: + self._pending_connections.pop(dest_cid, None) + + if dest_cid in self._connections: + logger.debug( + f"āš ļø Connection {dest_cid.hex()} already exists in _connections!" + ) + connection = self._connections[dest_cid] + else: + from .connection import QUICConnection + + host, port = addr + quic_version = "quic" + remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}") + + connection = QUICConnection( + quic_connection=quic_conn, + remote_addr=addr, + remote_peer_id=None, + local_peer_id=self._transport._peer_id, + is_initiator=False, + maddr=remote_maddr, + transport=self._transport, + security_manager=self._security_manager, + listener_socket=self._socket, + ) + + logger.debug(f"šŸ”„ Created NEW QUICConnection for {dest_cid.hex()}") + + self._connections[dest_cid] = connection + + self._addr_to_cid[addr] = dest_cid + self._cid_to_addr[dest_cid] = addr + + if self._nursery: + connection._nursery = self._nursery + await connection.connect(self._nursery) + logger.debug(f"Connection connected succesfully for {dest_cid.hex()}") + + if self._security_manager: + try: + peer_id = await connection._verify_peer_identity_with_security() + if peer_id: + connection.peer_id = peer_id + logger.info( + f"Security verification successful for {dest_cid.hex()}" + ) + except Exception as e: + logger.error( + f"Security verification failed for {dest_cid.hex()}: {e}" + ) + await connection.close() + return + + if self._nursery: + connection._nursery = self._nursery + await connection._start_background_tasks() + logger.debug( + f"Started background tasks for connection {dest_cid.hex()}" + ) + + try: + logger.debug(f"Invoking user callback {dest_cid.hex()}") + await self._handler(connection) + + except Exception as e: + logger.error(f"Error in user callback: {e}") + + self._stats["connections_accepted"] += 1 + logger.info(f"Enhanced connection {dest_cid.hex()} established from {addr}") + + except Exception as e: + logger.error(f"āŒ Error promoting connection {dest_cid.hex()}: {e}") + await self._remove_connection(dest_cid) + + async def _remove_connection(self, dest_cid: bytes) -> None: + """Remove connection by connection ID.""" + try: + # Remove connection + connection = self._connections.pop(dest_cid, None) + if connection: + await connection.close() + + # Clean up mappings + addr = self._cid_to_addr.pop(dest_cid, None) + if addr: + self._addr_to_cid.pop(addr, None) + + logger.debug(f"Removed connection {dest_cid.hex()}") + + except Exception as e: + logger.error(f"Error removing connection {dest_cid.hex()}: {e}") + + async def _remove_pending_connection(self, dest_cid: bytes) -> None: + """Remove pending connection by connection ID.""" + try: + self._pending_connections.pop(dest_cid, None) + addr = self._cid_to_addr.pop(dest_cid, None) + if addr: + self._addr_to_cid.pop(addr, None) + logger.debug(f"Removed pending connection {dest_cid.hex()}") + except Exception as e: + logger.error(f"Error removing pending connection {dest_cid.hex()}: {e}") + + async def _remove_connection_by_addr(self, addr: tuple[str, int]) -> None: + """Remove connection by address (fallback method).""" + dest_cid = self._addr_to_cid.get(addr) + if dest_cid: + await self._remove_connection(dest_cid) + + async def _transmit_for_connection( + self, quic_conn: QuicConnection, addr: tuple[str, int] + ) -> None: + """Enhanced transmission diagnostics to analyze datagram content.""" + try: + logger.debug(f" TRANSMIT: Starting transmission to {addr}") + + # Get current timestamp for timing + import time + + now = time.time() + + datagrams = quic_conn.datagrams_to_send(now=now) + logger.debug(f" TRANSMIT: Got {len(datagrams)} datagrams to send") + + if not datagrams: + logger.debug("āš ļø TRANSMIT: No datagrams to send") + return + + for i, (datagram, dest_addr) in enumerate(datagrams): + logger.debug(f" TRANSMIT: Analyzing datagram {i}") + logger.debug(f" TRANSMIT: Datagram size: {len(datagram)} bytes") + logger.debug(f" TRANSMIT: Destination: {dest_addr}") + logger.debug(f" TRANSMIT: Expected destination: {addr}") + + # Analyze datagram content + if len(datagram) > 0: + # QUIC packet format analysis + first_byte = datagram[0] + header_form = (first_byte & 0x80) >> 7 # Bit 7 + + # For long header packets (handshake), analyze further + if header_form == 1: # Long header + # CRYPTO frame type is 0x06 + crypto_frame_found = False + for offset in range(len(datagram)): + if datagram[offset] == 0x06: + crypto_frame_found = True + break + + if not crypto_frame_found: + logger.error("No CRYPTO frame found in datagram!") + # Look for other frame types + frame_types_found = set() + for offset in range(len(datagram)): + frame_type = datagram[offset] + if frame_type in [0x00, 0x01]: # PADDING/PING + frame_types_found.add("PADDING/PING") + elif frame_type == 0x02: # ACK + frame_types_found.add("ACK") + elif frame_type == 0x06: # CRYPTO + frame_types_found.add("CRYPTO") + + if self._socket: + try: + await self._socket.sendto(datagram, addr) + except Exception as send_error: + logger.error(f"Socket send failed: {send_error}") + else: + logger.error("No socket available!") + except Exception as e: + logger.debug(f"Transmission error: {e}") + + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + """Start listening on the given multiaddr with enhanced connection handling.""" + if self._listening: + raise QUICListenError("Already listening") + + if not is_quic_multiaddr(maddr): + raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") + + if self._transport._background_nursery: + active_nursery = self._transport._background_nursery + logger.debug("Using transport background nursery for listener") + elif nursery: + active_nursery = nursery + self._transport._background_nursery = nursery + logger.debug("Using provided nursery for listener") + else: + raise QUICListenError("No nursery available") + + try: + host, port = quic_multiaddr_to_endpoint(maddr) + + # Create and configure socket + self._socket = await self._create_socket(host, port) + self._nursery = active_nursery + + # Get the actual bound address + bound_host, bound_port = self._socket.getsockname() + quic_version = multiaddr_to_quic_version(maddr) + bound_maddr = create_quic_multiaddr(bound_host, bound_port, quic_version) + self._bound_addresses = [bound_maddr] + + self._listening = True + + # Start packet handling loop + active_nursery.start_soon(self._handle_incoming_packets) + + logger.info( + f"QUIC listener started on {bound_maddr} with connection ID support" + ) + return True + + except Exception as e: + await self.close() + raise QUICListenError(f"Failed to start listening: {e}") from e + + async def _create_socket(self, host: str, port: int) -> trio.socket.SocketType: + """Create and configure UDP socket.""" + try: + # Determine address family + try: + import ipaddress + + ip = ipaddress.ip_address(host) + family = socket.AF_INET if ip.version == 4 else socket.AF_INET6 + except ValueError: + family = socket.AF_INET + + # Create UDP socket + sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + # Set socket options + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, "SO_REUSEPORT"): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # type: ignore[attr-defined] + + # Bind to address + await sock.bind((host, port)) + + logger.debug(f"Created and bound UDP socket to {host}:{port}") + return sock + + except Exception as e: + raise QUICListenError(f"Failed to create socket: {e}") from e + + async def _handle_incoming_packets(self) -> None: + """Handle incoming UDP packets with enhanced routing.""" + logger.debug("Started enhanced packet handling loop") + + try: + while self._listening and self._socket: + try: + # Receive UDP packet + data, addr = await self._socket.recvfrom(65536) + + # Process packet asynchronously + if self._nursery: + self._nursery.start_soon(self._process_packet, data, addr) + + except trio.ClosedResourceError: + logger.debug("Socket closed, exiting packet handler") + break + except Exception as e: + logger.error(f"Error receiving packet: {e}") + await trio.sleep(0.01) + except trio.Cancelled: + logger.info("Packet handling cancelled") + raise + finally: + logger.debug("Enhanced packet handling loop terminated") + + async def close(self) -> None: + """Close the listener and clean up resources.""" + if self._closed: + return + + self._closed = True + self._listening = False + + try: + # Close all connections + async with self._connection_lock: + for dest_cid in list(self._connections.keys()): + await self._remove_connection(dest_cid) + + for dest_cid in list(self._pending_connections.keys()): + await self._remove_pending_connection(dest_cid) + + # Close socket + if self._socket: + self._socket.close() + self._socket = None + + self._bound_addresses.clear() + + logger.info("QUIC listener closed") + + except Exception as e: + logger.error(f"Error closing listener: {e}") + + async def _remove_connection_by_object( + self, connection_obj: QUICConnection + ) -> None: + """Remove a connection by object reference.""" + try: + # Find the connection ID for this object + connection_cid = None + for cid, tracked_connection in self._connections.items(): + if tracked_connection is connection_obj: + connection_cid = cid + break + + if connection_cid: + await self._remove_connection(connection_cid) + logger.debug(f"Removed connection {connection_cid.hex()}") + else: + logger.warning("Connection object not found in tracking") + + except Exception as e: + logger.error(f"Error removing connection by object: {e}") + + def get_addresses(self) -> list[Multiaddr]: + """Get the bound addresses.""" + return self._bound_addresses.copy() + + async def _handle_new_established_connection( + self, connection: QUICConnection + ) -> None: + """Handle newly established connection by adding to swarm.""" + try: + logger.debug( + f"New QUIC connection established from {connection._remote_addr}" + ) + + if self._transport._swarm: + logger.debug("Adding QUIC connection directly to swarm") + await self._transport._swarm.add_conn(connection) + logger.debug("Successfully added QUIC connection to swarm") + else: + logger.error("No swarm available for QUIC connection") + await connection.close() + + except Exception as e: + logger.error(f"Error adding QUIC connection to swarm: {e}") + await connection.close() + + def get_addrs(self) -> tuple[Multiaddr]: + return tuple(self.get_addresses()) + + def is_listening(self) -> bool: + """ + Check if the listener is currently listening for connections. + + Returns: + bool: True if the listener is actively listening, False otherwise + + """ + return self._listening and not self._closed + + def get_stats(self) -> dict[str, int | bool]: + """ + Get listener statistics including the listening state. + + Returns: + dict: Statistics dictionary with current state information + + """ + stats = self._stats.copy() + stats["is_listening"] = self.is_listening() + stats["active_connections"] = len(self._connections) + stats["pending_connections"] = len(self._pending_connections) + return stats diff --git a/libp2p/transport/quic/security.py b/libp2p/transport/quic/security.py new file mode 100644 index 00000000..43ebfa37 --- /dev/null +++ b/libp2p/transport/quic/security.py @@ -0,0 +1,1165 @@ +""" +QUIC Security helpers implementation +""" + +from dataclasses import dataclass, field +import logging +import ssl +from typing import Any + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec, rsa +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.x509.base import Certificate +from cryptography.x509.extensions import Extension, UnrecognizedExtension +from cryptography.x509.oid import NameOID + +from libp2p.crypto.keys import PrivateKey, PublicKey +from libp2p.crypto.serialization import deserialize_public_key +from libp2p.peer.id import ID + +from .exceptions import ( + QUICCertificateError, + QUICPeerVerificationError, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +# libp2p TLS Extension OID - Official libp2p specification +LIBP2P_TLS_EXTENSION_OID = x509.ObjectIdentifier("1.3.6.1.4.1.53594.1.1") + +# Certificate validity period +CERTIFICATE_VALIDITY_DAYS = 365 +CERTIFICATE_NOT_BEFORE_BUFFER = 3600 # 1 hour before now + + +@dataclass +@dataclass +class TLSConfig: + """TLS configuration for QUIC transport with libp2p extensions.""" + + certificate: x509.Certificate + private_key: ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey + peer_id: ID + + def get_certificate_der(self) -> bytes: + """Get certificate in DER format for external use.""" + return self.certificate.public_bytes(serialization.Encoding.DER) + + def get_private_key_der(self) -> bytes: + """Get private key in DER format for external use.""" + return self.private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + def get_certificate_pem(self) -> bytes: + """Get certificate in PEM format.""" + return self.certificate.public_bytes(serialization.Encoding.PEM) + + def get_private_key_pem(self) -> bytes: + """Get private key in PEM format.""" + return self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +class LibP2PExtensionHandler: + """ + Handles libp2p-specific TLS extensions for peer identity verification. + + Based on libp2p TLS specification: + https://github.com/libp2p/specs/blob/master/tls/tls.md + """ + + @staticmethod + def create_signed_key_extension( + libp2p_private_key: PrivateKey, + cert_public_key: bytes, + ) -> bytes: + """ + Create the libp2p Public Key Extension with signed key proof. + + The extension contains: + 1. The libp2p public key + 2. A signature proving ownership of the private key + + Args: + libp2p_private_key: The libp2p identity private key + cert_public_key: The certificate's public key bytes + + Returns: + Encoded extension value + + """ + try: + # Get the libp2p public key + libp2p_public_key = libp2p_private_key.get_public_key() + + # Create the signature payload: "libp2p-tls-handshake:" + cert_public_key + signature_payload = b"libp2p-tls-handshake:" + cert_public_key + + # Sign the payload with the libp2p private key + signature = libp2p_private_key.sign(signature_payload) + + # Get the public key bytes + public_key_bytes = libp2p_public_key.serialize() + + # Create ASN.1 DER encoded structure (go-libp2p compatible) + return LibP2PExtensionHandler._create_asn1_der_extension( + public_key_bytes, signature + ) + + except Exception as e: + raise QUICCertificateError( + f"Failed to create signed key extension: {e}" + ) from e + + @staticmethod + def _create_asn1_der_extension(public_key_bytes: bytes, signature: bytes) -> bytes: + """ + Create ASN.1 DER encoded extension (go-libp2p compatible). + + Structure: + SEQUENCE { + publicKey OCTET STRING, + signature OCTET STRING + } + """ + # Encode public key as OCTET STRING + pubkey_octets = LibP2PExtensionHandler._encode_der_octet_string( + public_key_bytes + ) + + # Encode signature as OCTET STRING + sig_octets = LibP2PExtensionHandler._encode_der_octet_string(signature) + + # Combine into SEQUENCE + sequence_content = pubkey_octets + sig_octets + + # Encode as SEQUENCE + return LibP2PExtensionHandler._encode_der_sequence(sequence_content) + + @staticmethod + def _encode_der_length(length: int) -> bytes: + """Encode length in DER format.""" + if length < 128: + # Short form + return bytes([length]) + else: + # Long form + length_bytes = length.to_bytes( + (length.bit_length() + 7) // 8, byteorder="big" + ) + return bytes([0x80 | len(length_bytes)]) + length_bytes + + @staticmethod + def _encode_der_octet_string(data: bytes) -> bytes: + """Encode data as DER OCTET STRING.""" + return ( + bytes([0x04]) + LibP2PExtensionHandler._encode_der_length(len(data)) + data + ) + + @staticmethod + def _encode_der_sequence(data: bytes) -> bytes: + """Encode data as DER SEQUENCE.""" + return ( + bytes([0x30]) + LibP2PExtensionHandler._encode_der_length(len(data)) + data + ) + + @staticmethod + def parse_signed_key_extension( + extension: Extension[Any], + ) -> tuple[PublicKey, bytes]: + """ + Parse the libp2p Public Key Extension with support for all crypto types. + Handles both ASN.1 DER format (from go-libp2p) and simple binary format. + """ + try: + logger.debug(f"šŸ” Extension type: {type(extension)}") + logger.debug(f"šŸ” Extension.value type: {type(extension.value)}") + + # Extract the raw bytes from the extension + if isinstance(extension.value, UnrecognizedExtension): + raw_bytes = extension.value.value + logger.debug( + "šŸ” Extension is UnrecognizedExtension, using .value property" + ) + else: + raw_bytes = extension.value + logger.debug("šŸ” Extension.value is already bytes") + + logger.debug(f"šŸ” Total extension length: {len(raw_bytes)} bytes") + logger.debug(f"šŸ” Extension hex (first 50 bytes): {raw_bytes[:50].hex()}") + + if not isinstance(raw_bytes, bytes): + raise QUICCertificateError(f"Expected bytes, got {type(raw_bytes)}") + + # Check if this is ASN.1 DER encoded (from go-libp2p) + if len(raw_bytes) >= 4 and raw_bytes[0] == 0x30: + logger.debug("šŸ” Detected ASN.1 DER encoding") + return LibP2PExtensionHandler._parse_asn1_der_extension(raw_bytes) + else: + logger.debug("šŸ” Using simple binary format parsing") + return LibP2PExtensionHandler._parse_simple_binary_extension(raw_bytes) + + except Exception as e: + logger.debug(f"āŒ Extension parsing failed: {e}") + import traceback + + logger.debug(f"āŒ Traceback: {traceback.format_exc()}") + raise QUICCertificateError( + f"Failed to parse signed key extension: {e}" + ) from e + + @staticmethod + def _parse_asn1_der_extension(raw_bytes: bytes) -> tuple[PublicKey, bytes]: + """ + Parse ASN.1 DER encoded extension (go-libp2p format). + + The structure is typically: + SEQUENCE { + publicKey OCTET STRING, + signature OCTET STRING + } + """ + try: + offset = 0 + + # Parse SEQUENCE tag + if raw_bytes[offset] != 0x30: + raise QUICCertificateError( + f"Expected SEQUENCE tag (0x30), got {raw_bytes[offset]:02x}" + ) + offset += 1 + + # Parse SEQUENCE length + seq_length, length_bytes = LibP2PExtensionHandler._parse_der_length( + raw_bytes[offset:] + ) + offset += length_bytes + logger.debug(f"šŸ” SEQUENCE length: {seq_length} bytes") + + # Parse first OCTET STRING (public key) + if raw_bytes[offset] != 0x04: + raise QUICCertificateError( + f"Expected OCTET STRING tag (0x04), got {raw_bytes[offset]:02x}" + ) + offset += 1 + + pubkey_length, length_bytes = LibP2PExtensionHandler._parse_der_length( + raw_bytes[offset:] + ) + offset += length_bytes + logger.debug(f"šŸ” Public key length: {pubkey_length} bytes") + + if len(raw_bytes) < offset + pubkey_length: + raise QUICCertificateError("Extension too short for public key data") + + public_key_bytes = raw_bytes[offset : offset + pubkey_length] + offset += pubkey_length + + # Parse second OCTET STRING (signature) + if offset < len(raw_bytes) and raw_bytes[offset] == 0x04: + offset += 1 + sig_length, length_bytes = LibP2PExtensionHandler._parse_der_length( + raw_bytes[offset:] + ) + offset += length_bytes + logger.debug(f"šŸ” Signature length: {sig_length} bytes") + + if len(raw_bytes) < offset + sig_length: + raise QUICCertificateError("Extension too short for signature data") + + signature_data = raw_bytes[offset : offset + sig_length] + else: + # Signature might be the remaining bytes + signature_data = raw_bytes[offset:] + + logger.debug(f"šŸ” Public key data length: {len(public_key_bytes)} bytes") + logger.debug(f"šŸ” Signature data length: {len(signature_data)} bytes") + + # Deserialize the public key + public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes) + logger.debug(f"šŸ” Successfully deserialized public key: {type(public_key)}") + + # Extract signature based on key type + signature = LibP2PExtensionHandler._extract_signature_by_key_type( + public_key, signature_data + ) + + return public_key, signature + + except Exception as e: + raise QUICCertificateError( + f"Failed to parse ASN.1 DER extension: {e}" + ) from e + + @staticmethod + def _parse_der_length(data: bytes) -> tuple[int, int]: + """ + Parse DER length encoding. + Returns (length_value, bytes_consumed). + """ + if not data: + raise QUICCertificateError("No data for DER length") + + first_byte = data[0] + + # Short form (length < 128) + if first_byte < 0x80: + return first_byte, 1 + + # Long form + num_bytes = first_byte & 0x7F + if len(data) < 1 + num_bytes: + raise QUICCertificateError("Insufficient data for DER long form length") + + length = 0 + for i in range(1, num_bytes + 1): + length = (length << 8) | data[i] + + return length, 1 + num_bytes + + @staticmethod + def _parse_simple_binary_extension(raw_bytes: bytes) -> tuple[PublicKey, bytes]: + """ + Parse simple binary format extension (original py-libp2p format). + Format: [4-byte pubkey length][pubkey][4-byte sig length][signature] + """ + offset = 0 + + # Parse public key length and data + if len(raw_bytes) < 4: + raise QUICCertificateError("Extension too short for public key length") + + public_key_length = int.from_bytes( + raw_bytes[offset : offset + 4], byteorder="big" + ) + logger.debug(f"šŸ” Public key length: {public_key_length} bytes") + offset += 4 + + if len(raw_bytes) < offset + public_key_length: + raise QUICCertificateError("Extension too short for public key data") + + public_key_bytes = raw_bytes[offset : offset + public_key_length] + offset += public_key_length + + # Parse signature length and data + if len(raw_bytes) < offset + 4: + raise QUICCertificateError("Extension too short for signature length") + + signature_length = int.from_bytes( + raw_bytes[offset : offset + 4], byteorder="big" + ) + logger.debug(f"šŸ” Signature length: {signature_length} bytes") + offset += 4 + + if len(raw_bytes) < offset + signature_length: + raise QUICCertificateError("Extension too short for signature data") + + signature_data = raw_bytes[offset : offset + signature_length] + + # Deserialize the public key + public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes) + logger.debug(f"šŸ” Successfully deserialized public key: {type(public_key)}") + + # Extract signature based on key type + signature = LibP2PExtensionHandler._extract_signature_by_key_type( + public_key, signature_data + ) + + return public_key, signature + + @staticmethod + def _extract_signature_by_key_type( + public_key: PublicKey, signature_data: bytes + ) -> bytes: + """ + Extract the actual signature from signature_data based on the key type. + Different crypto libraries have different signature formats. + """ + if not hasattr(public_key, "get_type"): + logger.debug("āš ļø Public key has no get_type method, using signature as-is") + return signature_data + + key_type = public_key.get_type() + key_type_name = key_type.name if hasattr(key_type, "name") else str(key_type) + logger.debug(f"šŸ” Processing signature for key type: {key_type_name}") + + # Handle different key types + if key_type_name == "Ed25519": + return LibP2PExtensionHandler._extract_ed25519_signature(signature_data) + + elif key_type_name == "Secp256k1": + return LibP2PExtensionHandler._extract_secp256k1_signature(signature_data) + + elif key_type_name == "RSA": + return LibP2PExtensionHandler._extract_rsa_signature(signature_data) + + elif key_type_name in ["ECDSA", "ECC_P256"]: + return LibP2PExtensionHandler._extract_ecdsa_signature(signature_data) + + else: + logger.debug( + f"āš ļø Unknown key type {key_type_name}, using generic extraction" + ) + return LibP2PExtensionHandler._extract_generic_signature(signature_data) + + @staticmethod + def _extract_ed25519_signature(signature_data: bytes) -> bytes: + """Extract Ed25519 signature (must be exactly 64 bytes).""" + logger.debug("šŸ”§ Extracting Ed25519 signature") + + if len(signature_data) == 64: + logger.debug("āœ… Ed25519 signature is already 64 bytes") + return signature_data + + logger.debug( + f"āš ļø Ed25519 signature is {len(signature_data)} bytes, extracting 64 bytes" + ) + + # Look for the payload marker and extract signature before it + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index >= 64: + # The signature is likely the first 64 bytes before the payload + signature = signature_data[:64] + logger.debug("šŸ”§ Using first 64 bytes as Ed25519 signature") + return signature + + elif marker_index > 0 and marker_index == 64: + # Perfect case: signature is exactly before the marker + signature = signature_data[:marker_index] + logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker") + return signature + + else: + # Fallback: try to extract first 64 bytes + if len(signature_data) >= 64: + signature = signature_data[:64] + logger.debug("šŸ”§ Fallback: using first 64 bytes") + return signature + else: + logger.debug( + f"Cannot extract 64 bytes from {len(signature_data)} byte signature" + ) + return signature_data + + @staticmethod + def _extract_secp256k1_signature(signature_data: bytes) -> bytes: + """ + Extract Secp256k1 signature. Secp256k1 can use either DER-encoded + or raw format depending on the implementation. + """ + logger.debug("šŸ”§ Extracting Secp256k1 signature") + + # Look for payload marker to separate signature from payload + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index > 0: + signature = signature_data[:marker_index] + logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker") + + # Check if it's DER-encoded (starts with 0x30) + if len(signature) >= 2 and signature[0] == 0x30: + logger.debug("šŸ” Secp256k1 signature appears to be DER-encoded") + return LibP2PExtensionHandler._validate_der_signature(signature) + else: + logger.debug("šŸ” Secp256k1 signature appears to be raw format") + return signature + else: + # No marker found, check if the whole data is DER-encoded + if len(signature_data) >= 2 and signature_data[0] == 0x30: + logger.debug( + "šŸ” Secp256k1 signature appears to be DER-encoded (no marker)" + ) + return LibP2PExtensionHandler._validate_der_signature(signature_data) + else: + logger.debug("šŸ” Using Secp256k1 signature data as-is") + return signature_data + + @staticmethod + def _extract_rsa_signature(signature_data: bytes) -> bytes: + """ + Extract RSA signature. + RSA signatures are typically raw bytes with length matching the key size. + """ + logger.debug("šŸ”§ Extracting RSA signature") + + # Look for payload marker to separate signature from payload + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index > 0: + signature = signature_data[:marker_index] + logger.debug( + f"šŸ”§ Using {len(signature)} bytes before payload marker for RSA" + ) + return signature + else: + logger.debug("šŸ” Using RSA signature data as-is") + return signature_data + + @staticmethod + def _extract_ecdsa_signature(signature_data: bytes) -> bytes: + """ + Extract ECDSA signature (typically DER-encoded ASN.1). + ECDSA signatures start with 0x30 (ASN.1 SEQUENCE). + """ + logger.debug("šŸ”§ Extracting ECDSA signature") + + # Look for payload marker to separate signature from payload + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index > 0: + signature = signature_data[:marker_index] + logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker") + + # Validate DER encoding for ECDSA + if len(signature) >= 2 and signature[0] == 0x30: + return LibP2PExtensionHandler._validate_der_signature(signature) + else: + logger.debug( + "āš ļø ECDSA signature doesn't start with DER header, using as-is" + ) + return signature + else: + # Check if the whole data is DER-encoded + if len(signature_data) >= 2 and signature_data[0] == 0x30: + logger.debug("šŸ” ECDSA signature appears to be DER-encoded (no marker)") + return LibP2PExtensionHandler._validate_der_signature(signature_data) + else: + logger.debug("šŸ” Using ECDSA signature data as-is") + return signature_data + + @staticmethod + def _extract_generic_signature(signature_data: bytes) -> bytes: + """ + Generic signature extraction for unknown key types. + Tries to detect DER encoding or extract based on payload marker. + """ + logger.debug("šŸ”§ Extracting signature using generic method") + + # Look for payload marker to separate signature from payload + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index > 0: + signature = signature_data[:marker_index] + logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker") + + # Check if it's DER-encoded + if len(signature) >= 2 and signature[0] == 0x30: + return LibP2PExtensionHandler._validate_der_signature(signature) + else: + return signature + else: + # Check if the whole data is DER-encoded + if len(signature_data) >= 2 and signature_data[0] == 0x30: + logger.debug( + "šŸ” Generic signature appears to be DER-encoded (no marker)" + ) + return LibP2PExtensionHandler._validate_der_signature(signature_data) + else: + logger.debug("šŸ” Using signature data as-is") + return signature_data + + @staticmethod + def _validate_der_signature(signature: bytes) -> bytes: + """ + Validate and potentially fix DER-encoded signatures. + DER signatures have the format: 30 [length] ... + """ + if len(signature) < 2: + return signature + + if signature[0] != 0x30: + logger.debug("āš ļø Signature doesn't start with DER SEQUENCE tag") + return signature + + # Get the DER length + der_length = signature[1] + expected_total_length = der_length + 2 + + logger.debug( + f"šŸ” DER signature: length byte = {der_length}, " + f"expected total = {expected_total_length}, " + f"actual length = {len(signature)}" + ) + + if len(signature) == expected_total_length: + logger.debug("āœ… DER signature length is correct") + return signature + elif len(signature) > expected_total_length: + logger.debug( + "Truncating DER signature from " + f"{len(signature)} to {expected_total_length} bytes" + ) + return signature[:expected_total_length] + else: + logger.debug("DER signature is shorter than expected, using as-is") + return signature + + +class LibP2PKeyConverter: + """ + Converts between libp2p key formats and cryptography library formats. + Handles different key types: Ed25519, Secp256k1, RSA, ECDSA. + """ + + @staticmethod + def libp2p_to_tls_private_key( + libp2p_key: PrivateKey, + ) -> ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey: + """ + Convert libp2p private key to TLS-compatible private key. + + For certificate generation, we create a separate ephemeral key + rather than using the libp2p identity key directly. + """ + # For QUIC, we prefer ECDSA keys for smaller certificates + # Generate ephemeral P-256 key for certificate signing + private_key = ec.generate_private_key(ec.SECP256R1()) + return private_key + + @staticmethod + def serialize_public_key(public_key: PublicKey) -> bytes: + """Serialize libp2p public key to bytes.""" + return public_key.serialize() + + @staticmethod + def deserialize_public_key(key_bytes: bytes) -> PublicKey: + """ + Deserialize libp2p public key from protobuf bytes. + + Args: + key_bytes: Protobuf-serialized public key bytes + + Returns: + Deserialized PublicKey instance + + """ + try: + # Use the official libp2p deserialization function + return deserialize_public_key(key_bytes) + except Exception as e: + raise QUICCertificateError(f"Failed to deserialize public key: {e}") from e + + +class CertificateGenerator: + """ + Generates X.509 certificates with libp2p peer identity extensions. + Follows libp2p TLS specification for QUIC transport. + """ + + def __init__(self) -> None: + self.extension_handler = LibP2PExtensionHandler() + self.key_converter = LibP2PKeyConverter() + + def generate_certificate( + self, + libp2p_private_key: PrivateKey, + peer_id: ID, + validity_days: int = CERTIFICATE_VALIDITY_DAYS, + ) -> TLSConfig: + """ + Generate a TLS certificate with embedded libp2p peer identity. + Fixed to use datetime objects for validity periods. + + Args: + libp2p_private_key: The libp2p identity private key + peer_id: The libp2p peer ID + validity_days: Certificate validity period in days + + Returns: + TLSConfig with certificate and private key + + Raises: + QUICCertificateError: If certificate generation fails + + """ + try: + # Generate ephemeral private key for certificate + cert_private_key = self.key_converter.libp2p_to_tls_private_key( + libp2p_private_key + ) + cert_public_key = cert_private_key.public_key() + + # Get certificate public key bytes for extension + cert_public_key_bytes = cert_public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + # Create libp2p extension with signed key proof + extension_data = self.extension_handler.create_signed_key_extension( + libp2p_private_key, cert_public_key_bytes + ) + + from datetime import datetime, timedelta, timezone + + now = datetime.now(timezone.utc) + not_before = now - timedelta(minutes=1) + not_after = now + timedelta(days=validity_days) + + # Generate serial number + serial_number = int(now.timestamp()) + + certificate = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, peer_id.to_base58())] # type: ignore + ) + ) + .issuer_name( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, peer_id.to_base58())] # type: ignore + ) + ) + .public_key(cert_public_key) + .serial_number(serial_number) + .not_valid_before(not_before) + .not_valid_after(not_after) + .add_extension( + x509.UnrecognizedExtension( + oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data + ), + critical=False, + ) + .sign(cert_private_key, hashes.SHA256()) + ) + + logger.info(f"Generated libp2p TLS certificate for peer {peer_id}") + logger.debug(f"Certificate valid from {not_before} to {not_after}") + + return TLSConfig( + certificate=certificate, private_key=cert_private_key, peer_id=peer_id + ) + + except Exception as e: + raise QUICCertificateError(f"Failed to generate certificate: {e}") from e + + +class PeerAuthenticator: + """ + Authenticates remote peers using libp2p TLS certificates. + Validates both TLS certificate integrity and libp2p peer identity. + """ + + def __init__(self) -> None: + self.extension_handler = LibP2PExtensionHandler() + + def verify_peer_certificate( + self, certificate: x509.Certificate, expected_peer_id: ID | None = None + ) -> ID: + """ + Verify a peer's TLS certificate and extract/validate peer identity. + + Args: + certificate: The peer's TLS certificate + expected_peer_id: Expected peer ID (for outbound connections) + + Returns: + The verified peer ID + + Raises: + QUICPeerVerificationError: If verification fails + + """ + try: + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + + if certificate.not_valid_after_utc < now: + raise QUICPeerVerificationError("Certificate has expired") + + if certificate.not_valid_before_utc > now: + raise QUICPeerVerificationError("Certificate not yet valid") + + # Extract libp2p extension + libp2p_extension = None + for extension in certificate.extensions: + if extension.oid == LIBP2P_TLS_EXTENSION_OID: + libp2p_extension = extension + break + + if not libp2p_extension: + raise QUICPeerVerificationError("Certificate missing libp2p extension") + + assert libp2p_extension.value is not None + logger.debug(f"Extension type: {type(libp2p_extension)}") + logger.debug(f"Extension value type: {type(libp2p_extension.value)}") + if hasattr(libp2p_extension.value, "__len__"): + logger.debug(f"Extension value length: {len(libp2p_extension.value)}") + logger.debug(f"Extension value: {libp2p_extension.value}") + # Parse the extension to get public key and signature + public_key, signature = self.extension_handler.parse_signed_key_extension( + libp2p_extension + ) + + # Get certificate public key for signature verification + cert_public_key_bytes = certificate.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + # Verify the signature proves ownership of the libp2p private key + signature_payload = b"libp2p-tls-handshake:" + cert_public_key_bytes + + try: + public_key.verify(signature_payload, signature) + except Exception as e: + raise QUICPeerVerificationError( + f"Invalid signature in libp2p extension: {e}" + ) + + # Derive peer ID from public key + derived_peer_id = ID.from_pubkey(public_key) + + # Verify against expected peer ID if provided + if expected_peer_id and derived_peer_id != expected_peer_id: + logger.debug(f"Expected Peer id: {expected_peer_id}") + logger.debug(f"Derived Peer ID: {derived_peer_id}") + raise QUICPeerVerificationError( + f"Peer ID mismatch: expected {expected_peer_id}, " + f"got {derived_peer_id}" + ) + + logger.debug( + f"Successfully verified peer certificate for {derived_peer_id}" + ) + return derived_peer_id + + except QUICPeerVerificationError: + raise + except Exception as e: + raise QUICPeerVerificationError( + f"Certificate verification failed: {e}" + ) from e + + +@dataclass +class QUICTLSSecurityConfig: + """ + Type-safe TLS security configuration for QUIC transport. + """ + + # Core TLS components (required) + certificate: Certificate + private_key: EllipticCurvePrivateKey | RSAPrivateKey + + # Certificate chain (optional) + certificate_chain: list[Certificate] = field(default_factory=list) + + # ALPN protocols + alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"]) + + # TLS verification settings + verify_mode: ssl.VerifyMode = ssl.CERT_NONE + check_hostname: bool = False + request_client_certificate: bool = False + + # Optional peer ID for validation + peer_id: ID | None = None + + # Configuration metadata + is_client_config: bool = False + config_name: str | None = None + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate the TLS configuration.""" + if self.certificate is None: + raise ValueError("Certificate is required") + + if self.private_key is None: + raise ValueError("Private key is required") + + if not isinstance(self.certificate, x509.Certificate): + raise TypeError( + f"Certificate must be x509.Certificate, got {type(self.certificate)}" + ) + + if not isinstance( + self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey) + ): + raise TypeError( + f"Private key must be EC or RSA key, got {type(self.private_key)}" + ) + + if not self.alpn_protocols: + raise ValueError("At least one ALPN protocol is required") + + def validate_certificate_key_match(self) -> bool: + """ + Validate that the certificate and private key match. + + Returns: + True if certificate and private key match + + """ + try: + from cryptography.hazmat.primitives import serialization + + # Get public keys from both certificate and private key + cert_public_key = self.certificate.public_key() + private_public_key = self.private_key.public_key() + + # Compare their PEM representations + cert_pub_pem = cert_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + private_pub_pem = private_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return cert_pub_pem == private_pub_pem + + except Exception: + return False + + def has_libp2p_extension(self) -> bool: + """ + Check if the certificate has the required libp2p extension. + + Returns: + True if libp2p extension is present + + """ + try: + for ext in self.certificate.extensions: + if ext.oid == LIBP2P_TLS_EXTENSION_OID: + return True + return False + except Exception: + return False + + def is_certificate_valid(self) -> bool: + """ + Check if the certificate is currently valid (not expired). + + Returns: + True if certificate is valid + + """ + try: + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + not_before = self.certificate.not_valid_before_utc + not_after = self.certificate.not_valid_after_utc + + return not_before <= now <= not_after + except Exception: + return False + + def get_certificate_info(self) -> dict[Any, Any]: + """ + Get certificate information for debugging. + + Returns: + Dictionary with certificate details + + """ + try: + return { + "subject": str(self.certificate.subject), + "issuer": str(self.certificate.issuer), + "serial_number": self.certificate.serial_number, + "not_valid_before_utc": self.certificate.not_valid_before_utc, + "not_valid_after_utc": self.certificate.not_valid_after_utc, + "has_libp2p_extension": self.has_libp2p_extension(), + "is_valid": self.is_certificate_valid(), + "certificate_key_match": self.validate_certificate_key_match(), + } + except Exception as e: + return {"error": str(e)} + + def debug_config(self) -> None: + """logger.debug debugging information about this configuration.""" + logger.debug( + f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===" + ) + logger.debug(f"Is client config: {self.is_client_config}") + logger.debug(f"ALPN protocols: {self.alpn_protocols}") + logger.debug(f"Verify mode: {self.verify_mode}") + logger.debug(f"Check hostname: {self.check_hostname}") + logger.debug(f"Certificate chain length: {len(self.certificate_chain)}") + + cert_info: dict[Any, Any] = self.get_certificate_info() + for key, value in cert_info.items(): + logger.debug(f"Certificate {key}: {value}") + + logger.debug(f"Private key type: {type(self.private_key).__name__}") + if hasattr(self.private_key, "key_size"): + logger.debug(f"Private key size: {self.private_key.key_size}") + + +def create_server_tls_config( + certificate: Certificate, + private_key: EllipticCurvePrivateKey | RSAPrivateKey, + peer_id: ID | None = None, + **kwargs: Any, +) -> QUICTLSSecurityConfig: + """ + Create a server TLS configuration. + + Args: + certificate: X.509 certificate + private_key: Private key corresponding to certificate + peer_id: Optional peer ID for validation + kwargs: Additional configuration parameters + + Returns: + Server TLS configuration + + """ + return QUICTLSSecurityConfig( + certificate=certificate, + private_key=private_key, + peer_id=peer_id, + is_client_config=False, + config_name="server", + verify_mode=ssl.CERT_NONE, + check_hostname=False, + request_client_certificate=True, + **kwargs, + ) + + +def create_client_tls_config( + certificate: Certificate, + private_key: EllipticCurvePrivateKey | RSAPrivateKey, + peer_id: ID | None = None, + **kwargs: Any, +) -> QUICTLSSecurityConfig: + """ + Create a client TLS configuration. + + Args: + certificate: X.509 certificate + private_key: Private key corresponding to certificate + peer_id: Optional peer ID for validation + kwargs: Additional configuration parameters + + Returns: + Client TLS configuration + + """ + return QUICTLSSecurityConfig( + certificate=certificate, + private_key=private_key, + peer_id=peer_id, + is_client_config=True, + config_name="client", + verify_mode=ssl.CERT_NONE, + check_hostname=False, + **kwargs, + ) + + +class QUICTLSConfigManager: + """ + Manages TLS configuration for QUIC transport with libp2p security. + Integrates with aioquic's TLS configuration system. + """ + + def __init__(self, libp2p_private_key: PrivateKey, peer_id: ID) -> None: + self.libp2p_private_key = libp2p_private_key + self.peer_id = peer_id + self.certificate_generator = CertificateGenerator() + self.peer_authenticator = PeerAuthenticator() + + # Generate certificate for this peer + self.tls_config = self.certificate_generator.generate_certificate( + libp2p_private_key, peer_id + ) + + def create_server_config(self) -> QUICTLSSecurityConfig: + """ + Create server configuration using the new class-based approach. + + Returns: + QUICTLSSecurityConfig instance for server + + """ + config = create_server_tls_config( + certificate=self.tls_config.certificate, + private_key=self.tls_config.private_key, + peer_id=self.peer_id, + ) + + return config + + def create_client_config(self) -> QUICTLSSecurityConfig: + """ + Create client configuration using the new class-based approach. + + Returns: + QUICTLSSecurityConfig instance for client + + """ + config = create_client_tls_config( + certificate=self.tls_config.certificate, + private_key=self.tls_config.private_key, + peer_id=self.peer_id, + ) + + return config + + def verify_peer_identity( + self, peer_certificate: x509.Certificate, expected_peer_id: ID | None = None + ) -> ID: + """ + Verify remote peer's identity from their TLS certificate. + + Args: + peer_certificate: Remote peer's TLS certificate + expected_peer_id: Expected peer ID (for outbound connections) + + Returns: + Verified peer ID + + """ + return self.peer_authenticator.verify_peer_certificate( + peer_certificate, expected_peer_id + ) + + def get_local_peer_id(self) -> ID: + """Get the local peer ID.""" + return self.peer_id + + +# Factory function for creating QUIC security transport +def create_quic_security_transport( + libp2p_private_key: PrivateKey, peer_id: ID +) -> QUICTLSConfigManager: + """ + Factory function to create QUIC security transport. + + Args: + libp2p_private_key: The libp2p identity private key + peer_id: The libp2p peer ID + + Returns: + Configured QUIC TLS manager + + """ + return QUICTLSConfigManager(libp2p_private_key, peer_id) diff --git a/libp2p/transport/quic/stream.py b/libp2p/transport/quic/stream.py new file mode 100644 index 00000000..dac8925e --- /dev/null +++ b/libp2p/transport/quic/stream.py @@ -0,0 +1,656 @@ +""" +QUIC Stream implementation +Provides stream interface over QUIC's native multiplexing. +""" + +from enum import Enum +import logging +import time +from types import TracebackType +from typing import TYPE_CHECKING, Any, cast + +import trio + +from .exceptions import ( + QUICStreamBackpressureError, + QUICStreamClosedError, + QUICStreamResetError, + QUICStreamTimeoutError, +) + +if TYPE_CHECKING: + from libp2p.abc import IMuxedStream + from libp2p.custom_types import TProtocol + + from .connection import QUICConnection +else: + IMuxedStream = cast(type, object) + TProtocol = cast(type, object) + +logger = logging.getLogger(__name__) + + +class StreamState(Enum): + """Stream lifecycle states following libp2p patterns.""" + + OPEN = "open" + WRITE_CLOSED = "write_closed" + READ_CLOSED = "read_closed" + CLOSED = "closed" + RESET = "reset" + + +class StreamDirection(Enum): + """Stream direction for tracking initiator.""" + + INBOUND = "inbound" + OUTBOUND = "outbound" + + +class StreamTimeline: + """Track stream lifecycle events for debugging and monitoring.""" + + def __init__(self) -> None: + self.created_at = time.time() + self.opened_at: float | None = None + self.first_data_at: float | None = None + self.closed_at: float | None = None + self.reset_at: float | None = None + self.error_code: int | None = None + + def record_open(self) -> None: + self.opened_at = time.time() + + def record_first_data(self) -> None: + if self.first_data_at is None: + self.first_data_at = time.time() + + def record_close(self) -> None: + self.closed_at = time.time() + + def record_reset(self, error_code: int) -> None: + self.reset_at = time.time() + self.error_code = error_code + + +class QUICStream(IMuxedStream): + """ + QUIC Stream implementation following libp2p IMuxedStream interface. + + Based on patterns from go-libp2p and js-libp2p, this implementation: + - Leverages QUIC's native multiplexing and flow control + - Integrates with libp2p resource management + - Provides comprehensive error handling with QUIC-specific codes + - Supports bidirectional communication with independent close semantics + - Implements proper stream lifecycle management + """ + + def __init__( + self, + connection: "QUICConnection", + stream_id: int, + direction: StreamDirection, + remote_addr: tuple[str, int], + resource_scope: Any | None = None, + ): + """ + Initialize QUIC stream. + + Args: + connection: Parent QUIC connection + stream_id: QUIC stream identifier + direction: Stream direction (inbound/outbound) + resource_scope: Resource manager scope for memory accounting + remote_addr: Remote addr stream is connected to + + """ + self._connection = connection + self._stream_id = stream_id + self._direction = direction + self._resource_scope = resource_scope + + # libp2p interface compliance + self._protocol: TProtocol | None = None + self._metadata: dict[str, Any] = {} + self._remote_addr = remote_addr + + # Stream state management + self._state = StreamState.OPEN + self._state_lock = trio.Lock() + + # Flow control and buffering + self._receive_buffer = bytearray() + self._receive_buffer_lock = trio.Lock() + self._receive_event = trio.Event() + self._backpressure_event = trio.Event() + self._backpressure_event.set() # Initially no backpressure + + # Close/reset state + self._write_closed = False + self._read_closed = False + self._close_event = trio.Event() + self._reset_error_code: int | None = None + + # Lifecycle tracking + self._timeline = StreamTimeline() + self._timeline.record_open() + + # Resource accounting + self._memory_reserved = 0 + + # Stream constant configurations + self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT + self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT + self.FLOW_CONTROL_WINDOW_SIZE = ( + connection._transport._config.STREAM_FLOW_CONTROL_WINDOW + ) + self.MAX_RECEIVE_BUFFER_SIZE = ( + connection._transport._config.MAX_STREAM_RECEIVE_BUFFER + ) + + if self._resource_scope: + self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE) + + logger.debug( + f"Created QUIC stream {stream_id} " + f"({direction.value}, connection: {connection.remote_peer_id()})" + ) + + # Properties for libp2p interface compliance + + @property + def protocol(self) -> TProtocol | None: + """Get the protocol identifier for this stream.""" + return self._protocol + + @protocol.setter + def protocol(self, protocol_id: TProtocol) -> None: + """Set the protocol identifier for this stream.""" + self._protocol = protocol_id + self._metadata["protocol"] = protocol_id + logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}") + + @property + def stream_id(self) -> str: + """Get stream ID as string for libp2p compatibility.""" + return str(self._stream_id) + + @property + def muxed_conn(self) -> "QUICConnection": # type: ignore + """Get the parent muxed connection.""" + return self._connection + + @property + def state(self) -> StreamState: + """Get current stream state.""" + return self._state + + @property + def direction(self) -> StreamDirection: + """Get stream direction.""" + return self._direction + + @property + def is_initiator(self) -> bool: + """Check if this stream was locally initiated.""" + return self._direction == StreamDirection.OUTBOUND + + # Core stream operations + + async def read(self, n: int | None = None) -> bytes: + """ + Read data from the stream with QUIC flow control. + + Args: + n: Maximum number of bytes to read. If None or -1, read all available. + + Returns: + Data read from stream + + Raises: + QUICStreamClosedError: Stream is closed + QUICStreamResetError: Stream was reset + QUICStreamTimeoutError: Read timeout exceeded + + """ + if n is None: + n = -1 + + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + raise QUICStreamClosedError(f"Stream {self.stream_id} is closed") + + if self._read_closed: + # Return any remaining buffered data, then EOF + async with self._receive_buffer_lock: + if self._receive_buffer: + data = self._extract_data_from_buffer(n) + self._timeline.record_first_data() + return data + return b"" + + # Wait for data with timeout + timeout = self.READ_TIMEOUT + try: + with trio.move_on_after(timeout) as cancel_scope: + while True: + async with self._receive_buffer_lock: + if self._receive_buffer: + data = self._extract_data_from_buffer(n) + self._timeline.record_first_data() + return data + + # Check if stream was closed while waiting + if self._read_closed: + return b"" + + # Wait for more data + await self._receive_event.wait() + self._receive_event = trio.Event() # Reset for next wait + + if cancel_scope.cancelled_caught: + raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}") + + return b"" + except QUICStreamResetError: + # Stream was reset while reading + raise + except Exception as e: + logger.error(f"Error reading from stream {self.stream_id}: {e}") + await self._handle_stream_error(e) + raise + + async def write(self, data: bytes) -> None: + """ + Write data to the stream with QUIC flow control. + + Args: + data: Data to write + + Raises: + QUICStreamClosedError: Stream is closed for writing + QUICStreamBackpressureError: Flow control window exhausted + QUICStreamResetError: Stream was reset + + """ + if not data: + return + + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + raise QUICStreamClosedError(f"Stream {self.stream_id} is closed") + + if self._write_closed: + raise QUICStreamClosedError( + f"Stream {self.stream_id} write side is closed" + ) + + try: + # Handle flow control backpressure + await self._backpressure_event.wait() + + # Send data through QUIC connection + self._connection._quic.send_stream_data(self._stream_id, data) + await self._connection._transmit() + + self._timeline.record_first_data() + logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}") + + except Exception as e: + logger.error(f"Error writing to stream {self.stream_id}: {e}") + # Convert QUIC-specific errors + if "flow control" in str(e).lower(): + raise QUICStreamBackpressureError(f"Flow control limit reached: {e}") + await self._handle_stream_error(e) + raise + + async def close(self) -> None: + """ + Close the stream gracefully (both read and write sides). + + This implements proper close semantics where both sides + are closed and resources are cleaned up. + """ + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + return + + logger.debug(f"Closing stream {self.stream_id}") + + # Close both sides + if not self._write_closed: + await self.close_write() + if not self._read_closed: + await self.close_read() + + # Update state and cleanup + async with self._state_lock: + self._state = StreamState.CLOSED + + await self._cleanup_resources() + self._timeline.record_close() + self._close_event.set() + + logger.debug(f"Stream {self.stream_id} closed") + + async def close_write(self) -> None: + """Close the write side of the stream.""" + if self._write_closed: + return + + try: + # Send FIN to close write side + self._connection._quic.send_stream_data( + self._stream_id, b"", end_stream=True + ) + await self._connection._transmit() + + self._write_closed = True + + async with self._state_lock: + if self._read_closed: + self._state = StreamState.CLOSED + else: + self._state = StreamState.WRITE_CLOSED + + logger.debug(f"Stream {self.stream_id} write side closed") + + except Exception as e: + logger.error(f"Error closing write side of stream {self.stream_id}: {e}") + + async def close_read(self) -> None: + """Close the read side of the stream.""" + if self._read_closed: + return + + try: + self._read_closed = True + + async with self._state_lock: + if self._write_closed: + self._state = StreamState.CLOSED + else: + self._state = StreamState.READ_CLOSED + + # Wake up any pending reads + self._receive_event.set() + + logger.debug(f"Stream {self.stream_id} read side closed") + + except Exception as e: + logger.error(f"Error closing read side of stream {self.stream_id}: {e}") + + async def reset(self, error_code: int = 0) -> None: + """ + Reset the stream with the given error code. + + Args: + error_code: QUIC error code for the reset + + """ + async with self._state_lock: + if self._state == StreamState.RESET: + return + + logger.debug( + f"Resetting stream {self.stream_id} with error code {error_code}" + ) + + self._state = StreamState.RESET + self._reset_error_code = error_code + + try: + # Send QUIC reset frame + self._connection._quic.reset_stream(self._stream_id, error_code) + await self._connection._transmit() + + except Exception as e: + logger.error(f"Error sending reset for stream {self.stream_id}: {e}") + finally: + # Always cleanup resources + await self._cleanup_resources() + self._timeline.record_reset(error_code) + self._close_event.set() + + def is_closed(self) -> bool: + """Check if stream is completely closed.""" + return self._state in (StreamState.CLOSED, StreamState.RESET) + + def is_reset(self) -> bool: + """Check if stream was reset.""" + return self._state == StreamState.RESET + + def can_read(self) -> bool: + """Check if stream can be read from.""" + return not self._read_closed and self._state not in ( + StreamState.CLOSED, + StreamState.RESET, + ) + + def can_write(self) -> bool: + """Check if stream can be written to.""" + return not self._write_closed and self._state not in ( + StreamState.CLOSED, + StreamState.RESET, + ) + + async def handle_data_received(self, data: bytes, end_stream: bool) -> None: + """ + Handle data received from the QUIC connection. + + Args: + data: Received data + end_stream: Whether this is the last data (FIN received) + + """ + if self._state == StreamState.RESET: + return + + if data: + async with self._receive_buffer_lock: + if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE: + logger.warning( + f"Stream {self.stream_id} receive buffer overflow, " + f"dropping {len(data)} bytes" + ) + return + + self._receive_buffer.extend(data) + self._timeline.record_first_data() + + # Notify waiting readers + self._receive_event.set() + + logger.debug(f"Stream {self.stream_id} received {len(data)} bytes") + + if end_stream: + self._read_closed = True + async with self._state_lock: + if self._write_closed: + self._state = StreamState.CLOSED + else: + self._state = StreamState.READ_CLOSED + + # Wake up readers to process remaining data and EOF + self._receive_event.set() + + logger.debug(f"Stream {self.stream_id} received FIN") + + async def handle_stop_sending(self, error_code: int) -> None: + """ + Handle STOP_SENDING frame from remote peer. + + When a STOP_SENDING frame is received, the peer is requesting that we + stop sending data on this stream. We respond by resetting the stream. + + Args: + error_code: Error code from the STOP_SENDING frame + + """ + logger.debug( + f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})" + ) + + self._write_closed = True + + # Wake up any pending write operations + self._backpressure_event.set() + + async with self._state_lock: + if self.direction == StreamDirection.OUTBOUND: + self._state = StreamState.CLOSED + elif self._read_closed: + self._state = StreamState.CLOSED + else: + # Only write side closed - add WRITE_CLOSED state if needed + self._state = StreamState.WRITE_CLOSED + + # Send RESET_STREAM in response (QUIC protocol requirement) + try: + self._connection._quic.reset_stream(int(self.stream_id), error_code) + await self._connection._transmit() + logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}") + except Exception as e: + logger.warning( + f"Could not send RESET_STREAM for stream {self.stream_id}: {e}" + ) + + async def handle_reset(self, error_code: int) -> None: + """ + Handle stream reset from remote peer. + + Args: + error_code: QUIC error code from reset frame + + """ + logger.debug( + f"Stream {self.stream_id} reset by peer with error code {error_code}" + ) + + async with self._state_lock: + self._state = StreamState.RESET + self._reset_error_code = error_code + + await self._cleanup_resources() + self._timeline.record_reset(error_code) + self._close_event.set() + + # Wake up any pending operations + self._receive_event.set() + self._backpressure_event.set() + + async def handle_flow_control_update(self, available_window: int) -> None: + """ + Handle flow control window updates. + + Args: + available_window: Available flow control window size + + """ + if available_window > 0: + self._backpressure_event.set() + logger.debug( + f"Stream {self.stream_id} flow control".__add__( + f"window updated: {available_window}" + ) + ) + else: + self._backpressure_event = trio.Event() # Reset to blocking state + logger.debug(f"Stream {self.stream_id} flow control window exhausted") + + def _extract_data_from_buffer(self, n: int) -> bytes: + """Extract data from receive buffer with specified limit.""" + if n == -1: + # Read all available data + data = bytes(self._receive_buffer) + self._receive_buffer.clear() + else: + # Read up to n bytes + data = bytes(self._receive_buffer[:n]) + self._receive_buffer = self._receive_buffer[n:] + + return data + + async def _handle_stream_error(self, error: Exception) -> None: + """Handle errors by resetting the stream.""" + logger.error(f"Stream {self.stream_id} error: {error}") + await self.reset(error_code=1) # Generic error code + + def _reserve_memory(self, size: int) -> None: + """Reserve memory with resource manager.""" + if self._resource_scope: + try: + self._resource_scope.reserve_memory(size) + self._memory_reserved += size + except Exception as e: + logger.warning( + f"Failed to reserve memory for stream {self.stream_id}: {e}" + ) + + def _release_memory(self, size: int) -> None: + """Release memory with resource manager.""" + if self._resource_scope and size > 0: + try: + self._resource_scope.release_memory(size) + self._memory_reserved = max(0, self._memory_reserved - size) + except Exception as e: + logger.warning( + f"Failed to release memory for stream {self.stream_id}: {e}" + ) + + async def _cleanup_resources(self) -> None: + """Clean up stream resources.""" + # Release all reserved memory + if self._memory_reserved > 0: + self._release_memory(self._memory_reserved) + + # Clear receive buffer + async with self._receive_buffer_lock: + self._receive_buffer.clear() + + # Remove from connection's stream registry + self._connection._remove_stream(self._stream_id) + + logger.debug(f"Stream {self.stream_id} resources cleaned up") + + # Abstact implementations + + def get_remote_address(self) -> tuple[str, int]: + return self._remote_addr + + async def __aenter__(self) -> "QUICStream": + """Enter the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the async context manager and close the stream.""" + logger.debug("Exiting the context and closing the stream") + await self.close() + + def set_deadline(self, ttl: int) -> bool: + """ + Set a deadline for the stream. QUIC does not support deadlines natively, + so this method always returns False to indicate the operation is unsupported. + + :param ttl: Time-to-live in seconds (ignored). + :return: False, as deadlines are not supported. + """ + raise NotImplementedError("QUIC does not support setting read deadlines") + + # String representation for debugging + + def __repr__(self) -> str: + return ( + f"QUICStream(id={self.stream_id}, " + f"state={self._state.value}, " + f"direction={self._direction.value}, " + f"protocol={self._protocol})" + ) + + def __str__(self) -> str: + return f"QUICStream({self.stream_id})" diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py new file mode 100644 index 00000000..ef0df368 --- /dev/null +++ b/libp2p/transport/quic/transport.py @@ -0,0 +1,491 @@ +""" +QUIC Transport implementation +""" + +import copy +import logging +import ssl +from typing import TYPE_CHECKING, cast + +from aioquic.quic.configuration import ( + QuicConfiguration, +) +from aioquic.quic.connection import ( + QuicConnection as NativeQUICConnection, +) +from aioquic.quic.logger import QuicLogger +import multiaddr +import trio + +from libp2p.abc import ( + ITransport, +) +from libp2p.crypto.keys import ( + PrivateKey, +) +from libp2p.custom_types import TProtocol, TQUICConnHandlerFn +from libp2p.peer.id import ( + ID, +) +from libp2p.transport.quic.security import QUICTLSSecurityConfig +from libp2p.transport.quic.utils import ( + create_client_config_from_base, + create_server_config_from_base, + get_alpn_protocols, + is_quic_multiaddr, + multiaddr_to_quic_version, + quic_multiaddr_to_endpoint, + quic_version_to_wire_format, +) + +if TYPE_CHECKING: + from libp2p.network.swarm import Swarm +else: + Swarm = cast(type, object) + +from .config import ( + QUICTransportConfig, +) +from .connection import ( + QUICConnection, +) +from .exceptions import ( + QUICDialError, + QUICListenError, + QUICSecurityError, +) +from .listener import ( + QUICListener, +) +from .security import ( + QUICTLSConfigManager, + create_quic_security_transport, +) + +QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 +QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 + +logger = logging.getLogger(__name__) + + +class QUICTransport(ITransport): + """ + QUIC Stream implementation following libp2p IMuxedStream interface. + """ + + def __init__( + self, private_key: PrivateKey, config: QUICTransportConfig | None = None + ) -> None: + """ + Initialize QUIC transport with security integration. + + Args: + private_key: libp2p private key for identity and TLS cert generation + config: QUIC transport configuration options + + """ + self._private_key = private_key + self._peer_id = ID.from_pubkey(private_key.get_public_key()) + self._config = config or QUICTransportConfig() + + # Connection management + self._connections: dict[str, QUICConnection] = {} + self._listeners: list[QUICListener] = [] + + # Security manager for TLS integration + self._security_manager = create_quic_security_transport( + self._private_key, self._peer_id + ) + + # QUIC configurations for different versions + self._quic_configs: dict[TProtocol, QuicConfiguration] = {} + self._setup_quic_configurations() + + # Resource management + self._closed = False + self._nursery_manager = trio.CapacityLimiter(1) + self._background_nursery: trio.Nursery | None = None + + self._swarm: Swarm | None = None + + logger.debug( + f"Initialized QUIC transport with security for peer {self._peer_id}" + ) + + def set_background_nursery(self, nursery: trio.Nursery) -> None: + """Set the nursery to use for background tasks (called by swarm).""" + self._background_nursery = nursery + logger.debug("Transport background nursery set") + + def set_swarm(self, swarm: Swarm) -> None: + """Set the swarm for adding incoming connections.""" + self._swarm = swarm + + def _setup_quic_configurations(self) -> None: + """Setup QUIC configurations.""" + try: + # Get TLS configuration from security manager + server_tls_config = self._security_manager.create_server_config() + client_tls_config = self._security_manager.create_client_config() + + # Base server configuration + base_server_config = QuicConfiguration( + is_client=False, + alpn_protocols=get_alpn_protocols(), + verify_mode=self._config.verify_mode, + max_datagram_frame_size=self._config.max_datagram_size, + idle_timeout=self._config.idle_timeout, + ) + + # Base client configuration + base_client_config = QuicConfiguration( + is_client=True, + alpn_protocols=get_alpn_protocols(), + verify_mode=self._config.verify_mode, + max_datagram_frame_size=self._config.max_datagram_size, + idle_timeout=self._config.idle_timeout, + ) + + # Apply TLS configuration + self._apply_tls_configuration(base_server_config, server_tls_config) + self._apply_tls_configuration(base_client_config, client_tls_config) + + # QUIC v1 (RFC 9000) configurations + if self._config.enable_v1: + quic_v1_server_config = create_server_config_from_base( + base_server_config, self._security_manager, self._config + ) + quic_v1_server_config.supported_versions = [ + quic_version_to_wire_format(QUIC_V1_PROTOCOL) + ] + + quic_v1_client_config = create_client_config_from_base( + base_client_config, self._security_manager, self._config + ) + quic_v1_client_config.supported_versions = [ + quic_version_to_wire_format(QUIC_V1_PROTOCOL) + ] + + # Store both server and client configs for v1 + self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = ( + quic_v1_server_config + ) + self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = ( + quic_v1_client_config + ) + + # QUIC draft-29 configurations for compatibility + if self._config.enable_draft29: + draft29_server_config: QuicConfiguration = copy.copy(base_server_config) + draft29_server_config.supported_versions = [ + quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL) + ] + + draft29_client_config = copy.copy(base_client_config) + draft29_client_config.supported_versions = [ + quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL) + ] + + self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = ( + draft29_server_config + ) + self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = ( + draft29_client_config + ) + + logger.debug("QUIC configurations initialized with libp2p TLS security") + + except Exception as e: + raise QUICSecurityError( + f"Failed to setup QUIC TLS configurations: {e}" + ) from e + + def _apply_tls_configuration( + self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig + ) -> None: + """ + Apply TLS configuration to a QUIC configuration using aioquic's actual API. + + Args: + config: QuicConfiguration to update + tls_config: TLS configuration dictionary from security manager + + """ + try: + config.certificate = tls_config.certificate + config.private_key = tls_config.private_key + config.certificate_chain = tls_config.certificate_chain + config.alpn_protocols = tls_config.alpn_protocols + config.verify_mode = ssl.CERT_NONE + + logger.debug("Successfully applied TLS configuration to QUIC config") + + except Exception as e: + raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e + + async def dial( + self, + maddr: multiaddr.Multiaddr, + ) -> QUICConnection: + """ + Dial a remote peer using QUIC transport with security verification. + + Args: + maddr: Multiaddr of the remote peer (e.g., /ip4/1.2.3.4/udp/4001/quic-v1) + peer_id: Expected peer ID for verification + nursery: Nursery to execute the background tasks + + Returns: + Raw connection interface to the remote peer + + Raises: + QUICDialError: If dialing fails + QUICSecurityError: If security verification fails + + """ + if self._closed: + raise QUICDialError("Transport is closed") + + if not is_quic_multiaddr(maddr): + raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}") + + try: + # Extract connection details from multiaddr + host, port = quic_multiaddr_to_endpoint(maddr) + remote_peer_id = maddr.get_peer_id() + if remote_peer_id is not None: + remote_peer_id = ID.from_base58(remote_peer_id) + + if remote_peer_id is None: + logger.error("Unable to derive peer id from multiaddr") + raise QUICDialError("Unable to derive peer id from multiaddr") + quic_version = multiaddr_to_quic_version(maddr) + + # Get appropriate QUIC client configuration + config_key = TProtocol(f"{quic_version}_client") + logger.debug("config_key", config_key, self._quic_configs.keys()) + config = self._quic_configs.get(config_key) + if not config: + raise QUICDialError(f"Unsupported QUIC version: {quic_version}") + + config.is_client = True + config.quic_logger = QuicLogger() + + # Ensure client certificate is properly set for mutual authentication + if not config.certificate or not config.private_key: + logger.warning( + "Client config missing certificate - applying TLS config" + ) + client_tls_config = self._security_manager.create_client_config() + self._apply_tls_configuration(config, client_tls_config) + + # Debug log to verify certificate is present + logger.info( + f"Dialing QUIC connection to {host}:{port} (version: {{quic_version}})" + ) + + logger.debug("Starting QUIC Connection") + # Create QUIC connection using aioquic's sans-IO core + native_quic_connection = NativeQUICConnection(configuration=config) + + # Create trio-based QUIC connection wrapper with security + connection = QUICConnection( + quic_connection=native_quic_connection, + remote_addr=(host, port), + remote_peer_id=remote_peer_id, + local_peer_id=self._peer_id, + is_initiator=True, + maddr=maddr, + transport=self, + security_manager=self._security_manager, + ) + logger.debug("QUIC Connection Created") + + if self._background_nursery is None: + logger.error("No nursery set to execute background tasks") + raise QUICDialError("No nursery found to execute tasks") + + await connection.connect(self._background_nursery) + + # Store connection for management + conn_id = f"{host}:{port}" + self._connections[conn_id] = connection + + return connection + + except Exception as e: + logger.error(f"Failed to dial QUIC connection to {maddr}: {e}") + raise QUICDialError(f"Dial failed: {e}") from e + + async def _verify_peer_identity( + self, connection: QUICConnection, expected_peer_id: ID + ) -> None: + """ + Verify remote peer identity after TLS handshake. + + Args: + connection: The established QUIC connection + expected_peer_id: Expected peer ID + + Raises: + QUICSecurityError: If peer verification fails + + """ + try: + # Get peer certificate from the connection + peer_certificate = await connection.get_peer_certificate() + + if not peer_certificate: + raise QUICSecurityError("No peer certificate available") + + # Verify peer identity using security manager + verified_peer_id = self._security_manager.verify_peer_identity( + peer_certificate, expected_peer_id + ) + + if verified_peer_id != expected_peer_id: + raise QUICSecurityError( + "Peer ID verification failed: expected " + f"{expected_peer_id}, got {verified_peer_id}" + ) + + logger.debug(f"Peer identity verified: {verified_peer_id}") + logger.debug(f"Peer identity verified: {verified_peer_id}") + + except Exception as e: + raise QUICSecurityError(f"Peer identity verification failed: {e}") from e + + def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener: + """ + Create a QUIC listener with integrated security. + + Args: + handler_function: Function to handle new connections + + Returns: + QUIC listener instance + + Raises: + QUICListenError: If transport is closed + + """ + if self._closed: + raise QUICListenError("Transport is closed") + + # Get server configurations for the listener + server_configs = { + version: config + for version, config in self._quic_configs.items() + if version.endswith("_server") + } + + listener = QUICListener( + transport=self, + handler_function=handler_function, + quic_configs=server_configs, + config=self._config, + security_manager=self._security_manager, + ) + + self._listeners.append(listener) + logger.debug("Created QUIC listener with security") + return listener + + def can_dial(self, maddr: multiaddr.Multiaddr) -> bool: + """ + Check if this transport can dial the given multiaddr. + + Args: + maddr: Multiaddr to check + + Returns: + True if this transport can dial the address + + """ + return is_quic_multiaddr(maddr) + + def protocols(self) -> list[TProtocol]: + """ + Get supported protocol identifiers. + + Returns: + List of supported protocol strings + + """ + protocols = [QUIC_V1_PROTOCOL] + if self._config.enable_draft29: + protocols.append(QUIC_DRAFT29_PROTOCOL) + return protocols + + def listen_order(self) -> int: + """ + Get the listen order priority for this transport. + Matches go-libp2p's ListenOrder = 1 for QUIC. + + Returns: + Priority order for listening (lower = higher priority) + + """ + return 1 + + async def close(self) -> None: + """Close the transport and cleanup resources.""" + if self._closed: + return + + self._closed = True + logger.debug("Closing QUIC transport") + + # Close all active connections and listeners concurrently using trio nursery + async with trio.open_nursery() as nursery: + # Close all connections + for connection in self._connections.values(): + nursery.start_soon(connection.close) + + # Close all listeners + for listener in self._listeners: + nursery.start_soon(listener.close) + + self._connections.clear() + self._listeners.clear() + + logger.debug("QUIC transport closed") + + async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None: + """Clean up a terminated connection from all listeners.""" + try: + for listener in self._listeners: + await listener._remove_connection_by_object(connection) + logger.debug( + "āœ… TRANSPORT: Cleaned up terminated connection from all listeners" + ) + except Exception as e: + logger.error(f"āŒ TRANSPORT: Error cleaning up terminated connection: {e}") + + def get_stats(self) -> dict[str, int | list[str] | object]: + """Get transport statistics including security info.""" + return { + "active_connections": len(self._connections), + "active_listeners": len(self._listeners), + "supported_protocols": self.protocols(), + "local_peer_id": str(self._peer_id), + "security_enabled": True, + "tls_configured": True, + } + + def get_security_manager(self) -> QUICTLSConfigManager: + """ + Get the security manager for this transport. + + Returns: + The QUIC TLS configuration manager + + """ + return self._security_manager + + def get_listener_socket(self) -> trio.socket.SocketType | None: + """Get the socket from the first active listener.""" + for listener in self._listeners: + if listener.is_listening() and listener._socket: + return listener._socket + return None diff --git a/libp2p/transport/quic/utils.py b/libp2p/transport/quic/utils.py new file mode 100644 index 00000000..37b7880b --- /dev/null +++ b/libp2p/transport/quic/utils.py @@ -0,0 +1,466 @@ +""" +Multiaddr utilities for QUIC transport - Module 4. +Essential utilities required for QUIC transport implementation. +Based on go-libp2p and js-libp2p QUIC implementations. +""" + +import ipaddress +import logging +import ssl + +from aioquic.quic.configuration import QuicConfiguration +import multiaddr + +from libp2p.custom_types import TProtocol +from libp2p.transport.quic.security import QUICTLSConfigManager + +from .config import QUICTransportConfig +from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError + +logger = logging.getLogger(__name__) + +# Protocol constants +QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 +QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 +UDP_PROTOCOL = "udp" +IP4_PROTOCOL = "ip4" +IP6_PROTOCOL = "ip6" + +SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server" +CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client" + +SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server" +CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client" + +CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = { + SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000 + CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000 + SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29 + CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29 +} + +# QUIC version to wire format mappings (required for aioquic) +QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = { + QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000 + QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29 +} + +# ALPN protocols for libp2p over QUIC +LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"] + + +def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool: + """ + Check if a multiaddr represents a QUIC address. + + Valid QUIC multiaddrs: + - /ip4/127.0.0.1/udp/4001/quic-v1 + - /ip4/127.0.0.1/udp/4001/quic + - /ip6/::1/udp/4001/quic-v1 + - /ip6/::1/udp/4001/quic + + Args: + maddr: Multiaddr to check + + Returns: + True if the multiaddr represents a QUIC address + + """ + try: + addr_str = str(maddr) + + # Check for required components + has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str + has_udp = f"/{UDP_PROTOCOL}/" in addr_str + has_quic = ( + f"/{QUIC_V1_PROTOCOL}" in addr_str + or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str + or "/quic" in addr_str + ) + + return has_ip and has_udp and has_quic + + except Exception: + return False + + +def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]: + """ + Extract host and port from a QUIC multiaddr. + + Args: + maddr: QUIC multiaddr + + Returns: + Tuple of (host, port) + + Raises: + QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address + + """ + if not is_quic_multiaddr(maddr): + raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}") + + try: + host = None + port = None + + # Try to get IPv4 address + try: + host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore + except Exception: + pass + + # Try to get IPv6 address if IPv4 not found + if host is None: + try: + host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore + except Exception: + pass + + # Get UDP port + try: + port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore + port = int(port_str) + except Exception: + pass + + if host is None or port is None: + raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}") + + return host, port + + except Exception as e: + raise QUICInvalidMultiaddrError( + f"Failed to parse QUIC multiaddr {maddr}: {e}" + ) from e + + +def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol: + """ + Determine QUIC version from multiaddr. + + Args: + maddr: QUIC multiaddr + + Returns: + QUIC version identifier ("quic-v1" or "quic") + + Raises: + QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol + + """ + try: + addr_str = str(maddr) + + if f"/{QUIC_V1_PROTOCOL}" in addr_str: + return QUIC_V1_PROTOCOL # RFC 9000 + elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str: + return QUIC_DRAFT29_PROTOCOL # draft-29 + else: + raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}") + + except Exception as e: + raise QUICInvalidMultiaddrError( + f"Failed to determine QUIC version from {maddr}: {e}" + ) from e + + +def create_quic_multiaddr( + host: str, port: int, version: str = "quic-v1" +) -> multiaddr.Multiaddr: + """ + Create a QUIC multiaddr from host, port, and version. + + Args: + host: IP address (IPv4 or IPv6) + port: UDP port number + version: QUIC version ("quic-v1" or "quic") + + Returns: + QUIC multiaddr + + Raises: + QUICInvalidMultiaddrError: If invalid parameters provided + + """ + try: + # Determine IP version + try: + ip = ipaddress.ip_address(host) + if isinstance(ip, ipaddress.IPv4Address): + ip_proto = IP4_PROTOCOL + else: + ip_proto = IP6_PROTOCOL + except ValueError: + raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}") + + # Validate port + if not (0 <= port <= 65535): + raise QUICInvalidMultiaddrError(f"Invalid port: {port}") + + # Validate and normalize QUIC version + if version == "quic-v1" or version == "/quic-v1": + quic_proto = QUIC_V1_PROTOCOL + elif version == "quic" or version == "/quic": + quic_proto = QUIC_DRAFT29_PROTOCOL + else: + raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}") + + # Construct multiaddr + addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}" + return multiaddr.Multiaddr(addr_str) + + except Exception as e: + raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e + + +def quic_version_to_wire_format(version: TProtocol) -> int: + """ + Convert QUIC version string to wire format integer for aioquic. + + Args: + version: QUIC version string ("quic-v1" or "quic") + + Returns: + Wire format version number + + Raises: + QUICUnsupportedVersionError: If version is not supported + + """ + wire_version = QUIC_VERSION_MAPPINGS.get(version) + if wire_version is None: + raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}") + + return wire_version + + +def custom_quic_version_to_wire_format(version: TProtocol) -> int: + """ + Convert QUIC version string to wire format integer for aioquic. + + Args: + version: QUIC version string ("quic-v1" or "quic") + + Returns: + Wire format version number + + Raises: + QUICUnsupportedVersionError: If version is not supported + + """ + wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version) + if wire_version is None: + raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}") + + return wire_version + + +def get_alpn_protocols() -> list[str]: + """ + Get ALPN protocols for libp2p over QUIC. + + Returns: + List of ALPN protocol identifiers + + """ + return LIBP2P_ALPN_PROTOCOLS.copy() + + +def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr: + """ + Normalize a QUIC multiaddr to canonical form. + + Args: + maddr: Input QUIC multiaddr + + Returns: + Normalized multiaddr + + Raises: + QUICInvalidMultiaddrError: If not a valid QUIC multiaddr + + """ + if not is_quic_multiaddr(maddr): + raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}") + + host, port = quic_multiaddr_to_endpoint(maddr) + version = multiaddr_to_quic_version(maddr) + + return create_quic_multiaddr(host, port, version) + + +def create_server_config_from_base( + base_config: QuicConfiguration, + security_manager: QUICTLSConfigManager | None = None, + transport_config: QUICTransportConfig | None = None, +) -> QuicConfiguration: + """ + Create a server configuration without using deepcopy. + Manually copies attributes while handling cryptography objects properly. + """ + try: + # Create new server configuration from scratch + server_config = QuicConfiguration(is_client=False) + server_config.verify_mode = ssl.CERT_NONE + + # Copy basic configuration attributes (these are safe to copy) + copyable_attrs = [ + "alpn_protocols", + "verify_mode", + "max_datagram_frame_size", + "idle_timeout", + "max_concurrent_streams", + "supported_versions", + "max_data", + "max_stream_data", + "stateless_retry", + "quantum_readiness_test", + ] + + for attr in copyable_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(server_config, attr, value) + + # Handle cryptography objects - these need direct reference, not copying + crypto_attrs = [ + "certificate", + "private_key", + "certificate_chain", + "ca_certs", + ] + + for attr in crypto_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(server_config, attr, value) + + # Apply security manager configuration if available + if security_manager: + try: + server_tls_config = security_manager.create_server_config() + + # Override with security manager's TLS configuration + if server_tls_config.certificate: + server_config.certificate = server_tls_config.certificate + if server_tls_config.private_key: + server_config.private_key = server_tls_config.private_key + if server_tls_config.certificate_chain: + server_config.certificate_chain = ( + server_tls_config.certificate_chain + ) + if server_tls_config.alpn_protocols: + server_config.alpn_protocols = server_tls_config.alpn_protocols + server_tls_config.request_client_certificate = True + if getattr(server_tls_config, "request_client_certificate", False): + server_config._libp2p_request_client_cert = True # type: ignore + else: + logger.error( + "šŸ”§ Failed to set request_client_certificate in server config" + ) + + except Exception as e: + logger.warning(f"Failed to apply security manager config: {e}") + + # Set transport-specific defaults if provided + if transport_config: + if server_config.idle_timeout == 0: + server_config.idle_timeout = getattr( + transport_config, "idle_timeout", 30.0 + ) + if server_config.max_datagram_frame_size is None: + server_config.max_datagram_frame_size = getattr( + transport_config, "max_datagram_size", 1200 + ) + # Ensure we have ALPN protocols + if not server_config.alpn_protocols: + server_config.alpn_protocols = ["libp2p"] + + logger.debug("Successfully created server config without deepcopy") + return server_config + + except Exception as e: + logger.error(f"Failed to create server config: {e}") + raise + + +def create_client_config_from_base( + base_config: QuicConfiguration, + security_manager: QUICTLSConfigManager | None = None, + transport_config: QUICTransportConfig | None = None, +) -> QuicConfiguration: + """ + Create a client configuration without using deepcopy. + """ + try: + # Create new client configuration from scratch + client_config = QuicConfiguration(is_client=True) + client_config.verify_mode = ssl.CERT_NONE + + # Copy basic configuration attributes + copyable_attrs = [ + "alpn_protocols", + "verify_mode", + "max_datagram_frame_size", + "idle_timeout", + "max_concurrent_streams", + "supported_versions", + "max_data", + "max_stream_data", + "quantum_readiness_test", + ] + + for attr in copyable_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(client_config, attr, value) + + # Handle cryptography objects - these need direct reference, not copying + crypto_attrs = [ + "certificate", + "private_key", + "certificate_chain", + "ca_certs", + ] + + for attr in crypto_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(client_config, attr, value) + + # Apply security manager configuration if available + if security_manager: + try: + client_tls_config = security_manager.create_client_config() + + # Override with security manager's TLS configuration + if client_tls_config.certificate: + client_config.certificate = client_tls_config.certificate + if client_tls_config.private_key: + client_config.private_key = client_tls_config.private_key + if client_tls_config.certificate_chain: + client_config.certificate_chain = ( + client_tls_config.certificate_chain + ) + if client_tls_config.alpn_protocols: + client_config.alpn_protocols = client_tls_config.alpn_protocols + + except Exception as e: + logger.warning(f"Failed to apply security manager config: {e}") + + # Ensure we have ALPN protocols + if not client_config.alpn_protocols: + client_config.alpn_protocols = ["libp2p"] + + logger.debug("Successfully created client config without deepcopy") + return client_config + + except Exception as e: + logger.error(f"Failed to create client config: {e}") + raise diff --git a/libp2p/transport/transport_registry.py b/libp2p/transport/transport_registry.py new file mode 100644 index 00000000..2f6a4c8b --- /dev/null +++ b/libp2p/transport/transport_registry.py @@ -0,0 +1,267 @@ +""" +Transport registry for dynamic transport selection based on multiaddr protocols. +""" + +from collections.abc import Callable +import logging +from typing import Any + +from multiaddr import Multiaddr +from multiaddr.protocols import Protocol + +from libp2p.abc import ITransport +from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.multiaddr_utils import ( + is_valid_websocket_multiaddr, +) + + +# Import QUIC utilities here to avoid circular imports +def _get_quic_transport() -> Any: + from libp2p.transport.quic.transport import QUICTransport + + return QUICTransport + + +def _get_quic_validation() -> Callable[[Multiaddr], bool]: + from libp2p.transport.quic.utils import is_quic_multiaddr + + return is_quic_multiaddr + + +# Import WebsocketTransport here to avoid circular imports +def _get_websocket_transport() -> Any: + from libp2p.transport.websocket.transport import WebsocketTransport + + return WebsocketTransport + + +logger = logging.getLogger("libp2p.transport.registry") + + +def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool: + """ + Validate that a multiaddr has a valid TCP structure. + + :param maddr: The multiaddr to validate + :return: True if valid TCP structure, False otherwise + """ + try: + # TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080 + # or /ip6/::1/tcp/8080 + protocols: list[Protocol] = list(maddr.protocols()) + + # Must have at least 2 protocols: network (ip4/ip6) + tcp + if len(protocols) < 2: + return False + + # First protocol should be a network protocol (ip4, ip6, dns4, dns6) + if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]: + return False + + # Second protocol should be tcp + if protocols[1].name != "tcp": + return False + + # Should not have any protocols after tcp (unless it's a valid + # continuation like p2p) + # For now, we'll be strict and only allow network + tcp + if len(protocols) > 2: + # Check if the additional protocols are valid continuations + valid_continuations = ["p2p"] # Add more as needed + for i in range(2, len(protocols)): + if protocols[i].name not in valid_continuations: + return False + + return True + + except Exception: + return False + + +class TransportRegistry: + """ + Registry for mapping multiaddr protocols to transport implementations. + """ + + def __init__(self) -> None: + self._transports: dict[str, type[ITransport]] = {} + self._register_default_transports() + + def _register_default_transports(self) -> None: + """Register the default transport implementations.""" + # Register TCP transport for /tcp protocol + self.register_transport("tcp", TCP) + + # Register WebSocket transport for /ws and /wss protocols + WebsocketTransport = _get_websocket_transport() + self.register_transport("ws", WebsocketTransport) + self.register_transport("wss", WebsocketTransport) + + # Register QUIC transport for /quic and /quic-v1 protocols + QUICTransport = _get_quic_transport() + self.register_transport("quic", QUICTransport) + self.register_transport("quic-v1", QUICTransport) + + def register_transport( + self, protocol: str, transport_class: type[ITransport] + ) -> None: + """ + Register a transport class for a specific protocol. + + :param protocol: The protocol identifier (e.g., "tcp", "ws") + :param transport_class: The transport class to register + """ + self._transports[protocol] = transport_class + logger.debug( + f"Registered transport {transport_class.__name__} for protocol {protocol}" + ) + + def get_transport(self, protocol: str) -> type[ITransport] | None: + """ + Get the transport class for a specific protocol. + + :param protocol: The protocol identifier + :return: The transport class or None if not found + """ + return self._transports.get(protocol) + + def get_supported_protocols(self) -> list[str]: + """Get list of supported transport protocols.""" + return list(self._transports.keys()) + + def create_transport( + self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any + ) -> ITransport | None: + """ + Create a transport instance for a specific protocol. + + :param protocol: The protocol identifier + :param upgrader: The transport upgrader instance (required for WebSocket) + :param kwargs: Additional arguments for transport construction + :return: Transport instance or None if protocol not supported or creation fails + """ + transport_class = self.get_transport(protocol) + if transport_class is None: + return None + + try: + if protocol in ["ws", "wss"]: + # WebSocket transport requires upgrader + if upgrader is None: + logger.warning( + f"WebSocket transport '{protocol}' requires upgrader" + ) + return None + # Use explicit WebsocketTransport to avoid type issues + WebsocketTransport = _get_websocket_transport() + return WebsocketTransport( + upgrader, + tls_client_config=kwargs.get("tls_client_config"), + tls_server_config=kwargs.get("tls_server_config"), + handshake_timeout=kwargs.get("handshake_timeout", 15.0), + ) + elif protocol in ["quic", "quic-v1"]: + # QUIC transport requires private_key + private_key = kwargs.get("private_key") + if private_key is None: + logger.warning(f"QUIC transport '{protocol}' requires private_key") + return None + # Use explicit QUICTransport to avoid type issues + QUICTransport = _get_quic_transport() + config = kwargs.get("config") + return QUICTransport(private_key, config) + else: + # TCP transport doesn't require upgrader + return transport_class() + except Exception as e: + logger.error(f"Failed to create transport for protocol {protocol}: {e}") + return None + + +# Global transport registry instance (lazy initialization) +_global_registry: TransportRegistry | None = None + + +def get_transport_registry() -> TransportRegistry: + """Get the global transport registry instance.""" + global _global_registry + if _global_registry is None: + _global_registry = TransportRegistry() + return _global_registry + + +def register_transport(protocol: str, transport_class: type[ITransport]) -> None: + """Register a transport class in the global registry.""" + registry = get_transport_registry() + registry.register_transport(protocol, transport_class) + + +def create_transport_for_multiaddr( + maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any +) -> ITransport | None: + """ + Create the appropriate transport for a given multiaddr. + + :param maddr: The multiaddr to create transport for + :param upgrader: The transport upgrader instance + :param kwargs: Additional arguments for transport construction + (e.g., private_key for QUIC) + :return: Transport instance or None if no suitable transport found + """ + try: + # Get all protocols in the multiaddr + protocols = [proto.name for proto in maddr.protocols()] + + # Check for supported transport protocols in order of preference + # We need to validate that the multiaddr structure is valid for our transports + if "quic" in protocols or "quic-v1" in protocols: + # For QUIC, we need a valid structure like: + # /ip4/127.0.0.1/udp/4001/quic + # /ip4/127.0.0.1/udp/4001/quic-v1 + is_quic_multiaddr = _get_quic_validation() + if is_quic_multiaddr(maddr): + # Determine QUIC version + registry = get_transport_registry() + if "quic-v1" in protocols: + return registry.create_transport("quic-v1", upgrader, **kwargs) + else: + return registry.create_transport("quic", upgrader, **kwargs) + elif "ws" in protocols or "wss" in protocols or "tls" in protocols: + # For WebSocket, we need a valid structure like: + # /ip4/127.0.0.1/tcp/8080/ws (insecure) + # /ip4/127.0.0.1/tcp/8080/wss (secure) + # /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS) + # /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI) + if is_valid_websocket_multiaddr(maddr): + # Determine if this is a secure WebSocket connection + registry = get_transport_registry() + if "wss" in protocols or "tls" in protocols: + return registry.create_transport("wss", upgrader, **kwargs) + else: + return registry.create_transport("ws", upgrader, **kwargs) + elif "tcp" in protocols: + # For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080 + # Check if the multiaddr has proper TCP structure + if _is_valid_tcp_multiaddr(maddr): + registry = get_transport_registry() + return registry.create_transport("tcp", upgrader) + + # If no supported transport protocol found or structure is invalid, return None + logger.warning( + f"No supported transport protocol found or invalid structure in " + f"multiaddr: {maddr}" + ) + return None + + except Exception as e: + # Handle any errors gracefully (e.g., invalid multiaddr) + logger.warning(f"Error processing multiaddr {maddr}: {e}") + return None + + +def get_supported_transport_protocols() -> list[str]: + """Get list of supported transport protocols from the global registry.""" + registry = get_transport_registry() + return registry.get_supported_protocols() diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 40ba5321..dad2ad72 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, MultiselectError, ) +from libp2p.protocol_muxer.multiselect import ( + DEFAULT_NEGOTIATE_TIMEOUT, +) from libp2p.security.exceptions import ( HandshakeFailure, ) @@ -37,9 +40,12 @@ class TransportUpgrader: self, secure_transports_by_protocol: TSecurityOptions, muxer_transports_by_protocol: TMuxerOptions, + negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, ): self.security_multistream = SecurityMultistream(secure_transports_by_protocol) - self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) + self.muxer_multistream = MuxerMultistream( + muxer_transports_by_protocol, negotiate_timeout + ) async def upgrade_security( self, diff --git a/libp2p/transport/websocket/connection.py b/libp2p/transport/websocket/connection.py new file mode 100644 index 00000000..f5be8812 --- /dev/null +++ b/libp2p/transport/websocket/connection.py @@ -0,0 +1,198 @@ +import logging +import time +from typing import Any + +import trio + +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException + +logger = logging.getLogger(__name__) + + +class P2PWebSocketConnection(ReadWriteCloser): + """ + Wraps a WebSocketConnection to provide the raw stream interface + that libp2p protocols expect. + + Implements production-ready buffer management and flow control + as recommended in the libp2p WebSocket specification. + """ + + def __init__( + self, + ws_connection: Any, + ws_context: Any = None, + is_secure: bool = False, + max_buffered_amount: int = 4 * 1024 * 1024, + ) -> None: + self._ws_connection = ws_connection + self._ws_context = ws_context + self._is_secure = is_secure + self._read_buffer = b"" + self._read_lock = trio.Lock() + self._connection_start_time = time.time() + self._bytes_read = 0 + self._bytes_written = 0 + self._closed = False + self._close_lock = trio.Lock() + self._max_buffered_amount = max_buffered_amount + self._write_lock = trio.Lock() + + async def write(self, data: bytes) -> None: + """Write data with flow control and buffer management""" + if self._closed: + raise IOException("Connection is closed") + + async with self._write_lock: + try: + logger.debug(f"WebSocket writing {len(data)} bytes") + + # Check buffer amount for flow control + if hasattr(self._ws_connection, "bufferedAmount"): + buffered = self._ws_connection.bufferedAmount + if buffered > self._max_buffered_amount: + logger.warning(f"WebSocket buffer full: {buffered} bytes") + # In production, you might want to + # wait or implement backpressure + # For now, we'll continue but log the warning + + # Send as a binary WebSocket message + await self._ws_connection.send_message(data) + self._bytes_written += len(data) + logger.debug(f"WebSocket wrote {len(data)} bytes successfully") + + except Exception as e: + logger.error(f"WebSocket write failed: {e}") + self._closed = True + raise IOException from e + + async def read(self, n: int | None = None) -> bytes: + """ + Read up to n bytes (if n is given), else read up to 64KiB. + This implementation provides byte-level access to WebSocket messages, + which is required for libp2p protocol compatibility. + + For WebSocket compatibility with libp2p protocols, this method: + 1. Buffers incoming WebSocket messages + 2. Returns exactly the requested number of bytes when n is specified + 3. Accumulates multiple WebSocket messages if needed to satisfy the request + 4. Returns empty bytes (not raises) when connection is closed and no data + available + """ + if self._closed: + raise IOException("Connection is closed") + + async with self._read_lock: + try: + # If n is None, read at least one message and return all buffered data + if n is None: + if not self._read_buffer: + try: + # Use a short timeout to avoid blocking indefinitely + with trio.fail_after(1.0): # 1 second timeout + message = await self._ws_connection.get_message() + if isinstance(message, str): + message = message.encode("utf-8") + self._read_buffer = message + except trio.TooSlowError: + # No message available within timeout + return b"" + except Exception: + # Return empty bytes if no data available + # (connection closed) + return b"" + + result = self._read_buffer + self._read_buffer = b"" + self._bytes_read += len(result) + return result + + # For specific byte count requests, return UP TO n bytes (not exactly n) + # This matches TCP semantics where read(1024) returns available data + # up to 1024 bytes + + # If we don't have any data buffered, try to get at least one message + if not self._read_buffer: + try: + # Use a short timeout to avoid blocking indefinitely + with trio.fail_after(1.0): # 1 second timeout + message = await self._ws_connection.get_message() + if isinstance(message, str): + message = message.encode("utf-8") + self._read_buffer = message + except trio.TooSlowError: + return b"" # No data available + except Exception: + return b"" + + # Now return up to n bytes from the buffer (TCP-like semantics) + if len(self._read_buffer) == 0: + return b"" + + # Return up to n bytes (like TCP read()) + result = self._read_buffer[:n] + self._read_buffer = self._read_buffer[len(result) :] + self._bytes_read += len(result) + return result + + except Exception as e: + logger.error(f"WebSocket read failed: {e}") + raise IOException from e + + async def close(self) -> None: + """Close the WebSocket connection. This method is idempotent.""" + async with self._close_lock: + if self._closed: + return # Already closed + + logger.debug("WebSocket connection closing") + self._closed = True + try: + # Always close the connection directly, avoid context manager issues + # The context manager may be causing cancel scope corruption + logger.debug("WebSocket closing connection directly") + await self._ws_connection.aclose() + # Exit the context manager if we have one + if self._ws_context is not None: + await self._ws_context.__aexit__(None, None, None) + except Exception as e: + logger.error(f"WebSocket close error: {e}") + # Don't raise here, as close() should be idempotent + finally: + logger.debug("WebSocket connection closed") + + def is_closed(self) -> bool: + """Check if the connection is closed""" + return self._closed + + def conn_state(self) -> dict[str, Any]: + """ + Return connection state information similar to Go's ConnState() method. + + :return: Dictionary containing connection state information + """ + current_time = time.time() + return { + "transport": "websocket", + "secure": self._is_secure, + "connection_duration": current_time - self._connection_start_time, + "bytes_read": self._bytes_read, + "bytes_written": self._bytes_written, + "total_bytes": self._bytes_read + self._bytes_written, + } + + def get_remote_address(self) -> tuple[str, int] | None: + # Try to get remote address from the WebSocket connection + try: + remote = self._ws_connection.remote + if hasattr(remote, "address") and hasattr(remote, "port"): + return str(remote.address), int(remote.port) + elif isinstance(remote, str): + # Parse address:port format + if ":" in remote: + host, port = remote.rsplit(":", 1) + return host, int(port) + except Exception: + pass + return None diff --git a/libp2p/transport/websocket/listener.py b/libp2p/transport/websocket/listener.py new file mode 100644 index 00000000..1ea3bc9b --- /dev/null +++ b/libp2p/transport/websocket/listener.py @@ -0,0 +1,225 @@ +from collections.abc import Awaitable, Callable +import logging +import ssl +from typing import Any + +from multiaddr import Multiaddr +import trio +from trio_typing import TaskStatus +from trio_websocket import serve_websocket + +from libp2p.abc import IListener +from libp2p.custom_types import THandler +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr + +from .connection import P2PWebSocketConnection + +logger = logging.getLogger("libp2p.transport.websocket.listener") + + +class WebsocketListener(IListener): + """ + Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection. + """ + + def __init__( + self, + handler: THandler, + upgrader: TransportUpgrader, + tls_config: ssl.SSLContext | None = None, + handshake_timeout: float = 15.0, + ) -> None: + self._handler = handler + self._upgrader = upgrader + self._tls_config = tls_config + self._handshake_timeout = handshake_timeout + self._server = None + self._shutdown_event = trio.Event() + self._nursery: trio.Nursery | None = None + self._listeners: Any = None + self._is_wss = False # Track whether this is a WSS listener + + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + logger.debug(f"WebsocketListener.listen called with {maddr}") + + # Parse the WebSocket multiaddr to determine if it's secure + try: + parsed = parse_websocket_multiaddr(maddr) + except ValueError as e: + raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e + + # Check if WSS is requested but no TLS config provided + if parsed.is_wss and self._tls_config is None: + raise ValueError( + f"Cannot listen on WSS address {maddr} without TLS configuration" + ) + + # Store whether this is a WSS listener + self._is_wss = parsed.is_wss + + # Extract host and port from the base multiaddr + host = ( + parsed.rest_multiaddr.value_for_protocol("ip4") + or parsed.rest_multiaddr.value_for_protocol("ip6") + or parsed.rest_multiaddr.value_for_protocol("dns") + or parsed.rest_multiaddr.value_for_protocol("dns4") + or parsed.rest_multiaddr.value_for_protocol("dns6") + or "0.0.0.0" + ) + port_str = parsed.rest_multiaddr.value_for_protocol("tcp") + if port_str is None: + raise ValueError(f"No TCP port found in multiaddr: {maddr}") + port = int(port_str) + + logger.debug( + f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}" + ) + + async def serve_websocket_tcp( + handler: Callable[[Any], Awaitable[None]], + port: int, + host: str, + task_status: TaskStatus[Any], + ) -> None: + """Start TCP server and handle WebSocket connections manually""" + logger.debug( + "serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss + ) + + async def websocket_handler(request: Any) -> None: + """Handle WebSocket requests""" + logger.debug("WebSocket request received") + try: + # Apply handshake timeout + with trio.fail_after(self._handshake_timeout): + # Accept the WebSocket connection + ws_connection = await request.accept() + logger.debug("WebSocket handshake successful") + + # Create the WebSocket connection wrapper + conn = P2PWebSocketConnection( + ws_connection, is_secure=parsed.is_wss + ) # type: ignore[no-untyped-call] + + # Call the handler function that was passed to create_listener + # This handler will handle the security and muxing upgrades + logger.debug("Calling connection handler") + await self._handler(conn) + + # Don't keep the connection alive indefinitely + # Let the handler manage the connection lifecycle + logger.debug( + "Handler completed, connection will be managed by handler" + ) + + except trio.TooSlowError: + logger.debug( + f"WebSocket handshake timeout after {self._handshake_timeout}s" + ) + try: + await request.reject(408) # Request Timeout + except Exception: + pass + except Exception as e: + logger.debug(f"WebSocket connection error: {e}") + logger.debug(f"Error type: {type(e)}") + import traceback + + logger.debug(f"Traceback: {traceback.format_exc()}") + # Reject the connection + try: + await request.reject(400) + except Exception: + pass + + # Use trio_websocket.serve_websocket for proper WebSocket handling + ssl_context = self._tls_config if parsed.is_wss else None + await serve_websocket( + websocket_handler, host, port, ssl_context, task_status=task_status + ) + + # Store the nursery for shutdown + self._nursery = nursery + + # Start the server using nursery.start() like TCP does + logger.debug("Calling nursery.start()...") + started_listeners = await nursery.start( + serve_websocket_tcp, + None, # No handler needed since it's defined inside serve_websocket_tcp + port, + host, + ) + logger.debug(f"nursery.start() returned: {started_listeners}") + + if started_listeners is None: + logger.error(f"Failed to start WebSocket listener for {maddr}") + return False + + # Store the listeners for get_addrs() and close() - these are real + # SocketListener objects + self._listeners = started_listeners + logger.debug( + "WebsocketListener.listen returning True with WebSocketServer object" + ) + return True + + def get_addrs(self) -> tuple[Multiaddr, ...]: + if not hasattr(self, "_listeners") or not self._listeners: + logger.debug("No listeners available for get_addrs()") + return () + + # Handle WebSocketServer objects + if hasattr(self._listeners, "port"): + # This is a WebSocketServer object + port = self._listeners.port + # Create a multiaddr from the port with correct WSS/WS protocol + protocol = "wss" if self._is_wss else "ws" + return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),) + else: + # This is a list of listeners (like TCP) + listeners = self._listeners + # Get addresses from listeners like TCP does + return tuple( + _multiaddr_from_socket(listener.socket, self._is_wss) + for listener in listeners + ) + + async def close(self) -> None: + """Close the WebSocket listener and stop accepting new connections""" + logger.debug("WebsocketListener.close called") + if hasattr(self, "_listeners") and self._listeners: + # Signal shutdown + self._shutdown_event.set() + + # Close the WebSocket server + if hasattr(self._listeners, "aclose"): + # This is a WebSocketServer object + logger.debug("Closing WebSocket server") + await self._listeners.aclose() + logger.debug("WebSocket server closed") + elif isinstance(self._listeners, (list, tuple)): + # This is a list of listeners (like TCP) + logger.debug("Closing TCP listeners") + for listener in self._listeners: + await listener.aclose() + logger.debug("TCP listeners closed") + else: + # Unknown type, try to close it directly + logger.debug("Closing unknown listener type") + if hasattr(self._listeners, "close"): + self._listeners.close() + logger.debug("Unknown listener closed") + + # Clear the listeners reference + self._listeners = None + logger.debug("WebsocketListener.close completed") + + +def _multiaddr_from_socket( + socket: trio.socket.SocketType, is_wss: bool = False +) -> Multiaddr: + """Convert socket to multiaddr""" + ip, port = socket.getsockname() + protocol = "wss" if is_wss else "ws" + return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}") diff --git a/libp2p/transport/websocket/multiaddr_utils.py b/libp2p/transport/websocket/multiaddr_utils.py new file mode 100644 index 00000000..16a38073 --- /dev/null +++ b/libp2p/transport/websocket/multiaddr_utils.py @@ -0,0 +1,202 @@ +""" +WebSocket multiaddr parsing utilities. +""" + +from typing import NamedTuple + +from multiaddr import Multiaddr +from multiaddr.protocols import Protocol + + +class ParsedWebSocketMultiaddr(NamedTuple): + """Parsed WebSocket multiaddr information.""" + + is_wss: bool + sni: str | None + rest_multiaddr: Multiaddr + + +def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr: + """ + Parse a WebSocket multiaddr and extract security information. + + :param maddr: The multiaddr to parse + :return: Parsed WebSocket multiaddr information + :raises ValueError: If the multiaddr is not a valid WebSocket multiaddr + """ + # First validate that this is a valid WebSocket multiaddr + if not is_valid_websocket_multiaddr(maddr): + raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}") + + protocols = list(maddr.protocols()) + + # Find the WebSocket protocol and check for security + is_wss = False + sni = None + ws_index = -1 + tls_index = -1 + sni_index = -1 + + # Find protocol indices + for i, protocol in enumerate(protocols): + if protocol.name == "ws": + ws_index = i + elif protocol.name == "wss": + ws_index = i + is_wss = True + elif protocol.name == "tls": + tls_index = i + elif protocol.name == "sni": + sni_index = i + sni = protocol.value + + if ws_index == -1: + raise ValueError("Not a WebSocket multiaddr") + + # Handle /wss protocol (convert to /tls/ws internally) + if is_wss and tls_index == -1: + # Convert /wss to /tls/ws format + # Remove /wss to get the base multiaddr + without_wss = maddr.decapsulate(Multiaddr("/wss")) + return ParsedWebSocketMultiaddr( + is_wss=True, sni=None, rest_multiaddr=without_wss + ) + + # Handle /tls/ws and /tls/sni/.../ws formats + if tls_index != -1: + is_wss = True + # Extract the base multiaddr (everything before /tls) + # For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080 + # Use multiaddr methods to properly extract the base + rest_multiaddr = maddr + # Remove /tls/ws or /tls/sni/.../ws from the end + if sni_index != -1: + # /tls/sni/example.com/ws format + rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws")) + rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}")) + rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls")) + else: + # /tls/ws format + rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws")) + rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls")) + return ParsedWebSocketMultiaddr( + is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr + ) + + # Regular /ws multiaddr - remove /ws and any additional protocols + rest_multiaddr = maddr.decapsulate(Multiaddr("/ws")) + return ParsedWebSocketMultiaddr( + is_wss=False, sni=None, rest_multiaddr=rest_multiaddr + ) + + +def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool: + """ + Validate that a multiaddr has a valid WebSocket structure. + + :param maddr: The multiaddr to validate + :return: True if valid WebSocket structure, False otherwise + """ + try: + # WebSocket multiaddr should have structure like: + # /ip4/127.0.0.1/tcp/8080/ws (insecure) + # /ip4/127.0.0.1/tcp/8080/wss (secure) + # /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS) + # /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI) + protocols: list[Protocol] = list(maddr.protocols()) + + # Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss + if len(protocols) < 3: + return False + + # First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6) + if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]: + return False + + # Second protocol should be tcp + if protocols[1].name != "tcp": + return False + + # Check for valid WebSocket protocols + ws_protocols = ["ws", "wss"] + tls_protocols = ["tls"] + sni_protocols = ["sni"] + + # Find the WebSocket protocol + ws_protocol_found = False + tls_found = False + # sni_found = False # Not used currently + + for i, protocol in enumerate(protocols[2:], start=2): + if protocol.name in ws_protocols: + ws_protocol_found = True + break + elif protocol.name in tls_protocols: + tls_found = True + elif protocol.name in sni_protocols: + pass # sni_found = True # Not used in current implementation + + if not ws_protocol_found: + return False + + # Validate protocol sequence + # For /ws: network + tcp + ws + # For /wss: network + tcp + wss + # For /tls/ws: network + tcp + tls + ws + # For /tls/sni/example.com/ws: network + tcp + tls + sni + ws + + # Check if it's a simple /ws or /wss + if len(protocols) == 3: + return protocols[2].name in ["ws", "wss"] + + # Check for /tls/ws or /tls/sni/.../ws patterns + if tls_found: + # Must end with /ws (not /wss when using /tls) + if protocols[-1].name != "ws": + return False + + # Check for valid TLS sequence + tls_index = None + for i, protocol in enumerate(protocols[2:], start=2): + if protocol.name == "tls": + tls_index = i + break + + if tls_index is None: + return False + + # After tls, we can have sni, then ws + remaining_protocols = protocols[tls_index + 1 :] + if len(remaining_protocols) == 1: + # /tls/ws + return remaining_protocols[0].name == "ws" + elif len(remaining_protocols) == 2: + # /tls/sni/example.com/ws + return ( + remaining_protocols[0].name == "sni" + and remaining_protocols[1].name == "ws" + ) + else: + return False + + # If we have more than 3 protocols but no TLS, check for valid continuations + # Allow additional protocols after the WebSocket protocol (like /p2p) + valid_continuations = ["p2p"] + + # Find the WebSocket protocol index + ws_index = None + for i, protocol in enumerate(protocols): + if protocol.name in ["ws", "wss"]: + ws_index = i + break + + if ws_index is not None: + # Check protocols after the WebSocket protocol + for i in range(ws_index + 1, len(protocols)): + if protocols[i].name not in valid_continuations: + return False + + return True + + except Exception: + return False diff --git a/libp2p/transport/websocket/transport.py b/libp2p/transport/websocket/transport.py new file mode 100644 index 00000000..30da5942 --- /dev/null +++ b/libp2p/transport/websocket/transport.py @@ -0,0 +1,229 @@ +import logging +import ssl + +from multiaddr import Multiaddr + +from libp2p.abc import IListener, ITransport +from libp2p.custom_types import THandler +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.transport.exceptions import OpenConnectionError +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr + +from .connection import P2PWebSocketConnection +from .listener import WebsocketListener + +logger = logging.getLogger(__name__) + + +class WebsocketTransport(ITransport): + """ + Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss + + Implements production-ready WebSocket transport with: + - Flow control and buffer management + - Connection limits and rate limiting + - Proper error handling and cleanup + - Support for both WS and WSS protocols + - TLS configuration and handshake timeout + """ + + def __init__( + self, + upgrader: TransportUpgrader, + tls_client_config: ssl.SSLContext | None = None, + tls_server_config: ssl.SSLContext | None = None, + handshake_timeout: float = 15.0, + max_buffered_amount: int = 4 * 1024 * 1024, + ): + self._upgrader = upgrader + self._tls_client_config = tls_client_config + self._tls_server_config = tls_server_config + self._handshake_timeout = handshake_timeout + self._max_buffered_amount = max_buffered_amount + self._connection_count = 0 + self._max_connections = 1000 # Production limit + + async def dial(self, maddr: Multiaddr) -> RawConnection: + """Dial a WebSocket connection to the given multiaddr.""" + logger.debug(f"WebsocketTransport.dial called with {maddr}") + + # Parse the WebSocket multiaddr to determine if it's secure + try: + parsed = parse_websocket_multiaddr(maddr) + except ValueError as e: + raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e + + # Extract host and port from the base multiaddr + host = ( + parsed.rest_multiaddr.value_for_protocol("ip4") + or parsed.rest_multiaddr.value_for_protocol("ip6") + or parsed.rest_multiaddr.value_for_protocol("dns") + or parsed.rest_multiaddr.value_for_protocol("dns4") + or parsed.rest_multiaddr.value_for_protocol("dns6") + ) + port_str = parsed.rest_multiaddr.value_for_protocol("tcp") + if port_str is None: + raise ValueError(f"No TCP port found in multiaddr: {maddr}") + port = int(port_str) + + # Build WebSocket URL based on security + if parsed.is_wss: + ws_url = f"wss://{host}:{port}/" + else: + ws_url = f"ws://{host}:{port}/" + + logger.debug( + f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})" + ) + + try: + # Check connection limits + if self._connection_count >= self._max_connections: + raise OpenConnectionError( + f"Maximum connections reached: {self._max_connections}" + ) + + # Prepare SSL context for WSS connections + ssl_context = None + if parsed.is_wss: + if self._tls_client_config: + ssl_context = self._tls_client_config + else: + # Create default SSL context for client + ssl_context = ssl.create_default_context() + # Set SNI if available + if parsed.sni: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}") + + # Use a different approach: start background nursery that will persist + logger.debug("WebsocketTransport.dial establishing connection") + + # Import trio-websocket functions + from trio_websocket import connect_websocket + from trio_websocket._impl import _url_to_host + + # Parse the WebSocket URL to get host, port, resource + # like trio-websocket does + ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host( + ws_url, ssl_context + ) + + logger.debug( + f"WebsocketTransport.dial parsed URL: host={ws_host}, " + f"port={ws_port}, resource={ws_resource}" + ) + + # Create a background task manager for this connection + import trio + + nursery_manager = trio.lowlevel.current_task().parent_nursery + if nursery_manager is None: + raise OpenConnectionError( + f"No parent nursery available for WebSocket connection to {maddr}" + ) + + # Apply timeout to the connection process + with trio.fail_after(self._handshake_timeout): + logger.debug("WebsocketTransport.dial connecting WebSocket") + ws = await connect_websocket( + nursery_manager, # Use the existing nursery from libp2p + ws_host, + ws_port, + ws_resource, + use_ssl=ws_ssl_context, + message_queue_size=1024, # Reasonable defaults + max_message_size=16 * 1024 * 1024, # 16MB max message + ) + logger.debug("WebsocketTransport.dial WebSocket connection established") + + # Create our connection wrapper with both WSS support and flow control + conn = P2PWebSocketConnection( + ws, + None, + is_secure=parsed.is_wss, + max_buffered_amount=self._max_buffered_amount, + ) + logger.debug("WebsocketTransport.dial created P2PWebSocketConnection") + + self._connection_count += 1 + logger.debug(f"Total connections: {self._connection_count}") + + return RawConnection(conn, initiator=True) + except trio.TooSlowError as e: + raise OpenConnectionError( + f"WebSocket handshake timeout after {self._handshake_timeout}s " + f"for {maddr}" + ) from e + except Exception as e: + logger.error(f"Failed to dial WebSocket {maddr}: {e}") + raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e + + def create_listener(self, handler: THandler) -> IListener: # type: ignore[override] + """ + The type checker is incorrectly reporting this as an inconsistent override. + """ + logger.debug("WebsocketTransport.create_listener called") + return WebsocketListener( + handler, self._upgrader, self._tls_server_config, self._handshake_timeout + ) + + def resolve(self, maddr: Multiaddr) -> list[Multiaddr]: + """ + Resolve a WebSocket multiaddr, automatically adding SNI for DNS names. + Similar to Go's Resolve() method. + + :param maddr: The multiaddr to resolve + :return: List of resolved multiaddrs + """ + try: + parsed = parse_websocket_multiaddr(maddr) + except ValueError as e: + logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}") + return [maddr] # Return original if not a valid WebSocket multiaddr + + logger.debug( + f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}" + ) + + if not parsed.is_wss: + # No /tls/ws component, this isn't a secure websocket multiaddr + return [maddr] + + if parsed.sni is not None: + # Already has SNI, return as-is + return [maddr] + + # Try to extract DNS name from the base multiaddr + dns_name = None + for protocol_name in ["dns", "dns4", "dns6"]: + try: + dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name) + break + except Exception: + continue + + if dns_name is None: + # No DNS name found, return original + return [maddr] + + # Create new multiaddr with SNI + # For /dns/example.com/tcp/8080/wss -> + # /dns/example.com/tcp/8080/tls/sni/example.com/ws + try: + # Remove /wss and add /tls/sni/example.com/ws + without_wss = maddr.decapsulate(Multiaddr("/wss")) + sni_component = Multiaddr(f"/sni/{dns_name}") + resolved = ( + without_wss.encapsulate(Multiaddr("/tls")) + .encapsulate(sni_component) + .encapsulate(Multiaddr("/ws")) + ) + logger.debug(f"Resolved {maddr} to {resolved}") + return [resolved] + except Exception as e: + logger.debug(f"Failed to resolve multiaddr {maddr}: {e}") + return [maddr] diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py index 77b797a1..5ce58671 100644 --- a/libp2p/utils/address_validation.py +++ b/libp2p/utils/address_validation.py @@ -3,38 +3,24 @@ from __future__ import annotations import socket from multiaddr import Multiaddr - -try: - from multiaddr.utils import ( # type: ignore - get_network_addrs, - get_thin_waist_addresses, - ) - - _HAS_THIN_WAIST = True -except ImportError: # pragma: no cover - only executed in older environments - _HAS_THIN_WAIST = False - get_thin_waist_addresses = None # type: ignore - get_network_addrs = None # type: ignore +from multiaddr.utils import get_network_addrs, get_thin_waist_addresses def _safe_get_network_addrs(ip_version: int) -> list[str]: """ Internal safe wrapper. Returns a list of IP addresses for the requested IP version. - Falls back to minimal defaults when Thin Waist helpers are missing. :param ip_version: 4 or 6 """ - if _HAS_THIN_WAIST and get_network_addrs: - try: - return get_network_addrs(ip_version) or [] - except Exception: # pragma: no cover - defensive - return [] - # Fallback behavior (very conservative) - if ip_version == 4: - return ["127.0.0.1"] - if ip_version == 6: - return ["::1"] - return [] + try: + return get_network_addrs(ip_version) or [] + except Exception: # pragma: no cover - defensive + # Fallback behavior (very conservative) + if ip_version == 4: + return ["127.0.0.1"] + if ip_version == 6: + return ["::1"] + return [] def find_free_port() -> int: @@ -47,16 +33,13 @@ def find_free_port() -> int: def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]: """ Internal safe expansion wrapper. Returns a list of Multiaddr objects. - If Thin Waist isn't available, returns [addr] (identity). """ - if _HAS_THIN_WAIST and get_thin_waist_addresses: - try: - if port is not None: - return get_thin_waist_addresses(addr, port=port) or [] - return get_thin_waist_addresses(addr) or [] - except Exception: # pragma: no cover - defensive - return [addr] - return [addr] + try: + if port is not None: + return get_thin_waist_addresses(addr, port=port) or [] + return get_thin_waist_addresses(addr) or [] + except Exception: # pragma: no cover - defensive + return [addr] def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]: @@ -73,8 +56,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr seen_v4: set[str] = set() for ip in _safe_get_network_addrs(4): - seen_v4.add(ip) - addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + if ip not in seen_v4: # Avoid duplicates + seen_v4.add(ip) + addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) # Ensure IPv4 loopback is always included when IPv4 interfaces are discovered if seen_v4 and "127.0.0.1" not in seen_v4: @@ -89,8 +73,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr # # seen_v6: set[str] = set() # for ip in _safe_get_network_addrs(6): - # seen_v6.add(ip) - # addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + # if ip not in seen_v6: # Avoid duplicates + # seen_v6.add(ip) + # addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) # # # Always include IPv6 loopback for testing purposes when IPv6 is available # # This ensures IPv6 functionality can be tested even without global IPv6 addresses @@ -99,7 +84,7 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr # Fallback if nothing discovered if not addrs: - addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")) + addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) return addrs @@ -120,6 +105,20 @@ def expand_wildcard_address( return expanded +def get_wildcard_address(port: int, protocol: str = "tcp") -> Multiaddr: + """ + Get wildcard address (0.0.0.0) when explicitly needed. + + This function provides access to wildcard binding as a feature when + explicitly required, preserving the ability to bind to all interfaces. + + :param port: Port number. + :param protocol: Transport protocol. + :return: A Multiaddr with wildcard binding (0.0.0.0). + """ + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr: """ Choose an optimal address for an example to bind to: @@ -148,13 +147,14 @@ def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr: if "/ip4/127." in str(c) or "/ip6/::1" in str(c): return c - # As a final fallback, produce a wildcard - return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + # As a final fallback, produce a loopback address + return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}") __all__ = [ "get_available_interfaces", "get_optimal_binding_address", + "get_wildcard_address", "expand_wildcard_address", "find_free_port", ] diff --git a/newsfragments/585.feature.rst b/newsfragments/585.feature.rst new file mode 100644 index 00000000..ca9ef3dc --- /dev/null +++ b/newsfragments/585.feature.rst @@ -0,0 +1,12 @@ +Added experimental WebSocket transport support with basic WS and WSS functionality. This includes: + +- WebSocket transport implementation with trio-websocket backend +- Support for both WS (WebSocket) and WSS (WebSocket Secure) protocols +- Basic connection management and stream handling +- TLS configuration support for WSS connections +- Multiaddr parsing for WebSocket addresses +- Integration with libp2p host and peer discovery + +**Note**: This is experimental functionality. Advanced features like proxy support, +interop testing, and production examples are still in development. See + https://github.com/libp2p/py-libp2p/discussions/937 for the complete roadmap of missing features. diff --git a/newsfragments/763.feature.rst b/newsfragments/763.feature.rst new file mode 100644 index 00000000..838b0cae --- /dev/null +++ b/newsfragments/763.feature.rst @@ -0,0 +1 @@ +Add QUIC transport support for faster, more efficient peer-to-peer connections with native stream multiplexing. diff --git a/newsfragments/843.bugfix.rst b/newsfragments/843.bugfix.rst new file mode 100644 index 00000000..6160bbc7 --- /dev/null +++ b/newsfragments/843.bugfix.rst @@ -0,0 +1 @@ +Fixed message id type inconsistency in handle ihave and message id parsing improvement in handle iwant in pubsub module. diff --git a/newsfragments/885.feature.rst b/newsfragments/885.feature.rst new file mode 100644 index 00000000..e255be4f --- /dev/null +++ b/newsfragments/885.feature.rst @@ -0,0 +1,2 @@ +Updated all example scripts and core modules to use secure loopback addresses instead of wildcard addresses for network binding. +The `get_wildcard_address` function and related logic now utilize all available interfaces safely, improving security and consistency across the codebase. diff --git a/newsfragments/896.bugfix.rst b/newsfragments/896.bugfix.rst new file mode 100644 index 00000000..aaf338d4 --- /dev/null +++ b/newsfragments/896.bugfix.rst @@ -0,0 +1 @@ +Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly diff --git a/newsfragments/897.bugfix.rst b/newsfragments/897.bugfix.rst new file mode 100644 index 00000000..575b5769 --- /dev/null +++ b/newsfragments/897.bugfix.rst @@ -0,0 +1,6 @@ +enhancement: Add write lock to `YamuxStream` to prevent concurrent write race conditions + +- Implements ReadWriteLock for `YamuxStream` write operations +- Prevents data corruption from concurrent write operations +- Read operations remain lock-free due to existing `Yamux` architecture +- Resolves race conditions identified in Issue #793 diff --git a/newsfragments/917.internal.rst b/newsfragments/917.internal.rst new file mode 100644 index 00000000..ed06f3ed --- /dev/null +++ b/newsfragments/917.internal.rst @@ -0,0 +1,11 @@ +Replace magic numbers with named constants and enums for clarity and maintainability + +**Key Changes:** +- **Introduced type-safe enums** for better code clarity: + - `RelayRole(Flag)` enum with HOP, STOP, CLIENT roles supporting bitwise combinations (e.g., `RelayRole.HOP | RelayRole.STOP`) + - `ReservationStatus(Enum)` for reservation lifecycle management (ACTIVE, EXPIRED, REJECTED) +- **Replaced magic numbers with named constants** throughout the codebase, improving code maintainability and eliminating hardcoded timeout values (15s, 30s, 10s) with descriptive constant names +- **Added comprehensive timeout configuration system** with new `TimeoutConfig` dataclass supporting component-specific timeouts (discovery, protocol, DCUtR) +- **Enhanced configurability** of `RelayDiscovery`, `CircuitV2Protocol`, and `DCUtRProtocol` constructors with optional timeout parameters +- **Improved architecture consistency** with clean configuration flow across all circuit relay components +**Backward Compatibility:** All changes maintain full backward compatibility. Existing code continues to work unchanged while new timeout configuration options are available for users who need them. diff --git a/newsfragments/927.bugfix.rst b/newsfragments/927.bugfix.rst new file mode 100644 index 00000000..99573ff9 --- /dev/null +++ b/newsfragments/927.bugfix.rst @@ -0,0 +1 @@ +Fix multiaddr dependency to use the last py-multiaddr commit hash to resolve installation issues diff --git a/newsfragments/934.misc.rst b/newsfragments/934.misc.rst new file mode 100644 index 00000000..0a6d9120 --- /dev/null +++ b/newsfragments/934.misc.rst @@ -0,0 +1 @@ +Updated multiaddr dependency from git repository to pip package version 0.0.11. diff --git a/newsfragments/952.bugfix.rst b/newsfragments/952.bugfix.rst new file mode 100644 index 00000000..3dcd6407 --- /dev/null +++ b/newsfragments/952.bugfix.rst @@ -0,0 +1 @@ +Fixed Windows CI/CD tests to use correct Python version instead of hardcoded Python 3.11. test 2 diff --git a/pyproject.toml b/pyproject.toml index 7f08697e..1e3c4a02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,14 @@ maintainers = [ { name = "Dave Grantham", email = "dwg@linuxprogrammer.org" }, ] dependencies = [ + "aioquic>=1.2.0", "base58>=1.0.3", - "coincurve>=10.0.0", + "coincurve==21.0.0", "exceptiongroup>=1.2.0; python_version < '3.11'", + "fastecdsa==2.3.2; sys_platform != 'win32'", "grpcio>=1.41.0", "lru-dict>=1.1.6", - # "multiaddr>=0.0.9", - "multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6", + "multiaddr>=0.0.11", "mypy-protobuf>=3.0.0", "noiseprotocol>=0.3.0", "protobuf>=4.25.0,<5.0.0", @@ -32,7 +33,7 @@ dependencies = [ "rpcudp>=3.0.0", "trio-typing>=0.0.4", "trio>=0.26.0", - "fastecdsa==2.3.2; sys_platform != 'win32'", + "trio-websocket>=0.11.0", "zeroconf (>=0.147.0,<0.148.0)", ] classifiers = [ @@ -52,6 +53,7 @@ Homepage = "https://github.com/libp2p/py-libp2p" [project.scripts] chat-demo = "examples.chat.chat:main" echo-demo = "examples.echo.echo:main" +echo-quic-demo="examples.echo.echo_quic:main" ping-demo = "examples.ping.ping:main" identify-demo = "examples.identify.identify:main" identify-push-demo = "examples.identify_push.identify_push_demo:run_main" @@ -77,6 +79,7 @@ dev = [ "pytest>=7.0.0", "pytest-xdist>=2.4.0", "pytest-trio>=0.5.2", + "pytest-timeout>=2.4.0", "factory-boy>=2.12.0,<3.0.0", "ruff>=0.11.10", "pyrefly (>=0.17.1,<0.18.0)", @@ -88,11 +91,12 @@ docs = [ "tomli; python_version < '3.11'", ] test = [ + "factory-boy>=2.12.0,<3.0.0", "p2pclient==0.2.0", "pytest>=7.0.0", - "pytest-xdist>=2.4.0", + "pytest-timeout>=2.4.0", "pytest-trio>=0.5.2", - "factory-boy>=2.12.0,<3.0.0", + "pytest-xdist>=2.4.0", ] [tool.setuptools] @@ -282,4 +286,5 @@ project_excludes = [ "**/*pb2.py", "**/*.pyi", ".venv/**", + "./tests/interop/nim_libp2p", ] diff --git a/tests/conftest.py b/tests/conftest.py index ba3b7da0..343a03d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ import pytest - @pytest.fixture def security_protocol(): - return None + return None \ No newline at end of file diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index df08ff98..47bc3ace 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -250,10 +250,13 @@ def test_new_swarm_tcp_multiaddr_supported(): assert isinstance(swarm.transport, TCP) -def test_new_swarm_quic_multiaddr_raises(): +def test_new_swarm_quic_multiaddr_supported(): + from libp2p.transport.quic.transport import QUICTransport + addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic") - with pytest.raises(ValueError, match="QUIC not yet supported"): - new_swarm(listen_addrs=[addr]) + swarm = new_swarm(listen_addrs=[addr]) + assert isinstance(swarm, Swarm) + assert isinstance(swarm.transport, QUICTransport) @pytest.mark.trio diff --git a/tests/core/pubsub/test_gossipsub.py b/tests/core/pubsub/test_gossipsub.py index 91205b29..5c341d0b 100644 --- a/tests/core/pubsub/test_gossipsub.py +++ b/tests/core/pubsub/test_gossipsub.py @@ -1,4 +1,8 @@ import random +from unittest.mock import ( + AsyncMock, + MagicMock, +) import pytest import trio @@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import ( PROTOCOL_ID, GossipSub, ) +from libp2p.pubsub.pb import ( + rpc_pb2, +) from libp2p.tools.utils import ( connect, ) @@ -754,3 +761,173 @@ async def test_single_host(): assert connected_peers == 0, ( f"Single host has {connected_peers} connections, expected 0" ) + + +@pytest.mark.trio +async def test_handle_ihave(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock emit_iwant to capture calls + mock_emit_iwant = AsyncMock() + monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant) + + # Create a test message ID as a string representation of a (seqno, from) tuple + test_seqno = b"1234" + test_from = id_bob.to_bytes() + test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')" + ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id]) + + # Mock seen_messages.cache to avoid false positives + monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {}) + + # Simulate Bob sending IHAVE to Alice + await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob) + + # Check if emit_iwant was called with the correct message ID + mock_emit_iwant.assert_called_once() + called_args = mock_emit_iwant.call_args[0] + assert called_args[0] == [test_msg_id] # Expected message IDs + assert called_args[1] == id_bob # Sender peer ID + + +@pytest.mark.trio +async def test_handle_iwant(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_alice = pubsubs_gsub[index_alice].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock mcache.get to return a message + test_message = rpc_pb2.Message(data=b"test_data") + test_seqno = b"1234" + test_from = id_alice.to_bytes() + + # āœ… Correct: use raw tuple and str() to serialize, no hex() + test_msg_id = str((test_seqno, test_from)) + + mock_mcache_get = MagicMock(return_value=test_message) + monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get) + + # Mock write_msg to capture the sent packet + mock_write_msg = AsyncMock() + monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg) + + # Simulate Alice sending IWANT to Bob + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id]) + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + + # Check if write_msg was called with the correct packet + mock_write_msg.assert_called_once() + packet = mock_write_msg.call_args[0][1] + assert isinstance(packet, rpc_pb2.RPC) + assert len(packet.publish) == 1 + assert packet.publish[0] == test_message + + # Verify that mcache.get was called with the correct parsed message ID + mock_mcache_get.assert_called_once() + called_msg_id = mock_mcache_get.call_args[0][0] + assert isinstance(called_msg_id, tuple) + assert called_msg_id == (test_seqno, test_from) + + +@pytest.mark.trio +async def test_handle_iwant_invalid_msg_id(monkeypatch): + """ + Test that handle_iwant raises ValueError for malformed message IDs. + """ + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_alice = pubsubs_gsub[index_alice].my_id + + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) + + # Malformed message ID (not a tuple string) + malformed_msg_id = "not_a_valid_msg_id" + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id]) + + # Mock mcache.get and write_msg to ensure they are not called + mock_mcache_get = MagicMock() + monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get) + mock_write_msg = AsyncMock() + monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg) + + with pytest.raises(ValueError): + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + mock_mcache_get.assert_not_called() + mock_write_msg.assert_not_called() + + # Message ID that's a tuple string but not (bytes, bytes) + invalid_tuple_msg_id = "('abc', 123)" + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id]) + with pytest.raises(ValueError): + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + mock_mcache_get.assert_not_called() + mock_write_msg.assert_not_called() + + +@pytest.mark.trio +async def test_handle_ihave_empty_message_ids(monkeypatch): + """ + Test that handle_ihave with an empty messageIDs list does not call emit_iwant. + """ + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock emit_iwant to capture calls + mock_emit_iwant = AsyncMock() + monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant) + + # Empty messageIDs list + ihave_msg = rpc_pb2.ControlIHave(messageIDs=[]) + + # Mock seen_messages.cache to avoid false positives + monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {}) + + # Simulate Bob sending IHAVE to Alice + await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob) + + # emit_iwant should not be called since there are no message IDs + mock_emit_iwant.assert_not_called() diff --git a/tests/core/pubsub/test_gossipsub_px_and_backoff.py b/tests/core/pubsub/test_gossipsub_px_and_backoff.py index 72ad5f9d..9701b6e5 100644 --- a/tests/core/pubsub/test_gossipsub_px_and_backoff.py +++ b/tests/core/pubsub/test_gossipsub_px_and_backoff.py @@ -65,7 +65,7 @@ async def test_prune_backoff(): @pytest.mark.trio async def test_unsubscribe_backoff(): async with PubsubFactory.create_batch_with_gossipsub( - 2, heartbeat_interval=1, prune_back_off=1, unsubscribe_back_off=2 + 2, heartbeat_interval=0.5, prune_back_off=2, unsubscribe_back_off=4 ) as pubsubs: gsub0 = pubsubs[0].router gsub1 = pubsubs[1].router @@ -107,7 +107,8 @@ async def test_unsubscribe_backoff(): ) # try to graft again (should succeed after backoff) - await trio.sleep(1) + # Wait longer than unsubscribe_back_off (4 seconds) + some buffer + await trio.sleep(4.5) await gsub0.emit_graft(topic, host_1.get_id()) await trio.sleep(1) assert host_0.get_id() in gsub1.mesh[topic], ( diff --git a/tests/core/stream_muxer/test_muxer_multistream.py b/tests/core/stream_muxer/test_muxer_multistream.py new file mode 100644 index 00000000..070d47ae --- /dev/null +++ b/tests/core/stream_muxer/test_muxer_multistream.py @@ -0,0 +1,108 @@ +from unittest.mock import ( + AsyncMock, + MagicMock, +) + +import pytest + +from libp2p.custom_types import ( + TMuxerClass, + TProtocol, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.protocol_muxer.exceptions import ( + MultiselectError, +) +from libp2p.stream_muxer.muxer_multistream import ( + MuxerMultistream, +) + + +@pytest.mark.trio +async def test_muxer_timeout_configuration(): + """Test that muxer respects timeout configuration.""" + muxer = MuxerMultistream({}, negotiate_timeout=1) + assert muxer.negotiate_timeout == 1 + + +@pytest.mark.trio +async def test_select_transport_passes_timeout_to_multiselect(): + """Test that timeout is passed to multiselect client in select_transport.""" + # Mock dependencies + mock_conn = MagicMock() + mock_conn.is_initiator = False + + # Mock MultiselectClient + muxer = MuxerMultistream({}, negotiate_timeout=10) + muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None)) + muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock()) + + # Call select_transport + await muxer.select_transport(mock_conn) + + # Verify that select_one_of was called with the correct timeout + args, _ = muxer.multiselect.negotiate.call_args + assert args[1] == 10 + + +@pytest.mark.trio +async def test_new_conn_passes_timeout_to_multistream_client(): + """Test that timeout is passed to multistream client in new_conn.""" + # Mock dependencies + mock_conn = MagicMock() + mock_conn.is_initiator = True + mock_peer_id = ID(b"test_peer") + mock_communicator = MagicMock() + + # Mock MultistreamClient and transports + muxer = MuxerMultistream({}, negotiate_timeout=30) + muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol") + muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock()) + + # Call new_conn + await muxer.new_conn(mock_conn, mock_peer_id) + + # Verify that select_one_of was called with the correct timeout + muxer.multistream_client.select_one_of( + tuple(muxer.transports.keys()), mock_communicator, 30 + ) + + +@pytest.mark.trio +async def test_select_transport_no_protocol_selected(): + """ + Test that select_transport raises MultiselectError when no protocol is selected. + """ + # Mock dependencies + mock_conn = MagicMock() + mock_conn.is_initiator = False + + # Mock Multiselect to return None + muxer = MuxerMultistream({}, negotiate_timeout=30) + muxer.multiselect.negotiate = AsyncMock(return_value=(None, None)) + + # Expect MultiselectError to be raised + with pytest.raises(MultiselectError, match="no protocol selected"): + await muxer.select_transport(mock_conn) + + +@pytest.mark.trio +async def test_add_transport_updates_precedence(): + """Test that adding a transport updates protocol precedence.""" + # Mock transport classes + mock_transport1 = MagicMock(spec=TMuxerClass) + mock_transport2 = MagicMock(spec=TMuxerClass) + + # Initialize muxer and add transports + muxer = MuxerMultistream({}, negotiate_timeout=30) + muxer.add_transport(TProtocol("proto1"), mock_transport1) + muxer.add_transport(TProtocol("proto2"), mock_transport2) + + # Verify transport order + assert list(muxer.transports.keys()) == ["proto1", "proto2"] + + # Re-add proto1 to check if it moves to the end + muxer.add_transport(TProtocol("proto1"), mock_transport1) + assert list(muxer.transports.keys()) == ["proto2", "proto1"] diff --git a/tests/core/transport/quic/test_concurrency.py b/tests/core/transport/quic/test_concurrency.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/transport/quic/test_connection.py b/tests/core/transport/quic/test_connection.py new file mode 100644 index 00000000..9b3ad3a9 --- /dev/null +++ b/tests/core/transport/quic/test_connection.py @@ -0,0 +1,553 @@ +""" +Enhanced tests for QUIC connection functionality - Module 3. +Tests all new features including advanced stream management, resource management, +error handling, and concurrent operations. +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from multiaddr.multiaddr import Multiaddr +import trio + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.exceptions import ( + QUICConnectionClosedError, + QUICConnectionError, + QUICConnectionTimeoutError, + QUICPeerVerificationError, + QUICStreamLimitError, + QUICStreamTimeoutError, +) +from libp2p.transport.quic.security import QUICTLSConfigManager +from libp2p.transport.quic.stream import QUICStream, StreamDirection + + +class MockResourceScope: + """Mock resource scope for testing.""" + + def __init__(self): + self.memory_reserved = 0 + + def reserve_memory(self, size): + self.memory_reserved += size + + def release_memory(self, size): + self.memory_reserved = max(0, self.memory_reserved - size) + + +class TestQUICConnection: + """Test suite for QUIC connection functionality.""" + + @pytest.fixture + def mock_quic_connection(self): + """Create mock aioquic QuicConnection.""" + mock = Mock() + mock.next_event.return_value = None + mock.datagrams_to_send.return_value = [] + mock.get_timer.return_value = None + mock.connect = Mock() + mock.close = Mock() + mock.send_stream_data = Mock() + mock.reset_stream = Mock() + return mock + + @pytest.fixture + def mock_quic_transport(self): + mock = Mock() + mock._config = QUICTransportConfig() + return mock + + @pytest.fixture + def mock_resource_scope(self): + """Create mock resource scope.""" + return MockResourceScope() + + @pytest.fixture + def quic_connection( + self, + mock_quic_connection: Mock, + mock_quic_transport: Mock, + mock_resource_scope: MockResourceScope, + ): + """Create test QUIC connection with enhanced features.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + mock_security_manager = Mock() + + return QUICConnection( + quic_connection=mock_quic_connection, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=None, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=mock_quic_transport, + resource_scope=mock_resource_scope, + security_manager=mock_security_manager, + ) + + @pytest.fixture + def server_connection(self, mock_quic_connection, mock_resource_scope): + """Create server-side QUIC connection.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + return QUICConnection( + quic_connection=mock_quic_connection, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=False, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + resource_scope=mock_resource_scope, + ) + + # Basic functionality tests + + def test_connection_initialization_enhanced( + self, quic_connection, mock_resource_scope + ): + """Test enhanced connection initialization.""" + assert quic_connection._remote_addr == ("127.0.0.1", 4001) + assert quic_connection.is_initiator is True + assert not quic_connection.is_closed + assert not quic_connection.is_established + assert len(quic_connection._streams) == 0 + assert quic_connection._resource_scope == mock_resource_scope + assert quic_connection._outbound_stream_count == 0 + assert quic_connection._inbound_stream_count == 0 + assert len(quic_connection._stream_accept_queue) == 0 + + def test_stream_id_calculation_enhanced(self): + """Test enhanced stream ID calculation for client/server.""" + # Client connection (initiator) + client_conn = QUICConnection( + quic_connection=Mock(), + remote_addr=("127.0.0.1", 4001), + remote_peer_id=None, + local_peer_id=Mock(), + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + assert client_conn._next_stream_id == 0 # Client starts with 0 + + # Server connection (not initiator) + server_conn = QUICConnection( + quic_connection=Mock(), + remote_addr=("127.0.0.1", 4001), + remote_peer_id=None, + local_peer_id=Mock(), + is_initiator=False, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + assert server_conn._next_stream_id == 1 # Server starts with 1 + + def test_incoming_stream_detection_enhanced(self, quic_connection): + """Test enhanced incoming stream detection logic.""" + # For client (initiator), odd stream IDs are incoming + assert quic_connection._is_incoming_stream(1) is True # Server-initiated + assert quic_connection._is_incoming_stream(0) is False # Client-initiated + assert quic_connection._is_incoming_stream(5) is True # Server-initiated + assert quic_connection._is_incoming_stream(4) is False # Client-initiated + + # Stream management tests + + @pytest.mark.trio + async def test_open_stream_basic(self, quic_connection): + """Test basic stream opening.""" + quic_connection._started = True + + stream = await quic_connection.open_stream() + + assert isinstance(stream, QUICStream) + assert stream.stream_id == "0" + assert stream.direction == StreamDirection.OUTBOUND + assert 0 in quic_connection._streams + assert quic_connection._outbound_stream_count == 1 + + @pytest.mark.trio + async def test_open_stream_limit_reached(self, quic_connection): + """Test stream limit enforcement.""" + quic_connection._started = True + quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS + + with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"): + await quic_connection.open_stream() + + @pytest.mark.trio + async def test_open_stream_timeout(self, quic_connection: QUICConnection): + """Test stream opening timeout.""" + quic_connection._started = True + return + + # Mock the stream ID lock to simulate slow operation + async def slow_acquire(): + await trio.sleep(10) # Longer than timeout + + with patch.object( + quic_connection._stream_lock, "acquire", side_effect=slow_acquire + ): + with pytest.raises( + QUICStreamTimeoutError, match="Stream creation timed out" + ): + await quic_connection.open_stream(timeout=0.1) + + @pytest.mark.trio + async def test_accept_stream_basic(self, quic_connection): + """Test basic stream acceptance.""" + # Create a mock inbound stream + mock_stream = Mock(spec=QUICStream) + mock_stream.stream_id = "1" + + # Add to accept queue + quic_connection._stream_accept_queue.append(mock_stream) + quic_connection._stream_accept_event.set() + + accepted_stream = await quic_connection.accept_stream(timeout=0.1) + + assert accepted_stream == mock_stream + assert len(quic_connection._stream_accept_queue) == 0 + + @pytest.mark.trio + async def test_accept_stream_timeout(self, quic_connection): + """Test stream acceptance timeout.""" + with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"): + await quic_connection.accept_stream(timeout=0.1) + + @pytest.mark.trio + async def test_accept_stream_on_closed_connection(self, quic_connection): + """Test stream acceptance on closed connection.""" + await quic_connection.close() + + with pytest.raises(QUICConnectionClosedError, match="Connection is closed"): + await quic_connection.accept_stream() + + # Stream handler tests + + @pytest.mark.trio + async def test_stream_handler_setting(self, quic_connection): + """Test setting stream handler.""" + + async def mock_handler(stream): + pass + + quic_connection.set_stream_handler(mock_handler) + assert quic_connection._stream_handler == mock_handler + + # Connection lifecycle tests + + @pytest.mark.trio + async def test_connection_start_client(self, quic_connection): + """Test client connection start.""" + with patch.object( + quic_connection, "_initiate_connection", new_callable=AsyncMock + ) as mock_initiate: + await quic_connection.start() + + assert quic_connection._started + mock_initiate.assert_called_once() + + @pytest.mark.trio + async def test_connection_start_server(self, server_connection): + """Test server connection start.""" + await server_connection.start() + + assert server_connection._started + assert server_connection._established + assert server_connection._connected_event.is_set() + + @pytest.mark.trio + async def test_connection_start_already_started(self, quic_connection): + """Test starting already started connection.""" + quic_connection._started = True + + # Should not raise error, just log warning + await quic_connection.start() + assert quic_connection._started + + @pytest.mark.trio + async def test_connection_start_closed(self, quic_connection): + """Test starting closed connection.""" + quic_connection._closed = True + + with pytest.raises( + QUICConnectionError, match="Cannot start a closed connection" + ): + await quic_connection.start() + + @pytest.mark.trio + async def test_connection_connect_with_nursery( + self, quic_connection: QUICConnection + ): + """Test connection establishment with nursery.""" + quic_connection._started = True + quic_connection._established = True + quic_connection._connected_event.set() + + with patch.object( + quic_connection, "_start_background_tasks", new_callable=AsyncMock + ) as mock_start_tasks: + with patch.object( + quic_connection, + "_verify_peer_identity_with_security", + new_callable=AsyncMock, + ) as mock_verify: + async with trio.open_nursery() as nursery: + await quic_connection.connect(nursery) + + assert quic_connection._nursery == nursery + mock_start_tasks.assert_called_once() + mock_verify.assert_called_once() + + @pytest.mark.trio + @pytest.mark.slow + async def test_connection_connect_timeout( + self, quic_connection: QUICConnection + ) -> None: + """Test connection establishment timeout.""" + quic_connection._started = True + # Don't set connected event to simulate timeout + + with patch.object( + quic_connection, "_start_background_tasks", new_callable=AsyncMock + ): + async with trio.open_nursery() as nursery: + with pytest.raises( + QUICConnectionTimeoutError, match="Connection handshake timed out" + ): + await quic_connection.connect(nursery) + + # Resource management tests + + @pytest.mark.trio + async def test_stream_removal_resource_cleanup( + self, quic_connection: QUICConnection, mock_resource_scope + ): + """Test stream removal and resource cleanup.""" + quic_connection._started = True + + # Create a stream + stream = await quic_connection.open_stream() + + # Remove the stream + quic_connection._remove_stream(int(stream.stream_id)) + + assert int(stream.stream_id) not in quic_connection._streams + # Note: Count updates is async, so we can't test it directly here + + # Error handling tests + + @pytest.mark.trio + async def test_connection_error_handling(self, quic_connection) -> None: + """Test connection error handling.""" + error = Exception("Test error") + + with patch.object( + quic_connection, "close", new_callable=AsyncMock + ) as mock_close: + await quic_connection._handle_connection_error(error) + mock_close.assert_called_once() + + # Statistics and monitoring tests + + @pytest.mark.trio + async def test_connection_stats_enhanced(self, quic_connection) -> None: + """Test enhanced connection statistics.""" + quic_connection._started = True + + # Create some streams + _stream1 = await quic_connection.open_stream() + _stream2 = await quic_connection.open_stream() + + stats = quic_connection.get_stream_stats() + + expected_keys = [ + "total_streams", + "outbound_streams", + "inbound_streams", + "max_streams", + "stream_utilization", + "stats", + ] + + for key in expected_keys: + assert key in stats + + assert stats["total_streams"] == 2 + assert stats["outbound_streams"] == 2 + assert stats["inbound_streams"] == 0 + + @pytest.mark.trio + async def test_get_active_streams(self, quic_connection) -> None: + """Test getting active streams.""" + quic_connection._started = True + + # Create streams + stream1 = await quic_connection.open_stream() + stream2 = await quic_connection.open_stream() + + active_streams = quic_connection.get_active_streams() + + assert len(active_streams) == 2 + assert stream1 in active_streams + assert stream2 in active_streams + + @pytest.mark.trio + async def test_get_streams_by_protocol(self, quic_connection) -> None: + """Test getting streams by protocol.""" + quic_connection._started = True + + # Create streams with different protocols + stream1 = await quic_connection.open_stream() + stream1.protocol = "/test/1.0.0" + + stream2 = await quic_connection.open_stream() + stream2.protocol = "/other/1.0.0" + + test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0") + other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0") + + assert len(test_streams) == 1 + assert len(other_streams) == 1 + assert stream1 in test_streams + assert stream2 in other_streams + + # Enhanced close tests + + @pytest.mark.trio + async def test_connection_close_enhanced( + self, quic_connection: QUICConnection + ) -> None: + """Test enhanced connection close with stream cleanup.""" + quic_connection._started = True + + # Create some streams + _stream1 = await quic_connection.open_stream() + _stream2 = await quic_connection.open_stream() + + await quic_connection.close() + + assert quic_connection.is_closed + assert len(quic_connection._streams) == 0 + + # Concurrent operations tests + + @pytest.mark.trio + async def test_concurrent_stream_operations( + self, quic_connection: QUICConnection + ) -> None: + """Test concurrent stream operations.""" + quic_connection._started = True + + async def create_stream(): + return await quic_connection.open_stream() + + # Create multiple streams concurrently + async with trio.open_nursery() as nursery: + for i in range(10): + nursery.start_soon(create_stream) + + # Wait a bit for all to start + await trio.sleep(0.1) + + # Should have created streams without conflicts + assert quic_connection._outbound_stream_count == 10 + assert len(quic_connection._streams) == 10 + + # Connection properties tests + + def test_connection_properties(self, quic_connection: QUICConnection) -> None: + """Test connection property accessors.""" + assert quic_connection.multiaddr() == quic_connection._maddr + assert quic_connection.local_peer_id() == quic_connection._local_peer_id + assert quic_connection.remote_peer_id() == quic_connection._remote_peer_id + + # IRawConnection interface tests + + @pytest.mark.trio + async def test_raw_connection_write(self, quic_connection: QUICConnection) -> None: + """Test raw connection write interface.""" + quic_connection._started = True + + with patch.object(quic_connection, "open_stream") as mock_open: + mock_stream = AsyncMock() + mock_open.return_value = mock_stream + + await quic_connection.write(b"test data") + + mock_open.assert_called_once() + mock_stream.write.assert_called_once_with(b"test data") + mock_stream.close_write.assert_called_once() + + @pytest.mark.trio + async def test_raw_connection_read_not_implemented( + self, quic_connection: QUICConnection + ) -> None: + """Test raw connection read raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + await quic_connection.read() + + # Mock verification helpers + + def test_mock_resource_scope_functionality(self, mock_resource_scope) -> None: + """Test mock resource scope works correctly.""" + assert mock_resource_scope.memory_reserved == 0 + + mock_resource_scope.reserve_memory(1000) + assert mock_resource_scope.memory_reserved == 1000 + + mock_resource_scope.reserve_memory(500) + assert mock_resource_scope.memory_reserved == 1500 + + mock_resource_scope.release_memory(600) + assert mock_resource_scope.memory_reserved == 900 + + mock_resource_scope.release_memory(2000) # Should not go negative + assert mock_resource_scope.memory_reserved == 0 + + +@pytest.mark.trio +async def test_invalid_certificate_verification(): + key_pair1 = create_new_key_pair() + key_pair2 = create_new_key_pair() + + peer_id1 = ID.from_pubkey(key_pair1.public_key) + peer_id2 = ID.from_pubkey(key_pair2.public_key) + + manager = QUICTLSConfigManager( + libp2p_private_key=key_pair1.private_key, peer_id=peer_id1 + ) + + # Match the certificate against a different peer_id + with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"): + manager.verify_peer_identity(manager.tls_config.certificate, peer_id2) + + from cryptography.hazmat.primitives.serialization import Encoding + + # --- Corrupt the certificate by tampering the DER bytes --- + cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER) + corrupted_bytes = bytearray(cert_bytes) + + # Flip some random bytes in the middle of the certificate + corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF + + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + + # This will still parse (structurally valid), but the signature + # or fingerprint will break + corrupted_cert = x509.load_der_x509_certificate( + bytes(corrupted_bytes), backend=default_backend() + ) + + with pytest.raises( + QUICPeerVerificationError, match="Certificate verification failed" + ): + manager.verify_peer_identity(corrupted_cert, peer_id1) diff --git a/tests/core/transport/quic/test_connection_id.py b/tests/core/transport/quic/test_connection_id.py new file mode 100644 index 00000000..de371550 --- /dev/null +++ b/tests/core/transport/quic/test_connection_id.py @@ -0,0 +1,624 @@ +""" +QUIC Connection ID Management Tests + +This test module covers comprehensive testing of QUIC connection ID functionality +including generation, rotation, retirement, and validation according to RFC 9000. + +Tests are organized into: +1. Basic Connection ID Management +2. Connection ID Rotation and Updates +3. Connection ID Retirement +4. Error Conditions and Edge Cases +5. Integration Tests with Real Connections +""" + +import secrets +import time +from typing import Any +from unittest.mock import Mock + +import pytest +from aioquic.buffer import Buffer + +# Import aioquic components for low-level testing +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import QuicConnection, QuicConnectionId +from multiaddr import Multiaddr + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.transport import QUICTransport + + +class ConnectionIdTestHelper: + """Helper class for connection ID testing utilities.""" + + @staticmethod + def generate_connection_id(length: int = 8) -> bytes: + """Generate a random connection ID of specified length.""" + return secrets.token_bytes(length) + + @staticmethod + def create_quic_connection_id(cid: bytes, sequence: int = 0) -> QuicConnectionId: + """Create a QuicConnectionId object.""" + return QuicConnectionId( + cid=cid, + sequence_number=sequence, + stateless_reset_token=secrets.token_bytes(16), + ) + + @staticmethod + def extract_connection_ids_from_connection(conn: QUICConnection) -> dict[str, Any]: + """Extract connection ID information from a QUIC connection.""" + quic = conn._quic + return { + "host_cids": [cid.cid.hex() for cid in getattr(quic, "_host_cids", [])], + "peer_cid": getattr(quic, "_peer_cid", None), + "peer_cid_available": [ + cid.cid.hex() for cid in getattr(quic, "_peer_cid_available", []) + ], + "retire_connection_ids": getattr(quic, "_retire_connection_ids", []), + "host_cid_seq": getattr(quic, "_host_cid_seq", 0), + } + + +class TestBasicConnectionIdManagement: + """Test basic connection ID management functionality.""" + + @pytest.fixture + def mock_quic_connection(self): + """Create a mock QUIC connection with connection ID support.""" + mock_quic = Mock(spec=QuicConnection) + mock_quic._host_cids = [] + mock_quic._host_cid_seq = 0 + mock_quic._peer_cid = None + mock_quic._peer_cid_available = [] + mock_quic._retire_connection_ids = [] + mock_quic._configuration = Mock() + mock_quic._configuration.connection_id_length = 8 + mock_quic._remote_active_connection_id_limit = 8 + return mock_quic + + @pytest.fixture + def quic_connection(self, mock_quic_connection): + """Create a QUICConnection instance for testing.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + return QUICConnection( + quic_connection=mock_quic_connection, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + def test_connection_id_initialization(self, quic_connection): + """Test that connection ID tracking is properly initialized.""" + # Check that connection ID tracking structures are initialized + assert hasattr(quic_connection, "_available_connection_ids") + assert hasattr(quic_connection, "_current_connection_id") + assert hasattr(quic_connection, "_retired_connection_ids") + assert hasattr(quic_connection, "_connection_id_sequence_numbers") + + # Initial state should be empty + assert len(quic_connection._available_connection_ids) == 0 + assert quic_connection._current_connection_id is None + assert len(quic_connection._retired_connection_ids) == 0 + assert len(quic_connection._connection_id_sequence_numbers) == 0 + + def test_connection_id_stats_tracking(self, quic_connection): + """Test connection ID statistics are properly tracked.""" + stats = quic_connection.get_connection_id_stats() + + # Check that all expected stats are present + expected_keys = [ + "available_connection_ids", + "current_connection_id", + "retired_connection_ids", + "connection_ids_issued", + "connection_ids_retired", + "connection_id_changes", + "available_cid_list", + ] + + for key in expected_keys: + assert key in stats + + # Initial values should be zero/empty + assert stats["available_connection_ids"] == 0 + assert stats["current_connection_id"] is None + assert stats["retired_connection_ids"] == 0 + assert stats["connection_ids_issued"] == 0 + assert stats["connection_ids_retired"] == 0 + assert stats["connection_id_changes"] == 0 + assert stats["available_cid_list"] == [] + + def test_current_connection_id_getter(self, quic_connection): + """Test getting current connection ID.""" + # Initially no connection ID + assert quic_connection.get_current_connection_id() is None + + # Set a connection ID + test_cid = ConnectionIdTestHelper.generate_connection_id() + quic_connection._current_connection_id = test_cid + + assert quic_connection.get_current_connection_id() == test_cid + + def test_connection_id_generation(self): + """Test connection ID generation utilities.""" + # Test default length + cid1 = ConnectionIdTestHelper.generate_connection_id() + assert len(cid1) == 8 + assert isinstance(cid1, bytes) + + # Test custom length + cid2 = ConnectionIdTestHelper.generate_connection_id(16) + assert len(cid2) == 16 + + # Test uniqueness + cid3 = ConnectionIdTestHelper.generate_connection_id() + assert cid1 != cid3 + + +class TestConnectionIdRotationAndUpdates: + """Test connection ID rotation and update mechanisms.""" + + @pytest.fixture + def transport_config(self): + """Create transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=100, + ) + + @pytest.fixture + def server_key(self): + """Generate server private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def client_key(self): + """Generate client private key.""" + return create_new_key_pair().private_key + + def test_connection_id_replenishment(self): + """Test connection ID replenishment mechanism.""" + # Create a real QuicConnection to test replenishment + config = QuicConfiguration(is_client=True) + config.connection_id_length = 8 + + quic_conn = QuicConnection(configuration=config) + + # Initial state - should have some host connection IDs + initial_count = len(quic_conn._host_cids) + assert initial_count > 0 + + # Remove some connection IDs to trigger replenishment + while len(quic_conn._host_cids) > 2: + quic_conn._host_cids.pop() + + # Trigger replenishment + quic_conn._replenish_connection_ids() + + # Should have replenished up to the limit + assert len(quic_conn._host_cids) >= initial_count + + # All connection IDs should have unique sequence numbers + sequences = [cid.sequence_number for cid in quic_conn._host_cids] + assert len(sequences) == len(set(sequences)) + + def test_connection_id_sequence_numbers(self): + """Test connection ID sequence number management.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Get initial sequence number + initial_seq = quic_conn._host_cid_seq + + # Trigger replenishment to generate new connection IDs + quic_conn._replenish_connection_ids() + + # Sequence numbers should increment + assert quic_conn._host_cid_seq > initial_seq + + # All host connection IDs should have sequential numbers + sequences = [cid.sequence_number for cid in quic_conn._host_cids] + sequences.sort() + + # Check for proper sequence + for i in range(len(sequences) - 1): + assert sequences[i + 1] > sequences[i] + + def test_connection_id_limits(self): + """Test connection ID limit enforcement.""" + config = QuicConfiguration(is_client=True) + config.connection_id_length = 8 + + quic_conn = QuicConnection(configuration=config) + + # Set a reasonable limit + quic_conn._remote_active_connection_id_limit = 4 + + # Replenish connection IDs + quic_conn._replenish_connection_ids() + + # Should not exceed the limit + assert len(quic_conn._host_cids) <= quic_conn._remote_active_connection_id_limit + + +class TestConnectionIdRetirement: + """Test connection ID retirement functionality.""" + + def test_connection_id_retirement_basic(self): + """Test basic connection ID retirement.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Create a test connection ID to retire + test_cid = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=1 + ) + + # Add it to peer connection IDs + quic_conn._peer_cid_available.append(test_cid) + quic_conn._peer_cid_sequence_numbers.add(1) + + # Retire the connection ID + quic_conn._retire_peer_cid(test_cid) + + # Should be added to retirement list + assert 1 in quic_conn._retire_connection_ids + + def test_connection_id_retirement_limits(self): + """Test connection ID retirement limits.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Fill up retirement list near the limit + max_retirements = 32 # Based on aioquic's default limit + + for i in range(max_retirements): + quic_conn._retire_connection_ids.append(i) + + # Should be at limit + assert len(quic_conn._retire_connection_ids) == max_retirements + + def test_connection_id_retirement_events(self): + """Test that retirement generates proper events.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Create and add a host connection ID + test_cid = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=5 + ) + quic_conn._host_cids.append(test_cid) + + # Create a retirement frame buffer + from aioquic.buffer import Buffer + + buf = Buffer(capacity=16) + buf.push_uint_var(5) # sequence number to retire + buf.seek(0) + + # Process retirement (this should generate an event) + try: + quic_conn._handle_retire_connection_id_frame( + Mock(), # context + 0x19, # RETIRE_CONNECTION_ID frame type + buf, + ) + + # Check that connection ID was removed + remaining_sequences = [cid.sequence_number for cid in quic_conn._host_cids] + assert 5 not in remaining_sequences + + except Exception: + # May fail due to missing context, but that's okay for this test + pass + + +class TestConnectionIdErrorConditions: + """Test error conditions and edge cases in connection ID handling.""" + + def test_invalid_connection_id_length(self): + """Test handling of invalid connection ID lengths.""" + # Connection IDs must be 1-20 bytes according to RFC 9000 + + # Test too short (0 bytes) - this should be handled gracefully + empty_cid = b"" + assert len(empty_cid) == 0 + + # Test too long (>20 bytes) + long_cid = secrets.token_bytes(21) + assert len(long_cid) == 21 + + # Test valid lengths + for length in range(1, 21): + valid_cid = secrets.token_bytes(length) + assert len(valid_cid) == length + + def test_duplicate_sequence_numbers(self): + """Test handling of duplicate sequence numbers.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Create two connection IDs with same sequence number + cid1 = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=10 + ) + cid2 = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=10 + ) + + # Add first connection ID + quic_conn._peer_cid_available.append(cid1) + quic_conn._peer_cid_sequence_numbers.add(10) + + # Adding second with same sequence should be handled appropriately + # (The implementation should prevent duplicates) + if 10 not in quic_conn._peer_cid_sequence_numbers: + quic_conn._peer_cid_available.append(cid2) + quic_conn._peer_cid_sequence_numbers.add(10) + + # Should only have one entry for sequence 10 + sequences = [cid.sequence_number for cid in quic_conn._peer_cid_available] + assert sequences.count(10) <= 1 + + def test_retire_unknown_connection_id(self): + """Test retiring an unknown connection ID.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Try to create a buffer to retire unknown sequence number + buf = Buffer(capacity=16) + buf.push_uint_var(999) # Unknown sequence number + buf.seek(0) + + # This should raise an error when processed + # (Testing the error condition, not the full processing) + unknown_sequence = 999 + known_sequences = [cid.sequence_number for cid in quic_conn._host_cids] + + assert unknown_sequence not in known_sequences + + def test_retire_current_connection_id(self): + """Test that retiring current connection ID is prevented.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Get current connection ID if available + if quic_conn._host_cids: + current_cid = quic_conn._host_cids[0] + current_sequence = current_cid.sequence_number + + # Trying to retire current connection ID should be prevented + # This is tested by checking the sequence number logic + assert current_sequence >= 0 + + +class TestConnectionIdIntegration: + """Integration tests for connection ID functionality with real connections.""" + + @pytest.fixture + def server_config(self): + """Server transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=100, + ) + + @pytest.fixture + def client_config(self): + """Client transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + ) + + @pytest.fixture + def server_key(self): + """Generate server private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def client_key(self): + """Generate client private key.""" + return create_new_key_pair().private_key + + @pytest.mark.trio + async def test_connection_id_exchange_during_handshake( + self, server_key, client_key, server_config, client_config + ): + """Test connection ID exchange during connection handshake.""" + # This test would require a full connection setup + # For now, we test the setup components + + server_transport = QUICTransport(server_key, server_config) + client_transport = QUICTransport(client_key, client_config) + + # Verify transports are created with proper configuration + assert server_transport._config == server_config + assert client_transport._config == client_config + + # Test that connection ID tracking is available + # (Integration with actual networking would require more setup) + + def test_connection_id_extraction_utilities(self): + """Test connection ID extraction utilities.""" + # Create a mock connection with some connection IDs + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + mock_quic = Mock() + mock_quic._host_cids = [ + ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), i + ) + for i in range(3) + ] + mock_quic._peer_cid = None + mock_quic._peer_cid_available = [] + mock_quic._retire_connection_ids = [] + mock_quic._host_cid_seq = 3 + + quic_conn = QUICConnection( + quic_connection=mock_quic, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + # Extract connection ID information + cid_info = ConnectionIdTestHelper.extract_connection_ids_from_connection( + quic_conn + ) + + # Verify extraction works + assert "host_cids" in cid_info + assert "peer_cid" in cid_info + assert "peer_cid_available" in cid_info + assert "retire_connection_ids" in cid_info + assert "host_cid_seq" in cid_info + + # Check values + assert len(cid_info["host_cids"]) == 3 + assert cid_info["host_cid_seq"] == 3 + assert cid_info["peer_cid"] is None + assert len(cid_info["peer_cid_available"]) == 0 + assert len(cid_info["retire_connection_ids"]) == 0 + + +class TestConnectionIdStatistics: + """Test connection ID statistics and monitoring.""" + + @pytest.fixture + def connection_with_stats(self): + """Create a connection with connection ID statistics.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + mock_quic = Mock() + mock_quic._host_cids = [] + mock_quic._peer_cid = None + mock_quic._peer_cid_available = [] + mock_quic._retire_connection_ids = [] + + return QUICConnection( + quic_connection=mock_quic, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + def test_connection_id_stats_initialization(self, connection_with_stats): + """Test that connection ID statistics are properly initialized.""" + stats = connection_with_stats._stats + + # Check that connection ID stats are present + assert "connection_ids_issued" in stats + assert "connection_ids_retired" in stats + assert "connection_id_changes" in stats + + # Initial values should be zero + assert stats["connection_ids_issued"] == 0 + assert stats["connection_ids_retired"] == 0 + assert stats["connection_id_changes"] == 0 + + def test_connection_id_stats_update(self, connection_with_stats): + """Test updating connection ID statistics.""" + conn = connection_with_stats + + # Add some connection IDs to tracking + test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(3)] + + for cid in test_cids: + conn._available_connection_ids.add(cid) + + # Update stats (this would normally be done by the implementation) + conn._stats["connection_ids_issued"] = len(test_cids) + + # Verify stats + stats = conn.get_connection_id_stats() + assert stats["connection_ids_issued"] == 3 + assert stats["available_connection_ids"] == 3 + + def test_connection_id_list_representation(self, connection_with_stats): + """Test connection ID list representation in stats.""" + conn = connection_with_stats + + # Add some connection IDs + test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(2)] + + for cid in test_cids: + conn._available_connection_ids.add(cid) + + # Get stats + stats = conn.get_connection_id_stats() + + # Check that CID list is properly formatted + assert "available_cid_list" in stats + assert len(stats["available_cid_list"]) == 2 + + # All entries should be hex strings + for cid_hex in stats["available_cid_list"]: + assert isinstance(cid_hex, str) + assert len(cid_hex) == 16 # 8 bytes = 16 hex chars + + +# Performance and stress tests +class TestConnectionIdPerformance: + """Test connection ID performance and stress scenarios.""" + + def test_connection_id_generation_performance(self): + """Test connection ID generation performance.""" + start_time = time.time() + + # Generate many connection IDs + cids = [] + for _ in range(1000): + cid = ConnectionIdTestHelper.generate_connection_id() + cids.append(cid) + + end_time = time.time() + generation_time = end_time - start_time + + # Should be reasonably fast (less than 1 second for 1000 IDs) + assert generation_time < 1.0 + + # All should be unique + assert len(set(cids)) == len(cids) + + def test_connection_id_tracking_memory(self): + """Test memory usage of connection ID tracking.""" + conn_ids = set() + + # Add many connection IDs + for _ in range(1000): + cid = ConnectionIdTestHelper.generate_connection_id() + conn_ids.add(cid) + + # Verify they're all stored + assert len(conn_ids) == 1000 + + # Clean up + conn_ids.clear() + assert len(conn_ids) == 0 + + +if __name__ == "__main__": + # Run tests if executed directly + pytest.main([__file__, "-v"]) diff --git a/tests/core/transport/quic/test_integration.py b/tests/core/transport/quic/test_integration.py new file mode 100644 index 00000000..5016c996 --- /dev/null +++ b/tests/core/transport/quic/test_integration.py @@ -0,0 +1,418 @@ +""" +Basic QUIC Echo Test + +Simple test to verify the basic QUIC flow: +1. Client connects to server +2. Client sends data +3. Server receives data and echoes back +4. Client receives the echo + +This test focuses on identifying where the accept_stream issue occurs. +""" + +import logging + +import pytest +import multiaddr +import trio + +from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID +from libp2p import new_host +from libp2p.abc import INetStream +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.transport import QUICTransport +from libp2p.transport.quic.utils import create_quic_multiaddr + +# Set up logging to see what's happening +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class TestBasicQUICFlow: + """Test basic QUIC client-server communication flow.""" + + @pytest.fixture + def server_key(self): + """Generate server key pair.""" + return create_new_key_pair() + + @pytest.fixture + def client_key(self): + """Generate client key pair.""" + return create_new_key_pair() + + @pytest.fixture + def server_config(self): + """Simple server configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=10, + max_connections=5, + ) + + @pytest.fixture + def client_config(self): + """Simple client configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=5, + ) + + @pytest.mark.trio + async def test_basic_echo_flow( + self, server_key, client_key, server_config, client_config + ): + """Test basic client-server echo flow with detailed logging.""" + print("\n=== BASIC QUIC ECHO TEST ===") + + # Create server components + server_transport = QUICTransport(server_key.private_key, server_config) + + # Track test state + server_received_data = None + server_connection_established = False + echo_sent = False + + async def echo_server_handler(connection: QUICConnection) -> None: + """Simple echo server handler with detailed logging.""" + nonlocal server_received_data, server_connection_established, echo_sent + + print("šŸ”— SERVER: Connection handler called") + server_connection_established = True + + try: + print("šŸ“” SERVER: Waiting for incoming stream...") + + # Accept stream with timeout and detailed logging + print("šŸ“” SERVER: Calling accept_stream...") + stream = await connection.accept_stream(timeout=5.0) + + if stream is None: + print("āŒ SERVER: accept_stream returned None") + return + + print(f"āœ… SERVER: Stream accepted! Stream ID: {stream.stream_id}") + + # Read data from the stream + print("šŸ“– SERVER: Reading data from stream...") + server_data = await stream.read(1024) + + if not server_data: + print("āŒ SERVER: No data received from stream") + return + + server_received_data = server_data.decode("utf-8", errors="ignore") + print(f"šŸ“Ø SERVER: Received data: '{server_received_data}'") + + # Echo the data back + echo_message = f"ECHO: {server_received_data}" + print(f"šŸ“¤ SERVER: Sending echo: '{echo_message}'") + + await stream.write(echo_message.encode()) + echo_sent = True + print("āœ… SERVER: Echo sent successfully") + + # Close the stream + await stream.close() + print("šŸ”’ SERVER: Stream closed") + + except Exception as e: + print(f"āŒ SERVER: Error in handler: {e}") + import traceback + + traceback.print_exc() + + # Create listener + listener = server_transport.create_listener(echo_server_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + # Variables to track client state + client_connected = False + client_sent_data = False + client_received_echo = None + + try: + print("šŸš€ Starting server...") + + async with trio.open_nursery() as nursery: + # Start server listener + success = await listener.listen(listen_addr, nursery) + assert success, "Failed to start server listener" + + # Get server address + server_addrs = listener.get_addrs() + server_addr = multiaddr.Multiaddr( + f"{server_addrs[0]}/p2p/{ID.from_pubkey(server_key.public_key)}" + ) + print(f"šŸ”§ SERVER: Listening on {server_addr}") + + # Give server a moment to be ready + await trio.sleep(0.1) + + print("šŸš€ Starting client...") + + # Create client transport + client_transport = QUICTransport(client_key.private_key, client_config) + client_transport.set_background_nursery(nursery) + + try: + # Connect to server + print(f"šŸ“ž CLIENT: Connecting to {server_addr}") + connection = await client_transport.dial(server_addr) + client_connected = True + print("āœ… CLIENT: Connected to server") + + # Open a stream + print("šŸ“¤ CLIENT: Opening stream...") + stream = await connection.open_stream() + print(f"āœ… CLIENT: Stream opened with ID: {stream.stream_id}") + + # Send test data + test_message = "Hello QUIC Server!" + print(f"šŸ“Ø CLIENT: Sending message: '{test_message}'") + await stream.write(test_message.encode()) + client_sent_data = True + print("āœ… CLIENT: Message sent") + + # Read echo response + print("šŸ“– CLIENT: Waiting for echo response...") + response_data = await stream.read(1024) + + if response_data: + client_received_echo = response_data.decode( + "utf-8", errors="ignore" + ) + print(f"šŸ“¬ CLIENT: Received echo: '{client_received_echo}'") + else: + print("āŒ CLIENT: No echo response received") + + print("šŸ”’ CLIENT: Closing connection") + await connection.close() + print("šŸ”’ CLIENT: Connection closed") + + print("šŸ”’ CLIENT: Closing transport") + await client_transport.close() + print("šŸ”’ CLIENT: Transport closed") + + except Exception as e: + print(f"āŒ CLIENT: Error: {e}") + import traceback + + traceback.print_exc() + + finally: + await client_transport.close() + print("šŸ”’ CLIENT: Transport closed") + + # Give everything time to complete + await trio.sleep(0.5) + + # Cancel nursery to stop server + nursery.cancel_scope.cancel() + + finally: + # Cleanup + if not listener._closed: + await listener.close() + await server_transport.close() + + # Verify the flow worked + print("\nšŸ“Š TEST RESULTS:") + print(f" Server connection established: {server_connection_established}") + print(f" Client connected: {client_connected}") + print(f" Client sent data: {client_sent_data}") + print(f" Server received data: '{server_received_data}'") + print(f" Echo sent by server: {echo_sent}") + print(f" Client received echo: '{client_received_echo}'") + + # Test assertions + assert server_connection_established, "Server connection handler was not called" + assert client_connected, "Client failed to connect" + assert client_sent_data, "Client failed to send data" + assert server_received_data == "Hello QUIC Server!", ( + f"Server received wrong data: '{server_received_data}'" + ) + assert echo_sent, "Server failed to send echo" + assert client_received_echo == "ECHO: Hello QUIC Server!", ( + f"Client received wrong echo: '{client_received_echo}'" + ) + + print("āœ… BASIC ECHO TEST PASSED!") + + @pytest.mark.trio + async def test_server_accept_stream_timeout( + self, server_key, client_key, server_config, client_config + ): + """Test what happens when server accept_stream times out.""" + print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===") + + server_transport = QUICTransport(server_key.private_key, server_config) + + accept_stream_called = False + accept_stream_timeout = False + + async def timeout_test_handler(connection: QUICConnection) -> None: + """Handler that tests accept_stream timeout.""" + nonlocal accept_stream_called, accept_stream_timeout + + print("šŸ”— SERVER: Connection established, testing accept_stream timeout") + accept_stream_called = True + + try: + print("šŸ“” SERVER: Calling accept_stream with 2 second timeout...") + stream = await connection.accept_stream(timeout=2.0) + print(f"āœ… SERVER: accept_stream returned: {stream}") + + except Exception as e: + print(f"ā° SERVER: accept_stream timed out or failed: {e}") + accept_stream_timeout = True + + listener = server_transport.create_listener(timeout_test_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + client_connected = False + + try: + async with trio.open_nursery() as nursery: + # Start server + server_transport.set_background_nursery(nursery) + success = await listener.listen(listen_addr, nursery) + assert success + + server_addr = multiaddr.Multiaddr( + f"{listener.get_addrs()[0]}/p2p/{ID.from_pubkey(server_key.public_key)}" + ) + print(f"šŸ”§ SERVER: Listening on {server_addr}") + + # Create client but DON'T open a stream + async with trio.open_nursery() as client_nursery: + client_transport = QUICTransport( + client_key.private_key, client_config + ) + client_transport.set_background_nursery(client_nursery) + + try: + print("šŸ“ž CLIENT: Connecting (but NOT opening stream)...") + connection = await client_transport.dial(server_addr) + client_connected = True + print("āœ… CLIENT: Connected (no stream opened)") + + # Wait for server timeout + await trio.sleep(3.0) + + await connection.close() + print("šŸ”’ CLIENT: Connection closed") + + finally: + await client_transport.close() + + nursery.cancel_scope.cancel() + + finally: + await listener.close() + await server_transport.close() + + print("\nšŸ“Š TIMEOUT TEST RESULTS:") + print(f" Client connected: {client_connected}") + print(f" accept_stream called: {accept_stream_called}") + print(f" accept_stream timeout: {accept_stream_timeout}") + + assert client_connected, "Client should have connected" + assert accept_stream_called, "accept_stream should have been called" + assert accept_stream_timeout, ( + "accept_stream should have timed out when no stream was opened" + ) + + print("āœ… TIMEOUT TEST PASSED!") + + +@pytest.mark.trio +async def test_yamux_stress_ping(): + STREAM_COUNT = 100 + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + latencies = [] + failures = [] + + # === Server Setup === + server_host = new_host(listen_addrs=[listen_addr]) + + async def handle_ping(stream: INetStream) -> None: + try: + while True: + payload = await stream.read(PING_LENGTH) + if not payload: + break + await stream.write(payload) + except Exception: + await stream.reset() + + server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping) + + async with server_host.run(listen_addrs=[listen_addr]): + # Give server time to start + await trio.sleep(0.1) + + # === Client Setup === + destination = str(server_host.get_addrs()[0]) + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + + client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + client_host = new_host(listen_addrs=[client_listen_addr]) + + async with client_host.run(listen_addrs=[client_listen_addr]): + await client_host.connect(info) + + async def ping_stream(i: int): + stream = None + try: + start = trio.current_time() + stream = await client_host.new_stream( + info.peer_id, [PING_PROTOCOL_ID] + ) + + await stream.write(b"\x01" * PING_LENGTH) + + with trio.fail_after(5): + response = await stream.read(PING_LENGTH) + + if response == b"\x01" * PING_LENGTH: + latency_ms = int((trio.current_time() - start) * 1000) + latencies.append(latency_ms) + print(f"[Ping #{i}] Latency: {latency_ms} ms") + await stream.close() + except Exception as e: + print(f"[Ping #{i}] Failed: {e}") + failures.append(i) + if stream: + await stream.reset() + + async with trio.open_nursery() as nursery: + for i in range(STREAM_COUNT): + nursery.start_soon(ping_stream, i) + + # === Result Summary === + print("\nšŸ“Š Ping Stress Test Summary") + print(f"Total Streams Launched: {STREAM_COUNT}") + print(f"Successful Pings: {len(latencies)}") + print(f"Failed Pings: {len(failures)}") + if failures: + print(f"āŒ Failed stream indices: {failures}") + + # === Assertions === + assert len(latencies) == STREAM_COUNT, ( + f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}" + ) + assert all(isinstance(x, int) and x >= 0 for x in latencies), ( + "Invalid latencies" + ) + + avg_latency = sum(latencies) / len(latencies) + print(f"āœ… Average Latency: {avg_latency:.2f} ms") + assert avg_latency < 1000 diff --git a/tests/core/transport/quic/test_listener.py b/tests/core/transport/quic/test_listener.py new file mode 100644 index 00000000..840f7218 --- /dev/null +++ b/tests/core/transport/quic/test_listener.py @@ -0,0 +1,150 @@ +from unittest.mock import AsyncMock + +import pytest +from multiaddr.multiaddr import Multiaddr +import trio + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.transport.quic.exceptions import ( + QUICListenError, +) +from libp2p.transport.quic.listener import QUICListener +from libp2p.transport.quic.transport import ( + QUICTransport, + QUICTransportConfig, +) +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, +) + + +class TestQUICListener: + """Test suite for QUIC listener functionality.""" + + @pytest.fixture + def private_key(self): + """Generate test private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def transport_config(self): + """Generate test transport configuration.""" + return QUICTransportConfig(idle_timeout=10.0) + + @pytest.fixture + def transport(self, private_key, transport_config): + """Create test transport instance.""" + return QUICTransport(private_key, transport_config) + + @pytest.fixture + def connection_handler(self): + """Mock connection handler.""" + return AsyncMock() + + @pytest.fixture + def listener(self, transport, connection_handler): + """Create test listener.""" + return transport.create_listener(connection_handler) + + def test_listener_creation(self, transport, connection_handler): + """Test listener creation.""" + listener = transport.create_listener(connection_handler) + + assert isinstance(listener, QUICListener) + assert listener._transport == transport + assert listener._handler == connection_handler + assert not listener._listening + assert not listener._closed + + @pytest.mark.trio + async def test_listener_invalid_multiaddr(self, listener: QUICListener): + """Test listener with invalid multiaddr.""" + async with trio.open_nursery() as nursery: + invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + + with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"): + await listener.listen(invalid_addr, nursery) + + @pytest.mark.trio + async def test_listener_basic_lifecycle(self, listener: QUICListener): + """Test basic listener lifecycle.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random + + async with trio.open_nursery() as nursery: + # Start listening + success = await listener.listen(listen_addr, nursery) + assert success + assert listener.is_listening() + + # Check bound addresses + addrs = listener.get_addrs() + assert len(addrs) == 1 + + # Check stats + stats = listener.get_stats() + assert stats["is_listening"] is True + assert stats["active_connections"] == 0 + assert stats["pending_connections"] == 0 + + # Sender Cancel Signal + nursery.cancel_scope.cancel() + + await listener.close() + assert not listener.is_listening() + + @pytest.mark.trio + async def test_listener_double_listen(self, listener: QUICListener): + """Test that double listen raises error.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + await trio.sleep(0.01) + + addrs = listener.get_addrs() + assert len(addrs) > 0 + async with trio.open_nursery() as nursery2: + with pytest.raises(QUICListenError, match="Already listening"): + await listener.listen(listen_addr, nursery2) + nursery2.cancel_scope.cancel() + + nursery.cancel_scope.cancel() + finally: + await listener.close() + + @pytest.mark.trio + async def test_listener_port_binding(self, listener: QUICListener): + """Test listener port binding and cleanup.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + await trio.sleep(0.5) + + addrs = listener.get_addrs() + assert len(addrs) > 0 + + nursery.cancel_scope.cancel() + finally: + await listener.close() + + # By the time we get here, the listener and its tasks have been fully + # shut down, allowing the nursery to exit without hanging. + print("TEST COMPLETED SUCCESSFULLY.") + + @pytest.mark.trio + async def test_listener_stats_tracking(self, listener): + """Test listener statistics tracking.""" + initial_stats = listener.get_stats() + + # All counters should start at 0 + assert initial_stats["connections_accepted"] == 0 + assert initial_stats["connections_rejected"] == 0 + assert initial_stats["bytes_received"] == 0 + assert initial_stats["packets_processed"] == 0 diff --git a/tests/core/transport/quic/test_transport.py b/tests/core/transport/quic/test_transport.py new file mode 100644 index 00000000..f9d65d8a --- /dev/null +++ b/tests/core/transport/quic/test_transport.py @@ -0,0 +1,123 @@ +from unittest.mock import ( + Mock, +) + +import pytest + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.crypto.keys import PrivateKey +from libp2p.transport.quic.exceptions import ( + QUICDialError, + QUICListenError, +) +from libp2p.transport.quic.transport import ( + QUICTransport, + QUICTransportConfig, +) + + +class TestQUICTransport: + """Test suite for QUIC transport using trio.""" + + @pytest.fixture + def private_key(self): + """Generate test private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def transport_config(self): + """Generate test transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, enable_draft29=True, enable_v1=True + ) + + @pytest.fixture + def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig): + """Create test transport instance.""" + return QUICTransport(private_key, transport_config) + + def test_transport_initialization(self, transport): + """Test transport initialization.""" + assert transport._private_key is not None + assert transport._peer_id is not None + assert not transport._closed + assert len(transport._quic_configs) >= 1 + + def test_supported_protocols(self, transport): + """Test supported protocol identifiers.""" + protocols = transport.protocols() + # TODO: Update when quic-v1 compatible + # assert "quic-v1" in protocols + assert "quic" in protocols # draft-29 + + def test_can_dial_quic_addresses(self, transport: QUICTransport): + """Test multiaddr compatibility checking.""" + import multiaddr + + # Valid QUIC addresses + valid_addrs = [ + # TODO: Update Multiaddr package to accept quic-v1 + multiaddr.Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + multiaddr.Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + multiaddr.Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + ] + + for addr in valid_addrs: + assert transport.can_dial(addr) + + # Invalid addresses + invalid_addrs = [ + multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"), + multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001"), + multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/ws"), + ] + + for addr in invalid_addrs: + assert not transport.can_dial(addr) + + @pytest.mark.trio + async def test_transport_lifecycle(self, transport): + """Test transport lifecycle management using trio.""" + assert not transport._closed + + await transport.close() + assert transport._closed + + # Should be safe to close multiple times + await transport.close() + + @pytest.mark.trio + async def test_dial_closed_transport(self, transport: QUICTransport) -> None: + """Test dialing with closed transport raises error.""" + import multiaddr + + await transport.close() + + with pytest.raises(QUICDialError, match="Transport is closed"): + await transport.dial( + multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + ) + + def test_create_listener_closed_transport(self, transport: QUICTransport) -> None: + """Test creating listener with closed transport raises error.""" + transport._closed = True + + with pytest.raises(QUICListenError, match="Transport is closed"): + transport.create_listener(Mock()) diff --git a/tests/core/transport/quic/test_utils.py b/tests/core/transport/quic/test_utils.py new file mode 100644 index 00000000..900c5c7e --- /dev/null +++ b/tests/core/transport/quic/test_utils.py @@ -0,0 +1,321 @@ +""" +Test suite for QUIC multiaddr utilities. +Focused tests covering essential functionality required for QUIC transport. +""" + +import pytest +from multiaddr import Multiaddr + +from libp2p.custom_types import TProtocol +from libp2p.transport.quic.exceptions import ( + QUICInvalidMultiaddrError, + QUICUnsupportedVersionError, +) +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, + get_alpn_protocols, + is_quic_multiaddr, + multiaddr_to_quic_version, + normalize_quic_multiaddr, + quic_multiaddr_to_endpoint, + quic_version_to_wire_format, +) + + +class TestIsQuicMultiaddr: + """Test QUIC multiaddr detection.""" + + def test_valid_quic_v1_multiaddrs(self): + """Test valid QUIC v1 multiaddrs are detected.""" + valid_addrs = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip4/192.168.1.1/udp/8080/quic-v1", + "/ip6/::1/udp/4001/quic-v1", + "/ip6/2001:db8::1/udp/5000/quic-v1", + ] + + for addr_str in valid_addrs: + maddr = Multiaddr(addr_str) + assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" + + def test_valid_quic_draft29_multiaddrs(self): + """Test valid QUIC draft-29 multiaddrs are detected.""" + valid_addrs = [ + "/ip4/127.0.0.1/udp/4001/quic", + "/ip4/10.0.0.1/udp/9000/quic", + "/ip6/::1/udp/4001/quic", + "/ip6/fe80::1/udp/6000/quic", + ] + + for addr_str in valid_addrs: + maddr = Multiaddr(addr_str) + assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" + + def test_invalid_multiaddrs(self): + """Test non-QUIC multiaddrs are not detected.""" + invalid_addrs = [ + "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC + "/ip4/127.0.0.1/udp/4001", # UDP without QUIC + "/ip4/127.0.0.1/udp/4001/ws", # WebSocket + "/ip4/127.0.0.1/quic-v1", # Missing UDP + "/udp/4001/quic-v1", # Missing IP + "/dns4/example.com/tcp/443/tls", # Completely different + ] + + for addr_str in invalid_addrs: + maddr = Multiaddr(addr_str) + assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC" + + +class TestQuicMultiaddrToEndpoint: + """Test endpoint extraction from QUIC multiaddrs.""" + + def test_ipv4_extraction(self): + """Test IPv4 host/port extraction.""" + test_cases = [ + ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)), + ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)), + ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)), + ] + + for addr_str, expected in test_cases: + maddr = Multiaddr(addr_str) + result = quic_multiaddr_to_endpoint(maddr) + assert result == expected, f"Failed for {addr_str}" + + def test_ipv6_extraction(self): + """Test IPv6 host/port extraction.""" + test_cases = [ + ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)), + ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)), + ] + + for addr_str, expected in test_cases: + maddr = Multiaddr(addr_str) + result = quic_multiaddr_to_endpoint(maddr) + assert result == expected, f"Failed for {addr_str}" + + def test_invalid_multiaddr_raises_error(self): + """Test invalid multiaddrs raise appropriate errors.""" + invalid_addrs = [ + "/ip4/127.0.0.1/tcp/4001", # Not QUIC + "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol + ] + + for addr_str in invalid_addrs: + maddr = Multiaddr(addr_str) + with pytest.raises(QUICInvalidMultiaddrError): + quic_multiaddr_to_endpoint(maddr) + + +class TestMultiaddrToQuicVersion: + """Test QUIC version extraction.""" + + def test_quic_v1_detection(self): + """Test QUIC v1 version detection.""" + addrs = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip6/::1/udp/5000/quic-v1", + ] + + for addr_str in addrs: + maddr = Multiaddr(addr_str) + version = multiaddr_to_quic_version(maddr) + assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}" + + def test_quic_draft29_detection(self): + """Test QUIC draft-29 version detection.""" + addrs = [ + "/ip4/127.0.0.1/udp/4001/quic", + "/ip6/::1/udp/5000/quic", + ] + + for addr_str in addrs: + maddr = Multiaddr(addr_str) + version = multiaddr_to_quic_version(maddr) + assert version == "quic", f"Should detect quic for {addr_str}" + + def test_non_quic_raises_error(self): + """Test non-QUIC multiaddrs raise error.""" + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + with pytest.raises(QUICInvalidMultiaddrError): + multiaddr_to_quic_version(maddr) + + +class TestCreateQuicMultiaddr: + """Test QUIC multiaddr creation.""" + + def test_ipv4_creation(self): + """Test IPv4 QUIC multiaddr creation.""" + test_cases = [ + ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"), + ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"), + ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"), + ] + + for host, port, version, expected in test_cases: + result = create_quic_multiaddr(host, port, version) + assert str(result) == expected + + def test_ipv6_creation(self): + """Test IPv6 QUIC multiaddr creation.""" + test_cases = [ + ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"), + ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"), + ] + + for host, port, version, expected in test_cases: + result = create_quic_multiaddr(host, port, version) + assert str(result) == expected + + def test_default_version(self): + """Test default version is quic-v1.""" + result = create_quic_multiaddr("127.0.0.1", 4001) + expected = "/ip4/127.0.0.1/udp/4001/quic-v1" + assert str(result) == expected + + def test_invalid_inputs_raise_errors(self): + """Test invalid inputs raise appropriate errors.""" + # Invalid IP + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("invalid-ip", 4001) + + # Invalid port + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", 70000) + + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", -1) + + # Invalid version + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", 4001, "invalid-version") + + +class TestQuicVersionToWireFormat: + """Test QUIC version to wire format conversion.""" + + def test_supported_versions(self): + """Test supported version conversions.""" + test_cases = [ + ("quic-v1", 0x00000001), # RFC 9000 + ("quic", 0xFF00001D), # draft-29 + ] + + for version, expected_wire in test_cases: + result = quic_version_to_wire_format(TProtocol(version)) + assert result == expected_wire, f"Failed for version {version}" + + def test_unsupported_version_raises_error(self): + """Test unsupported versions raise error.""" + with pytest.raises(QUICUnsupportedVersionError): + quic_version_to_wire_format(TProtocol("unsupported-version")) + + +class TestGetAlpnProtocols: + """Test ALPN protocol retrieval.""" + + def test_returns_libp2p_protocols(self): + """Test returns expected libp2p ALPN protocols.""" + protocols = get_alpn_protocols() + assert protocols == ["libp2p"] + assert isinstance(protocols, list) + + def test_returns_copy(self): + """Test returns a copy, not the original list.""" + protocols1 = get_alpn_protocols() + protocols2 = get_alpn_protocols() + + # Modify one list + protocols1.append("test") + + # Other list should be unchanged + assert protocols2 == ["libp2p"] + + +class TestNormalizeQuicMultiaddr: + """Test QUIC multiaddr normalization.""" + + def test_already_normalized(self): + """Test already normalized multiaddrs pass through.""" + addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" + maddr = Multiaddr(addr_str) + + result = normalize_quic_multiaddr(maddr) + assert str(result) == addr_str + + def test_normalize_different_versions(self): + """Test normalization works for different QUIC versions.""" + test_cases = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip4/127.0.0.1/udp/4001/quic", + "/ip6/::1/udp/5000/quic-v1", + ] + + for addr_str in test_cases: + maddr = Multiaddr(addr_str) + result = normalize_quic_multiaddr(maddr) + + # Should be valid QUIC multiaddr + assert is_quic_multiaddr(result) + + # Should be parseable + host, port = quic_multiaddr_to_endpoint(result) + version = multiaddr_to_quic_version(result) + + # Should match original + orig_host, orig_port = quic_multiaddr_to_endpoint(maddr) + orig_version = multiaddr_to_quic_version(maddr) + + assert host == orig_host + assert port == orig_port + assert version == orig_version + + def test_non_quic_raises_error(self): + """Test non-QUIC multiaddrs raise error.""" + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + with pytest.raises(QUICInvalidMultiaddrError): + normalize_quic_multiaddr(maddr) + + +class TestIntegration: + """Integration tests for utility functions working together.""" + + def test_round_trip_conversion(self): + """Test creating and parsing multiaddrs works correctly.""" + test_cases = [ + ("127.0.0.1", 4001, "quic-v1"), + ("::1", 5000, "quic"), + ("192.168.1.100", 8080, "quic-v1"), + ] + + for host, port, version in test_cases: + # Create multiaddr + maddr = create_quic_multiaddr(host, port, version) + + # Should be detected as QUIC + assert is_quic_multiaddr(maddr) + + # Should extract original values + extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr) + extracted_version = multiaddr_to_quic_version(maddr) + + assert extracted_host == host + assert extracted_port == port + assert extracted_version == version + + # Should normalize to same value + normalized = normalize_quic_multiaddr(maddr) + assert str(normalized) == str(maddr) + + def test_wire_format_integration(self): + """Test wire format conversion works with version detection.""" + addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" + maddr = Multiaddr(addr_str) + + # Extract version and convert to wire format + version = multiaddr_to_quic_version(maddr) + wire_format = quic_version_to_wire_format(version) + + # Should be QUIC v1 wire format + assert wire_format == 0x00000001 diff --git a/tests/core/transport/test_transport_registry.py b/tests/core/transport/test_transport_registry.py new file mode 100644 index 00000000..ff2fb234 --- /dev/null +++ b/tests/core/transport/test_transport_registry.py @@ -0,0 +1,324 @@ +""" +Tests for the transport registry functionality. +""" + +from multiaddr import Multiaddr + +from libp2p.abc import IListener, IRawConnection, ITransport +from libp2p.custom_types import THandler +from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.transport_registry import ( + TransportRegistry, + create_transport_for_multiaddr, + get_supported_transport_protocols, + get_transport_registry, + register_transport, +) +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport + + +class TestTransportRegistry: + """Test the TransportRegistry class.""" + + def test_init(self): + """Test registry initialization.""" + registry = TransportRegistry() + assert isinstance(registry, TransportRegistry) + + # Check that default transports are registered + supported = registry.get_supported_protocols() + assert "tcp" in supported + assert "ws" in supported + + def test_register_transport(self): + """Test transport registration.""" + registry = TransportRegistry() + + # Register a custom transport + class CustomTransport(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("CustomTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "CustomTransport create_listener not implemented" + ) + + registry.register_transport("custom", CustomTransport) + assert registry.get_transport("custom") == CustomTransport + + def test_get_transport(self): + """Test getting registered transports.""" + registry = TransportRegistry() + + # Test existing transports + assert registry.get_transport("tcp") == TCP + assert registry.get_transport("ws") == WebsocketTransport + + # Test non-existent transport + assert registry.get_transport("nonexistent") is None + + def test_get_supported_protocols(self): + """Test getting supported protocols.""" + registry = TransportRegistry() + protocols = registry.get_supported_protocols() + + assert isinstance(protocols, list) + assert "tcp" in protocols + assert "ws" in protocols + + def test_create_transport_tcp(self): + """Test creating TCP transport.""" + registry = TransportRegistry() + upgrader = TransportUpgrader({}, {}) + + transport = registry.create_transport("tcp", upgrader) + assert isinstance(transport, TCP) + + def test_create_transport_websocket(self): + """Test creating WebSocket transport.""" + registry = TransportRegistry() + upgrader = TransportUpgrader({}, {}) + + transport = registry.create_transport("ws", upgrader) + assert isinstance(transport, WebsocketTransport) + + def test_create_transport_invalid_protocol(self): + """Test creating transport with invalid protocol.""" + registry = TransportRegistry() + upgrader = TransportUpgrader({}, {}) + + transport = registry.create_transport("invalid", upgrader) + assert transport is None + + def test_create_transport_websocket_no_upgrader(self): + """Test that WebSocket transport requires upgrader.""" + registry = TransportRegistry() + + # This should fail gracefully and return None + transport = registry.create_transport("ws", None) + assert transport is None + + +class TestGlobalRegistry: + """Test the global registry functions.""" + + def test_get_transport_registry(self): + """Test getting the global registry.""" + registry = get_transport_registry() + assert isinstance(registry, TransportRegistry) + + def test_register_transport_global(self): + """Test registering transport globally.""" + + class GlobalCustomTransport(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("GlobalCustomTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "GlobalCustomTransport create_listener not implemented" + ) + + # Register globally + register_transport("global_custom", GlobalCustomTransport) + + # Check that it's available + registry = get_transport_registry() + assert registry.get_transport("global_custom") == GlobalCustomTransport + + def test_get_supported_transport_protocols_global(self): + """Test getting supported protocols from global registry.""" + protocols = get_supported_transport_protocols() + assert isinstance(protocols, list) + assert "tcp" in protocols + assert "ws" in protocols + + +class TestTransportFactory: + """Test the transport factory functions.""" + + def test_create_transport_for_multiaddr_tcp(self): + """Test creating transport for TCP multiaddr.""" + upgrader = TransportUpgrader({}, {}) + + # TCP multiaddr + maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is not None + assert isinstance(transport, TCP) + + def test_create_transport_for_multiaddr_websocket(self): + """Test creating transport for WebSocket multiaddr.""" + upgrader = TransportUpgrader({}, {}) + + # WebSocket multiaddr + maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is not None + assert isinstance(transport, WebsocketTransport) + + def test_create_transport_for_multiaddr_websocket_secure(self): + """Test creating transport for WebSocket multiaddr.""" + upgrader = TransportUpgrader({}, {}) + + # WebSocket multiaddr + maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is not None + assert isinstance(transport, WebsocketTransport) + + def test_create_transport_for_multiaddr_ipv6(self): + """Test creating transport for IPv6 multiaddr.""" + upgrader = TransportUpgrader({}, {}) + + # IPv6 WebSocket multiaddr + maddr = Multiaddr("/ip6/::1/tcp/8080/ws") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is not None + assert isinstance(transport, WebsocketTransport) + + def test_create_transport_for_multiaddr_dns(self): + """Test creating transport for DNS multiaddr.""" + upgrader = TransportUpgrader({}, {}) + + # DNS WebSocket multiaddr + maddr = Multiaddr("/dns4/example.com/tcp/443/ws") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is not None + assert isinstance(transport, WebsocketTransport) + + def test_create_transport_for_multiaddr_unknown(self): + """Test creating transport for unknown multiaddr.""" + upgrader = TransportUpgrader({}, {}) + + # Unknown multiaddr + maddr = Multiaddr("/ip4/127.0.0.1/udp/8080") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is None + + def test_create_transport_for_multiaddr_with_upgrader(self): + """Test creating transport with upgrader.""" + upgrader = TransportUpgrader({}, {}) + + # This should work for both TCP and WebSocket with upgrader + maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080") + transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader) + assert transport_tcp is not None + + maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader) + assert transport_ws is not None + + +class TestTransportInterfaceCompliance: + """Test that all transports implement the required interface.""" + + def test_tcp_implements_itransport(self): + """Test that TCP transport implements ITransport.""" + transport = TCP() + assert isinstance(transport, ITransport) + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + assert callable(transport.dial) + assert callable(transport.create_listener) + + def test_websocket_implements_itransport(self): + """Test that WebSocket transport implements ITransport.""" + upgrader = TransportUpgrader({}, {}) + transport = WebsocketTransport(upgrader) + assert isinstance(transport, ITransport) + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + assert callable(transport.dial) + assert callable(transport.create_listener) + + +class TestErrorHandling: + """Test error handling in the transport registry.""" + + def test_create_transport_with_exception(self): + """Test handling of transport creation exceptions.""" + registry = TransportRegistry() + upgrader = TransportUpgrader({}, {}) + + # Register a transport that raises an exception + class ExceptionTransport(ITransport): + def __init__(self, *args, **kwargs): + raise RuntimeError("Transport creation failed") + + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("ExceptionTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "ExceptionTransport create_listener not implemented" + ) + + registry.register_transport("exception", ExceptionTransport) + + # Should handle exception gracefully and return None + transport = registry.create_transport("exception", upgrader) + assert transport is None + + def test_invalid_multiaddr_handling(self): + """Test handling of invalid multiaddrs.""" + upgrader = TransportUpgrader({}, {}) + + # Test with a multiaddr that has an unsupported transport protocol + # This should be handled gracefully by our transport registry + # udp is not a supported transport + maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") + transport = create_transport_for_multiaddr(maddr, upgrader) + + assert transport is None + + +class TestIntegration: + """Test integration scenarios.""" + + def test_multiple_transport_types(self): + """Test using multiple transport types in the same registry.""" + registry = TransportRegistry() + upgrader = TransportUpgrader({}, {}) + + # Create different transport types + tcp_transport = registry.create_transport("tcp", upgrader) + ws_transport = registry.create_transport("ws", upgrader) + + # All should be different types + assert isinstance(tcp_transport, TCP) + assert isinstance(ws_transport, WebsocketTransport) + + # All should be different instances + assert tcp_transport is not ws_transport + + def test_transport_registry_persistence(self): + """Test that transport registry persists across calls.""" + registry1 = get_transport_registry() + registry2 = get_transport_registry() + + # Should be the same instance + assert registry1 is registry2 + + # Register a transport in one + class PersistentTransport(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("PersistentTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "PersistentTransport create_listener not implemented" + ) + + registry1.register_transport("persistent", PersistentTransport) + + # Should be available in the other + assert registry2.get_transport("persistent") == PersistentTransport diff --git a/tests/core/transport/test_upgrader.py b/tests/core/transport/test_upgrader.py new file mode 100644 index 00000000..8535a039 --- /dev/null +++ b/tests/core/transport/test_upgrader.py @@ -0,0 +1,27 @@ +import pytest + +from libp2p.custom_types import ( + TMuxerOptions, + TSecurityOptions, +) +from libp2p.transport.upgrader import ( + TransportUpgrader, +) + + +@pytest.mark.trio +async def test_transport_upgrader_security_and_muxer_initialization(): + """Test TransportUpgrader initializes security and muxer multistreams correctly.""" + secure_transports: TSecurityOptions = {} + muxer_transports: TMuxerOptions = {} + negotiate_timeout = 15 + + upgrader = TransportUpgrader( + secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout + ) + + # Verify security multistream initialization + assert upgrader.security_multistream.transports == secure_transports + # Verify muxer multistream initialization and timeout + assert upgrader.muxer_multistream.transports == muxer_transports + assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout diff --git a/tests/core/transport/test_websocket.py b/tests/core/transport/test_websocket.py new file mode 100644 index 00000000..6c1e249d --- /dev/null +++ b/tests/core/transport/test_websocket.py @@ -0,0 +1,1631 @@ +# Import exceptiongroup for Python 3.11+ +import builtins +from collections.abc import Sequence +import logging +from typing import Any + +import pytest + +if hasattr(builtins, "ExceptionGroup"): + ExceptionGroup = builtins.ExceptionGroup +else: + # Fallback for older Python versions + ExceptionGroup = Exception +from multiaddr import Multiaddr +import trio + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.host.basic_host import BasicHost +from libp2p.network.swarm import Swarm +from libp2p.peer.id import ID +from libp2p.peer.peerstore import PeerStore +from libp2p.security.insecure.transport import InsecureTransport +from libp2p.stream_muxer.yamux.yamux import Yamux +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.multiaddr_utils import ( + is_valid_websocket_multiaddr, + parse_websocket_multiaddr, +) +from libp2p.transport.websocket.transport import WebsocketTransport + +logger = logging.getLogger(__name__) + +PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" + + +async def make_host( + listen_addrs: Sequence[Multiaddr] | None = None, +) -> tuple[BasicHost, Any | None]: + # Identity + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + peer_store = PeerStore() + peer_store.add_key_pair(peer_id, key_pair) + + # Upgrader + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + + # Transport + Swarm + Host + transport = WebsocketTransport(upgrader) + swarm = Swarm(peer_id, peer_store, upgrader, transport) + host = BasicHost(swarm) + + # Optionally run/listen + ctx = None + if listen_addrs: + ctx = host.run(listen_addrs) + await ctx.__aenter__() + + return host, ctx + + +def create_upgrader(): + """Helper function to create a transport upgrader""" + key_pair = create_new_key_pair() + return TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + + +# 2. Listener Basic Functionality Tests +@pytest.mark.trio +async def test_listener_basic_listen(): + """Test basic listen functionality""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test listening on IPv4 + ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # Test that listener can be created and has required methods + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + + # Test that listener can handle the address + assert ma.value_for_protocol("ip4") == "127.0.0.1" + assert ma.value_for_protocol("tcp") == "0" + + # Test that listener can be closed + await listener.close() + + +@pytest.mark.trio +async def test_listener_port_0_handling(): + """Test listening on port 0 gets actual port""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # Test that the address can be parsed correctly + port_str = ma.value_for_protocol("tcp") + assert port_str == "0" + + # Test that listener can be closed + await listener.close() + + +@pytest.mark.trio +async def test_listener_any_interface(): + """Test listening on 0.0.0.0""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # Test that the address can be parsed correctly + host = ma.value_for_protocol("ip4") + assert host == "0.0.0.0" + + # Test that listener can be closed + await listener.close() + + +@pytest.mark.trio +async def test_listener_address_preservation(): + """Test that p2p IDs are preserved in addresses""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Create address with p2p ID + p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF" + ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # Test that p2p ID is preserved in the address + addr_str = str(ma) + assert p2p_id in addr_str + + # Test that listener can be closed + await listener.close() + + +# 3. Dial Basic Functionality Tests +@pytest.mark.trio +async def test_dial_basic(): + """Test basic dial functionality""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that transport can parse addresses for dialing + ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + + # Test that the address can be parsed correctly + host = ma.value_for_protocol("ip4") + port = ma.value_for_protocol("tcp") + assert host == "127.0.0.1" + assert port == "8080" + + # Test that transport has the required methods + assert hasattr(transport, "dial") + assert callable(transport.dial) + + +@pytest.mark.trio +async def test_dial_with_p2p_id(): + """Test dialing with p2p ID suffix""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF" + ma = Multiaddr(f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{p2p_id}") + + # Test that p2p ID is preserved in the address + addr_str = str(ma) + assert p2p_id in addr_str + + # Test that transport can handle addresses with p2p IDs + assert hasattr(transport, "dial") + assert callable(transport.dial) + + +@pytest.mark.trio +async def test_dial_port_0_resolution(): + """Test dialing to resolved port 0 addresses""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that transport can handle port 0 addresses + ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") + + # Test that the address can be parsed correctly + port_str = ma.value_for_protocol("tcp") + assert port_str == "0" + + # Test that transport has the required methods + assert hasattr(transport, "dial") + assert callable(transport.dial) + + +# 4. Address Validation Tests (CRITICAL) +def test_address_validation_ipv4(): + """Test IPv4 address validation""" + # upgrader = create_upgrader() # Not used in this test + + # Valid IPv4 WebSocket addresses + valid_addresses = [ + "/ip4/127.0.0.1/tcp/8080/ws", + "/ip4/0.0.0.0/tcp/0/ws", + "/ip4/192.168.1.1/tcp/443/ws", + ] + + # Test valid addresses can be parsed + for addr_str in valid_addresses: + ma = Multiaddr(addr_str) + # Should not raise exception when creating transport address + transport_addr = str(ma) + assert "/ws" in transport_addr + + # Test that transport can handle addresses with p2p IDs + p2p_addr = Multiaddr( + "/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw" + ) + # Should not raise exception when creating transport address + transport_addr = str(p2p_addr) + assert "/ws" in transport_addr + + +def test_address_validation_ipv6(): + """Test IPv6 address validation""" + # upgrader = create_upgrader() # Not used in this test + + # Valid IPv6 WebSocket addresses + valid_addresses = [ + "/ip6/::1/tcp/8080/ws", + "/ip6/2001:db8::1/tcp/443/ws", + ] + + # Test valid addresses can be parsed + for addr_str in valid_addresses: + ma = Multiaddr(addr_str) + transport_addr = str(ma) + assert "/ws" in transport_addr + + +def test_address_validation_dns(): + """Test DNS address validation""" + # upgrader = create_upgrader() # Not used in this test + + # Valid DNS WebSocket addresses + valid_addresses = [ + "/dns4/example.com/tcp/80/ws", + "/dns6/example.com/tcp/443/ws", + "/dnsaddr/example.com/tcp/8080/ws", + ] + + # Test valid addresses can be parsed + for addr_str in valid_addresses: + ma = Multiaddr(addr_str) + transport_addr = str(ma) + assert "/ws" in transport_addr + + +def test_address_validation_mixed(): + """Test mixed address validation""" + # upgrader = create_upgrader() # Not used in this test + + # Mixed valid and invalid addresses + addresses = [ + "/ip4/127.0.0.1/tcp/8080/ws", # Valid + "/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws) + "/ip6/::1/tcp/8080/ws", # Valid + "/ip4/127.0.0.1/ws", # Invalid (no tcp) + "/dns4/example.com/tcp/80/ws", # Valid + ] + + # Convert to Multiaddr objects + multiaddrs = [Multiaddr(addr) for addr in addresses] + + # Test that valid addresses can be processed + valid_count = 0 + for ma in multiaddrs: + try: + # Try to extract transport part + addr_text = str(ma) + if "/ws" in addr_text and "/tcp/" in addr_text: + valid_count += 1 + except Exception: + pass + + assert valid_count == 3 # Should have 3 valid addresses + + +# 5. Error Handling Tests +@pytest.mark.trio +async def test_dial_invalid_address(): + """Test dialing invalid addresses""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test dialing non-WebSocket addresses + invalid_addresses = [ + Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws + Multiaddr("/ip4/127.0.0.1/ws"), # No tcp + ] + + for ma in invalid_addresses: + with pytest.raises(Exception): + await transport.dial(ma) + + +@pytest.mark.trio +async def test_listen_invalid_address(): + """Test listening on invalid addresses""" + # upgrader = create_upgrader() # Not used in this test + + # Test listening on non-WebSocket addresses + invalid_addresses = [ + Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws + Multiaddr("/ip4/127.0.0.1/ws"), # No tcp + ] + + # Test that invalid addresses are properly identified + for ma in invalid_addresses: + # Test that the address parsing works correctly + if "/ws" in str(ma) and "tcp" not in str(ma): + # This should be invalid + assert "tcp" not in str(ma) + elif "/ws" not in str(ma): + # This should be invalid + assert "/ws" not in str(ma) + + +@pytest.mark.trio +async def test_listen_port_in_use(): + """Test listening on port that's in use""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that transport can handle port conflicts + ma1 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + ma2 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + + # Test that both addresses can be parsed + assert ma1.value_for_protocol("tcp") == "8080" + assert ma2.value_for_protocol("tcp") == "8080" + + # Test that transport can handle these addresses + assert hasattr(transport, "create_listener") + assert callable(transport.create_listener) + + +# 6. Connection Lifecycle Tests +@pytest.mark.trio +async def test_connection_close(): + """Test connection closing""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that transport has required methods + assert hasattr(transport, "dial") + assert callable(transport.dial) + + # Test that listener can be created and closed + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + assert hasattr(listener, "close") + assert callable(listener.close) + + # Test that listener can be closed + await listener.close() + + +@pytest.mark.trio +async def test_multiple_connections(): + """Test multiple concurrent connections""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that transport can handle multiple addresses + addresses = [ + Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"), + Multiaddr("/ip4/127.0.0.1/tcp/8081/ws"), + Multiaddr("/ip4/127.0.0.1/tcp/8082/ws"), + ] + + # Test that all addresses can be parsed + for addr in addresses: + host = addr.value_for_protocol("ip4") + port = addr.value_for_protocol("tcp") + assert host == "127.0.0.1" + assert port in ["8080", "8081", "8082"] + + # Test that transport has required methods + assert hasattr(transport, "dial") + assert callable(transport.dial) + + +# Original test (kept for compatibility) +@pytest.mark.trio +async def test_websocket_dial_and_listen(): + """Test basic WebSocket dial and listen functionality with real data transfer""" + # Test that WebSocket transport can handle basic operations + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that transport can create listeners + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + assert listener is not None + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + + # Test that transport can handle WebSocket addresses + ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") + assert ma.value_for_protocol("ip4") == "127.0.0.1" + assert ma.value_for_protocol("tcp") == "0" + assert "ws" in str(ma) + + # Test that transport has dial method + assert hasattr(transport, "dial") + assert callable(transport.dial) + + # Test that transport can handle WebSocket multiaddrs + ws_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + assert ws_addr.value_for_protocol("ip4") == "127.0.0.1" + assert ws_addr.value_for_protocol("tcp") == "8080" + assert "ws" in str(ws_addr) + + # Cleanup + await listener.close() + + +@pytest.mark.trio +async def test_websocket_transport_basic(): + """Test basic WebSocket transport functionality without full libp2p stack""" + # Create WebSocket transport + key_pair = create_new_key_pair() + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + transport = WebsocketTransport(upgrader) + + assert transport is not None + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + assert listener is not None + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + + valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") + assert valid_addr.value_for_protocol("ip4") == "127.0.0.1" + assert valid_addr.value_for_protocol("tcp") == "0" + assert "ws" in str(valid_addr) + + await listener.close() + + +@pytest.mark.trio +async def test_websocket_simple_connection(): + """Test WebSocket transport creation and basic functionality without real conn""" + # Create WebSocket transport + key_pair = create_new_key_pair() + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + transport = WebsocketTransport(upgrader) + + assert transport is not None + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + + async def simple_handler(conn): + await conn.close() + + listener = transport.create_listener(simple_handler) + assert listener is not None + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + + test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") + assert test_addr.value_for_protocol("ip4") == "127.0.0.1" + assert test_addr.value_for_protocol("tcp") == "0" + assert "ws" in str(test_addr) + + await listener.close() + + +@pytest.mark.trio +async def test_websocket_real_connection(): + """Test WebSocket transport creation and basic functionality""" + # Create WebSocket transport + key_pair = create_new_key_pair() + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + transport = WebsocketTransport(upgrader) + + assert transport is not None + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + + async def handler(conn): + await conn.close() + + listener = transport.create_listener(handler) + assert listener is not None + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + + await listener.close() + + +@pytest.mark.trio +async def test_websocket_with_tcp_fallback(): + """Test WebSocket functionality using TCP transport as fallback""" + from tests.utils.factories import host_pair_factory + + async with host_pair_factory() as (host_a, host_b): + assert len(host_a.get_network().connections) > 0 + assert len(host_b.get_network().connections) > 0 + + test_protocol = TProtocol("/test/protocol/1.0.0") + received_data = None + + async def test_handler(stream): + nonlocal received_data + received_data = await stream.read(1024) + await stream.write(b"Response from TCP") + await stream.close() + + host_a.set_stream_handler(test_protocol, test_handler) + stream = await host_b.new_stream(host_a.get_id(), [test_protocol]) + + test_data = b"TCP protocol test" + await stream.write(test_data) + response = await stream.read(1024) + + assert received_data == test_data + assert response == b"Response from TCP" + + await stream.close() + + +@pytest.mark.trio +async def test_websocket_data_exchange(): + """Test WebSocket transport with actual data exchange between two hosts""" + from libp2p import create_yamux_muxer_option, new_host + from libp2p.crypto.secp256k1 import create_new_key_pair + from libp2p.custom_types import TProtocol + from libp2p.peer.peerinfo import info_from_p2p_addr + from libp2p.security.insecure.transport import ( + PLAINTEXT_PROTOCOL_ID, + InsecureTransport, + ) + + # Create two hosts with plaintext security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + + # Host A (listener) + security_options_a = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], + ) + + # Host B (dialer) + security_options_b = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], # WebSocket transport + ) + + # Test data + test_data = b"Hello WebSocket Data Exchange!" + received_data = None + + # Set up handler on host A + test_protocol = TProtocol("/test/websocket/data/1.0.0") + + async def data_handler(stream): + nonlocal received_data + received_data = await stream.read(len(test_data)) + await stream.write(received_data) # Echo back + await stream.close() + + host_a.set_stream_handler(test_protocol, data_handler) + + # Start both hosts + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert len(listen_addrs) > 0 + + # Find the WebSocket address + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + + # Connect host B to host A + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Create stream and test data exchange + stream = await host_b.new_stream(host_a.get_id(), [test_protocol]) + await stream.write(test_data) + response = await stream.read(len(test_data)) + await stream.close() + + # Verify data exchange + assert received_data == test_data, f"Expected {test_data}, got {received_data}" + assert response == test_data, f"Expected echo {test_data}, got {response}" + + +@pytest.mark.trio +async def test_websocket_host_pair_data_exchange(): + """ + Test WebSocket host pair with actual data exchange using host_pair_factory + pattern. + """ + from libp2p import create_yamux_muxer_option, new_host + from libp2p.crypto.secp256k1 import create_new_key_pair + from libp2p.custom_types import TProtocol + from libp2p.peer.peerinfo import info_from_p2p_addr + from libp2p.security.insecure.transport import ( + PLAINTEXT_PROTOCOL_ID, + InsecureTransport, + ) + + # Create two hosts with WebSocket transport and plaintext security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + + # Host A (listener) - WebSocket transport + security_options_a = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], + ) + + # Host B (dialer) - WebSocket transport + security_options_b = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], # WebSocket transport + ) + + # Test data + test_data = b"Hello WebSocket Host Pair Data Exchange!" + received_data = None + + # Set up handler on host A + test_protocol = TProtocol("/test/websocket/hostpair/1.0.0") + + async def data_handler(stream): + nonlocal received_data + received_data = await stream.read(len(test_data)) + await stream.write(received_data) # Echo back + await stream.close() + + host_a.set_stream_handler(test_protocol, data_handler) + + # Start both hosts and connect them (following host_pair_factory pattern) + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Connect the hosts using the same pattern as host_pair_factory + # Get host A's listen address and create peer info + listen_addrs = host_a.get_addrs() + assert len(listen_addrs) > 0 + + # Find the WebSocket address + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + + # Connect host B to host A + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Allow time for connection to establish (following host_pair_factory pattern) + await trio.sleep(0.1) + + # Verify connection is established + assert len(host_a.get_network().connections) > 0 + assert len(host_b.get_network().connections) > 0 + + # Test data exchange + stream = await host_b.new_stream(host_a.get_id(), [test_protocol]) + await stream.write(test_data) + response = await stream.read(len(test_data)) + await stream.close() + + # Verify data exchange + assert received_data == test_data, f"Expected {test_data}, got {received_data}" + assert response == test_data, f"Expected echo {test_data}, got {response}" + + +@pytest.mark.trio +async def test_wss_host_pair_data_exchange(): + """Test WSS host pair with actual data exchange using host_pair_factory pattern""" + import ssl + + from libp2p import create_yamux_muxer_option, new_host + from libp2p.crypto.secp256k1 import create_new_key_pair + from libp2p.custom_types import TProtocol + from libp2p.peer.peerinfo import info_from_p2p_addr + from libp2p.security.insecure.transport import ( + PLAINTEXT_PROTOCOL_ID, + InsecureTransport, + ) + + # Create TLS contexts for WSS (separate for client and server) + # For testing, we need to create a self-signed certificate + try: + import datetime + import ipaddress + import os + import tempfile + + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + # Create certificate + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), # type: ignore + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"), # type: ignore + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"), # type: ignore + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"), # type: ignore + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), # type: ignore + ] + ) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after( + datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1) + ) + .add_extension( + x509.SubjectAlternativeName( + [ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ] + ), + critical=False, + ) + .sign(private_key, hashes.SHA256()) + ) + + # Create temporary files for cert and key + cert_file = tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".crt") + key_file = tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".key") + + # Write certificate and key to files + cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) + key_file.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + cert_file.close() + key_file.close() + + # Server context for listener (Host A) + server_tls_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + server_tls_context.load_cert_chain(cert_file.name, key_file.name) + + # Client context for dialer (Host B) + client_tls_context = ssl.create_default_context() + client_tls_context.check_hostname = False + client_tls_context.verify_mode = ssl.CERT_NONE + + # Clean up temp files after use + def cleanup_certs(): + try: + os.unlink(cert_file.name) + os.unlink(key_file.name) + except Exception: + pass + + except ImportError: + pytest.skip("cryptography package required for WSS tests") + except Exception as e: + pytest.skip(f"Failed to create test certificates: {e}") + + # Create two hosts with WSS transport and plaintext security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + + # Host A (listener) - WSS transport with server TLS config + security_options_a = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], + tls_server_config=server_tls_context, + ) + + # Host B (dialer) - WSS transport with client TLS config + security_options_b = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], # WebSocket transport + tls_client_config=client_tls_context, + ) + + # Test data + test_data = b"Hello WSS Host Pair Data Exchange!" + received_data = None + + # Set up handler on host A + test_protocol = TProtocol("/test/wss/hostpair/1.0.0") + + async def data_handler(stream): + nonlocal received_data + received_data = await stream.read(len(test_data)) + await stream.write(received_data) # Echo back + await stream.close() + + host_a.set_stream_handler(test_protocol, data_handler) + + # Start both hosts and connect them (following host_pair_factory pattern) + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")]), + host_b.run(listen_addrs=[]), + ): + # Connect the hosts using the same pattern as host_pair_factory + # Get host A's listen address and create peer info + listen_addrs = host_a.get_addrs() + assert len(listen_addrs) > 0 + + # Find the WSS address + wss_addr = None + for addr in listen_addrs: + if "/wss" in str(addr): + wss_addr = addr + break + + assert wss_addr is not None, "No WSS listen address found" + + # Connect host B to host A + peer_info = info_from_p2p_addr(wss_addr) + await host_b.connect(peer_info) + + # Allow time for connection to establish (following host_pair_factory pattern) + await trio.sleep(0.1) + + # Verify connection is established + assert len(host_a.get_network().connections) > 0 + assert len(host_b.get_network().connections) > 0 + + # Test data exchange + stream = await host_b.new_stream(host_a.get_id(), [test_protocol]) + await stream.write(test_data) + response = await stream.read(len(test_data)) + await stream.close() + + # Verify data exchange + assert received_data == test_data, f"Expected {test_data}, got {received_data}" + assert response == test_data, f"Expected echo {test_data}, got {response}" + + +@pytest.mark.trio +async def test_websocket_transport_interface(): + """Test WebSocket transport interface compliance""" + key_pair = create_new_key_pair() + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + + transport = WebsocketTransport(upgrader) + + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + assert callable(transport.dial) + assert callable(transport.create_listener) + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + + test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + host = test_addr.value_for_protocol("ip4") + port = test_addr.value_for_protocol("tcp") + assert host == "127.0.0.1" + assert port == "8080" + + await listener.close() + + +# ============================================================================ +# WSS (WebSocket Secure) Tests +# ============================================================================ + + +def test_wss_multiaddr_validation(): + """Test WSS multiaddr validation and parsing.""" + # Valid WSS multiaddrs + valid_wss_addresses = [ + "/ip4/127.0.0.1/tcp/8080/wss", + "/ip6/::1/tcp/8080/wss", + "/dns/localhost/tcp/8080/wss", + "/ip4/127.0.0.1/tcp/8080/tls/ws", + "/ip6/::1/tcp/8080/tls/ws", + ] + + # Invalid WSS multiaddrs + invalid_wss_addresses = [ + "/ip4/127.0.0.1/tcp/8080/ws", # Regular WS, not WSS + "/ip4/127.0.0.1/tcp/8080", # No WebSocket protocol + "/ip4/127.0.0.1/wss", # No TCP + ] + + # Test valid WSS addresses + for addr_str in valid_wss_addresses: + ma = Multiaddr(addr_str) + assert is_valid_websocket_multiaddr(ma), f"Address {addr_str} should be valid" + + # Test parsing + parsed = parse_websocket_multiaddr(ma) + assert parsed.is_wss, f"Address {addr_str} should be parsed as WSS" + + # Test invalid addresses + for addr_str in invalid_wss_addresses: + ma = Multiaddr(addr_str) + if "/ws" in addr_str and "/wss" not in addr_str and "/tls" not in addr_str: + # Regular WS should be valid but not WSS + assert is_valid_websocket_multiaddr(ma), ( + f"Address {addr_str} should be valid" + ) + parsed = parse_websocket_multiaddr(ma) + assert not parsed.is_wss, f"Address {addr_str} should not be parsed as WSS" + else: + # Invalid addresses should fail validation + assert not is_valid_websocket_multiaddr(ma), ( + f"Address {addr_str} should be invalid" + ) + + +def test_wss_multiaddr_parsing(): + """Test WSS multiaddr parsing functionality.""" + # Test /wss format + wss_ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss") + parsed = parse_websocket_multiaddr(wss_ma) + assert parsed.is_wss + assert parsed.sni is None + assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1" + assert parsed.rest_multiaddr.value_for_protocol("tcp") == "8080" + + # Test /tls/ws format + tls_ws_ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/tls/ws") + parsed = parse_websocket_multiaddr(tls_ws_ma) + assert parsed.is_wss + assert parsed.sni is None + assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1" + assert parsed.rest_multiaddr.value_for_protocol("tcp") == "8080" + + # Test regular /ws format + ws_ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + parsed = parse_websocket_multiaddr(ws_ma) + assert not parsed.is_wss + assert parsed.sni is None + + +@pytest.mark.trio +async def test_wss_transport_creation(): + """Test WSS transport creation with TLS configuration.""" + import ssl + + # Create TLS contexts + client_ssl_context = ssl.create_default_context() + server_ssl_context = ssl.create_default_context() + server_ssl_context.check_hostname = False + server_ssl_context.verify_mode = ssl.CERT_NONE + + upgrader = create_upgrader() + + # Test creating WSS transport with TLS configs + wss_transport = WebsocketTransport( + upgrader, + tls_client_config=client_ssl_context, + tls_server_config=server_ssl_context, + ) + + assert wss_transport is not None + assert hasattr(wss_transport, "dial") + assert hasattr(wss_transport, "create_listener") + assert wss_transport._tls_client_config is not None + assert wss_transport._tls_server_config is not None + + +@pytest.mark.trio +async def test_wss_transport_without_tls_config(): + """Test WSS transport creation without TLS configuration.""" + upgrader = create_upgrader() + + # Test creating WSS transport without TLS configs (should still work) + wss_transport = WebsocketTransport(upgrader) + + assert wss_transport is not None + assert hasattr(wss_transport, "dial") + assert hasattr(wss_transport, "create_listener") + assert wss_transport._tls_client_config is None + assert wss_transport._tls_server_config is None + + +@pytest.mark.trio +async def test_wss_dial_parsing(): + """Test WSS dial functionality with multiaddr parsing.""" + # upgrader = create_upgrader() # Not used in this test + # transport = WebsocketTransport(upgrader) # Not used in this test + + # Test WSS multiaddr parsing in dial + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss") + + # Test that the transport can parse WSS addresses + # (We can't actually dial without a server, but we can test parsing) + try: + parsed = parse_websocket_multiaddr(wss_maddr) + assert parsed.is_wss + assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1" + assert parsed.rest_multiaddr.value_for_protocol("tcp") == "8080" + except Exception as e: + pytest.fail(f"WSS multiaddr parsing failed: {e}") + + +@pytest.mark.trio +async def test_wss_listen_parsing(): + """Test WSS listen functionality with multiaddr parsing.""" + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test WSS multiaddr parsing in listen + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # Test that the transport can parse WSS addresses + try: + parsed = parse_websocket_multiaddr(wss_maddr) + assert parsed.is_wss + assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1" + assert parsed.rest_multiaddr.value_for_protocol("tcp") == "0" + except Exception as e: + pytest.fail(f"WSS multiaddr parsing failed: {e}") + + await listener.close() + + +@pytest.mark.trio +async def test_wss_listen_without_tls_config(): + """Test WSS listen without TLS configuration should fail.""" + from libp2p.transport.websocket.transport import WebsocketTransport + + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) # No TLS config + + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # This should raise an error when TLS config is not provided + try: + nursery = trio.lowlevel.current_task().parent_nursery + if nursery is None: + pytest.fail("No parent nursery available for test") + # Type assertion to help the type checker understand nursery is not None + assert nursery is not None + await listener.listen(wss_maddr, nursery) + pytest.fail("WSS listen without TLS config should have failed") + except ValueError as e: + assert "without TLS configuration" in str(e) + except Exception as e: + pytest.fail(f"Unexpected error: {e}") + + await listener.close() + + +@pytest.mark.trio +async def test_wss_listen_with_tls_config(): + """Test WSS listen with TLS configuration.""" + import ssl + + # Create server TLS context + server_ssl_context = ssl.create_default_context() + server_ssl_context.check_hostname = False + server_ssl_context.verify_mode = ssl.CERT_NONE + + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader, tls_server_config=server_ssl_context) + + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + + # This should not raise an error when TLS config is provided + # Note: We can't actually start listening without proper certificates, + # but we can test that the validation passes + try: + parsed = parse_websocket_multiaddr(wss_maddr) + assert parsed.is_wss + assert transport._tls_server_config is not None + except Exception as e: + pytest.fail(f"WSS listen with TLS config failed: {e}") + + await listener.close() + + +def test_wss_transport_registry(): + """Test WSS support in transport registry.""" + from libp2p.transport.transport_registry import ( + create_transport_for_multiaddr, + get_supported_transport_protocols, + ) + + # Test that WSS is supported + supported = get_supported_transport_protocols() + assert "ws" in supported + assert "wss" in supported + + # Test transport creation for WSS multiaddrs + upgrader = create_upgrader() + + # Test WS multiaddr + ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + ws_transport = create_transport_for_multiaddr(ws_maddr, upgrader) + assert ws_transport is not None + assert isinstance(ws_transport, WebsocketTransport) + + # Test WSS multiaddr + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss") + wss_transport = create_transport_for_multiaddr(wss_maddr, upgrader) + assert wss_transport is not None + assert isinstance(wss_transport, WebsocketTransport) + + # Test TLS/WS multiaddr + tls_ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/tls/ws") + tls_ws_transport = create_transport_for_multiaddr(tls_ws_maddr, upgrader) + assert tls_ws_transport is not None + assert isinstance(tls_ws_transport, WebsocketTransport) + + +def test_wss_multiaddr_formats(): + """Test different WSS multiaddr formats.""" + # Test various WSS formats + wss_formats = [ + "/ip4/127.0.0.1/tcp/8080/wss", + "/ip6/::1/tcp/8080/wss", + "/dns/localhost/tcp/8080/wss", + "/ip4/127.0.0.1/tcp/8080/tls/ws", + "/ip6/::1/tcp/8080/tls/ws", + "/dns/example.com/tcp/443/tls/ws", + ] + + for addr_str in wss_formats: + ma = Multiaddr(addr_str) + + # Should be valid WebSocket multiaddr + assert is_valid_websocket_multiaddr(ma), f"Address {addr_str} should be valid" + + # Should parse as WSS + parsed = parse_websocket_multiaddr(ma) + assert parsed.is_wss, f"Address {addr_str} should be parsed as WSS" + + # Should have correct base multiaddr + assert parsed.rest_multiaddr.value_for_protocol("tcp") is not None + + +def test_wss_vs_ws_distinction(): + """Test that WSS and WS are properly distinguished.""" + # WS addresses should not be WSS + ws_addresses = [ + "/ip4/127.0.0.1/tcp/8080/ws", + "/ip6/::1/tcp/8080/ws", + "/dns/localhost/tcp/8080/ws", + ] + + for addr_str in ws_addresses: + ma = Multiaddr(addr_str) + parsed = parse_websocket_multiaddr(ma) + assert not parsed.is_wss, f"Address {addr_str} should not be WSS" + + # WSS addresses should be WSS + wss_addresses = [ + "/ip4/127.0.0.1/tcp/8080/wss", + "/ip4/127.0.0.1/tcp/8080/tls/ws", + ] + + for addr_str in wss_addresses: + ma = Multiaddr(addr_str) + parsed = parse_websocket_multiaddr(ma) + assert parsed.is_wss, f"Address {addr_str} should be WSS" + + +@pytest.mark.trio +async def test_wss_connection_handling(): + """Test WSS connection handling with security flag.""" + # upgrader = create_upgrader() # Not used in this test + # transport = WebsocketTransport(upgrader) # Not used in this test + + # Test that WSS connections are marked as secure + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss") + parsed = parse_websocket_multiaddr(wss_maddr) + assert parsed.is_wss + + # Test that WS connections are not marked as secure + ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + parsed = parse_websocket_multiaddr(ws_maddr) + assert not parsed.is_wss + + +def test_wss_error_handling(): + """Test WSS error handling for invalid configurations.""" + # upgrader = create_upgrader() # Not used in this test + + # Test invalid multiaddr formats + invalid_addresses = [ + "/ip4/127.0.0.1/tcp/8080", # No WebSocket protocol + "/ip4/127.0.0.1/wss", # No TCP + "/tcp/8080/wss", # No network protocol + ] + + for addr_str in invalid_addresses: + ma = Multiaddr(addr_str) + assert not is_valid_websocket_multiaddr(ma), ( + f"Address {addr_str} should be invalid" + ) + + # Should raise ValueError when parsing invalid addresses + with pytest.raises(ValueError): + parse_websocket_multiaddr(ma) + + +@pytest.mark.trio +async def test_handshake_timeout(): + """Test WebSocket handshake timeout functionality.""" + upgrader = create_upgrader() + + # Test creating transport with custom handshake timeout + transport = WebsocketTransport(upgrader, handshake_timeout=0.1) # 100ms timeout + assert transport._handshake_timeout == 0.1 + + # Test that the timeout is passed to the listener + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + # Type assertion to access private attribute for testing + assert hasattr(listener, "_handshake_timeout") + assert getattr(listener, "_handshake_timeout") == 0.1 + + +@pytest.mark.trio +async def test_handshake_timeout_creation(): + """Test handshake timeout in transport creation.""" + upgrader = create_upgrader() + + # Test creating transport with handshake timeout via create_transport + from libp2p.transport import create_transport + + transport = create_transport("ws", upgrader, handshake_timeout=5.0) + # Type assertion to access private attribute for testing + assert hasattr(transport, "_handshake_timeout") + assert getattr(transport, "_handshake_timeout") == 5.0 + + # Test default timeout + transport_default = create_transport("ws", upgrader) + assert hasattr(transport_default, "_handshake_timeout") + assert getattr(transport_default, "_handshake_timeout") == 15.0 + + +@pytest.mark.trio +async def test_connection_state_tracking(): + """Test WebSocket connection state tracking.""" + from libp2p.transport.websocket.connection import P2PWebSocketConnection + + # Create a mock WebSocket connection + class MockWebSocketConnection: + async def send_message(self, data: bytes) -> None: + pass + + async def get_message(self) -> bytes: + return b"test message" + + async def aclose(self) -> None: + pass + + mock_ws = MockWebSocketConnection() + conn = P2PWebSocketConnection(mock_ws, is_secure=True) + + # Test initial state + state = conn.conn_state() + assert state["transport"] == "websocket" + assert state["secure"] is True + assert state["bytes_read"] == 0 + assert state["bytes_written"] == 0 + assert state["total_bytes"] == 0 + assert state["connection_duration"] >= 0 + + # Test byte tracking (we can't actually read/write with mock, but we can test + # the method) + # The actual byte tracking will be tested in integration tests + assert hasattr(conn, "_bytes_read") + assert hasattr(conn, "_bytes_written") + assert hasattr(conn, "_connection_start_time") + + +@pytest.mark.trio +async def test_concurrent_close_handling(): + """Test concurrent close handling similar to Go implementation.""" + from libp2p.transport.websocket.connection import P2PWebSocketConnection + + # Create a mock WebSocket connection that tracks close calls + class MockWebSocketConnection: + def __init__(self): + self.close_calls = 0 + self.closed = False + + async def send_message(self, data: bytes) -> None: + if self.closed: + raise Exception("Connection closed") + pass + + async def get_message(self) -> bytes: + if self.closed: + raise Exception("Connection closed") + return b"test message" + + async def aclose(self) -> None: + self.close_calls += 1 + self.closed = True + + mock_ws = MockWebSocketConnection() + conn = P2PWebSocketConnection(mock_ws, is_secure=False) + + # Test that multiple close calls are handled gracefully + await conn.close() + await conn.close() # Second close should not raise an error + + # The mock should only be closed once + assert mock_ws.close_calls == 1 + assert mock_ws.closed is True + + +@pytest.mark.trio +async def test_zero_byte_write_handling(): + """Test zero-byte write handling similar to Go implementation.""" + from libp2p.transport.websocket.connection import P2PWebSocketConnection + + # Create a mock WebSocket connection that tracks write calls + class MockWebSocketConnection: + def __init__(self): + self.write_calls = [] + + async def send_message(self, data: bytes) -> None: + self.write_calls.append(len(data)) + + async def get_message(self) -> bytes: + return b"test message" + + async def aclose(self) -> None: + pass + + mock_ws = MockWebSocketConnection() + conn = P2PWebSocketConnection(mock_ws, is_secure=False) + + # Test zero-byte write + await conn.write(b"") + assert 0 in mock_ws.write_calls + + # Test normal write + await conn.write(b"hello") + assert 5 in mock_ws.write_calls + + # Test multiple zero-byte writes + for _ in range(10): + await conn.write(b"") + + # Should have 11 zero-byte writes total (1 initial + 10 in loop) + zero_byte_writes = [call for call in mock_ws.write_calls if call == 0] + assert len(zero_byte_writes) == 11 + + +@pytest.mark.trio +async def test_websocket_transport_protocols(): + """Test that WebSocket transport reports correct protocols.""" + # upgrader = create_upgrader() # Not used in this test + # transport = WebsocketTransport(upgrader) # Not used in this test + + # Test that the transport can handle both WS and WSS protocols + ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") + wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss") + + # Both should be valid WebSocket multiaddrs + assert is_valid_websocket_multiaddr(ws_maddr) + assert is_valid_websocket_multiaddr(wss_maddr) + + # Both should be parseable + ws_parsed = parse_websocket_multiaddr(ws_maddr) + wss_parsed = parse_websocket_multiaddr(wss_maddr) + + assert not ws_parsed.is_wss + assert wss_parsed.is_wss + + +@pytest.mark.trio +async def test_websocket_listener_addr_format(): + """Test WebSocket listener address format similar to Go implementation.""" + upgrader = create_upgrader() + + # Test WS listener + transport_ws = WebsocketTransport(upgrader) + + async def dummy_handler_ws(conn): + await trio.sleep(0) + + listener_ws = transport_ws.create_listener(dummy_handler_ws) + # Type assertion to access private attribute for testing + assert hasattr(listener_ws, "_handshake_timeout") + assert getattr(listener_ws, "_handshake_timeout") == 15.0 # Default timeout + + # Test WSS listener with TLS config + import ssl + + tls_config = ssl.create_default_context() + transport_wss = WebsocketTransport(upgrader, tls_server_config=tls_config) + + async def dummy_handler_wss(conn): + await trio.sleep(0) + + listener_wss = transport_wss.create_listener(dummy_handler_wss) + # Type assertion to access private attributes for testing + assert hasattr(listener_wss, "_tls_config") + assert getattr(listener_wss, "_tls_config") is not None + assert hasattr(listener_wss, "_handshake_timeout") + assert getattr(listener_wss, "_handshake_timeout") == 15.0 + + +@pytest.mark.trio +async def test_sni_resolution_limitation(): + """ + Test SNI resolution limitation - Python multiaddr library doesn't support + SNI protocol. + """ + upgrader = create_upgrader() + transport = WebsocketTransport(upgrader) + + # Test that WSS addresses are returned unchanged (SNI resolution not supported) + wss_maddr = Multiaddr("/dns/example.com/tcp/1234/wss") + resolved = transport.resolve(wss_maddr) + assert len(resolved) == 1 + assert resolved[0] == wss_maddr + + # Test that non-WSS addresses are returned unchanged + ws_maddr = Multiaddr("/dns/example.com/tcp/1234/ws") + resolved = transport.resolve(ws_maddr) + assert len(resolved) == 1 + assert resolved[0] == ws_maddr + + # Test that IP addresses are returned unchanged + ip_maddr = Multiaddr("/ip4/127.0.0.1/tcp/1234/wss") + resolved = transport.resolve(ip_maddr) + assert len(resolved) == 1 + assert resolved[0] == ip_maddr + + +@pytest.mark.trio +async def test_websocket_transport_can_dial(): + """Test WebSocket transport CanDial functionality similar to Go implementation.""" + # upgrader = create_upgrader() # Not used in this test + # transport = WebsocketTransport(upgrader) # Not used in this test + + # Test valid WebSocket addresses that should be dialable + valid_addresses = [ + "/ip4/127.0.0.1/tcp/5555/ws", + "/ip4/127.0.0.1/tcp/5555/wss", + "/ip4/127.0.0.1/tcp/5555/tls/ws", + # Note: SNI addresses not supported by Python multiaddr library + ] + + for addr_str in valid_addresses: + maddr = Multiaddr(addr_str) + # All these should be valid WebSocket multiaddrs + assert is_valid_websocket_multiaddr(maddr), ( + f"Address {addr_str} should be valid" + ) + + # Test invalid addresses that should not be dialable + invalid_addresses = [ + "/ip4/127.0.0.1/tcp/5555", # No WebSocket protocol + "/ip4/127.0.0.1/udp/5555/ws", # Wrong transport protocol + ] + + for addr_str in invalid_addresses: + maddr = Multiaddr(addr_str) + # These should not be valid WebSocket multiaddrs + assert not is_valid_websocket_multiaddr(maddr), ( + f"Address {addr_str} should be invalid" + ) diff --git a/tests/core/transport/test_websocket_p2p.py b/tests/core/transport/test_websocket_p2p.py new file mode 100644 index 00000000..2744bb34 --- /dev/null +++ b/tests/core/transport/test_websocket_p2p.py @@ -0,0 +1,532 @@ +#!/usr/bin/env python3 +""" +Python-to-Python WebSocket peer-to-peer tests. + +This module tests real WebSocket communication between two Python libp2p hosts, +including both WS and WSS (WebSocket Secure) scenarios. +""" + +import pytest +from multiaddr import Multiaddr + +from libp2p import create_yamux_muxer_option, new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair +from libp2p.custom_types import TProtocol +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) +from libp2p.transport.websocket.multiaddr_utils import ( + is_valid_websocket_multiaddr, + parse_websocket_multiaddr, +) + +PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0") +PING_LENGTH = 32 + + +@pytest.mark.trio +async def test_websocket_p2p_plaintext(): + """Test Python-to-Python WebSocket communication with plaintext security.""" + # Create two hosts with plaintext security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + + # Host A (listener) - use only plaintext security + security_options_a = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], + ) + + # Host B (dialer) - use only plaintext security + security_options_b = { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport + ) + + # Test data + test_data = b"Hello WebSocket P2P!" + received_data = None + + # Set up ping handler on host A + async def ping_handler(stream): + nonlocal received_data + received_data = await stream.read(len(test_data)) + await stream.write(received_data) # Echo back + await stream.close() + + host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler) + + # Start both hosts + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert len(listen_addrs) > 0 + + # Find the WebSocket address + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + assert is_valid_websocket_multiaddr(ws_addr), "Invalid WebSocket multiaddr" + + # Parse the WebSocket multiaddr + parsed = parse_websocket_multiaddr(ws_addr) + assert not parsed.is_wss, "Should be plain WebSocket, not WSS" + assert parsed.sni is None, "SNI should be None for plain WebSocket" + + # Connect host B to host A + from libp2p.peer.peerinfo import info_from_p2p_addr + + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Create stream and test communication + stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID]) + await stream.write(test_data) + response = await stream.read(len(test_data)) + await stream.close() + + # Verify communication + assert received_data == test_data, f"Expected {test_data}, got {received_data}" + assert response == test_data, f"Expected echo {test_data}, got {response}" + + +@pytest.mark.trio +async def test_websocket_p2p_noise(): + """Test Python-to-Python WebSocket communication with Noise security.""" + # Create two hosts with Noise security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + noise_key_pair_a = create_new_x25519_key_pair() + noise_key_pair_b = create_new_x25519_key_pair() + + # Host A (listener) + security_options_a = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_a, + noise_privkey=noise_key_pair_a.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], + ) + + # Host B (dialer) + security_options_b = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_b, + noise_privkey=noise_key_pair_b.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport + ) + + # Test data + test_data = b"Hello WebSocket P2P with Noise!" + received_data = None + + # Set up ping handler on host A + async def ping_handler(stream): + nonlocal received_data + received_data = await stream.read(len(test_data)) + await stream.write(received_data) # Echo back + await stream.close() + + host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler) + + # Start both hosts + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert len(listen_addrs) > 0 + + # Find the WebSocket address + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + assert is_valid_websocket_multiaddr(ws_addr), "Invalid WebSocket multiaddr" + + # Parse the WebSocket multiaddr + parsed = parse_websocket_multiaddr(ws_addr) + assert not parsed.is_wss, "Should be plain WebSocket, not WSS" + assert parsed.sni is None, "SNI should be None for plain WebSocket" + + # Connect host B to host A + from libp2p.peer.peerinfo import info_from_p2p_addr + + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Create stream and test communication + stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID]) + await stream.write(test_data) + response = await stream.read(len(test_data)) + await stream.close() + + # Verify communication + assert received_data == test_data, f"Expected {test_data}, got {received_data}" + assert response == test_data, f"Expected echo {test_data}, got {response}" + + +@pytest.mark.trio +async def test_websocket_p2p_libp2p_ping(): + """Test Python-to-Python WebSocket communication using libp2p ping protocol.""" + # Create two hosts with Noise security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + noise_key_pair_a = create_new_x25519_key_pair() + noise_key_pair_b = create_new_x25519_key_pair() + + # Host A (listener) + security_options_a = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_a, + noise_privkey=noise_key_pair_a.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], + ) + + # Host B (dialer) + security_options_b = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_b, + noise_privkey=noise_key_pair_b.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport + ) + + # Set up ping handler on host A (standard libp2p ping protocol) + async def ping_handler(stream): + # Read ping data (32 bytes) + ping_data = await stream.read(PING_LENGTH) + # Echo back the same data (pong) + await stream.write(ping_data) + await stream.close() + + host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler) + + # Start both hosts + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert len(listen_addrs) > 0 + + # Find the WebSocket address + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + + # Connect host B to host A + from libp2p.peer.peerinfo import info_from_p2p_addr + + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Create stream and test libp2p ping protocol + stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID]) + + # Send ping (32 bytes as per libp2p ping protocol) + ping_data = b"\x01" * PING_LENGTH + await stream.write(ping_data) + + # Receive pong (should be same 32 bytes) + pong_data = await stream.read(PING_LENGTH) + await stream.close() + + # Verify ping-pong + assert pong_data == ping_data, ( + f"Expected ping {ping_data}, got pong {pong_data}" + ) + + +@pytest.mark.trio +async def test_websocket_p2p_multiple_streams(): + """ + Test Python-to-Python WebSocket communication with multiple concurrent + streams. + """ + # Create two hosts with Noise security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + noise_key_pair_a = create_new_x25519_key_pair() + noise_key_pair_b = create_new_x25519_key_pair() + + # Host A (listener) + security_options_a = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_a, + noise_privkey=noise_key_pair_a.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], + ) + + # Host B (dialer) + security_options_b = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_b, + noise_privkey=noise_key_pair_b.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport + ) + + # Test protocol + test_protocol = TProtocol("/test/multiple/streams/1.0.0") + received_data = [] + + # Set up handler on host A + async def test_handler(stream): + data = await stream.read(1024) + received_data.append(data) + await stream.write(data) # Echo back + await stream.close() + + host_a.set_stream_handler(test_protocol, test_handler) + + # Start both hosts + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + + # Connect host B to host A + from libp2p.peer.peerinfo import info_from_p2p_addr + + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Create multiple concurrent streams + num_streams = 5 + test_data_list = [f"Stream {i} data".encode() for i in range(num_streams)] + + async def create_stream_and_test(stream_id: int, data: bytes): + stream = await host_b.new_stream(host_a.get_id(), [test_protocol]) + await stream.write(data) + response = await stream.read(len(data)) + await stream.close() + return response + + # Run all streams concurrently + tasks = [ + create_stream_and_test(i, test_data_list[i]) for i in range(num_streams) + ] + responses = [] + for task in tasks: + responses.append(await task) + + # Verify all communications + assert len(received_data) == num_streams, ( + f"Expected {num_streams} received messages, got {len(received_data)}" + ) + for i, (sent, received, response) in enumerate( + zip(test_data_list, received_data, responses) + ): + assert received == sent, f"Stream {i}: Expected {sent}, got {received}" + assert response == sent, f"Stream {i}: Expected echo {sent}, got {response}" + + +@pytest.mark.trio +async def test_websocket_p2p_connection_state(): + """Test WebSocket connection state tracking and metadata.""" + # Create two hosts with Noise security + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + noise_key_pair_a = create_new_x25519_key_pair() + noise_key_pair_b = create_new_x25519_key_pair() + + # Host A (listener) + security_options_a = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_a, + noise_privkey=noise_key_pair_a.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options_a, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], + ) + + # Host B (dialer) + security_options_b = { + NOISE_PROTOCOL_ID: NoiseTransport( + libp2p_keypair=key_pair_b, + noise_privkey=noise_key_pair_b.private_key, + early_data=None, + with_noise_pipes=False, + ) + } + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options_b, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport + ) + + # Set up handler on host A + async def test_handler(stream): + # Read some data + await stream.read(1024) + # Write some data back + await stream.write(b"Response data") + await stream.close() + + host_a.set_stream_handler(PING_PROTOCOL_ID, test_handler) + + # Start both hosts + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + ws_addr = None + for addr in listen_addrs: + if "/ws" in str(addr): + ws_addr = addr + break + + assert ws_addr is not None, "No WebSocket listen address found" + + # Connect host B to host A + from libp2p.peer.peerinfo import info_from_p2p_addr + + peer_info = info_from_p2p_addr(ws_addr) + await host_b.connect(peer_info) + + # Create stream and test communication + stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID]) + await stream.write(b"Test data for connection state") + response = await stream.read(1024) + await stream.close() + + # Verify response + assert response == b"Response data", f"Expected 'Response data', got {response}" + + # Test connection state (if available) + # Note: This tests the connection state tracking we implemented + connections = host_b.get_network().connections + assert len(connections) > 0, "Should have at least one connection" + + # Get the connection to host A + conn_to_a = None + for peer_id, conn_list in connections.items(): + if peer_id == host_a.get_id(): + # connections maps peer_id to list of connections, get the first one + conn_to_a = conn_list[0] if conn_list else None + break + + assert conn_to_a is not None, "Should have connection to host A" + + # Test that the connection has the expected properties + assert hasattr(conn_to_a, "muxed_conn"), "Connection should have muxed_conn" + assert hasattr(conn_to_a.muxed_conn, "secured_conn"), ( + "Muxed connection should have underlying secured_conn" + ) + + # If the underlying connection is our WebSocket connection, test its state + # Type assertion to access private attribute for testing + underlying_conn = getattr(conn_to_a.muxed_conn, "secured_conn") + if hasattr(underlying_conn, "conn_state"): + state = underlying_conn.conn_state() + assert "connection_start_time" in state, ( + "Connection state should include start time" + ) + assert "bytes_read" in state, "Connection state should include bytes read" + assert "bytes_written" in state, ( + "Connection state should include bytes written" + ) + assert state["bytes_read"] > 0, "Should have read some bytes" + assert state["bytes_written"] > 0, "Should have written some bytes" diff --git a/tests/examples/test_examples_bind_address.py b/tests/examples/test_examples_bind_address.py new file mode 100644 index 00000000..c0dd9de3 --- /dev/null +++ b/tests/examples/test_examples_bind_address.py @@ -0,0 +1,117 @@ +""" +Tests to verify that all examples use the new address paradigm consistently +""" + +from pathlib import Path + + +class TestExamplesAddressParadigm: + """Test suite to verify all examples use the new address paradigm consistently""" + + def get_example_files(self): + """Get all Python files in the examples directory""" + examples_dir = Path("examples") + return list(examples_dir.rglob("*.py")) + + def check_file_for_wildcard_binding(self, filepath): + """Check if a file contains 0.0.0.0 binding""" + with open(filepath, encoding="utf-8") as f: + content = f.read() + + # Check for various forms of wildcard binding + wildcard_patterns = [ + "0.0.0.0", + "/ip4/0.0.0.0/", + ] + + found_wildcards = [] + for line_num, line in enumerate(content.splitlines(), 1): + for pattern in wildcard_patterns: + if pattern in line and not line.strip().startswith("#"): + found_wildcards.append((line_num, line.strip())) + + return found_wildcards + + def test_examples_use_address_paradigm(self): + """Test that examples use the new address paradigm functions""" + example_files = self.get_example_files() + + # Files that should use the new paradigm + networking_examples = [ + "echo/echo.py", + "chat/chat.py", + "ping/ping.py", + "bootstrap/bootstrap.py", + "pubsub/pubsub.py", + "identify/identify.py", + ] + + paradigm_functions = [ + "get_available_interfaces", + "get_optimal_binding_address", + ] + + for filename in networking_examples: + filepath = None + for example_file in example_files: + if filename in str(example_file): + filepath = example_file + break + + if filepath is None: + continue + + with open(filepath, encoding="utf-8") as f: + content = f.read() + + # Check that the file uses the new paradigm functions + for func in paradigm_functions: + assert func in content, ( + f"{filepath} should use {func} from the new address paradigm" + ) + + def test_wildcard_available_as_feature(self): + """Test that wildcard is available as a feature when needed""" + example_files = self.get_example_files() + + # Check that network_discover.py demonstrates wildcard usage + network_discover_file = None + for example_file in example_files: + if "network_discover.py" in str(example_file): + network_discover_file = example_file + break + + if network_discover_file: + with open(network_discover_file, encoding="utf-8") as f: + content = f.read() + + # Should demonstrate wildcard expansion + assert "0.0.0.0" in content, ( + f"{network_discover_file} should demonstrate wildcard usage" + ) + assert "expand_wildcard_address" in content, ( + f"{network_discover_file} should use expand_wildcard_address" + ) + + def test_doc_examples_use_paradigm(self): + """Test that documentation examples use the new address paradigm""" + doc_examples_dir = Path("examples/doc-examples") + if not doc_examples_dir.exists(): + return + + doc_example_files = list(doc_examples_dir.glob("*.py")) + + paradigm_functions = [ + "get_available_interfaces", + "get_optimal_binding_address", + ] + + for filepath in doc_example_files: + with open(filepath, encoding="utf-8") as f: + content = f.read() + + # Check that doc examples use the new paradigm + for func in paradigm_functions: + assert func in content, ( + f"Documentation example {filepath} should use {func}" + ) diff --git a/tests/examples/test_quic_echo_example.py b/tests/examples/test_quic_echo_example.py new file mode 100644 index 00000000..fc843f4b --- /dev/null +++ b/tests/examples/test_quic_echo_example.py @@ -0,0 +1,6 @@ +def test_echo_quic_example(): + """Test that the QUIC echo example can be imported and has required functions.""" + from examples.echo import echo_quic + + assert hasattr(echo_quic, "main") + assert hasattr(echo_quic, "run") diff --git a/tests/interop/__init__.py b/tests/interop/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/interop/js_libp2p/js_node/src/package.json b/tests/interop/js_libp2p/js_node/src/package.json new file mode 100644 index 00000000..d1e17d28 --- /dev/null +++ b/tests/interop/js_libp2p/js_node/src/package.json @@ -0,0 +1,21 @@ +{ + "name": "src", + "version": "1.0.0", + "main": "ping.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "description": "", + "dependencies": { + "@chainsafe/libp2p-noise": "^9.0.0", + "@chainsafe/libp2p-yamux": "^5.0.1", + "@libp2p/ping": "^2.0.36", + "@libp2p/plaintext": "^2.0.29", + "@libp2p/websockets": "^9.2.18", + "libp2p": "^2.9.0", + "multiaddr": "^10.0.1" + } +} diff --git a/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs b/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs new file mode 100644 index 00000000..3951fc02 --- /dev/null +++ b/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs @@ -0,0 +1,122 @@ +import { createLibp2p } from 'libp2p' +import { webSockets } from '@libp2p/websockets' +import { ping } from '@libp2p/ping' +import { noise } from '@chainsafe/libp2p-noise' +import { plaintext } from '@libp2p/plaintext' +import { yamux } from '@chainsafe/libp2p-yamux' +// import { identify } from '@libp2p/identify' // Commented out for compatibility + +// Configuration from environment (with defaults for compatibility) +const TRANSPORT = process.env.transport || 'ws' +const SECURITY = process.env.security || 'noise' +const MUXER = process.env.muxer || 'yamux' +const IP = process.env.ip || '0.0.0.0' + +async function main() { + console.log(`šŸ”§ Configuration: transport=${TRANSPORT}, security=${SECURITY}, muxer=${MUXER}`) + + // Build options following the proven pattern from test-plans-fork + const options = { + start: true, + connectionGater: { + denyDialMultiaddr: async () => false + }, + connectionMonitor: { + enabled: false + }, + services: { + ping: ping() + } + } + + // Transport configuration (following get-libp2p.ts pattern) + switch (TRANSPORT) { + case 'ws': + options.transports = [webSockets()] + options.addresses = { + listen: [`/ip4/${IP}/tcp/0/ws`] + } + break + case 'wss': + process.env.NODE_TLS_REJECT_UNAUTHORIZED = '0' + options.transports = [webSockets()] + options.addresses = { + listen: [`/ip4/${IP}/tcp/0/wss`] + } + break + default: + throw new Error(`Unknown transport: ${TRANSPORT}`) + } + + // Security configuration + switch (SECURITY) { + case 'noise': + options.connectionEncryption = [noise()] + break + case 'plaintext': + options.connectionEncryption = [plaintext()] + break + default: + throw new Error(`Unknown security: ${SECURITY}`) + } + + // Muxer configuration + switch (MUXER) { + case 'yamux': + options.streamMuxers = [yamux()] + break + default: + throw new Error(`Unknown muxer: ${MUXER}`) + } + + console.log('šŸ”§ Creating libp2p node with proven interop configuration...') + const node = await createLibp2p(options) + + await node.start() + + console.log(node.peerId.toString()) + for (const addr of node.getMultiaddrs()) { + console.log(addr.toString()) + } + + // Debug: Print supported protocols + console.log('DEBUG: Supported protocols:') + if (node.services && node.services.registrar) { + const protocols = node.services.registrar.getProtocols() + for (const protocol of protocols) { + console.log('DEBUG: Protocol:', protocol) + } + } + + // Debug: Print connection encryption protocols + console.log('DEBUG: Connection encryption protocols:') + try { + if (node.components && node.components.connectionEncryption) { + for (const encrypter of node.components.connectionEncryption) { + console.log('DEBUG: Encrypter:', encrypter.protocol) + } + } + } catch (e) { + console.log('DEBUG: Could not access connectionEncryption:', e.message) + } + + // Debug: Print stream muxer protocols + console.log('DEBUG: Stream muxer protocols:') + try { + if (node.components && node.components.streamMuxers) { + for (const muxer of node.components.streamMuxers) { + console.log('DEBUG: Muxer:', muxer.protocol) + } + } + } catch (e) { + console.log('DEBUG: Could not access streamMuxers:', e.message) + } + + // Keep the process alive + await new Promise(() => {}) +} + +main().catch(err => { + console.error(err) + process.exit(1) +}) diff --git a/tests/interop/nim_libp2p/.gitignore b/tests/interop/nim_libp2p/.gitignore new file mode 100644 index 00000000..7bcc01ea --- /dev/null +++ b/tests/interop/nim_libp2p/.gitignore @@ -0,0 +1,8 @@ +nimble.develop +nimble.paths + +*.nimble +nim-libp2p/ + +nim_echo_server +config.nims diff --git a/tests/interop/nim_libp2p/conftest.py b/tests/interop/nim_libp2p/conftest.py new file mode 100644 index 00000000..5765a09d --- /dev/null +++ b/tests/interop/nim_libp2p/conftest.py @@ -0,0 +1,119 @@ +import fcntl +import logging +from pathlib import Path +import shutil +import subprocess +import time + +import pytest + +logger = logging.getLogger(__name__) + + +def check_nim_available(): + """Check if nim compiler is available.""" + return shutil.which("nim") is not None and shutil.which("nimble") is not None + + +def check_nim_binary_built(): + """Check if nim echo server binary is built.""" + current_dir = Path(__file__).parent + binary_path = current_dir / "nim_echo_server" + return binary_path.exists() and binary_path.stat().st_size > 0 + + +def run_nim_setup_with_lock(): + """Run nim setup with file locking to prevent parallel execution.""" + current_dir = Path(__file__).parent + lock_file = current_dir / ".setup_lock" + setup_script = current_dir / "scripts" / "setup_nim_echo.sh" + + if not setup_script.exists(): + raise RuntimeError(f"Setup script not found: {setup_script}") + + # Try to acquire lock + try: + with open(lock_file, "w") as f: + # Non-blocking lock attempt + fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + + # Double-check binary doesn't exist (another worker might have built it) + if check_nim_binary_built(): + logger.info("Binary already exists, skipping setup") + return + + logger.info("Acquired setup lock, running nim-libp2p setup...") + + # Make setup script executable and run it + setup_script.chmod(0o755) + result = subprocess.run( + [str(setup_script)], + cwd=current_dir, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + + if result.returncode != 0: + raise RuntimeError( + f"Setup failed (exit {result.returncode}):\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}" + ) + + # Verify binary was built + if not check_nim_binary_built(): + raise RuntimeError("nim_echo_server binary not found after setup") + + logger.info("nim-libp2p setup completed successfully") + + except BlockingIOError: + # Another worker is running setup, wait for it to complete + logger.info("Another worker is running setup, waiting...") + + # Wait for setup to complete (check every 2 seconds, max 5 minutes) + for _ in range(150): # 150 * 2 = 300 seconds = 5 minutes + if check_nim_binary_built(): + logger.info("Setup completed by another worker") + return + time.sleep(2) + + raise TimeoutError("Timed out waiting for setup to complete") + + finally: + # Clean up lock file + try: + lock_file.unlink(missing_ok=True) + except Exception: + pass + + +@pytest.fixture(scope="function") # Changed to function scope +def nim_echo_binary(): + """Get nim echo server binary path.""" + current_dir = Path(__file__).parent + binary_path = current_dir / "nim_echo_server" + + if not binary_path.exists(): + pytest.skip( + "nim_echo_server binary not found. " + "Run setup script: ./scripts/setup_nim_echo.sh" + ) + + return binary_path + + +@pytest.fixture +async def nim_server(nim_echo_binary): + """Start and stop nim echo server for tests.""" + # Import here to avoid circular imports + # pyrefly: ignore + from test_echo_interop import NimEchoServer + + server = NimEchoServer(nim_echo_binary) + + try: + peer_id, listen_addr = await server.start() + yield server, peer_id, listen_addr + finally: + await server.stop() diff --git a/tests/interop/nim_libp2p/nim_echo_server.nim b/tests/interop/nim_libp2p/nim_echo_server.nim new file mode 100644 index 00000000..a4f581d9 --- /dev/null +++ b/tests/interop/nim_libp2p/nim_echo_server.nim @@ -0,0 +1,108 @@ +{.used.} + +import chronos +import stew/byteutils +import libp2p + +## +# Simple Echo Protocol Implementation for py-libp2p Interop Testing +## +const EchoCodec = "/echo/1.0.0" + +type EchoProto = ref object of LPProtocol + +proc new(T: typedesc[EchoProto]): T = + proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} = + try: + echo "Echo server: Received connection from ", conn.peerId + + # Read and echo messages in a loop + while not conn.atEof: + try: + # Read length-prefixed message using nim-libp2p's readLp + let message = await conn.readLp(1024 * 1024) # Max 1MB + if message.len == 0: + echo "Echo server: Empty message, closing connection" + break + + let messageStr = string.fromBytes(message) + echo "Echo server: Received (", message.len, " bytes): ", messageStr + + # Echo back using writeLp + await conn.writeLp(message) + echo "Echo server: Echoed message back" + + except CatchableError as e: + echo "Echo server: Error processing message: ", e.msg + break + + except CancelledError as e: + echo "Echo server: Connection cancelled" + raise e + except CatchableError as e: + echo "Echo server: Exception in handler: ", e.msg + finally: + echo "Echo server: Connection closed" + await conn.close() + + return T.new(codecs = @[EchoCodec], handler = handle) + +## +# Create QUIC-enabled switch +## +proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch = + var switch = SwitchBuilder + .new() + .withRng(rng) + .withAddress(ma) + .withQuicTransport() + .build() + result = switch + +## +# Main server +## +proc main() {.async.} = + let + rng = newRng() + localAddr = MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet() + echoProto = EchoProto.new() + + echo "=== Nim Echo Server for py-libp2p Interop ===" + + # Create switch + let switch = createSwitch(localAddr, rng) + switch.mount(echoProto) + + # Start server + await switch.start() + + # Print connection info + echo "Peer ID: ", $switch.peerInfo.peerId + echo "Listening on:" + for addr in switch.peerInfo.addrs: + echo " ", $addr, "/p2p/", $switch.peerInfo.peerId + echo "Protocol: ", EchoCodec + echo "Ready for py-libp2p connections!" + echo "" + + # Keep running + try: + await sleepAsync(100.hours) + except CancelledError: + echo "Shutting down..." + finally: + await switch.stop() + +# Graceful shutdown handler +proc signalHandler() {.noconv.} = + echo "\nShutdown signal received" + quit(0) + +when isMainModule: + setControlCHook(signalHandler) + try: + waitFor(main()) + except CatchableError as e: + echo "Error: ", e.msg + quit(1) diff --git a/tests/interop/nim_libp2p/scripts/setup_nim_echo.sh b/tests/interop/nim_libp2p/scripts/setup_nim_echo.sh new file mode 100755 index 00000000..f80b2d27 --- /dev/null +++ b/tests/interop/nim_libp2p/scripts/setup_nim_echo.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +# tests/interop/nim_libp2p/scripts/setup_nim_echo.sh +# Cache-aware setup that skips installation if packages exist + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="${SCRIPT_DIR}/.." + +# Colors +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +main() { + log_info "Setting up nim echo server for interop testing..." + + # Check if nim is available + if ! command -v nim &> /dev/null || ! command -v nimble &> /dev/null; then + log_error "Nim not found. Please install nim first." + exit 1 + fi + + cd "${PROJECT_DIR}" + + # Create logs directory + mkdir -p logs + + # Check if binary already exists + if [[ -f "nim_echo_server" ]]; then + log_info "nim_echo_server already exists, skipping build" + return 0 + fi + + # Check if libp2p is already installed (cache-aware) + if nimble list -i | grep -q "libp2p"; then + log_info "libp2p already installed, skipping installation" + else + log_info "Installing nim-libp2p globally..." + nimble install -y libp2p + fi + + log_info "Building nim echo server..." + # Compile the echo server + nim c \ + -d:release \ + -d:chronicles_log_level=INFO \ + -d:libp2p_quic_support \ + -d:chronos_event_loop=iocp \ + -d:ssl \ + --opt:speed \ + --mm:orc \ + --verbosity:1 \ + -o:nim_echo_server \ + nim_echo_server.nim + + # Verify binary was created + if [[ -f "nim_echo_server" ]]; then + log_info "āœ… nim_echo_server built successfully" + log_info "Binary size: $(ls -lh nim_echo_server | awk '{print $5}')" + else + log_error "āŒ Failed to build nim_echo_server" + exit 1 + fi + + log_info "šŸŽ‰ Setup complete!" +} + +main "$@" diff --git a/tests/interop/nim_libp2p/test_echo_interop.py b/tests/interop/nim_libp2p/test_echo_interop.py new file mode 100644 index 00000000..8e2b3e33 --- /dev/null +++ b/tests/interop/nim_libp2p/test_echo_interop.py @@ -0,0 +1,195 @@ +import logging +from pathlib import Path +import subprocess +import time + +import pytest +import multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.utils.varint import encode_varint_prefixed, read_varint_prefixed_bytes + +# Configuration +PROTOCOL_ID = TProtocol("/echo/1.0.0") +TEST_TIMEOUT = 30 +SERVER_START_TIMEOUT = 10.0 + +# Setup logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class NimEchoServer: + """Simple nim echo server manager.""" + + def __init__(self, binary_path: Path): + self.binary_path = binary_path + self.process: None | subprocess.Popen = None + self.peer_id = None + self.listen_addr = None + + async def start(self): + """Start nim echo server and get connection info.""" + logger.info(f"Starting nim echo server: {self.binary_path}") + + self.process = subprocess.Popen( + [str(self.binary_path)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1, + ) + + # Parse output for connection info + start_time = time.time() + while time.time() - start_time < SERVER_START_TIMEOUT: + if self.process and self.process.poll() and self.process.stdout: + output = self.process.stdout.read() + raise RuntimeError(f"Server exited early: {output}") + + reader = self.process.stdout if self.process else None + if reader: + line = reader.readline().strip() + if not line: + continue + + logger.info(f"Server: {line}") + + if line.startswith("Peer ID:"): + self.peer_id = line.split(":", 1)[1].strip() + + elif "/quic-v1/p2p/" in line and self.peer_id: + if line.strip().startswith("/"): + self.listen_addr = line.strip() + logger.info(f"Server ready: {self.listen_addr}") + return self.peer_id, self.listen_addr + + await self.stop() + raise TimeoutError(f"Server failed to start within {SERVER_START_TIMEOUT}s") + + async def stop(self): + """Stop the server.""" + if self.process: + logger.info("Stopping nim echo server...") + try: + self.process.terminate() + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + self.process = None + + +async def run_echo_test(server_addr: str, messages: list[str]): + """Test echo protocol against nim server with proper timeout handling.""" + # Create py-libp2p QUIC client with shorter timeouts + + host = new_host( + enable_quic=True, + key_pair=create_new_key_pair(), + ) + + listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/0/quic-v1") + responses = [] + + try: + async with host.run(listen_addrs=[listen_addr]): + logger.info(f"Connecting to nim server: {server_addr}") + + # Connect to nim server + maddr = multiaddr.Multiaddr(server_addr) + info = info_from_p2p_addr(maddr) + await host.connect(info) + + # Create stream + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + logger.info("Stream created") + + # Test each message + for i, message in enumerate(messages, 1): + logger.info(f"Testing message {i}: {message}") + + # Send with varint length prefix + data = message.encode("utf-8") + prefixed_data = encode_varint_prefixed(data) + await stream.write(prefixed_data) + + # Read response + response_data = await read_varint_prefixed_bytes(stream) + response = response_data.decode("utf-8") + + logger.info(f"Got echo: {response}") + responses.append(response) + + # Verify echo + assert message == response, ( + f"Echo failed: sent {message!r}, got {response!r}" + ) + + await stream.close() + logger.info("āœ… All messages echoed correctly") + + finally: + await host.close() + + return responses + + +@pytest.mark.trio +@pytest.mark.timeout(TEST_TIMEOUT) +async def test_basic_echo_interop(nim_server): + """Test basic echo functionality between py-libp2p and nim-libp2p.""" + server, peer_id, listen_addr = nim_server + + test_messages = [ + "Hello from py-libp2p!", + "QUIC transport working", + "Echo test successful!", + "Unicode: ƑoĆ«l, 测试, Ψυχή", + ] + + logger.info(f"Testing against nim server: {peer_id}") + + # Run test with timeout + with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup + responses = await run_echo_test(listen_addr, test_messages) + + # Verify all messages echoed correctly + assert len(responses) == len(test_messages) + for sent, received in zip(test_messages, responses): + assert sent == received + + logger.info("āœ… Basic echo interop test passed!") + + +@pytest.mark.trio +@pytest.mark.timeout(TEST_TIMEOUT) +async def test_large_message_echo(nim_server): + """Test echo with larger messages.""" + server, peer_id, listen_addr = nim_server + + large_messages = [ + "x" * 1024, + "y" * 5000, + ] + + logger.info("Testing large message echo...") + + # Run test with timeout + with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup + responses = await run_echo_test(listen_addr, large_messages) + + assert len(responses) == len(large_messages) + for sent, received in zip(large_messages, responses): + assert sent == received + + logger.info("āœ… Large message echo test passed!") + + +if __name__ == "__main__": + # Run tests directly + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/interop/test_js_ws_ping.py b/tests/interop/test_js_ws_ping.py new file mode 100644 index 00000000..35819a86 --- /dev/null +++ b/tests/interop/test_js_ws_ping.py @@ -0,0 +1,127 @@ +import os +import signal +import subprocess + +import pytest +from multiaddr import Multiaddr +import trio +from trio.lowlevel import open_process + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.host.basic_host import BasicHost +from libp2p.network.exceptions import SwarmException +from libp2p.network.swarm import Swarm +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo +from libp2p.peer.peerstore import PeerStore +from libp2p.security.insecure.transport import InsecureTransport +from libp2p.stream_muxer.yamux.yamux import Yamux +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport + +PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0" + + +@pytest.mark.trio +async def test_ping_with_js_node(): + # Skip this test due to JavaScript dependency issues + pytest.skip("Skipping JS interop test due to dependency issues") + js_node_dir = os.path.join(os.path.dirname(__file__), "js_libp2p", "js_node", "src") + script_name = "./ws_ping_node.mjs" + + try: + subprocess.run( + ["npm", "install"], + cwd=js_node_dir, + check=True, + capture_output=True, + text=True, + ) + except (subprocess.CalledProcessError, FileNotFoundError) as e: + pytest.fail(f"Failed to run 'npm install': {e}") + + # Launch the JS libp2p node (long-running) + proc = await open_process( + ["node", script_name], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=js_node_dir, + ) + assert proc.stdout is not None, "stdout pipe missing" + assert proc.stderr is not None, "stderr pipe missing" + stdout = proc.stdout + stderr = proc.stderr + + try: + # Read first two lines (PeerID and multiaddr) + buffer = b"" + with trio.fail_after(30): + while buffer.count(b"\n") < 2: + chunk = await stdout.receive_some(1024) + if not chunk: + break + buffer += chunk + + lines = [line for line in buffer.decode().splitlines() if line.strip()] + if len(lines) < 2: + stderr_output = await stderr.receive_some(2048) + stderr_output = stderr_output.decode() + pytest.fail( + "JS node did not produce expected PeerID and multiaddr.\n" + f"Stdout: {buffer.decode()!r}\n" + f"Stderr: {stderr_output!r}" + ) + peer_id_line, addr_line = lines[0], lines[1] + peer_id = ID.from_base58(peer_id_line) + maddr = Multiaddr(addr_line) + + # Debug: Print what we're trying to connect to + print(f"JS Node Peer ID: {peer_id_line}") + print(f"JS Node Address: {addr_line}") + print(f"All JS Node lines: {lines}") + + # Set up Python host + key_pair = create_new_key_pair() + py_peer_id = ID.from_pubkey(key_pair.public_key) + peer_store = PeerStore() + peer_store.add_key_pair(py_peer_id, key_pair) + + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + transport = WebsocketTransport(upgrader) + swarm = Swarm(py_peer_id, peer_store, upgrader, transport) + host = BasicHost(swarm) + + # Connect to JS node + peer_info = PeerInfo(peer_id, [maddr]) + + print(f"Python trying to connect to: {peer_info}") + + # Use the host as a context manager + async with host.run(listen_addrs=[]): + await trio.sleep(1) + + try: + await host.connect(peer_info) + except SwarmException as e: + underlying_error = e.__cause__ + pytest.fail( + "Connection failed with SwarmException.\n" + f"THE REAL ERROR IS: {underlying_error!r}\n" + ) + + assert host.get_network().connections.get(peer_id) is not None + + # Ping protocol + stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")]) + await stream.write(b"ping") + data = await stream.read(4) + assert data == b"pong" + finally: + proc.send_signal(signal.SIGTERM) + await trio.sleep(0) diff --git a/tests/utils/test_default_bind_address.py b/tests/utils/test_default_bind_address.py new file mode 100644 index 00000000..b0598b5a --- /dev/null +++ b/tests/utils/test_default_bind_address.py @@ -0,0 +1,206 @@ +""" +Tests for the new address paradigm with wildcard support as a feature +""" + +import pytest +from multiaddr import Multiaddr + +from libp2p import new_host +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + get_wildcard_address, +) + + +class TestAddressParadigm: + """ + Test suite for verifying the new address paradigm: + - get_available_interfaces() returns all available interfaces + - get_optimal_binding_address() returns optimal address for examples + - get_wildcard_address() provides wildcard as a feature when needed + """ + + def test_wildcard_address_function(self): + """Test that get_wildcard_address() provides wildcard as a feature""" + port = 8000 + addr = get_wildcard_address(port) + + # Should return wildcard address when explicitly requested + assert "0.0.0.0" in str(addr) + addr_str = str(addr) + assert "/ip4/" in addr_str + assert f"/tcp/{port}" in addr_str + + def test_optimal_binding_address_selection(self): + """Test that optimal binding address uses good heuristics""" + port = 8000 + addr = get_optimal_binding_address(port) + + # Should return a valid IP address (could be loopback or local network) + addr_str = str(addr) + assert "/ip4/" in addr_str + assert f"/tcp/{port}" in addr_str + + # Should be from available interfaces + available_interfaces = get_available_interfaces(port) + assert addr in available_interfaces + + def test_available_interfaces_includes_loopback(self): + """Test that available interfaces always includes loopback address""" + port = 8000 + interfaces = get_available_interfaces(port) + + # Should have at least one interface + assert len(interfaces) > 0 + + # Should include loopback address + loopback_found = any("127.0.0.1" in str(addr) for addr in interfaces) + assert loopback_found, "Loopback address not found in available interfaces" + + # Available interfaces should not include wildcard by default + # (wildcard is available as a feature through get_wildcard_address()) + wildcard_found = any("0.0.0.0" in str(addr) for addr in interfaces) + assert not wildcard_found, ( + "Wildcard should not be in default available interfaces" + ) + + def test_host_default_listen_address(self): + """Test that new hosts use secure default addresses""" + # Create a host with a specific port + port = 8000 + listen_addr = Multiaddr(f"/ip4/127.0.0.1/tcp/{port}") + host = new_host(listen_addrs=[listen_addr]) + + # Verify the host configuration + assert host is not None + # Note: We can't test actual binding without running the host, + # but we've verified the address format is correct + + def test_paradigm_consistency(self): + """Test that the address paradigm is consistent""" + port = 8000 + + # get_optimal_binding_address should return a valid address + optimal_addr = get_optimal_binding_address(port) + assert "/ip4/" in str(optimal_addr) + assert f"/tcp/{port}" in str(optimal_addr) + + # get_wildcard_address should return wildcard when explicitly needed + wildcard_addr = get_wildcard_address(port) + assert "0.0.0.0" in str(wildcard_addr) + assert f"/tcp/{port}" in str(wildcard_addr) + + # Both should be valid Multiaddr objects + assert isinstance(optimal_addr, Multiaddr) + assert isinstance(wildcard_addr, Multiaddr) + + @pytest.mark.parametrize("protocol", ["tcp", "udp"]) + def test_different_protocols_support(self, protocol): + """Test that different protocols are supported by the paradigm""" + port = 8000 + + # Test optimal address with different protocols + optimal_addr = get_optimal_binding_address(port, protocol=protocol) + assert protocol in str(optimal_addr) + assert f"/{protocol}/{port}" in str(optimal_addr) + + # Test wildcard address with different protocols + wildcard_addr = get_wildcard_address(port, protocol=protocol) + assert "0.0.0.0" in str(wildcard_addr) + assert protocol in str(wildcard_addr) + assert f"/{protocol}/{port}" in str(wildcard_addr) + + # Test available interfaces with different protocols + interfaces = get_available_interfaces(port, protocol=protocol) + assert len(interfaces) > 0 + for addr in interfaces: + assert protocol in str(addr) + + def test_wildcard_available_as_feature(self): + """Test that wildcard binding is available as a feature when needed""" + port = 8000 + + # Wildcard should be available through get_wildcard_address() + wildcard_addr = get_wildcard_address(port) + assert "0.0.0.0" in str(wildcard_addr) + + # But should not be in default available interfaces + interfaces = get_available_interfaces(port) + wildcard_in_interfaces = any("0.0.0.0" in str(addr) for addr in interfaces) + assert not wildcard_in_interfaces, ( + "Wildcard should not be in default interfaces" + ) + + # Optimal address should not be wildcard by default + optimal = get_optimal_binding_address(port) + assert "0.0.0.0" not in str(optimal), ( + "Optimal address should not be wildcard by default" + ) + + def test_loopback_is_always_available(self): + """Test that loopback address is always available as an option""" + port = 8000 + interfaces = get_available_interfaces(port) + + # Loopback should always be available + loopback_addrs = [addr for addr in interfaces if "127.0.0.1" in str(addr)] + assert len(loopback_addrs) > 0, "Loopback address should always be available" + + # At least one loopback address should have the correct port + loopback_with_port = [ + addr for addr in loopback_addrs if f"/tcp/{port}" in str(addr) + ] + assert len(loopback_with_port) > 0, ( + f"Loopback address with port {port} should be available" + ) + + def test_optimal_address_selection_behavior(self): + """Test that optimal address selection works correctly""" + port = 8000 + interfaces = get_available_interfaces(port) + optimal = get_optimal_binding_address(port) + + # Should return one of the available interfaces + optimal_str = str(optimal) + interface_strs = [str(addr) for addr in interfaces] + assert optimal_str in interface_strs, ( + f"Optimal address {optimal_str} should be in available interfaces" + ) + + # Should prefer non-loopback when available, fallback to loopback + non_loopback_interfaces = [ + addr for addr in interfaces if "127.0.0.1" not in str(addr) + ] + if non_loopback_interfaces: + # Should prefer non-loopback when available + assert "127.0.0.1" not in str(optimal), ( + "Should prefer non-loopback when available" + ) + else: + # Should use loopback when no other interfaces available + assert "127.0.0.1" in str(optimal), ( + "Should use loopback when no other interfaces available" + ) + + def test_address_paradigm_completeness(self): + """Test that the address paradigm provides all necessary functionality""" + port = 8000 + + # Test that we get interface options + interfaces = get_available_interfaces(port) + assert len(interfaces) >= 1, "Should have at least one interface" + + # Test that loopback is always included + has_loopback = any("127.0.0.1" in str(addr) for addr in interfaces) + assert has_loopback, "Loopback should always be available" + + # Test that wildcard is available as a feature + wildcard_addr = get_wildcard_address(port) + assert "0.0.0.0" in str(wildcard_addr) + + # Test optimal selection + optimal = get_optimal_binding_address(port) + assert optimal in interfaces, ( + "Optimal address should be from available interfaces" + )