diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index ef963f80..0658d2b3 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -36,10 +36,48 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} + + - name: Install Nim for interop testing + if: matrix.toxenv == 'interop' + run: | + echo "Installing Nim for nim-libp2p interop testing..." + curl -sSf https://nim-lang.org/choosenim/init.sh | sh -s -- -y --firstInstall + echo "$HOME/.nimble/bin" >> $GITHUB_PATH + echo "$HOME/.choosenim/toolchains/nim-stable/bin" >> $GITHUB_PATH + + - name: Cache nimble packages + if: matrix.toxenv == 'interop' + uses: actions/cache@v4 + with: + path: | + ~/.nimble + ~/.choosenim/toolchains/*/lib + key: ${{ runner.os }}-nimble-${{ hashFiles('**/nim_echo_server.nim') }} + restore-keys: | + ${{ runner.os }}-nimble- + + - name: Build nim interop binaries + if: matrix.toxenv == 'interop' + run: | + export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH" + cd tests/interop/nim_libp2p + ./scripts/setup_nim_echo.sh + - run: | python -m pip install --upgrade pip python -m pip install tox - - run: | + + - name: Run Tests or Generate Docs + run: | + if [[ "${{ matrix.toxenv }}" == 'docs' ]]; then + export TOXENV=docs + else + export TOXENV=py${{ matrix.python }}-${{ matrix.toxenv }} + fi + # Set PATH for nim commands during tox + if [[ "${{ matrix.toxenv }}" == 'interop' ]]; then + export PATH="$HOME/.nimble/bin:$HOME/.choosenim/toolchains/nim-stable/bin:$PATH" + fi python -m tox run -r windows: diff --git a/.gitignore b/.gitignore index e17714b5..88676f10 100644 --- a/.gitignore +++ b/.gitignore @@ -178,7 +178,12 @@ env.bak/ #lockfiles uv.lock poetry.lock + +# JavaScript interop test files tests/interop/js_libp2p/js_node/node_modules/ tests/interop/js_libp2p/js_node/package-lock.json tests/interop/js_libp2p/js_node/src/node_modules/ tests/interop/js_libp2p/js_node/src/package-lock.json + +# Sphinx documentation build +_build/ diff --git a/README.md b/README.md index 77166429..f87fbea6 100644 --- a/README.md +++ b/README.md @@ -61,12 +61,12 @@ ______________________________________________________________________ ### Discovery -| **Discovery** | **Status** | **Source** | -| -------------------- | :--------: | :--------------------------------------------------------------------------------: | -| **`bootstrap`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) | -| **`random-walk`** | 🌱 | | -| **`mdns-discovery`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) | -| **`rendezvous`** | 🌱 | | +| **Discovery** | **Status** | **Source** | +| -------------------- | :--------: | :----------------------------------------------------------------------------------: | +| **`bootstrap`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/bootstrap) | +| **`random-walk`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/random_walk) | +| **`mdns-discovery`** | āœ… | [source](https://github.com/libp2p/py-libp2p/tree/main/libp2p/discovery/mdns) | +| **`rendezvous`** | 🌱 | | ______________________________________________________________________ diff --git a/docs/examples.echo_quic.rst b/docs/examples.echo_quic.rst new file mode 100644 index 00000000..0e3313df --- /dev/null +++ b/docs/examples.echo_quic.rst @@ -0,0 +1,43 @@ +QUIC Echo Demo +============== + +This example demonstrates a simple ``echo`` 
protocol using **QUIC transport**.
+
+QUIC provides built-in TLS security and stream multiplexing over UDP, making it an excellent transport choice for libp2p applications.
+
+.. code-block:: console
+
+    $ python -m pip install libp2p
+    Collecting libp2p
+    ...
+    Successfully installed libp2p-x.x.x
+    $ echo-quic-demo
+    Run this from the same folder in another console:
+
+    echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmAsbxRR1HiGJRNVPQLNMeNsBCsXT3rDjoYBQzgzNpM5mJ
+
+    Waiting for incoming connection...
+
+Copy the line that starts with ``echo-quic-demo -d``, open a new terminal in the same
+folder and paste it in:
+
+.. code-block:: console
+
+    $ echo-quic-demo -d /ip4/127.0.0.1/udp/8000/quic-v1/p2p/16Uiu2HAmAsbxRR1HiGJRNVPQLNMeNsBCsXT3rDjoYBQzgzNpM5mJ
+
+    I am 16Uiu2HAmE3N7KauPTmHddYPsbMcBp2C6XAmprELX3YcFEN9iXiBu
+    STARTING CLIENT CONNECTION PROCESS
+    CLIENT CONNECTED TO SERVER
+    Sent: hi, there!
+    Got: ECHO: hi, there!
+
+**Key differences from TCP Echo:**
+
+- Uses UDP instead of TCP: ``/udp/8000`` instead of ``/tcp/8000``
+- Includes QUIC protocol identifier: ``/quic-v1`` in the multiaddr
+- Built-in TLS security (no separate security transport needed)
+- Native stream multiplexing over a single QUIC connection
+
+.. literalinclude:: ../examples/echo/echo_quic.py
+   :language: python
+   :linenos:
diff --git a/docs/examples.multiple_connections.rst b/docs/examples.multiple_connections.rst
new file mode 100644
index 00000000..85ab8f2d
--- /dev/null
+++ b/docs/examples.multiple_connections.rst
@@ -0,0 +1,194 @@
+Multiple Connections Per Peer
+=============================
+
+This example demonstrates how to use the multiple connections per peer feature in py-libp2p.
+
+Overview
+--------
+
+The multiple connections per peer feature allows a libp2p node to maintain multiple network connections to the same peer. This provides several benefits:
+
+- **Improved reliability**: If one connection fails, others remain available
+- **Better performance**: Load can be distributed across multiple connections
+- **Enhanced throughput**: Multiple streams can be created in parallel
+- **Fault tolerance**: Redundant connections provide backup paths
+
+Configuration
+-------------
+
+The feature is configured through the ``ConnectionConfig`` class:
+
+.. code-block:: python
+
+    from libp2p.network.swarm import ConnectionConfig
+
+    # Default configuration
+    config = ConnectionConfig()
+    print(f"Max connections per peer: {config.max_connections_per_peer}")
+    print(f"Load balancing strategy: {config.load_balancing_strategy}")
+
+    # Custom configuration
+    custom_config = ConnectionConfig(
+        max_connections_per_peer=5,
+        connection_timeout=60.0,
+        load_balancing_strategy="least_loaded"
+    )
+
+Load Balancing Strategies
+-------------------------
+
+Two load balancing strategies are available:
+
+**Round Robin** (default)
+    Cycles through connections in order, distributing load evenly.
+
+**Least Loaded**
+    Selects the connection with the fewest active streams.
+
+API Usage
+---------
+
+The new API provides direct access to multiple connections:
+
+.. code-block:: python
+
+    from libp2p import new_swarm
+
+    # Create swarm with multiple connections support
+    swarm = new_swarm()
+
+    # Dial a peer - returns a list of connections
+    connections = await swarm.dial_peer(peer_id)
+    print(f"Established {len(connections)} connections")
+
+    # Get all connections to a peer
+    peer_connections = swarm.get_connections(peer_id)
+
+    # Get all connections (across all peers)
+    all_connections = swarm.get_connections()
+
+    # Get the complete connections map
+    connections_map = swarm.get_connections_map()
+
+    # Backward compatibility - get a single connection
+    single_conn = swarm.get_connection(peer_id)
+
+Backward Compatibility
+----------------------
+
+Existing code continues to work through backward compatibility features:
+
+.. code-block:: python
+
+    # Legacy 1:1 mapping (returns first connection for each peer)
+    legacy_connections = swarm.connections_legacy
+
+    # Single connection access (returns first available connection)
+    conn = swarm.get_connection(peer_id)
+
+Example
+-------
+
+A complete working example is available in the ``examples/doc-examples/multiple_connections_example.py`` file.
+
+Production Configuration
+------------------------
+
+For production use, consider these settings:
+
+**RetryConfig Parameters**
+
+The ``RetryConfig`` class controls connection retry behavior with exponential backoff:
+
+- **max_retries**: Maximum number of retry attempts before giving up (default: 3)
+- **initial_delay**: Initial delay in seconds before the first retry (default: 0.1s)
+- **max_delay**: Maximum delay cap to prevent excessive wait times (default: 30.0s)
+- **backoff_multiplier**: Exponential backoff multiplier; each retry multiplies the delay by this factor (default: 2.0)
+- **jitter_factor**: Random jitter (0.0-1.0) to prevent synchronized retries (default: 0.1)
+
+**ConnectionConfig Parameters**
+
+The ``ConnectionConfig`` class manages multi-connection behavior:
+
+- **max_connections_per_peer**: Maximum connections allowed to a single peer (default: 3)
+- **connection_timeout**: Timeout for establishing new connections in seconds (default: 30.0s)
+- **load_balancing_strategy**: Strategy for distributing streams ("round_robin" or "least_loaded")
+
+**Load Balancing Strategies Explained**
+
+- **round_robin**: Cycles through connections in order, distributing load evenly. Simple and predictable.
+- **least_loaded**: Selects the connection with the fewest active streams. Better for performance but more complex.
+
+.. code-block:: python
+
+    from libp2p.network.swarm import ConnectionConfig, RetryConfig
+
+    # Production-ready configuration
+    retry_config = RetryConfig(
+        max_retries=3,           # Maximum retry attempts before giving up
+        initial_delay=0.1,       # Start with 100ms delay
+        max_delay=30.0,          # Cap exponential backoff at 30 seconds
+        backoff_multiplier=2.0,  # Double delay each retry (100ms -> 200ms -> 400ms)
+        jitter_factor=0.1        # Add 10% random jitter to prevent thundering herd
+    )
+
+    connection_config = ConnectionConfig(
+        max_connections_per_peer=3,            # Allow up to 3 connections per peer
+        connection_timeout=30.0,               # 30 second timeout for new connections
+        load_balancing_strategy="round_robin"  # Simple, predictable load distribution
+    )
+
+    swarm = new_swarm(
+        retry_config=retry_config,
+        connection_config=connection_config
+    )
+
+**How RetryConfig Works in Practice**
+
+With the configuration above (``max_retries=3``), connection attempts follow this pattern:
+
+1. **Attempt 1**: Immediate connection attempt
+2. **Attempt 2**: Wait 100ms ± 10ms jitter, then retry
+3. **Attempt 3**: Wait 200ms ± 20ms jitter, then retry
+4. **Attempt 4**: Wait 400ms ± 40ms jitter, then retry
+5. **Give up**: After the three retries fail, the dial fails
+
+With a larger ``max_retries``, the delay keeps doubling (800ms, 1.6s, 3.2s, ...) until it reaches ``max_delay``; every later retry then waits the capped 30.0s ± 3.0s jitter. The jitter prevents multiple clients from retrying simultaneously, reducing server load.
+
+**Parameter Tuning Guidelines**
+
+**For Development/Testing:**
+
+- Use lower ``max_retries`` (1-2) and shorter delays for faster feedback
+- Example: ``RetryConfig(max_retries=2, initial_delay=0.01, max_delay=0.1)``
+
+**For Production:**
+
+- Use moderate ``max_retries`` (3-5) with reasonable delays for reliability
+- Example: ``RetryConfig(max_retries=5, initial_delay=0.1, max_delay=60.0)``
+
+**For High-Latency Networks:**
+
+- Use higher ``max_retries`` (5-10) with longer delays
+- Example: ``RetryConfig(max_retries=8, initial_delay=0.5, max_delay=120.0)``
+
+**For Load Balancing:**
+
+- Use ``round_robin`` for simple, predictable behavior
+- Use ``least_loaded`` when you need optimal performance and can handle complexity
+
+Architecture
+------------
+
+The implementation follows the same architectural patterns as the Go and JavaScript reference implementations:
+
+- **Core data structure**: ``dict[ID, list[INetConn]]`` for 1:many mapping
+- **API consistency**: Methods like ``get_connections()`` match reference implementations
+- **Load balancing**: Integrated at the API level for optimal performance
+- **Backward compatibility**: Maintains existing interfaces for gradual migration
+
+This design ensures consistency across libp2p implementations while providing the benefits of multiple connections per peer.
diff --git a/docs/examples.rst b/docs/examples.rst
index b8ba44d7..9f149ad0 100644
--- a/docs/examples.rst
+++ b/docs/examples.rst
@@ -9,9 +9,11 @@ Examples
    examples.identify_push
    examples.chat
    examples.echo
+   examples.echo_quic
    examples.ping
    examples.pubsub
    examples.circuit_relay
    examples.kademlia
    examples.mDNS
    examples.random_walk
+   examples.multiple_connections
diff --git a/docs/getting_started.rst b/docs/getting_started.rst
index a8303ce0..b5de85bc 100644
--- a/docs/getting_started.rst
+++ b/docs/getting_started.rst
@@ -28,6 +28,11 @@ For Python, the most common transport is TCP. Here's how to set up a basic TCP t
 .. literalinclude:: ../examples/doc-examples/example_transport.py
    :language: python
 
+Alternatively, QUIC is a modern transport protocol that provides built-in TLS security and stream multiplexing over UDP:
+
+.. literalinclude:: ../examples/doc-examples/example_quic_transport.py
+   :language: python
+
 Connection Encryption
 ^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/libp2p.transport.quic.rst b/docs/libp2p.transport.quic.rst
new file mode 100644
index 00000000..b7b4b561
--- /dev/null
+++ b/docs/libp2p.transport.quic.rst
@@ -0,0 +1,77 @@
+libp2p.transport.quic package
+=============================
+
+Submodules
+----------
+
+libp2p.transport.quic.config module
+-----------------------------------
+
+.. 
automodule:: libp2p.transport.quic.config + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.connection module +--------------------------------------- + +.. automodule:: libp2p.transport.quic.connection + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.exceptions module +--------------------------------------- + +.. automodule:: libp2p.transport.quic.exceptions + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.listener module +------------------------------------- + +.. automodule:: libp2p.transport.quic.listener + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.security module +------------------------------------- + +.. automodule:: libp2p.transport.quic.security + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.stream module +----------------------------------- + +.. automodule:: libp2p.transport.quic.stream + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.transport module +-------------------------------------- + +.. automodule:: libp2p.transport.quic.transport + :members: + :undoc-members: + :show-inheritance: + +libp2p.transport.quic.utils module +---------------------------------- + +.. automodule:: libp2p.transport.quic.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.transport.quic + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.transport.rst b/docs/libp2p.transport.rst index 0d92c48f..2a468143 100644 --- a/docs/libp2p.transport.rst +++ b/docs/libp2p.transport.rst @@ -9,6 +9,11 @@ Subpackages libp2p.transport.tcp +.. toctree:: + :maxdepth: 4 + + libp2p.transport.quic + Submodules ---------- diff --git a/examples/advanced/network_discover.py b/examples/advanced/network_discover.py new file mode 100644 index 00000000..87b44ddf --- /dev/null +++ b/examples/advanced/network_discover.py @@ -0,0 +1,63 @@ +""" +Advanced demonstration of Thin Waist address handling. 
+ +Run: + python -m examples.advanced.network_discovery +""" + +from __future__ import annotations + +from multiaddr import Multiaddr + +try: + from libp2p.utils.address_validation import ( + expand_wildcard_address, + get_available_interfaces, + get_optimal_binding_address, + ) +except ImportError: + # Fallbacks if utilities are missing + def get_available_interfaces(port: int, protocol: str = "tcp"): + return [Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")] + + def expand_wildcard_address(addr: Multiaddr, port: int | None = None): + if port is None: + return [addr] + addr_str = str(addr).rsplit("/", 1)[0] + return [Multiaddr(addr_str + f"/{port}")] + + def get_optimal_binding_address(port: int, protocol: str = "tcp"): + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + +def main() -> None: + port = 8080 + interfaces = get_available_interfaces(port) + print(f"Discovered interfaces for port {port}:") + for a in interfaces: + print(f" - {a}") + + wildcard_v4 = Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + expanded_v4 = expand_wildcard_address(wildcard_v4) + print("\nExpanded IPv4 wildcard:") + for a in expanded_v4: + print(f" - {a}") + + wildcard_v6 = Multiaddr(f"/ip6/::/tcp/{port}") + expanded_v6 = expand_wildcard_address(wildcard_v6) + print("\nExpanded IPv6 wildcard:") + for a in expanded_v6: + print(f" - {a}") + + print("\nOptimal binding address heuristic result:") + print(f" -> {get_optimal_binding_address(port)}") + + override_port = 9000 + overridden = expand_wildcard_address(wildcard_v4, port=override_port) + print(f"\nPort override expansion to {override_port}:") + for a in overridden: + print(f" - {a}") + + +if __name__ == "__main__": + main() diff --git a/examples/doc-examples/example_quic_transport.py b/examples/doc-examples/example_quic_transport.py new file mode 100644 index 00000000..da2f5395 --- /dev/null +++ b/examples/doc-examples/example_quic_transport.py @@ -0,0 +1,35 @@ +import secrets + +import multiaddr +import trio + +from libp2p import ( + new_host, +) +from libp2p.crypto.secp256k1 import ( + create_new_key_pair, +) + + +async def main(): + # Create a key pair for the host + secret = secrets.token_bytes(32) + key_pair = create_new_key_pair(secret) + + # Create a host with the key pair + host = new_host(key_pair=key_pair, enable_quic=True) + + # Configure the listening address + port = 8000 + listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic-v1") + + # Start the host + async with host.run(listen_addrs=[listen_addr]): + print("libp2p has started with QUIC transport") + print("libp2p is listening on:", host.get_addrs()) + # Keep the host running + await trio.sleep_forever() + + +# Run the async function +trio.run(main) diff --git a/examples/doc-examples/multiple_connections_example.py b/examples/doc-examples/multiple_connections_example.py new file mode 100644 index 00000000..f0738283 --- /dev/null +++ b/examples/doc-examples/multiple_connections_example.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +Example demonstrating multiple connections per peer support in libp2p. + +This example shows how to: +1. Configure multiple connections per peer +2. Use different load balancing strategies +3. Access multiple connections through the new API +4. 
Maintain backward compatibility +""" + +import logging + +import trio + +from libp2p import new_swarm +from libp2p.network.swarm import ConnectionConfig, RetryConfig + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def example_basic_multiple_connections() -> None: + """Example of basic multiple connections per peer usage.""" + logger.info("Creating swarm with multiple connections support...") + + # Create swarm with default configuration + swarm = new_swarm() + default_connection = ConnectionConfig() + + logger.info(f"Swarm created with peer ID: {swarm.get_peer_id()}") + logger.info( + f"Connection config: max_connections_per_peer=" + f"{default_connection.max_connections_per_peer}" + ) + + await swarm.close() + logger.info("Basic multiple connections example completed") + + +async def example_custom_connection_config() -> None: + """Example of custom connection configuration.""" + logger.info("Creating swarm with custom connection configuration...") + + # Custom connection configuration for high-performance scenarios + connection_config = ConnectionConfig( + max_connections_per_peer=5, # More connections per peer + connection_timeout=60.0, # Longer timeout + load_balancing_strategy="least_loaded", # Use least loaded strategy + ) + + # Create swarm with custom connection config + swarm = new_swarm(connection_config=connection_config) + + logger.info("Custom connection config applied:") + logger.info( + f" Max connections per peer: {connection_config.max_connections_per_peer}" + ) + logger.info(f" Connection timeout: {connection_config.connection_timeout}s") + logger.info( + f" Load balancing strategy: {connection_config.load_balancing_strategy}" + ) + + await swarm.close() + logger.info("Custom connection config example completed") + + +async def example_multiple_connections_api() -> None: + """Example of using the new multiple connections API.""" + logger.info("Demonstrating multiple connections API...") + + connection_config = ConnectionConfig( + max_connections_per_peer=3, load_balancing_strategy="round_robin" + ) + + swarm = new_swarm(connection_config=connection_config) + + logger.info("Multiple connections API features:") + logger.info(" - dial_peer() returns list[INetConn]") + logger.info(" - get_connections(peer_id) returns list[INetConn]") + logger.info(" - get_connections_map() returns dict[ID, list[INetConn]]") + logger.info( + " - get_connection(peer_id) returns INetConn | None (backward compatibility)" + ) + + await swarm.close() + logger.info("Multiple connections API example completed") + + +async def example_backward_compatibility() -> None: + """Example of backward compatibility features.""" + logger.info("Demonstrating backward compatibility...") + + swarm = new_swarm() + + logger.info("Backward compatibility features:") + logger.info(" - connections_legacy property provides 1:1 mapping") + logger.info(" - get_connection() method for single connection access") + logger.info(" - Existing code continues to work") + + await swarm.close() + logger.info("Backward compatibility example completed") + + +async def example_production_ready_config() -> None: + """Example of production-ready configuration.""" + logger.info("Creating swarm with production-ready configuration...") + + # Production-ready retry configuration + retry_config = RetryConfig( + max_retries=3, # Reasonable retry limit + initial_delay=0.1, # Quick initial retry + max_delay=30.0, # Cap exponential backoff + backoff_multiplier=2.0, # Standard 
exponential backoff + jitter_factor=0.1, # Small jitter to prevent thundering herd + ) + + # Production-ready connection configuration + connection_config = ConnectionConfig( + max_connections_per_peer=3, # Balance between performance and resource usage + connection_timeout=30.0, # Reasonable timeout + load_balancing_strategy="round_robin", # Simple, predictable strategy + ) + + # Create swarm with production config + swarm = new_swarm(retry_config=retry_config, connection_config=connection_config) + + logger.info("Production-ready configuration applied:") + logger.info( + f" Retry: {retry_config.max_retries} retries, " + f"{retry_config.max_delay}s max delay" + ) + logger.info(f" Connections: {connection_config.max_connections_per_peer} per peer") + logger.info(f" Load balancing: {connection_config.load_balancing_strategy}") + + await swarm.close() + logger.info("Production-ready configuration example completed") + + +async def main() -> None: + """Run all examples.""" + logger.info("Multiple Connections Per Peer Examples") + logger.info("=" * 50) + + try: + await example_basic_multiple_connections() + logger.info("-" * 30) + + await example_custom_connection_config() + logger.info("-" * 30) + + await example_multiple_connections_api() + logger.info("-" * 30) + + await example_backward_compatibility() + logger.info("-" * 30) + + await example_production_ready_config() + logger.info("-" * 30) + + logger.info("All examples completed successfully!") + + except Exception as e: + logger.error(f"Example failed: {e}") + raise + + +if __name__ == "__main__": + trio.run(main) diff --git a/examples/echo/echo.py b/examples/echo/echo.py index 126a7da2..19e98377 100644 --- a/examples/echo/echo.py +++ b/examples/echo/echo.py @@ -1,4 +1,6 @@ import argparse +import random +import secrets import multiaddr import trio @@ -12,40 +14,54 @@ from libp2p.crypto.secp256k1 import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.network.stream.exceptions import ( + StreamEOF, +) from libp2p.network.stream.net_stream import ( INetStream, ) from libp2p.peer.peerinfo import ( info_from_p2p_addr, ) +from libp2p.utils.address_validation import ( + find_free_port, + get_available_interfaces, +) PROTOCOL_ID = TProtocol("/echo/1.0.0") MAX_READ_LEN = 2**32 - 1 async def _echo_stream_handler(stream: INetStream) -> None: - # Wait until EOF - msg = await stream.read(MAX_READ_LEN) - await stream.write(msg) - await stream.close() + try: + peer_id = stream.muxed_conn.peer_id + print(f"Received connection from {peer_id}") + # Wait until EOF + msg = await stream.read(MAX_READ_LEN) + print(f"Echoing message: {msg.decode('utf-8')}") + await stream.write(msg) + except StreamEOF: + print("Stream closed by remote peer.") + except Exception as e: + print(f"Error in echo handler: {e}") + finally: + await stream.close() async def run(port: int, destination: str, seed: int | None = None) -> None: - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") + if port <= 0: + port = find_free_port() + listen_addr = get_available_interfaces(port) if seed: - import random - random.seed(seed) secret_number = random.getrandbits(32 * 8) secret = secret_number.to_bytes(length=32, byteorder="big") else: - import secrets - secret = secrets.token_bytes(32) host = new_host(key_pair=create_new_key_pair(secret)) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with host.run(listen_addrs=listen_addr), trio.open_nursery() as nursery: # Start the peer-store cleanup task 
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -54,10 +70,15 @@ async def run(port: int, destination: str, seed: int | None = None) -> None: if not destination: # its the server host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + # Print all listen addresses with peer ID (JS parity) + print("Listener ready, listening on:\n") + peer_id = host.get_id().to_string() + for addr in listen_addr: + print(f"{addr}/p2p/{peer_id}") + print( - "Run this from the same folder in another console:\n\n" - f"echo-demo " - f"-d {host.get_addrs()[0]}\n" + "\nRun this from the same folder in another console:\n\n" + f"echo-demo -d {host.get_addrs()[0]}\n" ) print("Waiting for incoming connections...") await trio.sleep_forever() diff --git a/examples/echo/echo_quic.py b/examples/echo/echo_quic.py new file mode 100644 index 00000000..248aed9f --- /dev/null +++ b/examples/echo/echo_quic.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +""" +QUIC Echo Example - Fixed version with proper client/server separation + +This program demonstrates a simple echo protocol using QUIC transport where a peer +listens for connections and copies back any input received on a stream. + +Fixed to properly separate client and server modes - clients don't start listeners. +""" + +import argparse +import logging + +from multiaddr import Multiaddr +import trio + +from libp2p import new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.network.stream.net_stream import INetStream +from libp2p.peer.peerinfo import info_from_p2p_addr + +PROTOCOL_ID = TProtocol("/echo/1.0.0") + + +async def _echo_stream_handler(stream: INetStream) -> None: + try: + msg = await stream.read() + await stream.write(msg) + await stream.close() + except Exception as e: + print(f"Echo handler error: {e}") + try: + await stream.close() + except: # noqa: E722 + pass + + +async def run_server(port: int, seed: int | None = None) -> None: + """Run echo server with QUIC transport.""" + listen_addr = Multiaddr(f"/ip4/0.0.0.0/udp/{port}/quic") + + if seed: + import random + + random.seed(seed) + secret_number = random.getrandbits(32 * 8) + secret = secret_number.to_bytes(length=32, byteorder="big") + else: + import secrets + + secret = secrets.token_bytes(32) + + # Create host with QUIC transport + host = new_host( + enable_quic=True, + key_pair=create_new_key_pair(secret), + ) + + # Server mode: start listener + async with host.run(listen_addrs=[listen_addr]): + try: + print(f"I am {host.get_id().to_string()}") + host.set_stream_handler(PROTOCOL_ID, _echo_stream_handler) + + print( + "Run this from the same folder in another console:\n\n" + f"python3 ./examples/echo/echo_quic.py " + f"-d {host.get_addrs()[0]}\n" + ) + print("Waiting for incoming QUIC connections...") + await trio.sleep_forever() + except KeyboardInterrupt: + print("Closing server gracefully...") + await host.close() + return + + +async def run_client(destination: str, seed: int | None = None) -> None: + """Run echo client with QUIC transport.""" + if seed: + import random + + random.seed(seed) + secret_number = random.getrandbits(32 * 8) + secret = secret_number.to_bytes(length=32, byteorder="big") + else: + import secrets + + secret = secrets.token_bytes(32) + + # Create host with QUIC transport + host = new_host( + enable_quic=True, + key_pair=create_new_key_pair(secret), + ) + + # Client mode: NO listener, just connect + async with host.run(listen_addrs=[]): # Empty listen_addrs for client + print(f"I am 
{host.get_id().to_string()}") + + maddr = Multiaddr(destination) + info = info_from_p2p_addr(maddr) + + # Connect to server + print("STARTING CLIENT CONNECTION PROCESS") + await host.connect(info) + print("CLIENT CONNECTED TO SERVER") + + # Start a stream with the destination + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + + msg = b"hi, there!\n" + + await stream.write(msg) + response = await stream.read() + + print(f"Sent: {msg.decode('utf-8')}") + print(f"Got: {response.decode('utf-8')}") + await stream.close() + await host.disconnect(info.peer_id) + + +async def run(port: int, destination: str, seed: int | None = None) -> None: + """ + Run echo server or client with QUIC transport. + + Fixed version that properly separates client and server modes. + """ + if not destination: # Server mode + await run_server(port, seed) + else: # Client mode + await run_client(destination, seed) + + +def main() -> None: + """Main function - help text updated for QUIC.""" + description = """ + This program demonstrates a simple echo protocol using QUIC + transport where a peer listens for connections and copies back + any input received on a stream. + + QUIC provides built-in TLS security and stream multiplexing over UDP. + + To use it, first run 'echo-quic-demo -p ', where is + the UDP port number. Then, run another host with , + 'echo-quic-demo -d ' + where is the QUIC multiaddress of the previous listener host. + """ + + example_maddr = "/ip4/127.0.0.1/udp/8000/quic/p2p/QmQn4SwGkDZKkUEpBRBv" + + parser = argparse.ArgumentParser(description=description) + parser.add_argument("-p", "--port", default=0, type=int, help="UDP port number") + parser.add_argument( + "-d", + "--destination", + type=str, + help=f"destination multiaddr string, e.g. {example_maddr}", + ) + parser.add_argument( + "-s", + "--seed", + type=int, + help="provide a seed to the random number generator", + ) + args = parser.parse_args() + + try: + trio.run(run, args.port, args.destination, args.seed) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + logging.getLogger("aioquic").setLevel(logging.DEBUG) + main() diff --git a/examples/kademlia/kademlia.py b/examples/kademlia/kademlia.py index 5daa70d7..faaa66be 100644 --- a/examples/kademlia/kademlia.py +++ b/examples/kademlia/kademlia.py @@ -41,6 +41,7 @@ from libp2p.tools.async_service import ( from libp2p.tools.utils import ( info_from_p2p_addr, ) +from libp2p.utils.paths import get_script_dir, join_paths # Configure logging logging.basicConfig( @@ -53,8 +54,8 @@ logger = logging.getLogger("kademlia-example") # Configure DHT module loggers to inherit from the parent logger # This ensures all kademlia-example.* loggers use the same configuration # Get the directory where this script is located -SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -SERVER_ADDR_LOG = os.path.join(SCRIPT_DIR, "server_node_addr.txt") +SCRIPT_DIR = get_script_dir(__file__) +SERVER_ADDR_LOG = join_paths(SCRIPT_DIR, "server_node_addr.txt") # Set the level for all child loggers for module in [ diff --git a/examples/pubsub/pubsub.py b/examples/pubsub/pubsub.py index 1ab6d650..41545658 100644 --- a/examples/pubsub/pubsub.py +++ b/examples/pubsub/pubsub.py @@ -1,6 +1,5 @@ import argparse import logging -import socket import base58 import multiaddr @@ -31,6 +30,9 @@ from libp2p.stream_muxer.mplex.mplex import ( from libp2p.tools.async_service.trio_service import ( background_trio_service, ) +from libp2p.utils.address_validation import ( + 
find_free_port, +) # Configure logging logging.basicConfig( @@ -77,13 +79,6 @@ async def publish_loop(pubsub, topic, termination_event): await trio.sleep(1) # Avoid tight loop on error -def find_free_port(): - """Find a free port on localhost.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to a free port provided by the OS - return s.getsockname()[1] - - async def monitor_peer_topics(pubsub, nursery, termination_event): """ Monitor for new topics that peers are subscribed to and diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 91d60ae5..3679409f 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,3 +1,11 @@ +"""Libp2p Python implementation.""" + +import logging + +from libp2p.transport.quic.utils import is_quic_multiaddr +from typing import Any +from libp2p.transport.quic.transport import QUICTransport +from libp2p.transport.quic.config import QUICTransportConfig from collections.abc import ( Mapping, Sequence, @@ -6,15 +14,12 @@ from importlib.metadata import version as __version from typing import ( Literal, Optional, - Type, - cast, ) import multiaddr from libp2p.abc import ( IHost, - IMuxedConn, INetworkService, IPeerRouting, IPeerStore, @@ -33,9 +38,6 @@ from libp2p.custom_types import ( TProtocol, TSecurityOptions, ) -from libp2p.discovery.mdns.mdns import ( - MDNSDiscovery, -) from libp2p.host.basic_host import ( BasicHost, ) @@ -45,27 +47,34 @@ from libp2p.host.routed_host import ( from libp2p.network.swarm import ( Swarm, ) +from libp2p.network.config import ( + ConnectionConfig, + RetryConfig +) from libp2p.peer.id import ( ID, ) from libp2p.peer.peerstore import ( PeerStore, + create_signed_peer_record, ) from libp2p.security.insecure.transport import ( PLAINTEXT_PROTOCOL_ID, InsecureTransport, ) -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import ( MPLEX_PROTOCOL_ID, Mplex, ) from libp2p.stream_muxer.yamux.yamux import ( + PROTOCOL_ID as YAMUX_PROTOCOL_ID, Yamux, ) -from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID from libp2p.transport.tcp.tcp import ( TCP, ) @@ -91,7 +100,7 @@ MUXER_YAMUX = "YAMUX" MUXER_MPLEX = "MPLEX" DEFAULT_NEGOTIATE_TIMEOUT = 5 - +logger = logging.getLogger(__name__) def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None: """ @@ -160,7 +169,6 @@ def get_default_muxer_options() -> TMuxerOptions: else: # YAMUX is default return create_yamux_muxer_option() - def new_swarm( key_pair: KeyPair | None = None, muxer_opt: TMuxerOptions | None = None, @@ -168,6 +176,9 @@ def new_swarm( peerstore_opt: IPeerStore | None = None, muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, + enable_quic: bool = False, + retry_config: Optional["RetryConfig"] = None, + connection_config: ConnectionConfig | QUICTransportConfig | None = None, ) -> INetworkService: """ Create a swarm instance based on the parameters. 
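# Usage sketch: how the new new_swarm() parameters compose. A minimal,
# illustrative example assuming the ConnectionConfig / RetryConfig import
# path from libp2p.network.config introduced earlier in this diff; the
# values mirror the documented defaults.
from libp2p import new_swarm
from libp2p.network.config import ConnectionConfig, RetryConfig

retry_config = RetryConfig(max_retries=3, initial_delay=0.1, max_delay=30.0)
connection_config = ConnectionConfig(
    max_connections_per_peer=3,
    load_balancing_strategy="round_robin",
)
swarm = new_swarm(retry_config=retry_config, connection_config=connection_config)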
@@ -178,6 +189,8 @@
     :param peerstore_opt: optional peerstore
     :param muxer_preference: optional explicit muxer preference
     :param listen_addrs: optional list of multiaddrs to listen on
+    :param enable_quic: optional flag to enable QUIC transport
+    :param connection_config: optional connection configuration; pass a QUICTransportConfig when QUIC is enabled
     :return: return a default swarm instance

     Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@@ -190,8 +203,6 @@

     id_opt = generate_peer_id_from(key_pair)

-
-
     # Generate X25519 keypair for Noise
     noise_key_pair = create_new_x25519_key_pair()

@@ -255,36 +266,18 @@
     else:
         transport = transport_maybe

-    # Use given muxer preference if provided, otherwise use global default
-    if muxer_preference is not None:
-        temp_pref = muxer_preference.upper()
-        if temp_pref not in [MUXER_YAMUX, MUXER_MPLEX]:
-            raise ValueError(
-                f"Unknown muxer: {muxer_preference}. Use 'YAMUX' or 'MPLEX'."
-            )
-        active_preference = temp_pref
-    else:
-        active_preference = DEFAULT_MUXER
-
-    # Use provided muxer options if given, otherwise create based on preference
-    if muxer_opt is not None:
-        muxer_transports_by_protocol = muxer_opt
-    else:
-        if active_preference == MUXER_MPLEX:
-            muxer_transports_by_protocol = create_mplex_muxer_option()
-        else:  # YAMUX is default
-            muxer_transports_by_protocol = create_yamux_muxer_option()
-
-    upgrader = TransportUpgrader(
-        secure_transports_by_protocol=secure_transports_by_protocol,
-        muxer_transports_by_protocol=muxer_transports_by_protocol,
-    )
-
     peerstore = peerstore_opt or PeerStore()
     # Store our key pair in peerstore
     peerstore.add_key_pair(id_opt, key_pair)

-    return Swarm(id_opt, peerstore, upgrader, transport)
+    return Swarm(
+        id_opt,
+        peerstore,
+        upgrader,
+        transport,
+        retry_config=retry_config,
+        connection_config=connection_config
+    )


 def new_host(
@@ -298,6 +291,8 @@
     enable_mDNS: bool = False,
     bootstrap: list[str] | None = None,
     negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
+    enable_quic: bool = False,
+    quic_transport_opt: QUICTransportConfig | None = None,
 ) -> IHost:
     """
     Create a new libp2p host based on the given parameters.
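# Usage sketch: the new QUIC options on new_host(). A minimal, illustrative
# example assuming QUICTransportConfig (imported in this diff from
# libp2p.transport.quic.config) can be constructed with its defaults.
import multiaddr
import trio

from libp2p import new_host
from libp2p.transport.quic.config import QUICTransportConfig


async def _quic_host_sketch() -> None:
    host = new_host(enable_quic=True, quic_transport_opt=QUICTransportConfig())
    listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/9000/quic-v1")
    async with host.run(listen_addrs=[listen_addr]):
        print("QUIC host listening on:", host.get_addrs())


# trio.run(_quic_host_sketch)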
@@ -311,19 +306,33 @@
     :param listen_addrs: optional list of multiaddrs to listen on
     :param enable_mDNS: whether to enable mDNS discovery
     :param bootstrap: optional list of bootstrap peer addresses as strings
+    :param enable_quic: optional flag to enable QUIC transport
+    :param quic_transport_opt: optional configuration for the QUIC transport
     :return: return a host instance
     """
+
+    if not enable_quic and quic_transport_opt is not None:
+        logger.warning("QUIC config provided but QUIC not enabled, ignoring QUIC config")
+
     swarm = new_swarm(
+        enable_quic=enable_quic,
         key_pair=key_pair,
         muxer_opt=muxer_opt,
         sec_opt=sec_opt,
         peerstore_opt=peerstore_opt,
         muxer_preference=muxer_preference,
         listen_addrs=listen_addrs,
+        connection_config=quic_transport_opt if enable_quic else None
     )

     if disc_opt is not None:
         return RoutedHost(swarm, disc_opt, enable_mDNS, bootstrap)
-    return BasicHost(network=swarm,enable_mDNS=enable_mDNS , bootstrap=bootstrap, negotitate_timeout=negotiate_timeout)
+    return BasicHost(
+        network=swarm,
+        enable_mDNS=enable_mDNS,
+        bootstrap=bootstrap,
+        negotitate_timeout=negotiate_timeout
+    )
+

 __version__ = __version("libp2p")
diff --git a/libp2p/abc.py b/libp2p/abc.py
index 90ad6a45..964c7454 100644
--- a/libp2p/abc.py
+++ b/libp2p/abc.py
@@ -970,6 +970,14 @@ class IPeerStore(

     # --------CERTIFIED-ADDR-BOOK----------

+    @abstractmethod
+    def get_local_record(self) -> Optional["Envelope"]:
+        """Get the local peer record wrapped in an Envelope."""
+
+    @abstractmethod
+    def set_local_record(self, envelope: "Envelope") -> None:
+        """Set the local peer record wrapped in an Envelope."""
+
     @abstractmethod
     def consume_peer_record(self, envelope: "Envelope", ttl: int) -> bool:
         """
@@ -1404,15 +1412,16 @@ class INetwork(ABC):
     ----------
     peerstore : IPeerStore
         The peer store for managing peer information.
-    connections : dict[ID, INetConn]
-        A mapping of peer IDs to network connections.
+    connections : dict[ID, list[INetConn]]
+        A mapping of peer IDs to lists of network connections
+        (multiple connections per peer).
     listeners : dict[str, IListener]
         A mapping of listener identifiers to listener instances.

     """

     peerstore: IPeerStore
-    connections: dict[ID, INetConn]
+    connections: dict[ID, list[INetConn]]
     listeners: dict[str, IListener]

     @abstractmethod
@@ -1428,9 +1437,56 @@
         """

     @abstractmethod
-    async def dial_peer(self, peer_id: ID) -> INetConn:
+    def get_connections(self, peer_id: ID | None = None) -> list[INetConn]:
         """
-        Create a connection to the specified peer.
+        Get connections for peer (like JS getConnections, Go ConnsToPeer).
+
+        Parameters
+        ----------
+        peer_id : ID | None
+            The peer ID to get connections for. If None, returns all connections.
+
+        Returns
+        -------
+        list[INetConn]
+            List of connections to the specified peer, or all connections
+            if peer_id is None.
+
+        """
+
+    @abstractmethod
+    def get_connections_map(self) -> dict[ID, list[INetConn]]:
+        """
+        Get all connections map (like JS getConnectionsMap).
+
+        Returns
+        -------
+        dict[ID, list[INetConn]]
+            The complete mapping of peer IDs to their connection lists.
+
+        """
+
+    @abstractmethod
+    def get_connection(self, peer_id: ID) -> INetConn | None:
+        """
+        Get single connection for backward compatibility.
+
+        Parameters
+        ----------
+        peer_id : ID
+            The peer ID to get a connection for.
+
+        Returns
+        -------
+        INetConn | None
+            The first available connection, or None if no connections exist.
+ + """ + + @abstractmethod + async def dial_peer(self, peer_id: ID) -> list[INetConn]: + """ + Create connections to the specified peer with load balancing. Parameters ---------- @@ -1439,8 +1495,8 @@ class INetwork(ABC): Returns ------- - INetConn - The network connection instance to the specified peer. + list[INetConn] + List of established connections to the peer. Raises ------ diff --git a/libp2p/custom_types.py b/libp2p/custom_types.py index 0b844133..d8e1a1d9 100644 --- a/libp2p/custom_types.py +++ b/libp2p/custom_types.py @@ -5,17 +5,17 @@ from collections.abc import ( ) from typing import TYPE_CHECKING, NewType, Union, cast +from libp2p.transport.quic.stream import QUICStream + if TYPE_CHECKING: - from libp2p.abc import ( - IMuxedConn, - INetStream, - ISecureTransport, - ) + from libp2p.abc import IMuxedConn, IMuxedStream, INetStream, ISecureTransport + from libp2p.transport.quic.connection import QUICConnection else: IMuxedConn = cast(type, object) INetStream = cast(type, object) ISecureTransport = cast(type, object) - + IMuxedStream = cast(type, object) + QUICConnection = cast(type, object) from libp2p.io.abc import ( ReadWriteCloser, @@ -37,3 +37,6 @@ SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] UnsubscribeFn = Callable[[], Awaitable[None]] +TQUICStreamHandlerFn = Callable[[QUICStream], Awaitable[None]] +TQUICConnHandlerFn = Callable[[QUICConnection], Awaitable[None]] +MessageID = NewType("MessageID", str) diff --git a/libp2p/discovery/bootstrap/bootstrap.py b/libp2p/discovery/bootstrap/bootstrap.py index 222a88a1..63985242 100644 --- a/libp2p/discovery/bootstrap/bootstrap.py +++ b/libp2p/discovery/bootstrap/bootstrap.py @@ -2,15 +2,20 @@ import logging from multiaddr import Multiaddr from multiaddr.resolvers import DNSResolver +import trio from libp2p.abc import ID, INetworkService, PeerInfo from libp2p.discovery.bootstrap.utils import validate_bootstrap_addresses from libp2p.discovery.events.peerDiscovery import peerDiscovery +from libp2p.network.exceptions import SwarmException from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.peer.peerstore import PERMANENT_ADDR_TTL logger = logging.getLogger("libp2p.discovery.bootstrap") resolver = DNSResolver() +DEFAULT_CONNECTION_TIMEOUT = 10 + class BootstrapDiscovery: """ @@ -19,68 +24,147 @@ class BootstrapDiscovery: """ def __init__(self, swarm: INetworkService, bootstrap_addrs: list[str]): + """ + Initialize BootstrapDiscovery. + + Args: + swarm: The network service (swarm) instance + bootstrap_addrs: List of bootstrap peer multiaddresses + + """ self.swarm = swarm self.peerstore = swarm.peerstore self.bootstrap_addrs = bootstrap_addrs or [] self.discovered_peers: set[str] = set() + self.connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT async def start(self) -> None: - """Process bootstrap addresses and emit peer discovery events.""" - logger.debug( + """Process bootstrap addresses and emit peer discovery events in parallel.""" + logger.info( f"Starting bootstrap discovery with " f"{len(self.bootstrap_addrs)} bootstrap addresses" ) + # Show all bootstrap addresses being processed + for i, addr in enumerate(self.bootstrap_addrs): + logger.debug(f"{i + 1}. 
{addr}") + # Validate and filter bootstrap addresses self.bootstrap_addrs = validate_bootstrap_addresses(self.bootstrap_addrs) + logger.info(f"Valid addresses after validation: {len(self.bootstrap_addrs)}") - for addr_str in self.bootstrap_addrs: - try: - await self._process_bootstrap_addr(addr_str) - except Exception as e: - logger.debug(f"Failed to process bootstrap address {addr_str}: {e}") + # Use Trio nursery for PARALLEL address processing + try: + async with trio.open_nursery() as nursery: + logger.debug( + f"Starting {len(self.bootstrap_addrs)} parallel address " + f"processing tasks" + ) + + # Start all bootstrap address processing tasks in parallel + for addr_str in self.bootstrap_addrs: + logger.debug(f"Starting parallel task for: {addr_str}") + nursery.start_soon(self._process_bootstrap_addr, addr_str) + + # The nursery will wait for all address processing tasks to complete + logger.debug( + "Nursery active - waiting for address processing tasks to complete" + ) + + except trio.Cancelled: + logger.debug("Bootstrap address processing cancelled - cleaning up tasks") + raise + except Exception as e: + logger.error(f"Bootstrap address processing failed: {e}") + raise + + logger.info("Bootstrap discovery startup complete - all tasks finished") def stop(self) -> None: """Clean up bootstrap discovery resources.""" - logger.debug("Stopping bootstrap discovery") + logger.info("Stopping bootstrap discovery and cleaning up tasks") + + # Clear discovered peers self.discovered_peers.clear() + logger.debug("Bootstrap discovery cleanup completed") + async def _process_bootstrap_addr(self, addr_str: str) -> None: """Convert string address to PeerInfo and add to peerstore.""" try: - multiaddr = Multiaddr(addr_str) + try: + multiaddr = Multiaddr(addr_str) + except Exception as e: + logger.debug(f"Invalid multiaddr format '{addr_str}': {e}") + return + + if self.is_dns_addr(multiaddr): + resolved_addrs = await resolver.resolve(multiaddr) + if resolved_addrs is None: + logger.warning(f"DNS resolution returned None for: {addr_str}") + return + + peer_id_str = multiaddr.get_peer_id() + if peer_id_str is None: + logger.warning(f"Missing peer ID in DNS address: {addr_str}") + return + peer_id = ID.from_base58(peer_id_str) + addrs = [addr for addr in resolved_addrs] + if not addrs: + logger.warning(f"No addresses resolved for DNS address: {addr_str}") + return + peer_info = PeerInfo(peer_id, addrs) + await self.add_addr(peer_info) + else: + peer_info = info_from_p2p_addr(multiaddr) + await self.add_addr(peer_info) except Exception as e: - logger.debug(f"Invalid multiaddr format '{addr_str}': {e}") - return - if self.is_dns_addr(multiaddr): - resolved_addrs = await resolver.resolve(multiaddr) - peer_id_str = multiaddr.get_peer_id() - if peer_id_str is None: - logger.warning(f"Missing peer ID in DNS address: {addr_str}") - return - peer_id = ID.from_base58(peer_id_str) - addrs = [addr for addr in resolved_addrs] - if not addrs: - logger.warning(f"No addresses resolved for DNS address: {addr_str}") - return - peer_info = PeerInfo(peer_id, addrs) - self.add_addr(peer_info) - else: - self.add_addr(info_from_p2p_addr(multiaddr)) + logger.warning(f"Failed to process bootstrap address {addr_str}: {e}") def is_dns_addr(self, addr: Multiaddr) -> bool: """Check if the address is a DNS address.""" return any(protocol.name == "dnsaddr" for protocol in addr.protocols()) - def add_addr(self, peer_info: PeerInfo) -> None: - """Add a peer to the peerstore and emit discovery event.""" + async def add_addr(self, 
peer_info: PeerInfo) -> None: + """ + Add a peer to the peerstore, emit discovery event, + and attempt connection in parallel. + """ + logger.debug( + f"Adding peer {peer_info.peer_id} with {len(peer_info.addrs)} addresses" + ) + # Skip if it's our own peer if peer_info.peer_id == self.swarm.get_peer_id(): logger.debug(f"Skipping own peer ID: {peer_info.peer_id}") return - # Always add addresses to peerstore (allows multiple addresses for same peer) - self.peerstore.add_addrs(peer_info.peer_id, peer_info.addrs, 10) + # Filter addresses to only include IPv4+TCP (only supported protocol) + ipv4_tcp_addrs = [] + filtered_out_addrs = [] + + for addr in peer_info.addrs: + if self._is_ipv4_tcp_addr(addr): + ipv4_tcp_addrs.append(addr) + else: + filtered_out_addrs.append(addr) + + # Log filtering results + logger.debug( + f"Address filtering for {peer_info.peer_id}: " + f"{len(ipv4_tcp_addrs)} IPv4+TCP, {len(filtered_out_addrs)} filtered" + ) + + # Skip peer if no IPv4+TCP addresses available + if not ipv4_tcp_addrs: + logger.warning( + f"āŒ No IPv4+TCP addresses for {peer_info.peer_id} - " + f"skipping connection attempts" + ) + return + + # Add only IPv4+TCP addresses to peerstore + self.peerstore.add_addrs(peer_info.peer_id, ipv4_tcp_addrs, PERMANENT_ADDR_TTL) # Only emit discovery event if this is the first time we see this peer peer_id_str = str(peer_info.peer_id) @@ -89,6 +173,140 @@ class BootstrapDiscovery: self.discovered_peers.add(peer_id_str) # Emit peer discovery event peerDiscovery.emit_peer_discovered(peer_info) - logger.debug(f"Peer discovered: {peer_info.peer_id}") + logger.info(f"Peer discovered: {peer_info.peer_id}") + + # Connect to peer (parallel across different bootstrap addresses) + logger.debug("Connecting to discovered peer...") + await self._connect_to_peer(peer_info.peer_id) + else: - logger.debug(f"Additional addresses added for peer: {peer_info.peer_id}") + logger.debug( + f"Additional addresses added for existing peer: {peer_info.peer_id}" + ) + # Even for existing peers, try to connect if not already connected + if peer_info.peer_id not in self.swarm.connections: + logger.debug("Connecting to existing peer...") + await self._connect_to_peer(peer_info.peer_id) + + async def _connect_to_peer(self, peer_id: ID) -> None: + """ + Attempt to establish a connection to a peer with timeout. + + Uses swarm.dial_peer to connect using addresses stored in peerstore. + Times out after self.connection_timeout seconds to prevent hanging. 
+ """ + logger.debug(f"Connection attempt for peer: {peer_id}") + + # Pre-connection validation: Check if already connected + if peer_id in self.swarm.connections: + logger.debug( + f"Already connected to {peer_id} - skipping connection attempt" + ) + return + + # Check available addresses before attempting connection + available_addrs = self.peerstore.addrs(peer_id) + logger.debug(f"Connecting to {peer_id} ({len(available_addrs)} addresses)") + + if not available_addrs: + logger.error(f"āŒ No addresses available for {peer_id} - cannot connect") + return + + # Record start time for connection attempt monitoring + connection_start_time = trio.current_time() + + try: + with trio.move_on_after(self.connection_timeout): + # Log connection attempt + logger.debug( + f"Attempting connection to {peer_id} using " + f"{len(available_addrs)} addresses" + ) + + # Use swarm.dial_peer to connect using stored addresses + await self.swarm.dial_peer(peer_id) + + # Calculate connection time + connection_time = trio.current_time() - connection_start_time + + # Post-connection validation: Verify connection was actually established + if peer_id in self.swarm.connections: + logger.info( + f"āœ… Connected to {peer_id} (took {connection_time:.2f}s)" + ) + + else: + logger.warning( + f"Dial succeeded but connection not found for {peer_id}" + ) + except trio.TooSlowError: + logger.warning( + f"āŒ Connection to {peer_id} timed out after {self.connection_timeout}s" + ) + except SwarmException as e: + # Calculate failed connection time + failed_connection_time = trio.current_time() - connection_start_time + + # Enhanced error logging + error_msg = str(e) + if "no addresses established a successful connection" in error_msg: + logger.warning( + f"āŒ Failed to connect to {peer_id} after trying all " + f"{len(available_addrs)} addresses " + f"(took {failed_connection_time:.2f}s)" + ) + # Log individual address failures if this is a MultiError + if ( + e.__cause__ is not None + and hasattr(e.__cause__, "exceptions") + and getattr(e.__cause__, "exceptions", None) is not None + ): + exceptions_list = getattr(e.__cause__, "exceptions") + logger.debug("šŸ“‹ Individual address failure details:") + for i, addr_exception in enumerate(exceptions_list, 1): + logger.debug(f"Address {i}: {addr_exception}") + # Also log the actual address that failed + if i <= len(available_addrs): + logger.debug(f"Failed address: {available_addrs[i - 1]}") + else: + logger.warning("No detailed exception information available") + else: + logger.warning( + f"āŒ Failed to connect to {peer_id}: {e} " + f"(took {failed_connection_time:.2f}s)" + ) + + except Exception as e: + # Handle unexpected errors that aren't swarm-specific + failed_connection_time = trio.current_time() - connection_start_time + logger.error( + f"āŒ Unexpected error connecting to {peer_id}: " + f"{e} (took {failed_connection_time:.2f}s)" + ) + # Don't re-raise to prevent killing the nursery and other parallel tasks + + def _is_ipv4_tcp_addr(self, addr: Multiaddr) -> bool: + """ + Check if address is IPv4 with TCP protocol only. + + Filters out IPv6, UDP, QUIC, WebSocket, and other unsupported protocols. + Only IPv4+TCP addresses are supported by the current transport. 
+ """ + try: + protocols = addr.protocols() + + # Must have IPv4 protocol + has_ipv4 = any(p.name == "ip4" for p in protocols) + if not has_ipv4: + return False + + # Must have TCP protocol + has_tcp = any(p.name == "tcp" for p in protocols) + if not has_tcp: + return False + + return True + + except Exception: + # If we can't parse the address, don't use it + return False diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index b40b0128..e370a3de 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -43,6 +43,7 @@ from libp2p.peer.id import ( from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.protocol_muxer.exceptions import ( MultiselectClientError, MultiselectError, @@ -110,6 +111,14 @@ class BasicHost(IHost): if bootstrap: self.bootstrap = BootstrapDiscovery(network, bootstrap) + # Cache a signed-record if the local-node in the PeerStore + envelope = create_signed_peer_record( + self.get_id(), + self.get_addrs(), + self.get_private_key(), + ) + self.get_peerstore().set_local_record(envelope) + def get_id(self) -> ID: """ :return: peer_id of host @@ -288,6 +297,11 @@ class BasicHost(IHost): protocol, handler = await self.multiselect.negotiate( MultiselectCommunicator(net_stream), self.negotiate_timeout ) + if protocol is None: + await net_stream.reset() + raise StreamFailure( + "Failed to negotiate protocol: no protocol selected" + ) except MultiselectError as error: peer_id = net_stream.muxed_conn.peer_id logger.debug( @@ -329,7 +343,7 @@ class BasicHost(IHost): :param peer_id: ID of the peer to check :return: True if peer has an active connection, False otherwise """ - return peer_id in self._network.connections + return len(self._network.get_connections(peer_id)) > 0 def get_peer_connection_info(self, peer_id: ID) -> INetConn | None: """ @@ -338,4 +352,4 @@ class BasicHost(IHost): :param peer_id: ID of the peer to get info for :return: Connection object if peer is connected, None otherwise """ - return self._network.connections.get(peer_id) + return self._network.get_connection(peer_id) diff --git a/libp2p/identity/identify/identify.py b/libp2p/identity/identify/identify.py index b2811ff9..146fbd2d 100644 --- a/libp2p/identity/identify/identify.py +++ b/libp2p/identity/identify/identify.py @@ -15,8 +15,7 @@ from libp2p.custom_types import ( from libp2p.network.stream.exceptions import ( StreamClosed, ) -from libp2p.peer.envelope import seal_record -from libp2p.peer.peer_record import PeerRecord +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.utils import ( decode_varint_with_size, get_agent_version, @@ -66,9 +65,7 @@ def _mk_identify_protobuf( protocols = tuple(str(p) for p in host.get_mux().get_protocols() if p is not None) # Create a signed peer-record for the remote peer - record = PeerRecord(host.get_id(), host.get_addrs()) - envelope = seal_record(record, host.get_private_key()) - protobuf = envelope.marshal_envelope() + envelope_bytes, _ = env_to_send_in_RPC(host) observed_addr = observed_multiaddr.to_bytes() if observed_multiaddr else b"" return Identify( @@ -78,7 +75,7 @@ def _mk_identify_protobuf( listen_addrs=map(_multiaddr_to_bytes, laddrs), observed_addr=observed_addr, protocols=protocols, - signedPeerRecord=protobuf, + signedPeerRecord=envelope_bytes, ) diff --git a/libp2p/kad_dht/kad_dht.py b/libp2p/kad_dht/kad_dht.py index 097b6c48..0d05aaf8 100644 --- a/libp2p/kad_dht/kad_dht.py +++ b/libp2p/kad_dht/kad_dht.py @@ -22,15 +22,18 @@ from libp2p.abc 
import ( IHost, ) from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.network.stream.net_stream import ( INetStream, ) +from libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from libp2p.tools.async_service import ( Service, ) @@ -234,6 +237,9 @@ class KadDHT(Service): await self.add_peer(peer_id) logger.debug(f"Added peer {peer_id} to routing table") + closer_peer_envelope: Envelope | None = None + provider_peer_envelope: Envelope | None = None + try: # Read varint-prefixed length for the message length_prefix = b"" @@ -274,6 +280,14 @@ class KadDHT(Service): ) logger.debug(f"Found {len(closest_peers)} peers close to target") + # Consume the source signed_peer_record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + # Build response message with protobuf response = Message() response.type = Message.MessageType.FIND_NODE @@ -298,6 +312,21 @@ class KadDHT(Service): except Exception: pass + # Add the signed-peer-record for each peer in the peer-proto + # if cached in the peerstore + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + + # Create sender_signed_peer_record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -312,6 +341,14 @@ class KadDHT(Service): key = message.key logger.debug(f"Received ADD_PROVIDER for key {key.hex()}") + # Consume the source signed-peer-record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + # Extract provider information for provider_proto in message.providerPeers: try: @@ -338,6 +375,17 @@ class KadDHT(Service): logger.debug( f"Added provider {provider_id} for key {key.hex()}" ) + + # Process the signed-records of provider if sent + if not maybe_consume_signed_record( + provider_proto, self.host + ): + logger.error( + "Received an invalid-signed-record," + "dropping the stream" + ) + await stream.close() + return except Exception as e: logger.warning(f"Failed to process provider info: {e}") @@ -346,6 +394,10 @@ class KadDHT(Service): response.type = Message.MessageType.ADD_PROVIDER response.key = key + # Add sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) await stream.write(response_bytes) @@ -357,6 +409,14 @@ class KadDHT(Service): key = message.key logger.debug(f"Received GET_PROVIDERS request for key {key.hex()}") + # Consume the source signed_peer_record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + # Find providers for the key providers = self.provider_store.get_providers(key) logger.debug( @@ -368,12 +428,28 @@ class 
KadDHT(Service): response.type = Message.MessageType.GET_PROVIDERS response.key = key + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Add provider information to response for provider_info in providers: provider_proto = response.providerPeers.add() provider_proto.id = provider_info.peer_id.to_bytes() provider_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add provider signed-records if cached + provider_peer_envelope = ( + self.host.get_peerstore().get_peer_record( + provider_info.peer_id + ) + ) + + if provider_peer_envelope is not None: + provider_proto.signedRecord = ( + provider_peer_envelope.marshal_envelope() + ) + # Add addresses if available for addr in provider_info.addrs: provider_proto.addrs.append(addr.to_bytes()) @@ -397,6 +473,16 @@ class KadDHT(Service): peer_proto.id = peer.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add the signed-records of closest_peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer) @@ -417,6 +503,14 @@ class KadDHT(Service): key = message.key logger.debug(f"Received GET_VALUE request for key {key.hex()}") + # Consume the sender_signed_peer_record + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + value = self.value_store.get(key) if value: logger.debug(f"Found value for key {key.hex()}") @@ -431,6 +525,10 @@ class KadDHT(Service): response.record.value = value response.record.timeReceived = str(time.time()) + # Create sender_signed_peer_record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) @@ -444,6 +542,10 @@ class KadDHT(Service): response.type = Message.MessageType.GET_VALUE response.key = key + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Add closest peers to key closest_peers = self.routing_table.find_local_closest_peers( key, 20 @@ -462,6 +564,16 @@ class KadDHT(Service): peer_proto.id = peer.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add signed-records of closer-peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer) + ) + + if closer_peer_envelope is not None: + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer) @@ -484,6 +596,15 @@ class KadDHT(Service): key = message.record.key value = message.record.value success = False + + # Consume the source signed_peer_record if sent + if not maybe_consume_signed_record(message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + await stream.close() + return + try: if not (key and value): raise ValueError( @@ -504,6 +625,12 @@ class KadDHT(Service): response.type = Message.MessageType.PUT_VALUE if success: response.key = key + + # Create sender_signed_peer_record for the response + envelope_bytes, _ = 
env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + + # Serialize and send response response_bytes = response.SerializeToString() await stream.write(varint.encode(len(response_bytes))) await stream.write(response_bytes) diff --git a/libp2p/kad_dht/pb/kademlia.proto b/libp2p/kad_dht/pb/kademlia.proto index fd198d28..7c3e5bad 100644 --- a/libp2p/kad_dht/pb/kademlia.proto +++ b/libp2p/kad_dht/pb/kademlia.proto @@ -27,6 +27,7 @@ message Message { bytes id = 1; repeated bytes addrs = 2; ConnectionType connection = 3; + optional bytes signedRecord = 4; // Envelope(PeerRecord) encoded } MessageType type = 1; @@ -35,4 +36,6 @@ message Message { Record record = 3; repeated Peer closerPeers = 8; repeated Peer providerPeers = 9; + + optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded } diff --git a/libp2p/kad_dht/pb/kademlia_pb2.py b/libp2p/kad_dht/pb/kademlia_pb2.py index 781333bf..ac23169c 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.py +++ b/libp2p/kad_dht/pb/kademlia_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: libp2p/kad_dht/pb/kademlia.proto +# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,21 +14,21 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xca\x03\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x1aN\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 \x01(\x0e\x32\x17.Message.ConnectionType\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n libp2p/kad_dht/pb/kademlia.proto\":\n\x06Record\x12\x0b\n\x03key\x18\x01 \x01(\x0c\x12\r\n\x05value\x18\x02 \x01(\x0c\x12\x14\n\x0ctimeReceived\x18\x05 \x01(\t\"\xa2\x04\n\x07Message\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.Message.MessageType\x12\x17\n\x0f\x63lusterLevelRaw\x18\n \x01(\x05\x12\x0b\n\x03key\x18\x02 \x01(\x0c\x12\x17\n\x06record\x18\x03 \x01(\x0b\x32\x07.Record\x12\"\n\x0b\x63loserPeers\x18\x08 \x03(\x0b\x32\r.Message.Peer\x12$\n\rproviderPeers\x18\t \x03(\x0b\x32\r.Message.Peer\x12\x19\n\x0csenderRecord\x18\x0b \x01(\x0cH\x00\x88\x01\x01\x1az\n\x04Peer\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\r\n\x05\x61\x64\x64rs\x18\x02 \x03(\x0c\x12+\n\nconnection\x18\x03 
\x01(\x0e\x32\x17.Message.ConnectionType\x12\x19\n\x0csignedRecord\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x0f\n\r_signedRecord\"i\n\x0bMessageType\x12\r\n\tPUT_VALUE\x10\x00\x12\r\n\tGET_VALUE\x10\x01\x12\x10\n\x0c\x41\x44\x44_PROVIDER\x10\x02\x12\x11\n\rGET_PROVIDERS\x10\x03\x12\r\n\tFIND_NODE\x10\x04\x12\x08\n\x04PING\x10\x05\"W\n\x0e\x43onnectionType\x12\x11\n\rNOT_CONNECTED\x10\x00\x12\r\n\tCONNECTED\x10\x01\x12\x0f\n\x0b\x43\x41N_CONNECT\x10\x02\x12\x12\n\x0e\x43\x41NNOT_CONNECT\x10\x03\x42\x0f\n\r_senderRecordb\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.kad_dht.pb.kademlia_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RECORD._serialized_start=36 - _RECORD._serialized_end=94 - _MESSAGE._serialized_start=97 - _MESSAGE._serialized_end=555 - _MESSAGE_PEER._serialized_start=281 - _MESSAGE_PEER._serialized_end=359 - _MESSAGE_MESSAGETYPE._serialized_start=361 - _MESSAGE_MESSAGETYPE._serialized_end=466 - _MESSAGE_CONNECTIONTYPE._serialized_start=468 - _MESSAGE_CONNECTIONTYPE._serialized_end=555 + _globals['_RECORD']._serialized_start=36 + _globals['_RECORD']._serialized_end=94 + _globals['_MESSAGE']._serialized_start=97 + _globals['_MESSAGE']._serialized_end=643 + _globals['_MESSAGE_PEER']._serialized_start=308 + _globals['_MESSAGE_PEER']._serialized_end=430 + _globals['_MESSAGE_MESSAGETYPE']._serialized_start=432 + _globals['_MESSAGE_MESSAGETYPE']._serialized_end=537 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_start=539 + _globals['_MESSAGE_CONNECTIONTYPE']._serialized_end=626 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/kad_dht/pb/kademlia_pb2.pyi b/libp2p/kad_dht/pb/kademlia_pb2.pyi index c8f16db2..6d80d77d 100644 --- a/libp2p/kad_dht/pb/kademlia_pb2.pyi +++ b/libp2p/kad_dht/pb/kademlia_pb2.pyi @@ -1,133 +1,70 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -""" +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing +DESCRIPTOR: _descriptor.FileDescriptor -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions +class Record(_message.Message): + __slots__ = ("key", "value", "timeReceived") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + TIMERECEIVED_FIELD_NUMBER: _ClassVar[int] + key: bytes + value: bytes + timeReceived: str + def __init__(self, key: _Optional[bytes] = ..., value: _Optional[bytes] = ..., timeReceived: _Optional[str] = ...) -> None: ... 
-DESCRIPTOR: google.protobuf.descriptor.FileDescriptor - -@typing.final -class Record(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - TIMERECEIVED_FIELD_NUMBER: builtins.int - key: builtins.bytes - value: builtins.bytes - timeReceived: builtins.str - def __init__( - self, - *, - key: builtins.bytes = ..., - value: builtins.bytes = ..., - timeReceived: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "timeReceived", b"timeReceived", "value", b"value"]) -> None: ... - -global___Record = Record - -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _MessageType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _MessageTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._MessageType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - PUT_VALUE: Message._MessageType.ValueType # 0 - GET_VALUE: Message._MessageType.ValueType # 1 - ADD_PROVIDER: Message._MessageType.ValueType # 2 - GET_PROVIDERS: Message._MessageType.ValueType # 3 - FIND_NODE: Message._MessageType.ValueType # 4 - PING: Message._MessageType.ValueType # 5 - - class MessageType(_MessageType, metaclass=_MessageTypeEnumTypeWrapper): ... - PUT_VALUE: Message.MessageType.ValueType # 0 - GET_VALUE: Message.MessageType.ValueType # 1 - ADD_PROVIDER: Message.MessageType.ValueType # 2 - GET_PROVIDERS: Message.MessageType.ValueType # 3 - FIND_NODE: Message.MessageType.ValueType # 4 - PING: Message.MessageType.ValueType # 5 - - class _ConnectionType: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _ConnectionTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Message._ConnectionType.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NOT_CONNECTED: Message._ConnectionType.ValueType # 0 - CONNECTED: Message._ConnectionType.ValueType # 1 - CAN_CONNECT: Message._ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message._ConnectionType.ValueType # 3 - - class ConnectionType(_ConnectionType, metaclass=_ConnectionTypeEnumTypeWrapper): ... - NOT_CONNECTED: Message.ConnectionType.ValueType # 0 - CONNECTED: Message.ConnectionType.ValueType # 1 - CAN_CONNECT: Message.ConnectionType.ValueType # 2 - CANNOT_CONNECT: Message.ConnectionType.ValueType # 3 - - @typing.final - class Peer(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - ID_FIELD_NUMBER: builtins.int - ADDRS_FIELD_NUMBER: builtins.int - CONNECTION_FIELD_NUMBER: builtins.int - id: builtins.bytes - connection: global___Message.ConnectionType.ValueType - @property - def addrs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... - def __init__( - self, - *, - id: builtins.bytes = ..., - addrs: collections.abc.Iterable[builtins.bytes] | None = ..., - connection: global___Message.ConnectionType.ValueType = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["addrs", b"addrs", "connection", b"connection", "id", b"id"]) -> None: ... 
- - TYPE_FIELD_NUMBER: builtins.int - CLUSTERLEVELRAW_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - RECORD_FIELD_NUMBER: builtins.int - CLOSERPEERS_FIELD_NUMBER: builtins.int - PROVIDERPEERS_FIELD_NUMBER: builtins.int - type: global___Message.MessageType.ValueType - clusterLevelRaw: builtins.int - key: builtins.bytes - @property - def record(self) -> global___Record: ... - @property - def closerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - @property - def providerPeers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message.Peer]: ... - def __init__( - self, - *, - type: global___Message.MessageType.ValueType = ..., - clusterLevelRaw: builtins.int = ..., - key: builtins.bytes = ..., - record: global___Record | None = ..., - closerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - providerPeers: collections.abc.Iterable[global___Message.Peer] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["record", b"record"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["closerPeers", b"closerPeers", "clusterLevelRaw", b"clusterLevelRaw", "key", b"key", "providerPeers", b"providerPeers", "record", b"record", "type", b"type"]) -> None: ... - -global___Message = Message +class Message(_message.Message): + __slots__ = ("type", "clusterLevelRaw", "key", "record", "closerPeers", "providerPeers", "senderRecord") + class MessageType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + PUT_VALUE: _ClassVar[Message.MessageType] + GET_VALUE: _ClassVar[Message.MessageType] + ADD_PROVIDER: _ClassVar[Message.MessageType] + GET_PROVIDERS: _ClassVar[Message.MessageType] + FIND_NODE: _ClassVar[Message.MessageType] + PING: _ClassVar[Message.MessageType] + PUT_VALUE: Message.MessageType + GET_VALUE: Message.MessageType + ADD_PROVIDER: Message.MessageType + GET_PROVIDERS: Message.MessageType + FIND_NODE: Message.MessageType + PING: Message.MessageType + class ConnectionType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NOT_CONNECTED: _ClassVar[Message.ConnectionType] + CONNECTED: _ClassVar[Message.ConnectionType] + CAN_CONNECT: _ClassVar[Message.ConnectionType] + CANNOT_CONNECT: _ClassVar[Message.ConnectionType] + NOT_CONNECTED: Message.ConnectionType + CONNECTED: Message.ConnectionType + CAN_CONNECT: Message.ConnectionType + CANNOT_CONNECT: Message.ConnectionType + class Peer(_message.Message): + __slots__ = ("id", "addrs", "connection", "signedRecord") + ID_FIELD_NUMBER: _ClassVar[int] + ADDRS_FIELD_NUMBER: _ClassVar[int] + CONNECTION_FIELD_NUMBER: _ClassVar[int] + SIGNEDRECORD_FIELD_NUMBER: _ClassVar[int] + id: bytes + addrs: _containers.RepeatedScalarFieldContainer[bytes] + connection: Message.ConnectionType + signedRecord: bytes + def __init__(self, id: _Optional[bytes] = ..., addrs: _Optional[_Iterable[bytes]] = ..., connection: _Optional[_Union[Message.ConnectionType, str]] = ..., signedRecord: _Optional[bytes] = ...) -> None: ... 
+ TYPE_FIELD_NUMBER: _ClassVar[int] + CLUSTERLEVELRAW_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + RECORD_FIELD_NUMBER: _ClassVar[int] + CLOSERPEERS_FIELD_NUMBER: _ClassVar[int] + PROVIDERPEERS_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + type: Message.MessageType + clusterLevelRaw: int + key: bytes + record: Record + closerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + providerPeers: _containers.RepeatedCompositeFieldContainer[Message.Peer] + senderRecord: bytes + def __init__(self, type: _Optional[_Union[Message.MessageType, str]] = ..., clusterLevelRaw: _Optional[int] = ..., key: _Optional[bytes] = ..., record: _Optional[_Union[Record, _Mapping]] = ..., closerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., providerPeers: _Optional[_Iterable[_Union[Message.Peer, _Mapping]]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... # type: ignore diff --git a/libp2p/kad_dht/peer_routing.py b/libp2p/kad_dht/peer_routing.py index c4a066f7..f5313cb6 100644 --- a/libp2p/kad_dht/peer_routing.py +++ b/libp2p/kad_dht/peer_routing.py @@ -15,12 +15,14 @@ from libp2p.abc import ( INetStream, IPeerRouting, ) +from libp2p.peer.envelope import Envelope from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( ALPHA, @@ -33,6 +35,7 @@ from .routing_table import ( RoutingTable, ) from .utils import ( + maybe_consume_signed_record, sort_peer_ids_by_distance, ) @@ -255,6 +258,10 @@ class PeerRouting(IPeerRouting): find_node_msg.type = Message.MessageType.FIND_NODE find_node_msg.key = target_key # Set target key directly as bytes + # Create sender_signed_peer_record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + find_node_msg.senderRecord = envelope_bytes + # Serialize and send the protobuf message with varint length prefix proto_bytes = find_node_msg.SerializeToString() logger.debug( @@ -299,7 +306,22 @@ class PeerRouting(IPeerRouting): # Process closest peers from response if response_msg.type == Message.MessageType.FIND_NODE: + # Consume the sender_signed_peer_record + if not maybe_consume_signed_record(response_msg, self.host, peer): + logger.error( + "Received an invalid-signed-record,ignoring the response" + ) + return [] + for peer_data in response_msg.closerPeers: + # Consume the received closer_peers signed-records, peer-id is + # sent with the peer-data + if not maybe_consume_signed_record(peer_data, self.host): + logger.error( + "Received an invalid-signed-record,ignoring the response" + ) + return [] + new_peer_id = ID(peer_data.id) if new_peer_id not in results: results.append(new_peer_id) @@ -332,6 +354,7 @@ class PeerRouting(IPeerRouting): """ try: # Read message length + peer_id = stream.muxed_conn.peer_id length_bytes = await stream.read(4) if not length_bytes: return @@ -345,10 +368,18 @@ class PeerRouting(IPeerRouting): # Parse protobuf message kad_message = Message() + closer_peer_envelope: Envelope | None = None try: kad_message.ParseFromString(message_bytes) if kad_message.type == Message.MessageType.FIND_NODE: + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(kad_message, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, dropping the stream" + ) + return + # Get target key directly from protobuf message target_key = kad_message.key @@ -361,12 +392,26 @@ class PeerRouting(IPeerRouting): response = Message() response.type = 
Message.MessageType.FIND_NODE + # Create sender_signed_peer_record for the response + envelope_bytes, _ = env_to_send_in_RPC(self.host) + response.senderRecord = envelope_bytes + # Add peer information to response for peer_id in closest_peers: peer_proto = response.closerPeers.add() peer_proto.id = peer_id.to_bytes() peer_proto.connection = Message.ConnectionType.CAN_CONNECT + # Add the signed-records of closest_peers if cached + closer_peer_envelope = ( + self.host.get_peerstore().get_peer_record(peer_id) + ) + + if isinstance(closer_peer_envelope, Envelope): + peer_proto.signedRecord = ( + closer_peer_envelope.marshal_envelope() + ) + # Add addresses if available try: addrs = self.host.get_peerstore().addrs(peer_id) diff --git a/libp2p/kad_dht/provider_store.py b/libp2p/kad_dht/provider_store.py index 5c34f0c7..77bb464f 100644 --- a/libp2p/kad_dht/provider_store.py +++ b/libp2p/kad_dht/provider_store.py @@ -22,12 +22,14 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( ALPHA, @@ -240,11 +242,18 @@ class ProviderStore: message.type = Message.MessageType.ADD_PROVIDER message.key = key + # Create sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Add our provider info provider = message.providerPeers.add() provider.id = self.local_peer_id.to_bytes() provider.addrs.extend(addrs) + # Add the provider's signed-peer-record + provider.signedRecord = envelope_bytes + # Serialize and send the message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -276,10 +285,15 @@ class ProviderStore: response = Message() response.ParseFromString(response_bytes) - # Check response type - response.type == Message.MessageType.ADD_PROVIDER - if response.type: - result = True + if response.type == Message.MessageType.ADD_PROVIDER: + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + result = False + else: + result = True except Exception as e: logger.warning(f"Error sending ADD_PROVIDER to {peer_id}: {e}") @@ -380,6 +394,10 @@ class ProviderStore: message.type = Message.MessageType.GET_PROVIDERS message.key = key + # Create sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Serialize and send the message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -414,10 +432,26 @@ class ProviderStore: if response.type != Message.MessageType.GET_PROVIDERS: return [] + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return [] + # Extract provider information providers = [] for provider_proto in response.providerPeers: try: + # Consume the provider's signed-peer-record if sent, peer-id + # already sent with the provider-proto + if not maybe_consume_signed_record(provider_proto, self.host): + logger.error( + "Received an invalid-signed-record, " + "ignoring the response" + ) + return [] + # Create peer ID from bytes provider_id = ID(provider_proto.id) @@ 
-431,6 +465,7 @@ class ProviderStore: # Create PeerInfo and add to result providers.append(PeerInfo(provider_id, addrs)) + except Exception as e: logger.warning(f"Failed to parse provider info: {e}") diff --git a/libp2p/kad_dht/utils.py b/libp2p/kad_dht/utils.py index 61158320..fe768723 100644 --- a/libp2p/kad_dht/utils.py +++ b/libp2p/kad_dht/utils.py @@ -2,13 +2,93 @@ Utility functions for Kademlia DHT implementation. """ +import logging + import base58 import multihash +from libp2p.abc import IHost +from libp2p.peer.envelope import consume_envelope from libp2p.peer.id import ( ID, ) +from .pb.kademlia_pb2 import ( + Message, +) + +logger = logging.getLogger("kademlia-example.utils") + + +def maybe_consume_signed_record( + msg: Message | Message.Peer, host: IHost, peer_id: ID | None = None +) -> bool: + """ + Attempt to parse and store a signed-peer-record (Envelope) received during + DHT communication. If the record is invalid, the peer-id does not match, or + updating the peerstore fails, the function logs an error and returns False. + + Parameters + ---------- + msg : Message | Message.Peer + The protobuf message received during DHT communication. Can either be a + top-level `Message` containing `senderRecord` or a `Message.Peer` + containing `signedRecord`. + host : IHost + The local host instance, providing access to the peerstore for storing + verified peer records. + peer_id : ID | None, optional + The expected peer ID for record validation. If provided, the peer ID + inside the record must match this value. + + Returns + ------- + bool + True if a valid signed peer record was successfully consumed and stored, + False otherwise. + + """ + if isinstance(msg, Message): + if msg.HasField("senderRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, record = consume_envelope( + msg.senderRecord, + "libp2p-peer-record", + ) + if not (isinstance(peer_id, ID) and record.peer_id == peer_id): + return False + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Failed to update the Certified-Addr-Book") + return False + except Exception as e: + logger.error("Failed to update the Certified-Addr-Book: %s", e) + return False + else: + if msg.HasField("signedRecord"): + try: + # Convert the signed-peer-record(Envelope) from + # protobuf bytes + envelope, record = consume_envelope( + msg.signedRecord, + "libp2p-peer-record", + ) + if not record.peer_id.to_bytes() == msg.id: + return False + # Use the default TTL of 2 hours (7200 seconds) + if not host.get_peerstore().consume_peer_record(envelope, 7200): + logger.error("Failed to update the Certified-Addr-Book") + return False + except Exception as e: + logger.error( + "Failed to update the Certified-Addr-Book: %s", + e, + ) + return False + return True + def create_key_from_binary(binary_data: bytes) -> bytes: """ diff --git a/libp2p/kad_dht/value_store.py b/libp2p/kad_dht/value_store.py index b79425fd..2002965f 100644 --- a/libp2p/kad_dht/value_store.py +++ b/libp2p/kad_dht/value_store.py @@ -15,9 +15,11 @@ from libp2p.abc import ( from libp2p.custom_types import ( TProtocol, ) +from libp2p.kad_dht.utils import maybe_consume_signed_record from libp2p.peer.id import ( ID, ) +from libp2p.peer.peerstore import env_to_send_in_RPC from .common import ( DEFAULT_TTL, @@ -110,6 +112,10 @@ class ValueStore: message = Message() message.type = Message.MessageType.PUT_VALUE + # Create sender's signed-peer-record + envelope_bytes, _ = 
env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Set message fields message.key = key message.record.key = key @@ -155,7 +161,13 @@ class ValueStore: # Check if response is valid if response.type == Message.MessageType.PUT_VALUE: - if response.key: + # Consume the sender's signed-peer-record if sent + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return False + if response.key == key: result = True return result @@ -231,6 +243,10 @@ class ValueStore: message.type = Message.MessageType.GET_VALUE message.key = key + # Create sender's signed-peer-record + envelope_bytes, _ = env_to_send_in_RPC(self.host) + message.senderRecord = envelope_bytes + # Serialize and send the protobuf message proto_bytes = message.SerializeToString() await stream.write(varint.encode(len(proto_bytes))) @@ -275,6 +291,13 @@ class ValueStore: and response.HasField("record") and response.record.value ): + # Consume the sender's signed-peer-record + if not maybe_consume_signed_record(response, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the response" + ) + return None + logger.debug( f"Received value for key {key.hex()} from peer {peer_id}" ) diff --git a/libp2p/network/config.py b/libp2p/network/config.py new file mode 100644 index 00000000..e0fad33c --- /dev/null +++ b/libp2p/network/config.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass + + +@dataclass +class RetryConfig: + """ + Configuration for retry logic with exponential backoff. + + This configuration controls how connection attempts are retried when they fail. + The retry mechanism uses exponential backoff with jitter to prevent thundering + herd problems in distributed systems. + + Attributes: + max_retries: Maximum number of retry attempts before giving up. + Default: 3 attempts + initial_delay: Initial delay in seconds before the first retry. + Default: 0.1 seconds (100ms) + max_delay: Maximum delay cap in seconds to prevent excessive wait times. + Default: 30.0 seconds + backoff_multiplier: Multiplier for exponential backoff (each retry multiplies + the delay by this factor). Default: 2.0 (doubles each time) + jitter_factor: Random jitter factor (0.0-1.0) to add randomness to delays + and prevent synchronized retries. Default: 0.1 (10% jitter) + + """ + + max_retries: int = 3 + initial_delay: float = 0.1 + max_delay: float = 30.0 + backoff_multiplier: float = 2.0 + jitter_factor: float = 0.1 + + +@dataclass +class ConnectionConfig: + """ + Configuration for multi-connection support. + + This configuration controls how multiple connections per peer are managed, + including connection limits, timeouts, and load balancing strategies. + + Attributes: + max_connections_per_peer: Maximum number of connections allowed to a single + peer. Default: 3 connections + connection_timeout: Timeout in seconds for establishing new connections. + Default: 30.0 seconds + load_balancing_strategy: Strategy for distributing streams across connections. 
+ Options: "round_robin" (default) or "least_loaded" + + """ + + max_connections_per_peer: int = 3 + connection_timeout: float = 30.0 + load_balancing_strategy: str = "round_robin" # or "least_loaded" + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + if not ( + self.load_balancing_strategy == "round_robin" + or self.load_balancing_strategy == "least_loaded" + ): + raise ValueError( + "Load balancing strategy can only be 'round_robin' or 'least_loaded'" + ) + + if self.max_connections_per_peer < 1: + raise ValueError("Max connection per peer should be atleast 1") + + if self.connection_timeout < 0: + raise ValueError("Connection timeout should be positive") diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index b54fdda4..49daab9c 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -17,6 +17,7 @@ from libp2p.stream_muxer.exceptions import ( MuxedStreamError, MuxedStreamReset, ) +from libp2p.transport.quic.exceptions import QUICStreamClosedError, QUICStreamResetError from .exceptions import ( StreamClosed, @@ -170,7 +171,7 @@ class NetStream(INetStream): elif self.__stream_state == StreamState.OPEN: self.__stream_state = StreamState.CLOSE_READ raise StreamEOF() from error - except MuxedStreamReset as error: + except (MuxedStreamReset, QUICStreamClosedError, QUICStreamResetError) as error: async with self._state_lock: if self.__stream_state in [ StreamState.OPEN, @@ -199,7 +200,12 @@ class NetStream(INetStream): try: await self.muxed_stream.write(data) - except (MuxedStreamClosed, MuxedStreamError) as error: + except ( + MuxedStreamClosed, + MuxedStreamError, + QUICStreamClosedError, + QUICStreamResetError, + ) as error: async with self._state_lock: if self.__stream_state == StreamState.OPEN: self.__stream_state = StreamState.CLOSE_WRITE diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index b9f366c3..01079a1c 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -3,6 +3,8 @@ from collections.abc import ( Callable, ) import logging +import random +from typing import cast from multiaddr import ( Multiaddr, @@ -25,6 +27,7 @@ from libp2p.custom_types import ( from libp2p.io.abc import ( ReadWriteCloser, ) +from libp2p.network.config import ConnectionConfig, RetryConfig from libp2p.peer.id import ( ID, ) @@ -39,6 +42,9 @@ from libp2p.transport.exceptions import ( OpenConnectionError, SecurityUpgradeFailure, ) +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.transport import QUICTransport from libp2p.transport.upgrader import ( TransportUpgrader, ) @@ -71,9 +77,7 @@ class Swarm(Service, INetworkService): peerstore: IPeerStore upgrader: TransportUpgrader transport: ITransport - # TODO: Connection and `peer_id` are 1-1 mapping in our implementation, - # whereas in Go one `peer_id` may point to multiple connections. 
- connections: dict[ID, INetConn] + connections: dict[ID, list[INetConn]] listeners: dict[str, IListener] common_stream_handler: StreamHandlerFn listener_nursery: trio.Nursery | None @@ -81,18 +85,31 @@ class Swarm(Service, INetworkService): notifees: list[INotifee] + # Enhanced: New configuration + retry_config: RetryConfig + connection_config: ConnectionConfig | QUICTransportConfig + _round_robin_index: dict[ID, int] + def __init__( self, peer_id: ID, peerstore: IPeerStore, upgrader: TransportUpgrader, transport: ITransport, + retry_config: RetryConfig | None = None, + connection_config: ConnectionConfig | QUICTransportConfig | None = None, ): self.self_id = peer_id self.peerstore = peerstore self.upgrader = upgrader self.transport = transport - self.connections = dict() + + # Enhanced: Initialize retry and connection configuration + self.retry_config = retry_config or RetryConfig() + self.connection_config = connection_config or ConnectionConfig() + + # Enhanced: Initialize connections as 1:many mapping + self.connections = {} self.listeners = dict() # Create Notifee array @@ -103,11 +120,19 @@ class Swarm(Service, INetworkService): self.listener_nursery = None self.event_listener_nursery_created = trio.Event() + # Load balancing state + self._round_robin_index = {} + async def run(self) -> None: async with trio.open_nursery() as nursery: # Create a nursery for listener tasks. self.listener_nursery = nursery self.event_listener_nursery_created.set() + + if isinstance(self.transport, QUICTransport): + self.transport.set_background_nursery(nursery) + self.transport.set_swarm(self) + try: await self.manager.wait_finished() finally: @@ -122,18 +147,74 @@ class Swarm(Service, INetworkService): def set_stream_handler(self, stream_handler: StreamHandlerFn) -> None: self.common_stream_handler = stream_handler - async def dial_peer(self, peer_id: ID) -> INetConn: + def get_connections(self, peer_id: ID | None = None) -> list[INetConn]: """ - Try to create a connection to peer_id. + Get connections for peer (like JS getConnections, Go ConnsToPeer). + + Parameters + ---------- + peer_id : ID | None + The peer ID to get connections for. If None, returns all connections. + + Returns + ------- + list[INetConn] + List of connections to the specified peer, or all connections + if peer_id is None. + + """ + if peer_id is not None: + return self.connections.get(peer_id, []) + + # Return all connections from all peers + all_conns = [] + for conns in self.connections.values(): + all_conns.extend(conns) + return all_conns + + def get_connections_map(self) -> dict[ID, list[INetConn]]: + """ + Get all connections map (like JS getConnectionsMap). + + Returns + ------- + dict[ID, list[INetConn]] + The complete mapping of peer IDs to their connection lists. + + """ + return self.connections.copy() + + def get_connection(self, peer_id: ID) -> INetConn | None: + """ + Get single connection for backward compatibility. + + Parameters + ---------- + peer_id : ID + The peer ID to get a connection for. + + Returns + ------- + INetConn | None + The first available connection, or None if no connections exist. + + """ + conns = self.get_connections(peer_id) + return conns[0] if conns else None + + async def dial_peer(self, peer_id: ID) -> list[INetConn]: + """ + Try to create connections to peer_id with enhanced retry logic. 
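+
+        Illustrative usage (assumes a started swarm and peerstore addresses
+        for ``peer_id``):
+
+            conns = await swarm.dial_peer(peer_id)
+            # one connection per successful address, capped at
+            # connection_config.max_connections_per_peer
+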
         :param peer_id: peer if we want to dial
         :raises SwarmException: raised when an error occurs
-        :return: muxed connection
+        :return: list of muxed connections
         """
-        if peer_id in self.connections:
-            # If muxed connection already exists for peer_id,
-            # set muxed connection equal to existing muxed connection
-            return self.connections[peer_id]
+        # Check if we already have connections
+        existing_connections = self.get_connections(peer_id)
+        if existing_connections:
+            logger.debug(f"Reusing existing connections to peer {peer_id}")
+            return existing_connections
 
         logger.debug("attempting to dial peer %s", peer_id)
 
@@ -146,12 +227,19 @@
         if not addrs:
             raise SwarmException(f"No known addresses to peer {peer_id}")
 
+        connections = []
         exceptions: list[SwarmException] = []
 
-        # Try all known addresses
+        # Enhanced: Try all known addresses with retry logic
        for multiaddr in addrs:
             try:
-                return await self.dial_addr(multiaddr, peer_id)
+                connection = await self._dial_with_retry(multiaddr, peer_id)
+                connections.append(connection)
+
+                # Limit number of connections per peer
+                if len(connections) >= self.connection_config.max_connections_per_peer:
+                    break
+
             except SwarmException as e:
                 exceptions.append(e)
                 logger.debug(
@@ -161,15 +249,73 @@
                     exc_info=e,
                 )
 
-        # Tried all addresses, raising exception.
-        raise SwarmException(
-            f"unable to connect to {peer_id}, no addresses established a successful "
-            "connection (with exceptions)"
-        ) from MultiError(exceptions)
+        if not connections:
+            # Tried all addresses, raising exception.
+            raise SwarmException(
+                f"unable to connect to {peer_id}, no addresses established a "
+                "successful connection (with exceptions)"
+            ) from MultiError(exceptions)
 
-    async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn:
+        return connections
+
+    async def _dial_with_retry(self, addr: Multiaddr, peer_id: ID) -> INetConn:
         """
-        Try to create a connection to peer_id with addr.
+        Enhanced: Dial with retry logic and exponential backoff.
+
+        :param addr: the address to dial
+        :param peer_id: the peer we want to connect to
+        :raises SwarmException: raised when all retry attempts fail
+        :return: network connection
+        """
+        last_exception = None
+
+        for attempt in range(self.retry_config.max_retries + 1):
+            try:
+                return await self._dial_addr_single_attempt(addr, peer_id)
+            except Exception as e:
+                last_exception = e
+                if attempt < self.retry_config.max_retries:
+                    delay = self._calculate_backoff_delay(attempt)
+                    logger.debug(
+                        f"Connection attempt {attempt + 1} failed, "
+                        f"retrying in {delay:.2f}s: {e}"
+                    )
+                    await trio.sleep(delay)
+                else:
+                    logger.debug(
+                        f"All {self.retry_config.max_retries + 1} attempts failed"
+                    )
+
+        # Convert the last exception to SwarmException for consistency
+        if last_exception is not None:
+            if isinstance(last_exception, SwarmException):
+                raise last_exception
+            else:
+                raise SwarmException(
+                    f"Failed to connect after "
+                    f"{self.retry_config.max_retries + 1} attempts"
+                ) from last_exception
+
+        # This should never be reached, but mypy requires it
+        raise SwarmException("Unexpected error in retry logic")
+
+    def _calculate_backoff_delay(self, attempt: int) -> float:
+        """
+        Enhanced: Calculate backoff delay with jitter to prevent thundering herd.
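+
+        With the defaults (initial_delay=0.1, backoff_multiplier=2.0,
+        max_delay=30.0, jitter_factor=0.1) the base delays are 0.1s, 0.2s,
+        0.4s, ... capped at 30s, each with +/-10% random jitter applied.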
+ + :param attempt: the current attempt number (0-based) + :return: delay in seconds + """ + delay = min( + self.retry_config.initial_delay + * (self.retry_config.backoff_multiplier**attempt), + self.retry_config.max_delay, + ) + + # Add jitter to prevent synchronized retries + jitter = delay * self.retry_config.jitter_factor + return delay + random.uniform(-jitter, jitter) + + async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetConn: + """ + Enhanced: Single attempt to dial an address (extracted from original dial_addr). :param addr: the address we want to connect with :param peer_id: the peer we want to connect to @@ -179,6 +325,7 @@ class Swarm(Service, INetworkService): # Dial peer (connection to peer does not yet exist) # Transport dials peer (gets back a raw conn) try: + addr = Multiaddr(f"{addr}/p2p/{peer_id}") raw_conn = await self.transport.dial(addr) except OpenConnectionError as error: logger.debug("fail to dial peer %s over base transport", peer_id) @@ -186,6 +333,15 @@ class Swarm(Service, INetworkService): f"fail to open connection to peer {peer_id}" ) from error + if isinstance(self.transport, QUICTransport) and isinstance( + raw_conn, IMuxedConn + ): + logger.info( + "Skipping upgrade for QUIC, QUIC connections are already multiplexed" + ) + swarm_conn = await self.add_conn(raw_conn) + return swarm_conn + logger.debug("dialed peer %s over base transport", peer_id) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure @@ -211,24 +367,103 @@ class Swarm(Service, INetworkService): logger.debug("upgraded mux for peer %s", peer_id) swarm_conn = await self.add_conn(muxed_conn) - logger.debug("successfully dialed peer %s", peer_id) - return swarm_conn + async def dial_addr(self, addr: Multiaddr, peer_id: ID) -> INetConn: + """ + Enhanced: Try to create a connection to peer_id with addr using retry logic. + + :param addr: the address we want to connect with + :param peer_id: the peer we want to connect to + :raises SwarmException: raised when an error occurs + :return: network connection + """ + return await self._dial_with_retry(addr, peer_id) + async def new_stream(self, peer_id: ID) -> INetStream: """ + Enhanced: Create a new stream with load balancing across multiple connections. 
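+
+        Illustrative usage (assumes the peer is reachable):
+
+            stream = await swarm.new_stream(peer_id)
+            await stream.write(b"hello")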
+ :param peer_id: peer_id of destination :raises SwarmException: raised when an error occurs :return: net stream instance """ logger.debug("attempting to open a stream to peer %s", peer_id) + # Get existing connections or dial new ones + connections = self.get_connections(peer_id) + if not connections: + connections = await self.dial_peer(peer_id) - swarm_conn = await self.dial_peer(peer_id) + # Load balancing strategy at interface level + connection = self._select_connection(connections, peer_id) - net_stream = await swarm_conn.new_stream() - logger.debug("successfully opened a stream to peer %s", peer_id) - return net_stream + if isinstance(self.transport, QUICTransport) and connection is not None: + conn = cast(SwarmConn, connection) + return await conn.new_stream() + + try: + net_stream = await connection.new_stream() + logger.debug("successfully opened a stream to peer %s", peer_id) + return net_stream + except Exception as e: + logger.debug(f"Failed to create stream on connection: {e}") + # Try other connections if available + for other_conn in connections: + if other_conn != connection: + try: + net_stream = await other_conn.new_stream() + logger.debug( + f"Successfully opened a stream to peer {peer_id} " + "using alternative connection" + ) + return net_stream + except Exception: + continue + + # All connections failed, raise exception + raise SwarmException(f"Failed to create stream to peer {peer_id}") from e + + def _select_connection(self, connections: list[INetConn], peer_id: ID) -> INetConn: + """ + Select connection based on load balancing strategy. + + Parameters + ---------- + connections : list[INetConn] + List of available connections. + peer_id : ID + The peer ID for round-robin tracking. + strategy : str + Load balancing strategy ("round_robin", "least_loaded", etc.). + + Returns + ------- + INetConn + Selected connection. + + """ + if not connections: + raise ValueError("No connections available") + + strategy = self.connection_config.load_balancing_strategy + + if strategy == "round_robin": + # Simple round-robin selection + if peer_id not in self._round_robin_index: + self._round_robin_index[peer_id] = 0 + + index = self._round_robin_index[peer_id] % len(connections) + self._round_robin_index[peer_id] += 1 + return connections[index] + + elif strategy == "least_loaded": + # Find connection with least streams + return min(connections, key=lambda c: len(c.get_streams())) + + else: + # Default to first connection + return connections[0] async def listen(self, *multiaddrs: Multiaddr) -> bool: """ @@ -248,17 +483,35 @@ class Swarm(Service, INetworkService): """ logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}") # We need to wait until `self.listener_nursery` is created. 
+ logger.debug("Starting to listen") await self.event_listener_nursery_created.wait() + success_count = 0 for maddr in multiaddrs: logger.debug(f"Swarm.listen processing multiaddr: {maddr}") if str(maddr) in self.listeners: - logger.debug(f"Swarm.listen: listener already exists for {maddr}") - return True + success_count += 1 + continue async def conn_handler( read_write_closer: ReadWriteCloser, maddr: Multiaddr = maddr ) -> None: + # No need to upgrade QUIC Connection + if isinstance(self.transport, QUICTransport): + try: + quic_conn = cast(QUICConnection, read_write_closer) + await self.add_conn(quic_conn) + peer_id = quic_conn.peer_id + logger.debug( + f"successfully opened quic connection to peer {peer_id}" + ) + # NOTE: This is a intentional barrier to prevent from the + # handler exiting and closing the connection. + await self.manager.wait_finished() + except Exception: + await read_write_closer.close() + return + raw_conn = RawConnection(read_write_closer, False) # Per, https://discuss.libp2p.io/t/multistream-security/130, we first @@ -309,13 +562,14 @@ class Swarm(Service, INetworkService): # Call notifiers since event occurred await self.notify_listen(maddr) - return True + success_count += 1 + logger.debug("successfully started listening on: %s", maddr) except OSError: # Failed. Continue looping. logger.debug("fail to listen on: %s", maddr) - # No maddr succeeded - return False + # Return true if at least one address succeeded + return success_count > 0 async def close(self) -> None: """ @@ -328,9 +582,9 @@ class Swarm(Service, INetworkService): # Perform alternative cleanup if the manager isn't initialized # Close all connections manually if hasattr(self, "connections"): - for conn_id in list(self.connections.keys()): - conn = self.connections[conn_id] - await conn.close() + for peer_id, conns in list(self.connections.items()): + for conn in conns: + await conn.close() # Clear connection tracking dictionary self.connections.clear() @@ -360,12 +614,28 @@ class Swarm(Service, INetworkService): logger.debug("swarm successfully closed") async def close_peer(self, peer_id: ID) -> None: - if peer_id not in self.connections: + """ + Close all connections to the specified peer. + + Parameters + ---------- + peer_id : ID + The peer ID to close connections for. + + """ + connections = self.get_connections(peer_id) + if not connections: return - connection = self.connections[peer_id] - # NOTE: `connection.close` will delete `peer_id` from `self.connections` - # and `notify_disconnected` for us. 
- await connection.close() + + # Close all connections + for connection in connections: + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection to {peer_id}: {e}") + + # Remove from connections dict + self.connections.pop(peer_id, None) logger.debug("successfully close the connection to peer %s", peer_id) @@ -379,26 +649,77 @@ class Swarm(Service, INetworkService): muxed_conn, self, ) - + logger.debug("Swarm::add_conn | starting muxed connection") self.manager.run_task(muxed_conn.start) await muxed_conn.event_started.wait() + logger.debug("Swarm::add_conn | starting swarm connection") self.manager.run_task(swarm_conn.start) await swarm_conn.event_started.wait() - # Store muxed_conn with peer id - self.connections[muxed_conn.peer_id] = swarm_conn + + # Add to connections dict with deduplication + peer_id = muxed_conn.peer_id + if peer_id not in self.connections: + self.connections[peer_id] = [] + + # Check for duplicate connections by comparing the underlying muxed connection + for existing_conn in self.connections[peer_id]: + if existing_conn.muxed_conn == muxed_conn: + logger.debug(f"Connection already exists for peer {peer_id}") + # existing_conn is a SwarmConn since it's stored in the connections list + return existing_conn # type: ignore[return-value] + + self.connections[peer_id].append(swarm_conn) + + # Trim if we exceed max connections + max_conns = self.connection_config.max_connections_per_peer + if len(self.connections[peer_id]) > max_conns: + self._trim_connections(peer_id) + # Call notifiers since event occurred await self.notify_connected(swarm_conn) return swarm_conn + def _trim_connections(self, peer_id: ID) -> None: + """ + Remove oldest connections when limit is exceeded. + """ + connections = self.connections[peer_id] + if len(connections) <= self.connection_config.max_connections_per_peer: + return + + # Sort by creation time and remove oldest + # For now, just keep the most recent connections + max_conns = self.connection_config.max_connections_per_peer + connections_to_remove = connections[:-max_conns] + + for conn in connections_to_remove: + logger.debug(f"Trimming old connection for peer {peer_id}") + trio.lowlevel.spawn_system_task(self._close_connection_async, conn) + + # Keep only the most recent connections + max_conns = self.connection_config.max_connections_per_peer + self.connections[peer_id] = connections[-max_conns:] + + async def _close_connection_async(self, connection: INetConn) -> None: + """Close a connection asynchronously.""" + try: + await connection.close() + except Exception as e: + logger.warning(f"Error closing connection: {e}") + def remove_conn(self, swarm_conn: SwarmConn) -> None: """ Simply remove the connection from Swarm's records, without closing the connection. """ peer_id = swarm_conn.muxed_conn.peer_id - if peer_id not in self.connections: - return - del self.connections[peer_id] + + if peer_id in self.connections: + self.connections[peer_id] = [ + conn for conn in self.connections[peer_id] if conn != swarm_conn + ] + if not self.connections[peer_id]: + del self.connections[peer_id] # Notifee @@ -444,3 +765,21 @@ class Swarm(Service, INetworkService): async with trio.open_nursery() as nursery: for notifee in self.notifees: nursery.start_soon(notifier, notifee) + + # Backward compatibility properties + @property + def connections_legacy(self) -> dict[ID, INetConn]: + """ + Legacy 1:1 mapping for backward compatibility. 
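+
+        Only the first connection per peer is exposed; callers that need all
+        connections to a peer should use ``get_connections`` instead.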
+ + Returns + ------- + dict[ID, INetConn] + Legacy mapping with only the first connection per peer. + + """ + legacy_conns = {} + for peer_id, conns in self.connections.items(): + if conns: + legacy_conns[peer_id] = conns[0] + return legacy_conns diff --git a/libp2p/peer/envelope.py b/libp2p/peer/envelope.py index e93a8280..f8bf9f43 100644 --- a/libp2p/peer/envelope.py +++ b/libp2p/peer/envelope.py @@ -1,5 +1,7 @@ from typing import Any, cast +import multiaddr + from libp2p.crypto.ed25519 import Ed25519PublicKey from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.crypto.rsa import RSAPublicKey @@ -131,6 +133,9 @@ class Envelope: ) return False + def _env_addrs_set(self) -> set[multiaddr.Multiaddr]: + return {b for b in self.record().addrs} + def pub_key_to_protobuf(pub_key: PublicKey) -> cryto_pb.PublicKey: """ diff --git a/libp2p/peer/peerstore.py b/libp2p/peer/peerstore.py index 043aaf0d..ddf1af1f 100644 --- a/libp2p/peer/peerstore.py +++ b/libp2p/peer/peerstore.py @@ -16,6 +16,7 @@ import trio from trio import MemoryReceiveChannel, MemorySendChannel from libp2p.abc import ( + IHost, IPeerStore, ) from libp2p.crypto.keys import ( @@ -23,7 +24,8 @@ from libp2p.crypto.keys import ( PrivateKey, PublicKey, ) -from libp2p.peer.envelope import Envelope +from libp2p.peer.envelope import Envelope, seal_record +from libp2p.peer.peer_record import PeerRecord from .id import ( ID, @@ -39,6 +41,86 @@ from .peerinfo import ( PERMANENT_ADDR_TTL = 0 +def create_signed_peer_record( + peer_id: ID, addrs: list[Multiaddr], pvt_key: PrivateKey +) -> Envelope: + """Creates a signed_peer_record wrapped in an Envelope""" + record = PeerRecord(peer_id, addrs) + envelope = seal_record(record, pvt_key) + return envelope + + +def env_to_send_in_RPC(host: IHost) -> tuple[bytes, bool]: + """ + Return the signed peer record (Envelope) to be sent in an RPC. + + This function checks whether the host already has a cached signed peer record + (SPR). If one exists and its addresses match the host's current listen + addresses, the cached envelope is reused. Otherwise, a new signed peer record + is created, cached, and returned. + + Parameters + ---------- + host : IHost + The local host instance, providing access to peer ID, listen addresses, + private key, and the peerstore. + + Returns + ------- + tuple[bytes, bool] + A 2-tuple where the first element is the serialized envelope (bytes) + for the signed peer record, and the second element is a boolean flag + indicating whether a new record was created (True) or an existing cached + one was reused (False). + + """ + listen_addrs_set = {addr for addr in host.get_addrs()} + local_env = host.get_peerstore().get_local_record() + + if local_env is None: + # No cached SPR yet -> create one + return issue_and_cache_local_record(host), True + else: + record_addrs_set = local_env._env_addrs_set() + if record_addrs_set == listen_addrs_set: + # Perfect match -> reuse cached envelope + return local_env.marshal_envelope(), False + else: + # Addresses changed -> issue a new SPR and cache it + return issue_and_cache_local_record(host), True + + +def issue_and_cache_local_record(host: IHost) -> bytes: + """ + Create and cache a new signed peer record (Envelope) for the host. + + This function generates a new signed peer record from the host’s peer ID, + listen addresses, and private key. The resulting envelope is stored in + the peerstore as the local record for future reuse. 
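+
+        This is invoked by ``env_to_send_in_RPC`` when no local record is
+        cached yet, or when the host's listen addresses have changed.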
+
+    Parameters
+    ----------
+    host : IHost
+        The local host instance, providing access to peer ID, listen addresses,
+        private key, and the peerstore.
+
+    Returns
+    -------
+    bytes
+        The serialized envelope (bytes) representing the newly created signed
+        peer record.
+
+    """
+    env = create_signed_peer_record(
+        host.get_id(),
+        host.get_addrs(),
+        host.get_private_key(),
+    )
+    # Cache it for future reuse
+    host.get_peerstore().set_local_record(env)
+    return env.marshal_envelope()
+
+
 class PeerRecordState:
     envelope: Envelope
     seq: int
@@ -55,8 +137,17 @@ class PeerStore(IPeerStore):
         self.peer_data_map = defaultdict(PeerData)
         self.addr_update_channels: dict[ID, MemorySendChannel[Multiaddr]] = {}
         self.peer_record_map: dict[ID, PeerRecordState] = {}
+        self.local_peer_record: Envelope | None = None
         self.max_records = max_records
 
+    def get_local_record(self) -> Envelope | None:
+        """Get the local signed record, wrapped in an Envelope."""
+        return self.local_peer_record
+
+    def set_local_record(self, envelope: Envelope) -> None:
+        """Set the local signed record, wrapped in an Envelope."""
+        self.local_peer_record = envelope
+
     def peer_info(self, peer_id: ID) -> PeerInfo:
         """
         :param peer_id: peer ID to get info for
diff --git a/libp2p/protocol_muxer/multiselect_communicator.py b/libp2p/protocol_muxer/multiselect_communicator.py
index 98a8129c..dff5b339 100644
--- a/libp2p/protocol_muxer/multiselect_communicator.py
+++ b/libp2p/protocol_muxer/multiselect_communicator.py
@@ -1,3 +1,5 @@
+from builtins import AssertionError
+
 from libp2p.abc import (
     IMultiselectCommunicator,
 )
@@ -36,7 +38,8 @@ class MultiselectCommunicator(IMultiselectCommunicator):
         msg_bytes = encode_delim(msg_str.encode())
         try:
             await self.read_writer.write(msg_bytes)
-        except IOException as error:
+        # Also handle a connection that closes during an ongoing negotiation (QUIC)
+        except (IOException, AssertionError, ValueError) as error:
             raise MultiselectCommunicatorError(
                 "fail to write to multiselect communicator"
             ) from error
diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py
index 3e0d454f..8167581d 100644
--- a/libp2p/pubsub/floodsub.py
+++ b/libp2p/pubsub/floodsub.py
@@ -15,6 +15,7 @@ from libp2p.custom_types import (
 from libp2p.peer.id import (
     ID,
 )
+from libp2p.peer.peerstore import env_to_send_in_RPC
 
 from .exceptions import (
     PubsubRouterError,
@@ -103,6 +104,11 @@ class FloodSub(IPubsubRouter):
         )
         rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
 
+        # Add the senderRecord of the peer in the RPC msg
+        if isinstance(self.pubsub, Pubsub):
+            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
+            rpc_msg.senderRecord = envelope_bytes
+
         logger.debug("publishing message %s", pubsub_msg)
 
         if self.pubsub is None:
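Both routers now embed the host's cached signed peer record in outgoing RPCs via
env_to_send_in_RPC, which reissues the record only when the host's listen
addresses change. A minimal sketch of the call (assuming a started host; the
surrounding setup is illustrative):

    from libp2p.peer.peerstore import env_to_send_in_RPC
    from libp2p.pubsub.pb import rpc_pb2

    # The second element reports whether a fresh record was issued (True)
    # or the cached one was reused (False).
    envelope_bytes, freshly_issued = env_to_send_in_RPC(host)
    rpc = rpc_pb2.RPC()
    rpc.senderRecord = envelope_bytes

diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py
index c345c138..45c6cd81 100644
--- a/libp2p/pubsub/gossipsub.py
+++ b/libp2p/pubsub/gossipsub.py
@@ -1,6 +1,3 @@
-from ast import (
-    literal_eval,
-)
 from collections import (
     defaultdict,
 )
@@ -22,6 +19,7 @@ from libp2p.abc import (
     IPubsubRouter,
 )
 from libp2p.custom_types import (
+    MessageID,
     TProtocol,
 )
 from libp2p.peer.id import (
@@ -34,10 +32,12 @@ from libp2p.peer.peerinfo import (
 )
 from libp2p.peer.peerstore import (
     PERMANENT_ADDR_TTL,
+    env_to_send_in_RPC,
 )
 from libp2p.pubsub import (
     floodsub,
 )
+from libp2p.pubsub.utils import maybe_consume_signed_record
 from libp2p.tools.async_service import (
     Service,
 )
@@ -54,6 +54,10 @@ from .pb import (
 from .pubsub import (
     Pubsub,
 )
+from .utils import (
+    parse_message_id_safe,
+    safe_parse_message_id,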
+)
 
 PROTOCOL_ID = TProtocol("/meshsub/1.0.0")
 PROTOCOL_ID_V11 = TProtocol("/meshsub/1.1.0")
@@ -226,6 +230,12 @@ class GossipSub(IPubsubRouter, Service):
         :param rpc: RPC message
         :param sender_peer_id: id of the peer who sent the message
         """
+        # Process the senderRecord if sent
+        if isinstance(self.pubsub, Pubsub):
+            if not maybe_consume_signed_record(rpc, self.pubsub.host, sender_peer_id):
+                logger.error("Received an invalid signed record, ignoring the message")
+                return
+
         control_message = rpc.control
 
         # Relay each rpc control message to the appropriate handler
@@ -253,6 +263,11 @@ class GossipSub(IPubsubRouter, Service):
         )
         rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
 
+        # Add the senderRecord of the peer in the RPC msg
+        if isinstance(self.pubsub, Pubsub):
+            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
+            rpc_msg.senderRecord = envelope_bytes
+
         logger.debug("publishing message %s", pubsub_msg)
 
         for peer_id in peers_gen:
@@ -781,8 +796,8 @@ class GossipSub(IPubsubRouter, Service):
 
         # Add all unknown message ids (ids that appear in ihave_msg but not in
         # seen_seqnos) to list of messages we want to request
-        msg_ids_wanted: list[str] = [
-            msg_id
+        msg_ids_wanted: list[MessageID] = [
+            parse_message_id_safe(msg_id)
             for msg_id in ihave_msg.messageIDs
             if msg_id not in seen_seqnos_and_peers
         ]
@@ -798,9 +813,9 @@ class GossipSub(IPubsubRouter, Service):
         Forwards all request messages that are present in mcache to the
         requesting peer.
         """
-        # FIXME: Update type of message ID
-        # FIXME: Find a better way to parse the msg ids
-        msg_ids: list[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
+        msg_ids: list[tuple[bytes, bytes]] = [
+            safe_parse_message_id(msg) for msg in iwant_msg.messageIDs
+        ]
         msgs_to_forward: list[rpc_pb2.Message] = []
         for msg_id_iwant in msg_ids:
             # Check if the wanted message ID is present in mcache
@@ -818,6 +833,13 @@ class GossipSub(IPubsubRouter, Service):
         # 1) Package these messages into a single packet
         packet: rpc_pb2.RPC = rpc_pb2.RPC()
 
+        # Here an RPC message is created and published in response to the
+        # iwant control msg, so we send a freshly created senderRecord
+        # with the RPC msg
+        if isinstance(self.pubsub, Pubsub):
+            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
+            packet.senderRecord = envelope_bytes
+
         packet.publish.extend(msgs_to_forward)
 
         if self.pubsub is None:
@@ -973,6 +995,12 @@ class GossipSub(IPubsubRouter, Service):
             raise NoPubsubAttached
         # Add control message to packet
         packet: rpc_pb2.RPC = rpc_pb2.RPC()
+
+        # Add the sender's peer-record in the RPC msg
+        if isinstance(self.pubsub, Pubsub):
+            envelope_bytes, _ = env_to_send_in_RPC(self.pubsub.host)
+            packet.senderRecord = envelope_bytes
+
         packet.control.CopyFrom(control_msg)
 
         # Get stream for peer from pubsub
diff --git a/libp2p/pubsub/pb/rpc.proto b/libp2p/pubsub/pb/rpc.proto
index 7abce0d6..d24db281 100644
--- a/libp2p/pubsub/pb/rpc.proto
+++ b/libp2p/pubsub/pb/rpc.proto
@@ -14,6 +14,7 @@ message RPC {
     }
 
     optional ControlMessage control = 3;
+    optional bytes senderRecord = 4;
 }
 
 message Message {
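The new field is a plain optional bytes slot (field 4) on the RPC envelope, so it
round-trips like any other proto2 field. A quick sketch (placeholder payload;
illustrative only):

    from libp2p.pubsub.pb import rpc_pb2

    rpc = rpc_pb2.RPC()
    rpc.senderRecord = b"serialized-envelope"  # normally from env_to_send_in_RPC
    wire = rpc.SerializeToString()

    decoded = rpc_pb2.RPC()
    decoded.ParseFromString(wire)
    assert decoded.HasField("senderRecord")

diff --git a/libp2p/pubsub/pb/rpc_pb2.py b/libp2p/pubsub/pb/rpc_pb2.py
index 30f0281b..e4a35745 100644
--- a/libp2p/pubsub/pb/rpc_pb2.py
+++ b/libp2p/pubsub/pb/rpc_pb2.py
@@ -1,11 +1,12 @@
 # -*- coding: utf-8 -*-
 # Generated by the protocol buffer compiler. DO NOT EDIT!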
# source: libp2p/pubsub/pb/rpc.proto +# Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,39 +14,39 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xb4\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 \x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1alibp2p/pubsub/pb/rpc.proto\x12\tpubsub.pb\"\xca\x01\n\x03RPC\x12-\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x16.pubsub.pb.RPC.SubOpts\x12#\n\x07publish\x18\x02 \x03(\x0b\x32\x12.pubsub.pb.Message\x12*\n\x07\x63ontrol\x18\x03 \x01(\x0b\x32\x19.pubsub.pb.ControlMessage\x12\x14\n\x0csenderRecord\x18\x04 \x01(\x0c\x1a-\n\x07SubOpts\x12\x11\n\tsubscribe\x18\x01 \x01(\x08\x12\x0f\n\x07topicid\x18\x02 \x01(\t\"i\n\x07Message\x12\x0f\n\x07\x66rom_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\r\n\x05seqno\x18\x03 \x01(\x0c\x12\x10\n\x08topicIDs\x18\x04 \x03(\t\x12\x11\n\tsignature\x18\x05 \x01(\x0c\x12\x0b\n\x03key\x18\x06 \x01(\x0c\"\xb0\x01\n\x0e\x43ontrolMessage\x12&\n\x05ihave\x18\x01 \x03(\x0b\x32\x17.pubsub.pb.ControlIHave\x12&\n\x05iwant\x18\x02 
\x03(\x0b\x32\x17.pubsub.pb.ControlIWant\x12&\n\x05graft\x18\x03 \x03(\x0b\x32\x17.pubsub.pb.ControlGraft\x12&\n\x05prune\x18\x04 \x03(\x0b\x32\x17.pubsub.pb.ControlPrune\"3\n\x0c\x43ontrolIHave\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\x12\n\nmessageIDs\x18\x02 \x03(\t\"\"\n\x0c\x43ontrolIWant\x12\x12\n\nmessageIDs\x18\x01 \x03(\t\"\x1f\n\x0c\x43ontrolGraft\x12\x0f\n\x07topicID\x18\x01 \x01(\t\"T\n\x0c\x43ontrolPrune\x12\x0f\n\x07topicID\x18\x01 \x01(\t\x12\"\n\x05peers\x18\x02 \x03(\x0b\x32\x13.pubsub.pb.PeerInfo\x12\x0f\n\x07\x62\x61\x63koff\x18\x03 \x01(\x04\"4\n\x08PeerInfo\x12\x0e\n\x06peerID\x18\x01 \x01(\x0c\x12\x18\n\x10signedPeerRecord\x18\x02 \x01(\x0c\"\x87\x03\n\x0fTopicDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04\x61uth\x18\x02 \x01(\x0b\x32#.pubsub.pb.TopicDescriptor.AuthOpts\x12/\n\x03\x65nc\x18\x03 \x01(\x0b\x32\".pubsub.pb.TopicDescriptor.EncOpts\x1a|\n\x08\x41uthOpts\x12:\n\x04mode\x18\x01 \x01(\x0e\x32,.pubsub.pb.TopicDescriptor.AuthOpts.AuthMode\x12\x0c\n\x04keys\x18\x02 \x03(\x0c\"&\n\x08\x41uthMode\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03KEY\x10\x01\x12\x07\n\x03WOT\x10\x02\x1a\x83\x01\n\x07\x45ncOpts\x12\x38\n\x04mode\x18\x01 \x01(\x0e\x32*.pubsub.pb.TopicDescriptor.EncOpts.EncMode\x12\x11\n\tkeyHashes\x18\x02 \x03(\x0c\"+\n\x07\x45ncMode\x12\x08\n\x04NONE\x10\x00\x12\r\n\tSHAREDKEY\x10\x01\x12\x07\n\x03WOT\x10\x02') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.pubsub.pb.rpc_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _RPC._serialized_start=42 - _RPC._serialized_end=222 - _RPC_SUBOPTS._serialized_start=177 - _RPC_SUBOPTS._serialized_end=222 - _MESSAGE._serialized_start=224 - _MESSAGE._serialized_end=329 - _CONTROLMESSAGE._serialized_start=332 - _CONTROLMESSAGE._serialized_end=508 - _CONTROLIHAVE._serialized_start=510 - _CONTROLIHAVE._serialized_end=561 - _CONTROLIWANT._serialized_start=563 - _CONTROLIWANT._serialized_end=597 - _CONTROLGRAFT._serialized_start=599 - _CONTROLGRAFT._serialized_end=630 - _CONTROLPRUNE._serialized_start=632 - _CONTROLPRUNE._serialized_end=716 - _PEERINFO._serialized_start=718 - _PEERINFO._serialized_end=770 - _TOPICDESCRIPTOR._serialized_start=773 - _TOPICDESCRIPTOR._serialized_end=1164 - _TOPICDESCRIPTOR_AUTHOPTS._serialized_start=906 - _TOPICDESCRIPTOR_AUTHOPTS._serialized_end=1030 - _TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_start=992 - _TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE._serialized_end=1030 - _TOPICDESCRIPTOR_ENCOPTS._serialized_start=1033 - _TOPICDESCRIPTOR_ENCOPTS._serialized_end=1164 - _TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_start=1121 - _TOPICDESCRIPTOR_ENCOPTS_ENCMODE._serialized_end=1164 + _globals['_RPC']._serialized_start=42 + _globals['_RPC']._serialized_end=244 + _globals['_RPC_SUBOPTS']._serialized_start=199 + _globals['_RPC_SUBOPTS']._serialized_end=244 + _globals['_MESSAGE']._serialized_start=246 + _globals['_MESSAGE']._serialized_end=351 + _globals['_CONTROLMESSAGE']._serialized_start=354 + _globals['_CONTROLMESSAGE']._serialized_end=530 + _globals['_CONTROLIHAVE']._serialized_start=532 + _globals['_CONTROLIHAVE']._serialized_end=583 + _globals['_CONTROLIWANT']._serialized_start=585 + _globals['_CONTROLIWANT']._serialized_end=619 + _globals['_CONTROLGRAFT']._serialized_start=621 + 
_globals['_CONTROLGRAFT']._serialized_end=652 + _globals['_CONTROLPRUNE']._serialized_start=654 + _globals['_CONTROLPRUNE']._serialized_end=738 + _globals['_PEERINFO']._serialized_start=740 + _globals['_PEERINFO']._serialized_end=792 + _globals['_TOPICDESCRIPTOR']._serialized_start=795 + _globals['_TOPICDESCRIPTOR']._serialized_end=1186 + _globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_start=928 + _globals['_TOPICDESCRIPTOR_AUTHOPTS']._serialized_end=1052 + _globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_start=1014 + _globals['_TOPICDESCRIPTOR_AUTHOPTS_AUTHMODE']._serialized_end=1052 + _globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_start=1055 + _globals['_TOPICDESCRIPTOR_ENCOPTS']._serialized_end=1186 + _globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_start=1143 + _globals['_TOPICDESCRIPTOR_ENCOPTS_ENCMODE']._serialized_end=1186 # @@protoc_insertion_point(module_scope) diff --git a/libp2p/pubsub/pb/rpc_pb2.pyi b/libp2p/pubsub/pb/rpc_pb2.pyi index 88738e2e..2609fd11 100644 --- a/libp2p/pubsub/pb/rpc_pb2.pyi +++ b/libp2p/pubsub/pb/rpc_pb2.pyi @@ -1,323 +1,132 @@ -""" -@generated by mypy-protobuf. Do not edit manually! -isort:skip_file -Modified from https://github.com/libp2p/go-libp2p-pubsub/blob/master/pb/rpc.proto""" +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union -import builtins -import collections.abc -import google.protobuf.descriptor -import google.protobuf.internal.containers -import google.protobuf.internal.enum_type_wrapper -import google.protobuf.message -import sys -import typing +DESCRIPTOR: _descriptor.FileDescriptor -if sys.version_info >= (3, 10): - import typing as typing_extensions -else: - import typing_extensions +class RPC(_message.Message): + __slots__ = ("subscriptions", "publish", "control", "senderRecord") + class SubOpts(_message.Message): + __slots__ = ("subscribe", "topicid") + SUBSCRIBE_FIELD_NUMBER: _ClassVar[int] + TOPICID_FIELD_NUMBER: _ClassVar[int] + subscribe: bool + topicid: str + def __init__(self, subscribe: bool = ..., topicid: _Optional[str] = ...) -> None: ... + SUBSCRIPTIONS_FIELD_NUMBER: _ClassVar[int] + PUBLISH_FIELD_NUMBER: _ClassVar[int] + CONTROL_FIELD_NUMBER: _ClassVar[int] + SENDERRECORD_FIELD_NUMBER: _ClassVar[int] + subscriptions: _containers.RepeatedCompositeFieldContainer[RPC.SubOpts] + publish: _containers.RepeatedCompositeFieldContainer[Message] + control: ControlMessage + senderRecord: bytes + def __init__(self, subscriptions: _Optional[_Iterable[_Union[RPC.SubOpts, _Mapping]]] = ..., publish: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., control: _Optional[_Union[ControlMessage, _Mapping]] = ..., senderRecord: _Optional[bytes] = ...) -> None: ... 
# type: ignore -DESCRIPTOR: google.protobuf.descriptor.FileDescriptor +class Message(_message.Message): + __slots__ = ("from_id", "data", "seqno", "topicIDs", "signature", "key") + FROM_ID_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + SEQNO_FIELD_NUMBER: _ClassVar[int] + TOPICIDS_FIELD_NUMBER: _ClassVar[int] + SIGNATURE_FIELD_NUMBER: _ClassVar[int] + KEY_FIELD_NUMBER: _ClassVar[int] + from_id: bytes + data: bytes + seqno: bytes + topicIDs: _containers.RepeatedScalarFieldContainer[str] + signature: bytes + key: bytes + def __init__(self, from_id: _Optional[bytes] = ..., data: _Optional[bytes] = ..., seqno: _Optional[bytes] = ..., topicIDs: _Optional[_Iterable[str]] = ..., signature: _Optional[bytes] = ..., key: _Optional[bytes] = ...) -> None: ... -@typing.final -class RPC(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor +class ControlMessage(_message.Message): + __slots__ = ("ihave", "iwant", "graft", "prune") + IHAVE_FIELD_NUMBER: _ClassVar[int] + IWANT_FIELD_NUMBER: _ClassVar[int] + GRAFT_FIELD_NUMBER: _ClassVar[int] + PRUNE_FIELD_NUMBER: _ClassVar[int] + ihave: _containers.RepeatedCompositeFieldContainer[ControlIHave] + iwant: _containers.RepeatedCompositeFieldContainer[ControlIWant] + graft: _containers.RepeatedCompositeFieldContainer[ControlGraft] + prune: _containers.RepeatedCompositeFieldContainer[ControlPrune] + def __init__(self, ihave: _Optional[_Iterable[_Union[ControlIHave, _Mapping]]] = ..., iwant: _Optional[_Iterable[_Union[ControlIWant, _Mapping]]] = ..., graft: _Optional[_Iterable[_Union[ControlGraft, _Mapping]]] = ..., prune: _Optional[_Iterable[_Union[ControlPrune, _Mapping]]] = ...) -> None: ... # type: ignore - @typing.final - class SubOpts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor +class ControlIHave(_message.Message): + __slots__ = ("topicID", "messageIDs") + TOPICID_FIELD_NUMBER: _ClassVar[int] + MESSAGEIDS_FIELD_NUMBER: _ClassVar[int] + topicID: str + messageIDs: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, topicID: _Optional[str] = ..., messageIDs: _Optional[_Iterable[str]] = ...) -> None: ... - SUBSCRIBE_FIELD_NUMBER: builtins.int - TOPICID_FIELD_NUMBER: builtins.int - subscribe: builtins.bool - """subscribe or unsubscribe""" - topicid: builtins.str - def __init__( - self, - *, - subscribe: builtins.bool | None = ..., - topicid: builtins.str | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["subscribe", b"subscribe", "topicid", b"topicid"]) -> None: ... +class ControlIWant(_message.Message): + __slots__ = ("messageIDs",) + MESSAGEIDS_FIELD_NUMBER: _ClassVar[int] + messageIDs: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, messageIDs: _Optional[_Iterable[str]] = ...) -> None: ... - SUBSCRIPTIONS_FIELD_NUMBER: builtins.int - PUBLISH_FIELD_NUMBER: builtins.int - CONTROL_FIELD_NUMBER: builtins.int - @property - def subscriptions(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___RPC.SubOpts]: ... - @property - def publish(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Message]: ... - @property - def control(self) -> global___ControlMessage: ... 
- def __init__( - self, - *, - subscriptions: collections.abc.Iterable[global___RPC.SubOpts] | None = ..., - publish: collections.abc.Iterable[global___Message] | None = ..., - control: global___ControlMessage | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["control", b"control"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["control", b"control", "publish", b"publish", "subscriptions", b"subscriptions"]) -> None: ... +class ControlGraft(_message.Message): + __slots__ = ("topicID",) + TOPICID_FIELD_NUMBER: _ClassVar[int] + topicID: str + def __init__(self, topicID: _Optional[str] = ...) -> None: ... -global___RPC = RPC +class ControlPrune(_message.Message): + __slots__ = ("topicID", "peers", "backoff") + TOPICID_FIELD_NUMBER: _ClassVar[int] + PEERS_FIELD_NUMBER: _ClassVar[int] + BACKOFF_FIELD_NUMBER: _ClassVar[int] + topicID: str + peers: _containers.RepeatedCompositeFieldContainer[PeerInfo] + backoff: int + def __init__(self, topicID: _Optional[str] = ..., peers: _Optional[_Iterable[_Union[PeerInfo, _Mapping]]] = ..., backoff: _Optional[int] = ...) -> None: ... # type: ignore -@typing.final -class Message(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor +class PeerInfo(_message.Message): + __slots__ = ("peerID", "signedPeerRecord") + PEERID_FIELD_NUMBER: _ClassVar[int] + SIGNEDPEERRECORD_FIELD_NUMBER: _ClassVar[int] + peerID: bytes + signedPeerRecord: bytes + def __init__(self, peerID: _Optional[bytes] = ..., signedPeerRecord: _Optional[bytes] = ...) -> None: ... - FROM_ID_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - SEQNO_FIELD_NUMBER: builtins.int - TOPICIDS_FIELD_NUMBER: builtins.int - SIGNATURE_FIELD_NUMBER: builtins.int - KEY_FIELD_NUMBER: builtins.int - from_id: builtins.bytes - data: builtins.bytes - seqno: builtins.bytes - signature: builtins.bytes - key: builtins.bytes - @property - def topicIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - def __init__( - self, - *, - from_id: builtins.bytes | None = ..., - data: builtins.bytes | None = ..., - seqno: builtins.bytes | None = ..., - topicIDs: collections.abc.Iterable[builtins.str] | None = ..., - signature: builtins.bytes | None = ..., - key: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "from_id", b"from_id", "key", b"key", "seqno", b"seqno", "signature", b"signature", "topicIDs", b"topicIDs"]) -> None: ... - -global___Message = Message - -@typing.final -class ControlMessage(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - IHAVE_FIELD_NUMBER: builtins.int - IWANT_FIELD_NUMBER: builtins.int - GRAFT_FIELD_NUMBER: builtins.int - PRUNE_FIELD_NUMBER: builtins.int - @property - def ihave(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIHave]: ... - @property - def iwant(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlIWant]: ... - @property - def graft(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlGraft]: ... - @property - def prune(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___ControlPrune]: ... 
- def __init__( - self, - *, - ihave: collections.abc.Iterable[global___ControlIHave] | None = ..., - iwant: collections.abc.Iterable[global___ControlIWant] | None = ..., - graft: collections.abc.Iterable[global___ControlGraft] | None = ..., - prune: collections.abc.Iterable[global___ControlPrune] | None = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["graft", b"graft", "ihave", b"ihave", "iwant", b"iwant", "prune", b"prune"]) -> None: ... - -global___ControlMessage = ControlMessage - -@typing.final -class ControlIHave(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TOPICID_FIELD_NUMBER: builtins.int - MESSAGEIDS_FIELD_NUMBER: builtins.int - topicID: builtins.str - @property - def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - def __init__( - self, - *, - topicID: builtins.str | None = ..., - messageIDs: collections.abc.Iterable[builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs", "topicID", b"topicID"]) -> None: ... - -global___ControlIHave = ControlIHave - -@typing.final -class ControlIWant(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - MESSAGEIDS_FIELD_NUMBER: builtins.int - @property - def messageIDs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - def __init__( - self, - *, - messageIDs: collections.abc.Iterable[builtins.str] | None = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["messageIDs", b"messageIDs"]) -> None: ... - -global___ControlIWant = ControlIWant - -@typing.final -class ControlGraft(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TOPICID_FIELD_NUMBER: builtins.int - topicID: builtins.str - def __init__( - self, - *, - topicID: builtins.str | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["topicID", b"topicID"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["topicID", b"topicID"]) -> None: ... - -global___ControlGraft = ControlGraft - -@typing.final -class ControlPrune(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TOPICID_FIELD_NUMBER: builtins.int - PEERS_FIELD_NUMBER: builtins.int - BACKOFF_FIELD_NUMBER: builtins.int - topicID: builtins.str - backoff: builtins.int - @property - def peers(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___PeerInfo]: ... - def __init__( - self, - *, - topicID: builtins.str | None = ..., - peers: collections.abc.Iterable[global___PeerInfo] | None = ..., - backoff: builtins.int | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["backoff", b"backoff", "topicID", b"topicID"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["backoff", b"backoff", "peers", b"peers", "topicID", b"topicID"]) -> None: ... 
- -global___ControlPrune = ControlPrune - -@typing.final -class PeerInfo(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - PEERID_FIELD_NUMBER: builtins.int - SIGNEDPEERRECORD_FIELD_NUMBER: builtins.int - peerID: builtins.bytes - signedPeerRecord: builtins.bytes - def __init__( - self, - *, - peerID: builtins.bytes | None = ..., - signedPeerRecord: builtins.bytes | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["peerID", b"peerID", "signedPeerRecord", b"signedPeerRecord"]) -> None: ... - -global___PeerInfo = PeerInfo - -@typing.final -class TopicDescriptor(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class AuthOpts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _AuthMode: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _AuthModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.AuthOpts._AuthMode.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NONE: TopicDescriptor.AuthOpts._AuthMode.ValueType # 0 - """no authentication, anyone can publish""" - KEY: TopicDescriptor.AuthOpts._AuthMode.ValueType # 1 - """only messages signed by keys in the topic descriptor are accepted""" - WOT: TopicDescriptor.AuthOpts._AuthMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - class AuthMode(_AuthMode, metaclass=_AuthModeEnumTypeWrapper): ... - NONE: TopicDescriptor.AuthOpts.AuthMode.ValueType # 0 - """no authentication, anyone can publish""" - KEY: TopicDescriptor.AuthOpts.AuthMode.ValueType # 1 - """only messages signed by keys in the topic descriptor are accepted""" - WOT: TopicDescriptor.AuthOpts.AuthMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - MODE_FIELD_NUMBER: builtins.int - KEYS_FIELD_NUMBER: builtins.int - mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType - @property - def keys(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: - """root keys to trust""" - - def __init__( - self, - *, - mode: global___TopicDescriptor.AuthOpts.AuthMode.ValueType | None = ..., - keys: collections.abc.Iterable[builtins.bytes] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["keys", b"keys", "mode", b"mode"]) -> None: ... 
- - @typing.final - class EncOpts(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class _EncMode: - ValueType = typing.NewType("ValueType", builtins.int) - V: typing_extensions.TypeAlias = ValueType - - class _EncModeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TopicDescriptor.EncOpts._EncMode.ValueType], builtins.type): - DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor - NONE: TopicDescriptor.EncOpts._EncMode.ValueType # 0 - """no encryption, anyone can read""" - SHAREDKEY: TopicDescriptor.EncOpts._EncMode.ValueType # 1 - """messages are encrypted with shared key""" - WOT: TopicDescriptor.EncOpts._EncMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - class EncMode(_EncMode, metaclass=_EncModeEnumTypeWrapper): ... - NONE: TopicDescriptor.EncOpts.EncMode.ValueType # 0 - """no encryption, anyone can read""" - SHAREDKEY: TopicDescriptor.EncOpts.EncMode.ValueType # 1 - """messages are encrypted with shared key""" - WOT: TopicDescriptor.EncOpts.EncMode.ValueType # 2 - """web of trust, certificates can allow publisher set to grow""" - - MODE_FIELD_NUMBER: builtins.int - KEYHASHES_FIELD_NUMBER: builtins.int - mode: global___TopicDescriptor.EncOpts.EncMode.ValueType - @property - def keyHashes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: - """the hashes of the shared keys used (salted)""" - - def __init__( - self, - *, - mode: global___TopicDescriptor.EncOpts.EncMode.ValueType | None = ..., - keyHashes: collections.abc.Iterable[builtins.bytes] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["mode", b"mode"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["keyHashes", b"keyHashes", "mode", b"mode"]) -> None: ... - - NAME_FIELD_NUMBER: builtins.int - AUTH_FIELD_NUMBER: builtins.int - ENC_FIELD_NUMBER: builtins.int - name: builtins.str - @property - def auth(self) -> global___TopicDescriptor.AuthOpts: ... - @property - def enc(self) -> global___TopicDescriptor.EncOpts: ... - def __init__( - self, - *, - name: builtins.str | None = ..., - auth: global___TopicDescriptor.AuthOpts | None = ..., - enc: global___TopicDescriptor.EncOpts | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["auth", b"auth", "enc", b"enc", "name", b"name"]) -> None: ... - -global___TopicDescriptor = TopicDescriptor +class TopicDescriptor(_message.Message): + __slots__ = ("name", "auth", "enc") + class AuthOpts(_message.Message): + __slots__ = ("mode", "keys") + class AuthMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NONE: _ClassVar[TopicDescriptor.AuthOpts.AuthMode] + KEY: _ClassVar[TopicDescriptor.AuthOpts.AuthMode] + WOT: _ClassVar[TopicDescriptor.AuthOpts.AuthMode] + NONE: TopicDescriptor.AuthOpts.AuthMode + KEY: TopicDescriptor.AuthOpts.AuthMode + WOT: TopicDescriptor.AuthOpts.AuthMode + MODE_FIELD_NUMBER: _ClassVar[int] + KEYS_FIELD_NUMBER: _ClassVar[int] + mode: TopicDescriptor.AuthOpts.AuthMode + keys: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, mode: _Optional[_Union[TopicDescriptor.AuthOpts.AuthMode, str]] = ..., keys: _Optional[_Iterable[bytes]] = ...) -> None: ... 
+ class EncOpts(_message.Message): + __slots__ = ("mode", "keyHashes") + class EncMode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NONE: _ClassVar[TopicDescriptor.EncOpts.EncMode] + SHAREDKEY: _ClassVar[TopicDescriptor.EncOpts.EncMode] + WOT: _ClassVar[TopicDescriptor.EncOpts.EncMode] + NONE: TopicDescriptor.EncOpts.EncMode + SHAREDKEY: TopicDescriptor.EncOpts.EncMode + WOT: TopicDescriptor.EncOpts.EncMode + MODE_FIELD_NUMBER: _ClassVar[int] + KEYHASHES_FIELD_NUMBER: _ClassVar[int] + mode: TopicDescriptor.EncOpts.EncMode + keyHashes: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, mode: _Optional[_Union[TopicDescriptor.EncOpts.EncMode, str]] = ..., keyHashes: _Optional[_Iterable[bytes]] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + AUTH_FIELD_NUMBER: _ClassVar[int] + ENC_FIELD_NUMBER: _ClassVar[int] + name: str + auth: TopicDescriptor.AuthOpts + enc: TopicDescriptor.EncOpts + def __init__(self, name: _Optional[str] = ..., auth: _Optional[_Union[TopicDescriptor.AuthOpts, _Mapping]] = ..., enc: _Optional[_Union[TopicDescriptor.EncOpts, _Mapping]] = ...) -> None: ... # type: ignore diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 5641ec5d..2c605fc3 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -56,6 +56,8 @@ from libp2p.peer.id import ( from libp2p.peer.peerdata import ( PeerDataError, ) +from libp2p.peer.peerstore import env_to_send_in_RPC +from libp2p.pubsub.utils import maybe_consume_signed_record from libp2p.tools.async_service import ( Service, ) @@ -247,6 +249,10 @@ class Pubsub(Service, IPubsub): packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) + # Add the sender's signedRecord in the RPC message + envelope_bytes, _ = env_to_send_in_RPC(self.host) + packet.senderRecord = envelope_bytes + return packet async def continuously_read_stream(self, stream: INetStream) -> None: @@ -263,6 +269,14 @@ class Pubsub(Service, IPubsub): incoming: bytes = await read_varint_prefixed_bytes(stream) rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC() rpc_incoming.ParseFromString(incoming) + + # Process the sender's signed-record if sent + if not maybe_consume_signed_record(rpc_incoming, self.host, peer_id): + logger.error( + "Received an invalid-signed-record, ignoring the incoming msg" + ) + continue + if rpc_incoming.publish: # deal with RPC.publish for msg in rpc_incoming.publish: @@ -572,6 +586,9 @@ class Pubsub(Service, IPubsub): [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)] ) + # Add the senderRecord of the peer in the RPC msg + envelope_bytes, _ = env_to_send_in_RPC(self.host) + packet.senderRecord = envelope_bytes # Send out subscribe message to all peers await self.message_all_peers(packet.SerializeToString()) @@ -604,6 +621,9 @@ class Pubsub(Service, IPubsub): packet.subscriptions.extend( [rpc_pb2.RPC.SubOpts(subscribe=False, topicid=topic_id)] ) + # Add the senderRecord of the peer in the RPC msg + envelope_bytes, _ = env_to_send_in_RPC(self.host) + packet.senderRecord = envelope_bytes # Send out unsubscribe message to all peers await self.message_all_peers(packet.SerializeToString()) diff --git a/libp2p/pubsub/utils.py b/libp2p/pubsub/utils.py new file mode 100644 index 00000000..6beaccc5 --- /dev/null +++ b/libp2p/pubsub/utils.py @@ -0,0 +1,80 @@ +import ast +import logging + +from libp2p.abc import IHost +from libp2p.custom_types import ( + MessageID, +) +from libp2p.peer.envelope import consume_envelope +from libp2p.peer.id import ID +from 
libp2p.pubsub.pb.rpc_pb2 import RPC
+
+logger = logging.getLogger("libp2p.pubsub.utils")
+
+
+def maybe_consume_signed_record(msg: RPC, host: IHost, peer_id: ID) -> bool:
+    """
+    Attempt to parse and store a signed peer record (Envelope) received during
+    PubSub communication. If the record is invalid, the peer ID does not match, or
+    updating the peerstore fails, the function logs an error and returns False.
+
+    Parameters
+    ----------
+    msg : RPC
+        The protobuf message received during PubSub communication.
+    host : IHost
+        The local host instance, providing access to the peerstore for storing
+        verified peer records.
+    peer_id : ID
+        The expected peer ID; the peer ID inside the record must match
+        this value.
+
+    Returns
+    -------
+    bool
+        True if a valid signed peer record was successfully consumed and stored,
+        False otherwise.
+
+    """
+    if msg.HasField("senderRecord"):
+        try:
+            # Convert the signed peer record (Envelope) from
+            # protobuf bytes
+            envelope, record = consume_envelope(msg.senderRecord, "libp2p-peer-record")
+            if record.peer_id != peer_id:
+                return False
+
+            # Use the default TTL of 2 hours (7200 seconds)
+            if not host.get_peerstore().consume_peer_record(envelope, 7200):
+                logger.error("Failed to update the Certified-Addr-Book")
+                return False
+        except Exception as e:
+            logger.error("Failed to update the Certified-Addr-Book: %s", e)
+            return False
+    return True
+
+
+def parse_message_id_safe(msg_id_str: str) -> MessageID:
+    """Safely handle a message ID as a string."""
+    return MessageID(msg_id_str)
+
+
+def safe_parse_message_id(msg_id_str: str) -> tuple[bytes, bytes]:
+    """
+    Safely parse a message ID using ast.literal_eval with validation.
+    :param msg_id_str: String representation of message ID
+    :return: Tuple of (seqno, from_id) as bytes
+    :raises ValueError: If parsing fails
+    """
+    try:
+        parsed = ast.literal_eval(msg_id_str)
+        if not isinstance(parsed, tuple) or len(parsed) != 2:
+            raise ValueError("Invalid message ID format")
+
+        seqno, from_id = parsed
+        if not isinstance(seqno, bytes) or not isinstance(from_id, bytes):
+            raise ValueError("Message ID components must be bytes")
+
+        return (seqno, from_id)
+    except (ValueError, SyntaxError) as e:
+        raise ValueError(f"Invalid message ID format: {e}") from e
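The parser round-trips the message IDs that gossipsub renders as the string form
of a (seqno, from_id) tuple of bytes. A quick sketch with illustrative values:

    from libp2p.pubsub.utils import safe_parse_message_id

    msg_id = str((b"\x00\x01", b"from-peer"))  # how such an ID is rendered
    seqno, from_id = safe_parse_message_id(msg_id)
    assert (seqno, from_id) == (b"\x00\x01", b"from-peer")

    safe_parse_message_id("not-a-tuple")  # raises ValueError

diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py
index a9c4b19c..ee8d4475 100644
--- a/libp2p/security/security_multistream.py
+++ b/libp2p/security/security_multistream.py
@@ -118,6 +118,8 @@ class SecurityMultistream(ABC):
             # Select protocol if non-initiator
             protocol, _ = await self.multiselect.negotiate(communicator)
             if protocol is None:
-                raise MultiselectError("fail to negotiate a security protocol")
+                raise MultiselectError(
+                    "Failed to negotiate a security protocol: no protocol selected"
+                )
             # Return transport from protocol
             return self.transports[protocol]
diff --git a/libp2p/stream_muxer/muxer_multistream.py b/libp2p/stream_muxer/muxer_multistream.py
index 322db912..ef90fac0 100644
--- a/libp2p/stream_muxer/muxer_multistream.py
+++ b/libp2p/stream_muxer/muxer_multistream.py
@@ -85,7 +85,9 @@ class MuxerMultistream:
         else:
             protocol, _ = await self.multiselect.negotiate(communicator)
             if protocol is None:
-                raise MultiselectError("fail to negotiate a stream muxer protocol")
+                raise MultiselectError(
+                    "Failed to negotiate a stream muxer protocol: no protocol selected"
+                )
         return self.transports[protocol]
 
     async def new_conn(self, conn: ISecureConn, peer_id: ID) -> 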
IMuxedConn:
diff --git a/libp2p/transport/quic/__init__.py b/libp2p/transport/quic/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/libp2p/transport/quic/config.py b/libp2p/transport/quic/config.py
new file mode 100644
index 00000000..e0c87adf
--- /dev/null
+++ b/libp2p/transport/quic/config.py
@@ -0,0 +1,345 @@
+"""
+Configuration classes for QUIC transport.
+"""
+
+from dataclasses import (
+    dataclass,
+    field,
+)
+import ssl
+from typing import Any, Literal, TypedDict
+
+from libp2p.custom_types import TProtocol
+from libp2p.network.config import ConnectionConfig
+
+
+class QUICTransportKwargs(TypedDict, total=False):
+    """Type definition for kwargs accepted by the new_transport function."""
+
+    # Connection settings
+    idle_timeout: float
+    max_datagram_size: int
+    local_port: int | None
+
+    # Protocol version support
+    enable_draft29: bool
+    enable_v1: bool
+
+    # TLS settings
+    verify_mode: ssl.VerifyMode
+    alpn_protocols: list[str]
+
+    # Performance settings
+    max_concurrent_streams: int
+    connection_window: int
+    stream_window: int
+
+    # Logging and debugging
+    enable_qlog: bool
+    qlog_dir: str | None
+
+    # Connection management
+    max_connections: int
+    connection_timeout: float
+
+    # Protocol identifiers
+    PROTOCOL_QUIC_V1: TProtocol
+    PROTOCOL_QUIC_DRAFT29: TProtocol
+
+
+@dataclass
+class QUICTransportConfig(ConnectionConfig):
+    """Configuration for QUIC transport."""
+
+    # Connection settings
+    idle_timeout: float = 30.0  # Seconds before an idle connection is closed.
+    max_datagram_size: int = (
+        1200  # Maximum size of UDP datagrams to avoid IP fragmentation.
+    )
+    local_port: int | None = (
+        None  # Local port to bind to. If None, a random port is chosen.
+    )
+
+    # Protocol version support
+    enable_draft29: bool = True  # Enable QUIC draft-29 for compatibility
+    enable_v1: bool = True  # Enable QUIC v1 (RFC 9000)
+
+    # TLS settings
+    verify_mode: ssl.VerifyMode = ssl.CERT_NONE
+    alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"])
+
+    # Performance settings
+    max_concurrent_streams: int = 100  # Maximum concurrent streams per connection
+    connection_window: int = 1024 * 1024  # Connection flow control window
+    stream_window: int = 64 * 1024  # Stream flow control window
+
+    # Logging and debugging
+    enable_qlog: bool = False  # Enable QUIC logging
+    qlog_dir: str | None = None  # Directory for QUIC logs
+
+    # Connection management
+    max_connections: int = 1000  # Maximum number of connections
+    connection_timeout: float = 10.0  # Connection establishment timeout
+
+    MAX_CONCURRENT_STREAMS: int = 1000
+    """Maximum number of concurrent streams per connection."""
+
+    MAX_INCOMING_STREAMS: int = 1000
+    """Maximum number of incoming streams per connection."""
+
+    CONNECTION_HANDSHAKE_TIMEOUT: float = 60.0
+    """Timeout for connection handshake (seconds)."""
+
+    MAX_OUTGOING_STREAMS: int = 1000
+    """Maximum number of outgoing streams per connection."""
+
+    CONNECTION_CLOSE_TIMEOUT: int = 10
+    """Timeout for closing a connection (seconds)."""
+
+    # Stream timeouts
+    STREAM_OPEN_TIMEOUT: float = 5.0
+    """Timeout for opening new streams (seconds)."""
+
+    STREAM_ACCEPT_TIMEOUT: float = 30.0
+    """Timeout for accepting incoming streams (seconds)."""
+
+    STREAM_READ_TIMEOUT: float = 30.0
+    """Default timeout for stream read operations (seconds)."""
+
+    STREAM_WRITE_TIMEOUT: float = 30.0
+    """Default timeout for stream write operations (seconds)."""
+
+    STREAM_CLOSE_TIMEOUT: float = 10.0
+    """Timeout for graceful stream close (seconds)."""
+
+    # Flow control
configuration + STREAM_FLOW_CONTROL_WINDOW: int = 1024 * 1024 # 1MB + """Per-stream flow control window size.""" + + CONNECTION_FLOW_CONTROL_WINDOW: int = 1536 * 1024 # 1.5MB + """Connection-wide flow control window size.""" + + # Buffer management + MAX_STREAM_RECEIVE_BUFFER: int = 2 * 1024 * 1024 # 2MB + """Maximum receive buffer size per stream.""" + + STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB + """Low watermark for stream receive buffer.""" + + STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB + """High watermark for stream receive buffer.""" + + # Stream lifecycle configuration + ENABLE_STREAM_RESET_ON_ERROR: bool = True + """Whether to automatically reset streams on errors.""" + + STREAM_RESET_ERROR_CODE: int = 1 + """Default error code for stream resets.""" + + ENABLE_STREAM_KEEP_ALIVE: bool = False + """Whether to enable stream keep-alive mechanisms.""" + + STREAM_KEEP_ALIVE_INTERVAL: float = 30.0 + """Interval for stream keep-alive pings (seconds).""" + + # Resource management + ENABLE_STREAM_RESOURCE_TRACKING: bool = True + """Whether to track stream resource usage.""" + + STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB + """Memory limit per individual stream.""" + + STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB + """Total memory limit for all streams per connection.""" + + # Concurrency and performance + ENABLE_STREAM_BATCHING: bool = True + """Whether to batch multiple stream operations.""" + + STREAM_BATCH_SIZE: int = 10 + """Number of streams to process in a batch.""" + + STREAM_PROCESSING_CONCURRENCY: int = 100 + """Maximum concurrent stream processing tasks.""" + + # Debugging and monitoring + ENABLE_STREAM_METRICS: bool = True + """Whether to collect stream metrics.""" + + ENABLE_STREAM_TIMELINE_TRACKING: bool = True + """Whether to track stream lifecycle timelines.""" + + STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0 + """Interval for collecting stream metrics (seconds).""" + + # Error handling configuration + STREAM_ERROR_RETRY_ATTEMPTS: int = 3 + """Number of retry attempts for recoverable stream errors.""" + + STREAM_ERROR_RETRY_DELAY: float = 1.0 + """Initial delay between stream error retries (seconds).""" + + STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0 + """Backoff factor for stream error retries.""" + + # Protocol identifiers matching go-libp2p + PROTOCOL_QUIC_V1: TProtocol = TProtocol("quic-v1") # RFC 9000 + PROTOCOL_QUIC_DRAFT29: TProtocol = TProtocol("quic") # draft-29 + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + if not (self.enable_draft29 or self.enable_v1): + raise ValueError("At least one QUIC version must be enabled") + + if self.idle_timeout <= 0: + raise ValueError("Idle timeout must be positive") + + if self.max_datagram_size < 1200: + raise ValueError("Max datagram size must be at least 1200 bytes") + + # Validate timeouts + timeout_fields = [ + "STREAM_OPEN_TIMEOUT", + "STREAM_ACCEPT_TIMEOUT", + "STREAM_READ_TIMEOUT", + "STREAM_WRITE_TIMEOUT", + "STREAM_CLOSE_TIMEOUT", + ] + for timeout_field in timeout_fields: + if getattr(self, timeout_field) <= 0: + raise ValueError(f"{timeout_field} must be positive") + + # Validate flow control windows + if self.STREAM_FLOW_CONTROL_WINDOW <= 0: + raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive") + + if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW: + raise ValueError( + "CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW" + ) + + # Validate 
buffer sizes
+        if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
+            raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
+
+        if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
+            raise ValueError(
+                "STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot "
+                "exceed MAX_STREAM_RECEIVE_BUFFER"
+            )
+
+        if (
+            self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
+            >= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
+        ):
+            raise ValueError(
+                "STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
+            )
+
+        # Validate memory limits
+        if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
+            raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
+
+        if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
+            raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
+
+        expected_stream_memory = (
+            self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
+        )
+        if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
+            # Allow some headroom, but warn if configuration seems inconsistent
+            import logging
+
+            logger = logging.getLogger(__name__)
+            logger.warning(
+                "Stream memory configuration may be inconsistent: "
+                f"{self.MAX_CONCURRENT_STREAMS} streams × "
+                f"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
+                "could exceed the connection limit of "
+                f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
+            )
+
+    def get_stream_config_dict(self) -> dict[str, Any]:
+        """Get stream-specific configuration as a dictionary."""
+        stream_config = {}
+        for attr_name in dir(self):
+            if attr_name.startswith(
+                ("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
+            ):
+                stream_config[attr_name.lower()] = getattr(self, attr_name)
+        return stream_config
+
+
+# Additional configuration classes for specific stream features
+
+
+class QUICStreamFlowControlConfig:
+    """Configuration for QUIC stream flow control."""
+
+    def __init__(
+        self,
+        initial_window_size: int = 512 * 1024,
+        max_window_size: int = 2 * 1024 * 1024,
+        window_update_threshold: float = 0.5,
+        enable_auto_tuning: bool = True,
+    ):
+        self.initial_window_size = initial_window_size
+        self.max_window_size = max_window_size
+        self.window_update_threshold = window_update_threshold
+        self.enable_auto_tuning = enable_auto_tuning
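+
+
+# A minimal sanity-check sketch of the validation above (illustrative only):
+#
+#     cfg = QUICTransportConfig()                  # defaults pass validation
+#     stream_knobs = cfg.get_stream_config_dict()  # stream-related knobs as a dict
+#
+#     QUICTransportConfig(idle_timeout=0)          # raises ValueError
+#     QUICTransportConfig(max_datagram_size=500)   # raises ValueError (< 1200)
+
+
+def create_stream_config_for_use_case(
+    use_case: Literal[
+        "high_throughput", "low_latency", "many_streams", "memory_constrained"
+    ],
+) -> QUICTransportConfig:
+    """
+    Create an optimized stream configuration for a specific use case.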
+
+    Args:
+        use_case: One of "high_throughput", "low_latency", "many_streams",
+            "memory_constrained"
+
+    Returns:
+        Optimized QUICTransportConfig
+
+    """
+    base_config = QUICTransportConfig()
+
+    if use_case == "high_throughput":
+        # Optimize for high throughput
+        base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024  # 2MB
+        base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024  # 10MB
+        base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024  # 4MB
+        base_config.STREAM_PROCESSING_CONCURRENCY = 200
+
+    elif use_case == "low_latency":
+        # Optimize for low latency
+        base_config.STREAM_OPEN_TIMEOUT = 1.0
+        base_config.STREAM_READ_TIMEOUT = 5.0
+        base_config.STREAM_WRITE_TIMEOUT = 5.0
+        base_config.ENABLE_STREAM_BATCHING = False
+        base_config.STREAM_BATCH_SIZE = 1
+
+    elif use_case == "many_streams":
+        # Optimize for many concurrent streams
+        base_config.MAX_CONCURRENT_STREAMS = 5000
+        base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024  # 128KB
+        base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024  # 256KB
+        base_config.STREAM_PROCESSING_CONCURRENCY = 500
+
+    elif use_case == "memory_constrained":
+        # Optimize for low memory usage
+        base_config.MAX_CONCURRENT_STREAMS = 100
+        base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024  # 64KB
+        base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024  # 256KB
+        base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024  # 128KB
+        base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024  # 512KB
+        base_config.STREAM_PROCESSING_CONCURRENCY = 50
+
+    else:
+        raise ValueError(f"Unknown use case: {use_case}")
+
+    return base_config
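Picking a profile is a one-liner; the factory mutates a default QUICTransportConfig
in place and rejects unknown names. A small usage sketch:

    from libp2p.transport.quic.config import create_stream_config_for_use_case

    config = create_stream_config_for_use_case("low_latency")
    assert config.STREAM_OPEN_TIMEOUT == 1.0
    assert config.ENABLE_STREAM_BATCHING is False

    # create_stream_config_for_use_case("bulk")  -> raises ValueError

diff --git a/libp2p/transport/quic/connection.py b/libp2p/transport/quic/connection.py
new file mode 100644
index 00000000..428acd83
--- /dev/null
+++ b/libp2p/transport/quic/connection.py
@@ -0,0 +1,1487 @@
+"""
+QUIC Connection implementation.
+Manages bidirectional QUIC connections with integrated stream multiplexing.
+"""
+
+from collections import defaultdict
+from collections.abc import Awaitable, Callable
+import logging
+import socket
+import time
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+from aioquic.quic import events
+from aioquic.quic.connection import QuicConnection
+from aioquic.quic.events import QuicEvent
+from cryptography import x509
+import multiaddr
+import trio
+
+from libp2p.abc import IMuxedConn, IRawConnection
+from libp2p.custom_types import TQUICStreamHandlerFn
+from libp2p.peer.id import ID
+from libp2p.stream_muxer.exceptions import MuxedConnUnavailable
+
+from .exceptions import (
+    QUICConnectionClosedError,
+    QUICConnectionError,
+    QUICConnectionTimeoutError,
+    QUICErrorContext,
+    QUICPeerVerificationError,
+    QUICStreamError,
+    QUICStreamLimitError,
+    QUICStreamTimeoutError,
+)
+from .stream import QUICStream, StreamDirection
+
+if TYPE_CHECKING:
+    from .security import QUICTLSConfigManager
+    from .transport import QUICTransport
+
+logger = logging.getLogger(__name__)
+
+
+class QUICConnection(IRawConnection, IMuxedConn):
+    """
+    QUIC connection implementing both raw connection and muxed connection interfaces.
+
+    Uses aioquic's sans-IO core with trio for native async support.
+    QUIC natively provides stream multiplexing, so this connection acts as both
+    a raw connection (for transport layer) and muxed connection (for upper layers).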
+
+    Features:
+    - Native QUIC stream multiplexing
+    - Integrated libp2p TLS security with peer identity verification
+    - Resource-aware stream management
+    - Comprehensive error handling
+    - Flow control integration
+    - Connection migration support
+    - Performance monitoring
+    - Complete connection ID management
+    """
+
+    def __init__(
+        self,
+        quic_connection: QuicConnection,
+        remote_addr: tuple[str, int],
+        remote_peer_id: ID | None,
+        local_peer_id: ID,
+        is_initiator: bool,
+        maddr: multiaddr.Multiaddr,
+        transport: "QUICTransport",
+        security_manager: Optional["QUICTLSConfigManager"] = None,
+        resource_scope: Any | None = None,
+        listener_socket: trio.socket.SocketType | None = None,
+    ):
+        """
+        Initialize QUIC connection with security integration.
+
+        Args:
+            quic_connection: aioquic QuicConnection instance
+            remote_addr: Remote peer address
+            remote_peer_id: Remote peer ID (may be None initially)
+            local_peer_id: Local peer ID
+            is_initiator: Whether this is the connection initiator
+            maddr: Multiaddr for this connection
+            transport: Parent QUIC transport
+            security_manager: Security manager for TLS/certificate handling
+            resource_scope: Resource manager scope for tracking
+            listener_socket: Socket of listener to transmit data
+
+        """
+        self._quic = quic_connection
+        self._remote_addr = remote_addr
+        self._remote_peer_id = remote_peer_id
+        self._local_peer_id = local_peer_id
+        self.peer_id = remote_peer_id or local_peer_id
+        self._is_initiator = is_initiator
+        self._maddr = maddr
+        self._transport = transport
+        self._security_manager = security_manager
+        self._resource_scope = resource_scope
+
+        # Trio networking - the socket may be provided by the listener
+        self._socket = listener_socket
+        self._owns_socket = listener_socket is None
+        self._connected_event = trio.Event()
+        self._closed_event = trio.Event()
+
+        self._streams: dict[int, QUICStream] = {}
+        self._stream_cache: dict[int, QUICStream] = {}  # Cache for frequent lookups
+        self._next_stream_id: int = self._calculate_initial_stream_id()
+        self._stream_handler: TQUICStreamHandlerFn | None = None
+
+        # Single lock for all stream operations
+        self._stream_lock = trio.Lock()
+
+        # Stream counting and limits
+        self._outbound_stream_count = 0
+        self._inbound_stream_count = 0
+
+        # Stream acceptance for incoming streams
+        self._stream_accept_queue: list[QUICStream] = []
+        self._stream_accept_event = trio.Event()
+
+        # Connection state
+        self._closed: bool = False
+        self._established = False
+        self._started = False
+        self._handshake_completed = False
+        self._peer_verified = False
+
+        # Security state
+        self._peer_certificate: x509.Certificate | None = None
+        self._handshake_events: list[events.HandshakeCompleted] = []
+
+        # Background task management
+        self._background_tasks_started = False
+        self._nursery: trio.Nursery | None = None
+        self._event_processing_task: Any | None = None
+        self.on_close: Callable[[], Awaitable[None]] | None = None
+        self.event_started = trio.Event()
+
+        self._available_connection_ids: set[bytes] = set()
+        self._current_connection_id: bytes | None = None
+        self._retired_connection_ids: set[bytes] = set()
+        self._connection_id_sequence_numbers: set[int] = set()
+
+        # Event processing control with batching
+        self._event_processing_active = False
+        self._event_batch: list[events.QuicEvent] = []
+        self._event_batch_size = 10
+        self._last_event_time = 0.0
+
+        # Set quic connection configuration
+        self.CONNECTION_CLOSE_TIMEOUT = transport._config.CONNECTION_CLOSE_TIMEOUT
+        self.MAX_INCOMING_STREAMS = transport._config.MAX_INCOMING_STREAMS
+        self.MAX_OUTGOING_STREAMS = transport._config.MAX_OUTGOING_STREAMS
+        self.CONNECTION_HANDSHAKE_TIMEOUT = (
+            transport._config.CONNECTION_HANDSHAKE_TIMEOUT
+        )
+        self.MAX_CONCURRENT_STREAMS = transport._config.MAX_CONCURRENT_STREAMS
+
+        # Performance and monitoring
+        self._connection_start_time = time.time()
+        self._stats = {
+            "streams_opened": 0,
+            "streams_accepted": 0,
+            "streams_closed": 0,
+            "streams_reset": 0,
+            "bytes_sent": 0,
+            "bytes_received": 0,
+            "packets_sent": 0,
+            "packets_received": 0,
+            "connection_ids_issued": 0,
+            "connection_ids_retired": 0,
+            "connection_id_changes": 0,
+        }
+
+        logger.debug(
+            f"Created QUIC connection to {remote_peer_id} "
+            f"(initiator: {is_initiator}, addr: {remote_addr}, "
+            f"security: {security_manager is not None})"
+        )
+
+    def _calculate_initial_stream_id(self) -> int:
+        """
+        Calculate the initial stream ID based on the QUIC specification.
+
+        QUIC stream IDs:
+        - Client-initiated bidirectional: 0, 4, 8, 12, ...
+        - Server-initiated bidirectional: 1, 5, 9, 13, ...
+        - Client-initiated unidirectional: 2, 6, 10, 14, ...
+        - Server-initiated unidirectional: 3, 7, 11, 15, ...
+
+        For libp2p, we primarily use bidirectional streams.
+        """
+        if self._is_initiator:
+            return 0
+        else:
+            return 1
+
+    @property
+    def is_initiator(self) -> bool:  # type: ignore
+        """Check if this connection is the initiator."""
+        return self._is_initiator
+
+    @property
+    def is_closed(self) -> bool:
+        """Check if connection is closed."""
+        return self._closed
+
+    @property
+    def is_established(self) -> bool:
+        """Check if connection is established (handshake completed)."""
+        return self._established and self._handshake_completed
+
+    @property
+    def is_started(self) -> bool:
+        """Check if connection has been started."""
+        return self._started
+
+    @property
+    def is_peer_verified(self) -> bool:
+        """Check if peer identity has been verified."""
+        return self._peer_verified
+
+    def multiaddr(self) -> multiaddr.Multiaddr:
+        """Get the multiaddr for this connection."""
+        return self._maddr
+
+    def local_peer_id(self) -> ID:
+        """Get the local peer ID."""
+        return self._local_peer_id
+
+    def remote_peer_id(self) -> ID | None:
+        """Get the remote peer ID."""
+        return self._remote_peer_id
+
+    def get_connection_id_stats(self) -> dict[str, Any]:
+        """Get connection ID statistics and current state."""
+        return {
+            "available_connection_ids": len(self._available_connection_ids),
+            "current_connection_id": self._current_connection_id.hex()
+            if self._current_connection_id
+            else None,
+            "retired_connection_ids": len(self._retired_connection_ids),
+            "connection_ids_issued": self._stats["connection_ids_issued"],
+            "connection_ids_retired": self._stats["connection_ids_retired"],
+            "connection_id_changes": self._stats["connection_id_changes"],
+            "available_cid_list": [cid.hex() for cid in self._available_connection_ids],
+        }
+
+    def get_current_connection_id(self) -> bytes | None:
+        """Get the current connection ID."""
+        return self._current_connection_id
+
+    # Fast stream lookup with caching
+    def _get_stream_fast(self, stream_id: int) -> QUICStream | None:
+        """Get stream with caching for performance."""
+        # Try the cache first
+        stream = self._stream_cache.get(stream_id)
+        if stream is not None:
+            return stream
+
+        # Fall back to the main dict
+        stream = self._streams.get(stream_id)
+        if stream is not None:
+            self._stream_cache[stream_id] = stream
+
+        return stream
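+
+    # A quick sketch of the stream-ID arithmetic used above (the helper name is
+    # illustrative, not part of this class): bidirectional stream IDs advance
+    # in steps of 4, so a client opens 0, 4, 8, ... and a server 1, 5, 9, ...
+    #
+    #     def next_stream_ids(first: int, n: int) -> list[int]:
+    #         return [first + 4 * i for i in range(n)]
+    #
+    #     assert next_stream_ids(0, 3) == [0, 4, 8]  # client-initiated
+    #     assert next_stream_ids(1, 3) == [1, 5, 9]  # server-initiated
+
+    # Connection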
lifecycle methods + + async def start(self) -> None: + """ + Start the connection and its background tasks. + + This method implements the IMuxedConn.start() interface. + It should be called to begin processing connection events. + """ + if self._started: + logger.warning("Connection already started") + return + + if self._closed: + raise QUICConnectionError("Cannot start a closed connection") + + self._started = True + self.event_started.set() + logger.debug(f"Starting QUIC connection to {self._remote_peer_id}") + + try: + # If this is a client connection, we need to establish the connection + if self._is_initiator: + await self._initiate_connection() + else: + # For server connections, we're already connected via the listener + self._established = True + self._connected_event.set() + + logger.debug(f"QUIC connection to {self._remote_peer_id} started") + + except Exception as e: + logger.error(f"Failed to start connection: {e}") + raise QUICConnectionError(f"Connection start failed: {e}") from e + + async def _initiate_connection(self) -> None: + """Initiate client-side connection, reusing listener socket if available.""" + try: + with QUICErrorContext("connection_initiation", "connection"): + if not self._socket: + logger.debug("Creating new socket for outbound connection") + self._socket = trio.socket.socket( + family=socket.AF_INET, type=socket.SOCK_DGRAM + ) + + await self._socket.bind(("0.0.0.0", 0)) + + self._quic.connect(self._remote_addr, now=time.time()) + + # Send initial packet(s) + await self._transmit() + + logger.debug(f"Initiated QUIC connection to {self._remote_addr}") + + except Exception as e: + logger.error(f"Failed to initiate connection: {e}") + raise QUICConnectionError(f"Connection initiation failed: {e}") from e + + async def connect(self, nursery: trio.Nursery) -> None: + """ + Establish the QUIC connection using trio nursery for background tasks. 
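+
+        Example (illustrative sketch, not from the test suite; assumes
+        ``conn`` was obtained by dialing a reachable peer via the QUIC
+        transport)::
+
+            async with trio.open_nursery() as nursery:
+                await conn.connect(nursery)
+                stream = await conn.open_stream()
+                await stream.write(b"ping")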
+
+        Args:
+            nursery: Trio nursery for managing connection background tasks
+
+        """
+        if self._closed:
+            raise QUICConnectionClosedError("Connection is closed")
+
+        self._nursery = nursery
+
+        try:
+            with QUICErrorContext("connection_establishment", "connection"):
+                # Start the connection if not already started
+                logger.debug("Starting connection establishment")
+                if not self._started:
+                    await self.start()
+
+                # Start background event processing
+                if not self._background_tasks_started:
+                    logger.debug("Starting background tasks")
+                    await self._start_background_tasks()
+                else:
+                    logger.debug("Background tasks already started")
+
+                # Wait for handshake completion with timeout
+                with trio.move_on_after(
+                    self.CONNECTION_HANDSHAKE_TIMEOUT
+                ) as cancel_scope:
+                    await self._connected_event.wait()
+
+                if cancel_scope.cancelled_caught:
+                    raise QUICConnectionTimeoutError(
+                        "Connection handshake timed out after "
+                        f"{self.CONNECTION_HANDSHAKE_TIMEOUT}s"
+                    )
+
+                logger.debug(
+                    "QUICConnection: Verifying peer identity with security manager"
+                )
+                # Verify peer identity using security manager
+                peer_id = await self._verify_peer_identity_with_security()
+
+                if peer_id:
+                    self.peer_id = peer_id
+
+                logger.debug(f"QUICConnection {id(self)}: Peer identity verified")
+                self._established = True
+                logger.debug(f"QUIC connection established with {self._remote_peer_id}")
+
+        except Exception as e:
+            logger.error(f"Failed to establish connection: {e}")
+            await self.close()
+            raise
+
+    async def _start_background_tasks(self) -> None:
+        """Start background tasks for connection management."""
+        if self._background_tasks_started or not self._nursery:
+            return
+
+        self._background_tasks_started = True
+
+        if self._is_initiator:
+            self._nursery.start_soon(self._client_packet_receiver)
+
+        self._nursery.start_soon(self._event_processing_loop)
+        self._nursery.start_soon(self._periodic_maintenance)
+
+        logger.debug("Started background tasks for QUIC connection")
+
+    async def _event_processing_loop(self) -> None:
+        """Main event processing loop for the connection."""
+        logger.debug(
+            f"Started QUIC event processing loop for connection id: {id(self)} "
+            f"and local peer id {str(self.local_peer_id())}"
+        )
+
+        try:
+            while not self._closed:
+                # Batch process events
+                await self._process_quic_events_batched()
+
+                # Handle timer events
+                await self._handle_timer_events()
+
+                # Transmit any pending data
+                await self._transmit()
+
+                # Short sleep to prevent busy waiting
+                await trio.sleep(0.01)
+
+        except Exception as e:
+            logger.error(f"Error in event processing loop: {e}")
+            await self._handle_connection_error(e)
+        finally:
+            logger.debug("QUIC event processing loop finished")
+
+    async def _periodic_maintenance(self) -> None:
+        """Perform periodic connection maintenance."""
+        try:
+            while not self._closed:
+                # Update connection statistics
+                self._update_stats()
+
+                # Check for idle streams that can be cleaned up
+                await self._cleanup_idle_streams()
+
+                if logger.isEnabledFor(logging.DEBUG):
+                    cid_stats = self.get_connection_id_stats()
+                    logger.debug(f"Connection ID stats: {cid_stats}")
+
+                # Clean cache periodically
+                await self._cleanup_cache()
+
+                # Sleep for maintenance interval
+                await trio.sleep(30.0)  # 30 seconds
+
+        except Exception as e:
+            logger.error(f"Error in periodic maintenance: {e}")
+
+    async def _cleanup_cache(self) -> None:
+        """Clean up stream cache periodically to prevent memory leaks."""
+        if len(self._stream_cache) > 100:  # Arbitrary threshold
+            # Remove closed streams from cache
+            closed_stream_ids = [
+                sid for sid, stream in self._stream_cache.items() if stream.is_closed()
+            ]
+            for sid in closed_stream_ids:
+                self._stream_cache.pop(sid, None)
+
+    async def _client_packet_receiver(self) -> None:
+        """Receive packets for client connections."""
+        logger.debug("Started QUIC client packet receiver")
+
+        try:
+            while not self._closed and self._socket:
+                try:
+                    # Receive UDP packets
+                    data, addr = await self._socket.recvfrom(65536)
+                    logger.debug(f"Client received {len(data)} bytes from {addr}")
+
+                    # Feed packet to QUIC connection
+                    self._quic.receive_datagram(data, addr, now=time.time())
+
+                    # Batch process events
+                    await self._process_quic_events_batched()
+
+                    # Send any response packets
+                    await self._transmit()
+
+                except trio.ClosedResourceError:
+                    logger.debug("Client socket closed")
+                    break
+                except Exception as e:
+                    logger.error(f"Error receiving client packet: {e}")
+                    await trio.sleep(0.01)
+
+        except trio.Cancelled:
+            logger.debug("Client packet receiver cancelled")
+            raise
+        finally:
+            logger.debug("Client packet receiver terminated")
+
+    # Security and identity methods
+
+    async def _verify_peer_identity_with_security(self) -> ID | None:
+        """
+        Verify peer identity using integrated security manager.
+
+        Raises:
+            QUICPeerVerificationError: If peer verification fails
+
+        """
+        logger.debug("Verifying peer identity")
+        if not self._security_manager:
+            logger.debug("No security manager available for peer verification")
+            return None
+
+        try:
+            # Extract peer certificate from TLS handshake
+            await self._extract_peer_certificate()
+
+            if not self._peer_certificate:
+                logger.debug("No peer certificate available for verification")
+                return None
+
+            # Validate certificate format and accessibility
+            if not self._validate_peer_certificate():
+                logger.debug("Validation failed for peer certificate")
+                raise QUICPeerVerificationError("Peer certificate validation failed")
+
+            # Verify peer identity using security manager
+            verified_peer_id = self._security_manager.verify_peer_identity(
+                self._peer_certificate,
+                self._remote_peer_id,  # Expected peer ID for outbound connections
+            )
+
+            # Update peer ID if it wasn't known (inbound connections)
+            if not self._remote_peer_id:
+                self._remote_peer_id = verified_peer_id
+                logger.debug(f"Discovered peer ID from certificate: {verified_peer_id}")
+            elif self._remote_peer_id != verified_peer_id:
+                raise QUICPeerVerificationError(
+                    f"Peer ID mismatch: expected {self._remote_peer_id}, "
+                    f"got {verified_peer_id}"
+                )
+
+            self._peer_verified = True
+            logger.debug(f"Peer identity verified successfully: {verified_peer_id}")
+
+            return verified_peer_id
+
+        except QUICPeerVerificationError:
+            # Re-raise verification errors as-is
+            raise
+        except Exception as e:
+            # Wrap other errors in verification error
+            raise QUICPeerVerificationError(f"Peer verification failed: {e}") from e
+
+    async def _extract_peer_certificate(self) -> None:
+        """Extract peer certificate from completed TLS handshake."""
+        try:
+            # Get peer certificate from aioquic TLS context
+            if self._quic.tls:
+                tls_context = self._quic.tls
+
+                if tls_context._peer_certificate:
+                    # aioquic stores the peer certificate as cryptography
+                    # x509.Certificate
+                    self._peer_certificate = tls_context._peer_certificate
+                    logger.debug(
+                        f"Extracted peer certificate: {self._peer_certificate.subject}"
+                    )
+                else:
+                    logger.debug("No peer certificate found in TLS context")
+
+            else:
+                logger.debug("No TLS context available for certificate 
extraction") + + except Exception as e: + logger.warning(f"Failed to extract peer certificate: {e}") + + # Try alternative approach - check if certificate is in handshake events + try: + # Some versions of aioquic might expose certificate differently + config = self._quic.configuration + if hasattr(config, "certificate") and config.certificate: + # This would be the local certificate, not peer certificate + # but we can use it for debugging + logger.debug("Found local certificate in configuration") + + except Exception as inner_e: + logger.error( + f"Alternative certificate extraction also failed: {inner_e}" + ) + + async def get_peer_certificate(self) -> x509.Certificate | None: + """ + Get the peer's TLS certificate. + + Returns: + The peer's X.509 certificate, or None if not available + + """ + # If we don't have a certificate yet, try to extract it + if not self._peer_certificate and self._handshake_completed: + await self._extract_peer_certificate() + + return self._peer_certificate + + def _validate_peer_certificate(self) -> bool: + """ + Validate that the peer certificate is properly formatted and accessible. + + Returns: + True if certificate is valid and accessible, False otherwise + + """ + if not self._peer_certificate: + return False + + try: + # Basic validation - try to access certificate properties + subject = self._peer_certificate.subject + serial_number = self._peer_certificate.serial_number + + logger.debug( + f"Certificate validation - Subject: {subject}, Serial: {serial_number}" + ) + return True + + except Exception as e: + logger.error(f"Certificate validation failed: {e}") + return False + + def get_security_manager(self) -> Optional["QUICTLSConfigManager"]: + """Get the security manager for this connection.""" + return self._security_manager + + def get_security_info(self) -> dict[str, Any]: + """Get security-related information about the connection.""" + info: dict[str, bool | Any | None] = { + "peer_verified": self._peer_verified, + "handshake_complete": self._handshake_completed, + "peer_id": str(self._remote_peer_id) if self._remote_peer_id else None, + "local_peer_id": str(self._local_peer_id), + "is_initiator": self._is_initiator, + "has_certificate": self._peer_certificate is not None, + "security_manager_available": self._security_manager is not None, + } + + # Add certificate details if available + if self._peer_certificate: + try: + info.update( + { + "certificate_subject": str(self._peer_certificate.subject), + "certificate_issuer": str(self._peer_certificate.issuer), + "certificate_serial": str(self._peer_certificate.serial_number), + "certificate_not_before": ( + self._peer_certificate.not_valid_before.isoformat() + ), + "certificate_not_after": ( + self._peer_certificate.not_valid_after.isoformat() + ), + } + ) + except Exception as e: + info["certificate_error"] = str(e) + + # Add TLS context debug info + try: + if hasattr(self._quic, "tls") and self._quic.tls: + tls_info = { + "tls_context_available": True, + "tls_state": getattr(self._quic.tls, "state", None), + } + + # Check for peer certificate in TLS context + if hasattr(self._quic.tls, "_peer_certificate"): + tls_info["tls_peer_certificate_available"] = ( + self._quic.tls._peer_certificate is not None + ) + + info["tls_debug"] = tls_info + else: + info["tls_debug"] = {"tls_context_available": False} + + except Exception as e: + info["tls_debug"] = {"error": str(e)} + + return info + + # Stream management methods (IMuxedConn interface) + + async def open_stream(self, timeout: float = 5.0) -> 
QUICStream:
+        """
+        Open a new outbound stream.
+
+        Args:
+            timeout: Timeout for stream creation
+
+        Returns:
+            New QUIC stream
+
+        Raises:
+            QUICStreamLimitError: Too many concurrent streams
+            QUICConnectionClosedError: Connection is closed
+            QUICStreamTimeoutError: Stream creation timed out
+
+        """
+        if self._closed:
+            raise QUICConnectionClosedError("Connection is closed")
+
+        if not self._started:
+            raise QUICConnectionError("Connection not started")
+
+        # Use single lock for all stream operations
+        with trio.move_on_after(timeout):
+            async with self._stream_lock:
+                # Check stream limits inside lock
+                if self._outbound_stream_count >= self.MAX_OUTGOING_STREAMS:
+                    raise QUICStreamLimitError(
+                        "Maximum outbound streams "
+                        f"({self.MAX_OUTGOING_STREAMS}) reached"
+                    )
+
+                # Generate next stream ID
+                stream_id = self._next_stream_id
+                self._next_stream_id += 4  # Increment by 4 for bidirectional streams
+
+                stream = QUICStream(
+                    connection=self,
+                    stream_id=stream_id,
+                    direction=StreamDirection.OUTBOUND,
+                    resource_scope=self._resource_scope,
+                    remote_addr=self._remote_addr,
+                )
+
+                self._streams[stream_id] = stream
+                self._stream_cache[stream_id] = stream  # Add to cache
+
+                self._outbound_stream_count += 1
+                self._stats["streams_opened"] += 1
+
+                logger.debug(f"Opened outbound QUIC stream {stream_id}")
+                return stream
+
+        raise QUICStreamTimeoutError(f"Stream creation timed out after {timeout}s")
+
+    async def accept_stream(self, timeout: float | None = None) -> QUICStream:
+        """
+        Accept incoming stream.
+
+        Args:
+            timeout: Optional timeout. If None, waits indefinitely.
+
+        """
+        if self._closed:
+            raise QUICConnectionClosedError("Connection is closed")
+
+        if timeout is not None:
+            with trio.move_on_after(timeout):
+                return await self._accept_stream_impl()
+            # Timeout occurred
+            if self._closed_event.is_set() or self._closed:
+                raise MuxedConnUnavailable("QUIC connection closed during timeout")
+            else:
+                raise QUICStreamTimeoutError(
+                    f"Stream accept timed out after {timeout}s"
+                )
+        else:
+            # No timeout - wait indefinitely
+            return await self._accept_stream_impl()
+
+    async def _accept_stream_impl(self) -> QUICStream:
+        while True:
+            if self._closed:
+                raise MuxedConnUnavailable("QUIC connection is closed")
+
+            # Use single lock for stream acceptance
+            async with self._stream_lock:
+                if self._stream_accept_queue:
+                    stream = self._stream_accept_queue.pop(0)
+                    logger.debug(f"Accepted inbound stream {stream.stream_id}")
+                    return stream
+
+                # trio.Event cannot be cleared, so re-arm it under the lock;
+                # any stream queued after this point sets the fresh event
+                if self._stream_accept_event.is_set():
+                    self._stream_accept_event = trio.Event()
+
+            if self._closed:
+                raise MuxedConnUnavailable("Connection closed while accepting stream")
+
+            # Wait for new streams (or connection close) to set the event
+            await self._stream_accept_event.wait()
+
+    def set_stream_handler(self, handler_function: TQUICStreamHandlerFn) -> None:
+        """
+        Set handler for incoming streams.
+
+        Args:
+            handler_function: Function to handle new incoming streams
+
+        """
+        self._stream_handler = handler_function
+        logger.debug("Set stream handler for incoming streams")
+
+    def _remove_stream(self, stream_id: int) -> None:
+        """
+        Remove stream from connection registry.
+        Called by stream cleanup process.
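+
+        Note: stream counters are decremented asynchronously on the
+        connection nursery, so counts may lag removal by one scheduling tick.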
+
+        """
+        if stream_id in self._streams:
+            stream = self._streams.pop(stream_id)
+            # Remove from cache too
+            self._stream_cache.pop(stream_id, None)
+
+            # Update stream counts asynchronously
+            async def update_counts() -> None:
+                async with self._stream_lock:
+                    if stream.direction == StreamDirection.OUTBOUND:
+                        self._outbound_stream_count = max(
+                            0, self._outbound_stream_count - 1
+                        )
+                    else:
+                        self._inbound_stream_count = max(
+                            0, self._inbound_stream_count - 1
+                        )
+                    self._stats["streams_closed"] += 1
+
+            # Schedule count update if we're in a trio context
+            if self._nursery:
+                self._nursery.start_soon(update_counts)
+
+            logger.debug(f"Removed stream {stream_id} from connection")
+
+    # Batched event processing to reduce overhead
+    async def _process_quic_events_batched(self) -> None:
+        """Process QUIC events in batches for better performance."""
+        if self._event_processing_active:
+            return  # Prevent recursion
+
+        self._event_processing_active = True
+
+        try:
+            current_time = time.time()
+            events_processed = 0
+
+            # Collect events into batch
+            while events_processed < self._event_batch_size:
+                event = self._quic.next_event()
+                if event is None:
+                    break
+
+                self._event_batch.append(event)
+                events_processed += 1
+
+            # Process batch if we have events or timeout
+            if self._event_batch and (
+                len(self._event_batch) >= self._event_batch_size
+                or current_time - self._last_event_time > 0.01  # 10ms timeout
+            ):
+                await self._process_event_batch()
+                self._event_batch.clear()
+                self._last_event_time = current_time
+
+        finally:
+            self._event_processing_active = False
+
+    async def _process_event_batch(self) -> None:
+        """Process a batch of events efficiently."""
+        if not self._event_batch:
+            return
+
+        # Group events by type for batch processing where possible
+        events_by_type: defaultdict[str, list[QuicEvent]] = defaultdict(list)
+        for event in self._event_batch:
+            events_by_type[type(event).__name__].append(event)
+
+        # Process events by type
+        for event_type, event_list in events_by_type.items():
+            if event_type == events.StreamDataReceived.__name__:
+                await self._handle_stream_data_batch(
+                    cast(list[events.StreamDataReceived], event_list)
+                )
+            else:
+                # Process other events individually
+                for event in event_list:
+                    await self._handle_quic_event(event)
+
+        logger.debug(f"Processed batch of {len(self._event_batch)} events")
+
+    async def _handle_stream_data_batch(
+        self, events_list: list[events.StreamDataReceived]
+    ) -> None:
+        """Handle stream data events in batch for better performance."""
+        # Group by stream ID
+        events_by_stream: defaultdict[int, list[events.StreamDataReceived]] = (
+            defaultdict(list)
+        )
+        for event in events_list:
+            events_by_stream[event.stream_id].append(event)
+
+        # Process each stream's events
+        for stream_id, stream_events in events_by_stream.items():
+            stream = self._get_stream_fast(stream_id)  # Use fast lookup
+
+            if not stream:
+                if self._is_incoming_stream(stream_id):
+                    try:
+                        stream = await self._create_inbound_stream(stream_id)
+                    except QUICStreamLimitError:
+                        # Reset stream if we can't handle it
+                        self._quic.reset_stream(stream_id, error_code=0x04)
+                        await self._transmit()
+                        continue
+                else:
+                    logger.error(
+                        f"Unexpected outbound stream {stream_id} in data event"
+                    )
+                    continue
+
+            # Process all events for this stream
+            for received_event in stream_events:
+                self._stats["bytes_received"] += len(received_event.data)
+                await stream.handle_data_received(
+                    received_event.data, received_event.end_stream
+                )
+
+    async def _create_inbound_stream(self, stream_id: int) -> QUICStream:
+        """Create inbound stream with proper limit checking."""
+        async with self._stream_lock:
+            # Double-check stream doesn't exist
+            existing_stream = self._streams.get(stream_id)
+            if existing_stream:
+                return existing_stream
+
+            # Check limits
+            if self._inbound_stream_count >= self.MAX_INCOMING_STREAMS:
+                logger.warning(f"Rejecting inbound stream {stream_id}: limit reached")
+                raise QUICStreamLimitError("Too many inbound streams")
+
+            # Create stream
+            stream = QUICStream(
+                connection=self,
+                stream_id=stream_id,
+                direction=StreamDirection.INBOUND,
+                resource_scope=self._resource_scope,
+                remote_addr=self._remote_addr,
+            )
+
+            self._streams[stream_id] = stream
+            self._stream_cache[stream_id] = stream  # Add to cache
+            self._inbound_stream_count += 1
+            self._stats["streams_accepted"] += 1
+
+            # Add to accept queue
+            self._stream_accept_queue.append(stream)
+            self._stream_accept_event.set()
+
+            logger.debug(f"Created inbound stream {stream_id}")
+            return stream
+
+    async def _process_quic_events(self) -> None:
+        """Process all pending QUIC events."""
+        # Delegate to batched processing for better performance
+        await self._process_quic_events_batched()
+
+    async def _handle_quic_event(self, event: events.QuicEvent) -> None:
+        """Handle a single QUIC event with complete event type coverage."""
+        logger.debug(f"Handling QUIC event: {type(event).__name__}")
+
+        try:
+            if isinstance(event, events.ConnectionTerminated):
+                await self._handle_connection_terminated(event)
+            elif isinstance(event, events.HandshakeCompleted):
+                await self._handle_handshake_completed(event)
+            elif isinstance(event, events.StreamDataReceived):
+                await self._handle_stream_data(event)
+            elif isinstance(event, events.StreamReset):
+                await self._handle_stream_reset(event)
+            elif isinstance(event, events.DatagramFrameReceived):
+                await self._handle_datagram_received(event)
+            # Connection ID event handlers
+            elif isinstance(event, events.ConnectionIdIssued):
+                await self._handle_connection_id_issued(event)
+            elif isinstance(event, events.ConnectionIdRetired):
+                await self._handle_connection_id_retired(event)
+            # Additional event handlers for completeness
+            elif isinstance(event, events.PingAcknowledged):
+                await self._handle_ping_acknowledged(event)
+            elif isinstance(event, events.ProtocolNegotiated):
+                await self._handle_protocol_negotiated(event)
+            elif isinstance(event, events.StopSendingReceived):
+                await self._handle_stop_sending_received(event)
+            else:
+                logger.debug(f"Unhandled QUIC event type: {type(event).__name__}")
+
+        except Exception as e:
+            logger.error(f"Error handling QUIC event {type(event).__name__}: {e}")
+
+    async def _handle_connection_id_issued(
+        self, event: events.ConnectionIdIssued
+    ) -> None:
+        """
+        Handle new connection ID issued by peer.
+
+        Without this handler the connection cannot rotate to peer-issued
+        connection IDs, which is what originally broke connection ID
+        management.
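+
+        Corresponds to a NEW_CONNECTION_ID frame from the peer
+        (RFC 9000, Section 5.1.1).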
+
+        """
+        logger.debug(f"šŸ†” New connection ID issued: {event.connection_id.hex()}")
+
+        # Add to available connection IDs
+        self._available_connection_ids.add(event.connection_id)
+
+        # If we don't have a current connection ID, use this one
+        if self._current_connection_id is None:
+            self._current_connection_id = event.connection_id
+            logger.debug(
+                f"šŸ†” Set current connection ID to: {event.connection_id.hex()}"
+            )
+
+        # Update statistics
+        self._stats["connection_ids_issued"] += 1
+
+        logger.debug(f"Available connection IDs: {len(self._available_connection_ids)}")
+
+    async def _handle_connection_id_retired(
+        self, event: events.ConnectionIdRetired
+    ) -> None:
+        """
+        Handle connection ID retirement.
+
+        This handles when the peer tells us to stop using a connection ID.
+        """
+        logger.debug(f"šŸ—‘ļø Connection ID retired: {event.connection_id.hex()}")
+
+        # Remove from available IDs and add to retired set
+        self._available_connection_ids.discard(event.connection_id)
+        self._retired_connection_ids.add(event.connection_id)
+
+        # If this was our current connection ID, switch to another
+        if self._current_connection_id == event.connection_id:
+            if self._available_connection_ids:
+                self._current_connection_id = next(iter(self._available_connection_ids))
+                logger.debug(
+                    "Switching to new connection ID: "
+                    f"{self._current_connection_id.hex()}"
+                )
+                self._stats["connection_id_changes"] += 1
+            else:
+                self._current_connection_id = None
+                logger.warning("āš ļø No available connection IDs after retirement!")
+
+        # Update statistics
+        self._stats["connection_ids_retired"] += 1
+
+    async def _handle_ping_acknowledged(self, event: events.PingAcknowledged) -> None:
+        """Handle ping acknowledgment."""
+        logger.debug(f"Ping acknowledged: uid={event.uid}")
+
+    async def _handle_protocol_negotiated(
+        self, event: events.ProtocolNegotiated
+    ) -> None:
+        """Handle protocol negotiation completion."""
+        logger.debug(f"Protocol negotiated: {event.alpn_protocol}")
+
+    async def _handle_stop_sending_received(
+        self, event: events.StopSendingReceived
+    ) -> None:
+        """Handle stop sending request from peer."""
+        logger.debug(
+            "Stop sending received: "
+            f"stream_id={event.stream_id}, error_code={event.error_code}"
+        )
+
+        # Use fast lookup
+        stream = self._get_stream_fast(event.stream_id)
+        if stream:
+            # Handle stop sending on the stream if method exists
+            await stream.handle_stop_sending(event.error_code)
+
+    async def _handle_handshake_completed(
+        self, event: events.HandshakeCompleted
+    ) -> None:
+        """Handle handshake completion with security integration."""
+        logger.debug("QUIC handshake completed")
+        self._handshake_completed = True
+
+        # Store handshake event for security verification
+        self._handshake_events.append(event)
+
+        # Try to extract certificate information after handshake
+        await self._extract_peer_certificate()
+
+        logger.debug("āœ… Setting connected event")
+        self._connected_event.set()
+
+    async def _handle_connection_terminated(
+        self, event: events.ConnectionTerminated
+    ) -> None:
+        """Handle connection termination."""
+        logger.debug(f"QUIC connection terminated: {event.reason_phrase}")
+
+        # Close all streams
+        for stream in list(self._streams.values()):
+            if event.error_code:
+                await stream.handle_reset(event.error_code)
+            else:
+                await stream.close()
+
+        self._streams.clear()
+        self._stream_cache.clear()  # Clear cache too
+        self._closed = True
+        self._closed_event.set()
+
+        self._stream_accept_event.set()
+        logger.debug(f"Woke up pending accept_stream() calls, {id(self)}")
+
+        await self._notify_parent_of_termination()
+
+    async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
+        """Handle stream data events - create streams and add to accept queue."""
+        stream_id = event.stream_id
+        self._stats["bytes_received"] += len(event.data)
+
+        try:
+            # Use fast lookup
+            stream = self._get_stream_fast(stream_id)
+
+            if not stream:
+                if self._is_incoming_stream(stream_id):
+                    logger.debug(f"Creating new incoming stream {stream_id}")
+                    stream = await self._create_inbound_stream(stream_id)
+                else:
+                    logger.error(
+                        f"Unexpected outbound stream {stream_id} in data event"
+                    )
+                    return
+
+            await stream.handle_data_received(event.data, event.end_stream)
+
+        except Exception as e:
+            logger.error(f"Error handling stream data for stream {stream_id}: {e}")
+
+    async def _get_or_create_stream(self, stream_id: int) -> QUICStream:
+        """Get existing stream or create new inbound stream."""
+        # Use fast lookup
+        stream = self._get_stream_fast(stream_id)
+        if stream:
+            return stream
+
+        # Check if this is an incoming stream
+        is_incoming = self._is_incoming_stream(stream_id)
+
+        if not is_incoming:
+            # This shouldn't happen - outbound streams should be created by open_stream
+            raise QUICStreamError(
+                f"Received data for unknown outbound stream {stream_id}"
+            )
+
+        # Create new inbound stream
+        return await self._create_inbound_stream(stream_id)
+
+    def _is_incoming_stream(self, stream_id: int) -> bool:
+        """
+        Determine if a stream ID represents an incoming stream.
+
+        For bidirectional streams:
+        - Even IDs are client-initiated
+        - Odd IDs are server-initiated
+        """
+        if self._is_initiator:
+            # We're the client, so odd stream IDs are incoming
+            return stream_id % 2 == 1
+        else:
+            # We're the server, so even stream IDs are incoming
+            return stream_id % 2 == 0
+
+    async def _handle_stream_reset(self, event: events.StreamReset) -> None:
+        """Stream reset handling."""
+        stream_id = event.stream_id
+        self._stats["streams_reset"] += 1
+
+        # Use fast lookup
+        stream = self._get_stream_fast(stream_id)
+        if stream:
+            try:
+                await stream.handle_reset(event.error_code)
+                logger.debug(
+                    f"Handled reset for stream {stream_id} "
+                    f"with error code {event.error_code}"
+                )
+            except Exception as e:
+                logger.error(f"Error handling stream reset for {stream_id}: {e}")
+                # Force remove the stream
+                self._remove_stream(stream_id)
+        else:
+            logger.debug(f"Received reset for unknown stream {stream_id}")
+
+    async def _handle_datagram_received(
+        self, event: events.DatagramFrameReceived
+    ) -> None:
+        """Handle datagram frame (if using QUIC datagrams)."""
+        logger.debug(f"Datagram frame received: size={len(event.data)}")
+        # For now, just log. 
Could be extended for custom datagram handling + + async def _handle_timer_events(self) -> None: + """Handle QUIC timer events.""" + timer = self._quic.get_timer() + if timer is not None: + now = time.time() + if timer <= now: + self._quic.handle_timer(now=now) + + # Network transmission + + async def _transmit(self) -> None: + """Transmit pending QUIC packets using available socket.""" + sock = self._socket + if not sock: + logger.debug("No socket to transmit") + return + + try: + current_time = time.time() + datagrams = self._quic.datagrams_to_send(now=current_time) + + # Batch stats updates + packet_count = 0 + total_bytes = 0 + + for data, addr in datagrams: + await sock.sendto(data, addr) + packet_count += 1 + total_bytes += len(data) + + # Update stats in batch + if packet_count > 0: + self._stats["packets_sent"] += packet_count + self._stats["bytes_sent"] += total_bytes + + except Exception as e: + logger.error(f"Transmission error: {e}") + await self._handle_connection_error(e) + + # Additional methods for stream data processing + async def _process_quic_event(self, event: events.QuicEvent) -> None: + """Process a single QUIC event.""" + await self._handle_quic_event(event) + + async def _transmit_pending_data(self) -> None: + """Transmit any pending data.""" + await self._transmit() + + # Error handling + + async def _handle_connection_error(self, error: Exception) -> None: + """Handle connection-level errors.""" + logger.error(f"Connection error: {error}") + + if not self._closed: + try: + await self.close() + except Exception as close_error: + logger.error(f"Error during connection close: {close_error}") + + # Connection close + + async def close(self) -> None: + """Connection close with proper stream cleanup.""" + if self._closed: + return + + self._closed = True + logger.debug(f"Closing QUIC connection to {self._remote_peer_id}") + + try: + # Close all streams gracefully + stream_close_tasks = [] + for stream in list(self._streams.values()): + if stream.can_write() or stream.can_read(): + stream_close_tasks.append(stream.close) + + if stream_close_tasks and self._nursery: + try: + # Close streams concurrently with timeout + with trio.move_on_after(self.CONNECTION_CLOSE_TIMEOUT): + async with trio.open_nursery() as close_nursery: + for task in stream_close_tasks: + close_nursery.start_soon(task) + except Exception as e: + logger.warning(f"Error during graceful stream close: {e}") + # Force reset remaining streams + for stream in self._streams.values(): + try: + await stream.reset(error_code=0) + except Exception: + pass + + if self.on_close: + await self.on_close() + + # Close QUIC connection + self._quic.close() + + if self._socket: + await self._transmit() # Send close frames + + # Close socket + if self._socket and self._owns_socket: + self._socket.close() + self._socket = None + + self._streams.clear() + self._stream_cache.clear() # Clear cache + self._closed_event.set() + + logger.debug(f"QUIC connection to {self._remote_peer_id} closed") + + except Exception as e: + logger.error(f"Error during connection close: {e}") + + async def _notify_parent_of_termination(self) -> None: + """ + Notify the parent listener/transport to remove this connection from tracking. + + This ensures that terminated connections are cleaned up from the + 'established connections' list. 
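+
+        Cleanup is attempted in order: transport-level cleanup first, then
+        each of the transport's listeners, then a lookup by the current
+        connection ID.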
+
+        """
+        try:
+            # Preferred path: let the transport clean up its own tracking
+            if self._transport:
+                try:
+                    await self._transport._cleanup_terminated_connection(self)
+                    logger.debug("Notified transport of connection termination")
+                    return
+                except Exception:
+                    logger.debug("Transport cleanup failed, trying listeners")
+
+                # Fallback: ask each listener to drop this connection
+                for listener in self._transport._listeners:
+                    try:
+                        await listener._remove_connection_by_object(self)
+                        logger.debug(
+                            "Found and notified listener of connection termination"
+                        )
+                        return
+                    except Exception:
+                        continue
+
+            # Last resort: use the connection ID if we have one (most reliable)
+            if self._current_connection_id:
+                await self._cleanup_by_connection_id(self._current_connection_id)
+                return
+
+            logger.warning(
+                "Could not notify parent of connection termination - no"
+                f" parent reference found for conn host {self._quic.host_cid.hex()}"
+            )
+
+        except Exception as e:
+            logger.error(f"Error notifying parent of connection termination: {e}")
+
+    async def _cleanup_by_connection_id(self, connection_id: bytes) -> None:
+        """Cleanup using connection ID as a fallback method."""
+        try:
+            for listener in self._transport._listeners:
+                for tracked_cid, tracked_conn in list(listener._connections.items()):
+                    if tracked_conn is self:
+                        await listener._remove_connection(tracked_cid)
+                        logger.debug(f"Removed connection {tracked_cid.hex()}")
+                        return
+
+            logger.debug("Fallback cleanup by connection ID completed")
+        except Exception as e:
+            logger.error(f"Error in fallback cleanup: {e}")
+
+    # IRawConnection interface (for compatibility)
+
+    def get_remote_address(self) -> tuple[str, int]:
+        return self._remote_addr
+
+    async def write(self, data: bytes) -> None:
+        """
+        Write data to the connection.
+        For QUIC, this creates a new stream for each write operation.
+        """
+        if self._closed:
+            raise QUICConnectionClosedError("Connection is closed")
+
+        stream = await self.open_stream()
+        try:
+            await stream.write(data)
+            await stream.close_write()
+        except Exception:
+            await stream.reset()
+            raise
+
+    async def read(self, n: int | None = -1) -> bytes:
+        """
+        Read data from the connection.
+
+        Not supported for QUIC: data transfer happens on individual streams,
+        so this always raises ``NotImplementedError``. The method exists only
+        for interface compatibility.
+        """
+        raise NotImplementedError(
+            "Use streams for reading data from QUIC connections. "
+            "Call accept_stream() or open_stream() instead."
+ ) + + # Utility and monitoring methods + + def get_stream_stats(self) -> dict[str, Any]: + """Get stream statistics for monitoring.""" + return { + "total_streams": len(self._streams), + "outbound_streams": self._outbound_stream_count, + "inbound_streams": self._inbound_stream_count, + "max_streams": self.MAX_CONCURRENT_STREAMS, + "stream_utilization": len(self._streams) / self.MAX_CONCURRENT_STREAMS, + "stats": self._stats.copy(), + "cache_size": len( + self._stream_cache + ), # Include cache metrics for monitoring + } + + def get_active_streams(self) -> list[QUICStream]: + """Get list of active streams.""" + return [stream for stream in self._streams.values() if not stream.is_closed()] + + def get_streams_by_protocol(self, protocol: str) -> list[QUICStream]: + """Get streams filtered by protocol.""" + return [ + stream + for stream in self._streams.values() + if hasattr(stream, "protocol") + and stream.protocol == protocol + and not stream.is_closed() + ] + + def _update_stats(self) -> None: + """Update connection statistics.""" + # Add any periodic stats updates here + pass + + async def _cleanup_idle_streams(self) -> None: + """Clean up idle streams that are no longer needed.""" + current_time = time.time() + streams_to_cleanup = [] + + for stream in self._streams.values(): + if stream.is_closed(): + # Check if stream has been closed for a while + if hasattr(stream, "_timeline") and stream._timeline.closed_at: + if current_time - stream._timeline.closed_at > 60: # 1 minute + streams_to_cleanup.append(stream.stream_id) + + for stream_id in streams_to_cleanup: + self._remove_stream(int(stream_id)) + + # String representation + + def __repr__(self) -> str: + current_cid: str | None = ( + self._current_connection_id.hex() if self._current_connection_id else None + ) + return ( + f"QUICConnection(peer={self._remote_peer_id}, " + f"addr={self._remote_addr}, " + f"initiator={self._is_initiator}, " + f"verified={self._peer_verified}, " + f"established={self._established}, " + f"streams={len(self._streams)}, " + f"current_cid={current_cid})" + ) + + def __str__(self) -> str: + return f"QUICConnection({self._remote_peer_id})" diff --git a/libp2p/transport/quic/exceptions.py b/libp2p/transport/quic/exceptions.py new file mode 100644 index 00000000..2df3dda5 --- /dev/null +++ b/libp2p/transport/quic/exceptions.py @@ -0,0 +1,391 @@ +""" +QUIC Transport exceptions +""" + +from typing import Any, Literal + + +class QUICError(Exception): + """Base exception for all QUIC transport errors.""" + + def __init__(self, message: str, error_code: int | None = None): + super().__init__(message) + self.error_code = error_code + + +# Transport-level exceptions + + +class QUICTransportError(QUICError): + """Base exception for QUIC transport operations.""" + + pass + + +class QUICDialError(QUICTransportError): + """Error occurred during QUIC connection establishment.""" + + pass + + +class QUICListenError(QUICTransportError): + """Error occurred during QUIC listener operations.""" + + pass + + +class QUICSecurityError(QUICTransportError): + """Error related to QUIC security/TLS operations.""" + + pass + + +# Connection-level exceptions + + +class QUICConnectionError(QUICError): + """Base exception for QUIC connection operations.""" + + pass + + +class QUICConnectionClosedError(QUICConnectionError): + """QUIC connection has been closed.""" + + pass + + +class QUICConnectionTimeoutError(QUICConnectionError): + """QUIC connection operation timed out.""" + + pass + + +class QUICHandshakeError(QUICConnectionError): + 
"""Error during QUIC handshake process.""" + + pass + + +class QUICPeerVerificationError(QUICConnectionError): + """Error verifying peer identity during handshake.""" + + pass + + +# Stream-level exceptions + + +class QUICStreamError(QUICError): + """Base exception for QUIC stream operations.""" + + def __init__( + self, + message: str, + stream_id: str | None = None, + error_code: int | None = None, + ): + super().__init__(message, error_code) + self.stream_id = stream_id + + +class QUICStreamClosedError(QUICStreamError): + """Stream is closed and cannot be used for I/O operations.""" + + pass + + +class QUICStreamResetError(QUICStreamError): + """Stream was reset by local or remote peer.""" + + def __init__( + self, + message: str, + stream_id: str | None = None, + error_code: int | None = None, + reset_by_peer: bool = False, + ): + super().__init__(message, stream_id, error_code) + self.reset_by_peer = reset_by_peer + + +class QUICStreamTimeoutError(QUICStreamError): + """Stream operation timed out.""" + + pass + + +class QUICStreamBackpressureError(QUICStreamError): + """Stream write blocked due to flow control.""" + + pass + + +class QUICStreamLimitError(QUICStreamError): + """Stream limit reached (too many concurrent streams).""" + + pass + + +class QUICStreamStateError(QUICStreamError): + """Invalid operation for current stream state.""" + + def __init__( + self, + message: str, + stream_id: str | None = None, + current_state: str | None = None, + attempted_operation: str | None = None, + ): + super().__init__(message, stream_id) + self.current_state = current_state + self.attempted_operation = attempted_operation + + +# Flow control exceptions + + +class QUICFlowControlError(QUICError): + """Base exception for flow control related errors.""" + + pass + + +class QUICFlowControlViolationError(QUICFlowControlError): + """Flow control limits were violated.""" + + pass + + +class QUICFlowControlDeadlockError(QUICFlowControlError): + """Flow control deadlock detected.""" + + pass + + +# Resource management exceptions + + +class QUICResourceError(QUICError): + """Base exception for resource management errors.""" + + pass + + +class QUICMemoryLimitError(QUICResourceError): + """Memory limit exceeded.""" + + pass + + +class QUICConnectionLimitError(QUICResourceError): + """Connection limit exceeded.""" + + pass + + +# Multiaddr and addressing exceptions + + +class QUICAddressError(QUICError): + """Base exception for QUIC addressing errors.""" + + pass + + +class QUICInvalidMultiaddrError(QUICAddressError): + """Invalid multiaddr format for QUIC transport.""" + + pass + + +class QUICAddressResolutionError(QUICAddressError): + """Failed to resolve QUIC address.""" + + pass + + +class QUICProtocolError(QUICError): + """Base exception for QUIC protocol errors.""" + + pass + + +class QUICVersionNegotiationError(QUICProtocolError): + """QUIC version negotiation failed.""" + + pass + + +class QUICUnsupportedVersionError(QUICProtocolError): + """Unsupported QUIC version.""" + + pass + + +# Configuration exceptions + + +class QUICConfigurationError(QUICError): + """Base exception for QUIC configuration errors.""" + + pass + + +class QUICInvalidConfigError(QUICConfigurationError): + """Invalid QUIC configuration parameters.""" + + pass + + +class QUICCertificateError(QUICConfigurationError): + """Error with TLS certificate configuration.""" + + pass + + +def map_quic_error_code(error_code: int) -> str: + """ + Map QUIC error codes to human-readable descriptions. 
+ Based on RFC 9000 Transport Error Codes. + """ + error_codes = { + 0x00: "NO_ERROR", + 0x01: "INTERNAL_ERROR", + 0x02: "CONNECTION_REFUSED", + 0x03: "FLOW_CONTROL_ERROR", + 0x04: "STREAM_LIMIT_ERROR", + 0x05: "STREAM_STATE_ERROR", + 0x06: "FINAL_SIZE_ERROR", + 0x07: "FRAME_ENCODING_ERROR", + 0x08: "TRANSPORT_PARAMETER_ERROR", + 0x09: "CONNECTION_ID_LIMIT_ERROR", + 0x0A: "PROTOCOL_VIOLATION", + 0x0B: "INVALID_TOKEN", + 0x0C: "APPLICATION_ERROR", + 0x0D: "CRYPTO_BUFFER_EXCEEDED", + 0x0E: "KEY_UPDATE_ERROR", + 0x0F: "AEAD_LIMIT_REACHED", + 0x10: "NO_VIABLE_PATH", + } + + return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}") + + +def create_stream_error( + error_type: str, + message: str, + stream_id: str | None = None, + error_code: int | None = None, +) -> QUICStreamError: + """ + Factory function to create appropriate stream error based on type. + + Args: + error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.) + message: Error message + stream_id: Stream identifier + error_code: QUIC error code + + Returns: + Appropriate QUICStreamError subclass + + """ + error_type = error_type.lower() + + if error_type in ("closed", "close"): + return QUICStreamClosedError(message, stream_id, error_code) + elif error_type == "reset": + return QUICStreamResetError(message, stream_id, error_code) + elif error_type == "timeout": + return QUICStreamTimeoutError(message, stream_id, error_code) + elif error_type in ("backpressure", "flow_control"): + return QUICStreamBackpressureError(message, stream_id, error_code) + elif error_type in ("limit", "stream_limit"): + return QUICStreamLimitError(message, stream_id, error_code) + elif error_type == "state": + return QUICStreamStateError(message, stream_id) + else: + return QUICStreamError(message, stream_id, error_code) + + +def create_connection_error( + error_type: str, message: str, error_code: int | None = None +) -> QUICConnectionError: + """ + Factory function to create appropriate connection error based on type. + + Args: + error_type: Type of error ("closed", "timeout", "handshake", etc.) + message: Error message + error_code: QUIC error code + + Returns: + Appropriate QUICConnectionError subclass + + """ + error_type = error_type.lower() + + if error_type in ("closed", "close"): + return QUICConnectionClosedError(message, error_code) + elif error_type == "timeout": + return QUICConnectionTimeoutError(message, error_code) + elif error_type == "handshake": + return QUICHandshakeError(message, error_code) + elif error_type in ("peer_verification", "verification"): + return QUICPeerVerificationError(message, error_code) + else: + return QUICConnectionError(message, error_code) + + +class QUICErrorContext: + """ + Context manager for handling QUIC errors with automatic error mapping. + Useful for converting low-level aioquic errors to py-libp2p QUIC errors. 
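+
+    Example (illustrative sketch; ``risky_quic_call`` stands in for any
+    aioquic operation)::
+
+        with QUICErrorContext("stream_write", "stream"):
+            risky_quic_call()  # timeouts re-raise as QUICStreamTimeoutError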
+
+    """
+
+    def __init__(self, operation: str, component: str = "quic") -> None:
+        self.operation = operation
+        self.component = component
+
+    def __enter__(self) -> "QUICErrorContext":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: Any,
+    ) -> Literal[False]:
+        if exc_type is None or exc_val is None:
+            return False
+
+        # Map common aioquic exceptions to our exceptions
+        if "ConnectionClosed" in str(exc_type):
+            raise QUICConnectionClosedError(
+                f"Connection closed during {self.operation}: {exc_val}"
+            ) from exc_val
+        elif "StreamReset" in str(exc_type):
+            raise QUICStreamResetError(
+                f"Stream reset during {self.operation}: {exc_val}"
+            ) from exc_val
+        elif "timeout" in str(exc_val).lower():
+            if "stream" in self.component.lower():
+                raise QUICStreamTimeoutError(
+                    f"Timeout during {self.operation}: {exc_val}"
+                ) from exc_val
+            else:
+                raise QUICConnectionTimeoutError(
+                    f"Timeout during {self.operation}: {exc_val}"
+                ) from exc_val
+        elif "flow control" in str(exc_val).lower():
+            raise QUICStreamBackpressureError(
+                f"Flow control error during {self.operation}: {exc_val}"
+            ) from exc_val
+
+        # Let other exceptions propagate
+        return False
diff --git a/libp2p/transport/quic/listener.py b/libp2p/transport/quic/listener.py
new file mode 100644
index 00000000..0e8e66ad
--- /dev/null
+++ b/libp2p/transport/quic/listener.py
@@ -0,0 +1,1041 @@
+"""
+QUIC Listener
+"""
+
+import logging
+import socket
+import struct
+import sys
+import time
+from typing import TYPE_CHECKING
+
+from aioquic.quic import events
+from aioquic.quic.configuration import QuicConfiguration
+from aioquic.quic.connection import QuicConnection
+from aioquic.quic.packet import QuicPacketType
+from multiaddr import Multiaddr
+import trio
+
+from libp2p.abc import IListener
+from libp2p.custom_types import (
+    TProtocol,
+    TQUICConnHandlerFn,
+)
+from libp2p.transport.quic.security import (
+    LIBP2P_TLS_EXTENSION_OID,
+    QUICTLSConfigManager,
+)
+
+from .config import QUICTransportConfig
+from .connection import QUICConnection
+from .exceptions import QUICListenError
+from .utils import (
+    create_quic_multiaddr,
+    create_server_config_from_base,
+    custom_quic_version_to_wire_format,
+    is_quic_multiaddr,
+    multiaddr_to_quic_version,
+    quic_multiaddr_to_endpoint,
+)
+
+if TYPE_CHECKING:
+    from .transport import QUICTransport
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+    handlers=[logging.StreamHandler(sys.stdout)],
+)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class QUICPacketInfo:
+    """Information extracted from a QUIC packet header."""
+
+    def __init__(
+        self,
+        version: int,
+        destination_cid: bytes,
+        source_cid: bytes,
+        packet_type: QuicPacketType,
+        token: bytes | None = None,
+    ):
+        self.version = version
+        self.destination_cid = destination_cid
+        self.source_cid = source_cid
+        self.packet_type = packet_type
+        self.token = token
+
+
+class QUICListener(IListener):
+    """
+    QUIC Listener with connection ID handling and protocol negotiation.
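+
+    Listeners are normally created by the transport rather than constructed
+    directly; a rough sketch (assuming the transport's ``create_listener``
+    factory and a trio-style ``listen``)::
+
+        listener = transport.create_listener(handler)
+        await listener.listen(listen_maddr, nursery)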
+ """ + + def __init__( + self, + transport: "QUICTransport", + handler_function: TQUICConnHandlerFn, + quic_configs: dict[TProtocol, QuicConfiguration], + config: QUICTransportConfig, + security_manager: QUICTLSConfigManager | None = None, + ): + """Initialize enhanced QUIC listener.""" + self._transport = transport + self._handler = handler_function + self._quic_configs = quic_configs + self._config = config + self._security_manager = security_manager + + # Network components + self._socket: trio.socket.SocketType | None = None + self._bound_addresses: list[Multiaddr] = [] + + # Enhanced connection management with connection ID routing + self._connections: dict[ + bytes, QUICConnection + ] = {} # destination_cid -> connection + self._pending_connections: dict[ + bytes, QuicConnection + ] = {} # destination_cid -> quic_conn + self._addr_to_cid: dict[ + tuple[str, int], bytes + ] = {} # (host, port) -> destination_cid + self._cid_to_addr: dict[ + bytes, tuple[str, int] + ] = {} # destination_cid -> (host, port) + self._connection_lock = trio.Lock() + + # Version negotiation support + self._supported_versions = self._get_supported_versions() + + # Listener state + self._closed = False + self._listening = False + self._nursery: trio.Nursery | None = None + + # Performance tracking + self._stats = { + "connections_accepted": 0, + "connections_rejected": 0, + "version_negotiations": 0, + "bytes_received": 0, + "packets_processed": 0, + "invalid_packets": 0, + } + + def _get_supported_versions(self) -> set[int]: + """Get wire format versions for all supported QUIC configurations.""" + versions: set[int] = set() + for protocol in self._quic_configs: + try: + config = self._quic_configs[protocol] + wire_versions = config.supported_versions + for version in wire_versions: + versions.add(version) + except Exception as e: + logger.warning(f"Failed to get wire version for {protocol}: {e}") + return versions + + def parse_quic_packet(self, data: bytes) -> QUICPacketInfo | None: + """ + Parse QUIC packet header to extract connection IDs and version. + Based on RFC 9000 packet format. + """ + try: + if len(data) < 1: + return None + + # Read first byte to get packet type and flags + first_byte = data[0] + + # Check if this is a long header packet (version negotiation, initial, etc.) 
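+            # RFC 9000: the most significant bit of the first byte is the
+            # Header Form bit - 1 for long header packets (Initial, 0-RTT,
+            # Handshake, Retry), 0 for short header (1-RTT) packets.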
+ is_long_header = (first_byte & 0x80) != 0 + + if not is_long_header: + cid_length = 8 # We are using standard CID length everywhere + + if len(data) < 1 + cid_length: + return None + + dest_cid = data[1 : 1 + cid_length] + + return QUICPacketInfo( + version=1, # Assume QUIC v1 for established connections + destination_cid=dest_cid, + source_cid=b"", # Not available in short header + packet_type=QuicPacketType.ONE_RTT, + token=b"", + ) + + # Long header packet parsing + offset = 1 + + # Extract version (4 bytes) + if len(data) < offset + 4: + return None + version = struct.unpack("!I", data[offset : offset + 4])[0] + offset += 4 + + # Extract destination connection ID length and value + if len(data) < offset + 1: + return None + dest_cid_len = data[offset] + offset += 1 + + if len(data) < offset + dest_cid_len: + return None + dest_cid = data[offset : offset + dest_cid_len] + offset += dest_cid_len + + # Extract source connection ID length and value + if len(data) < offset + 1: + return None + src_cid_len = data[offset] + offset += 1 + + if len(data) < offset + src_cid_len: + return None + src_cid = data[offset : offset + src_cid_len] + offset += src_cid_len + + # Determine packet type from first byte + packet_type_value = (first_byte & 0x30) >> 4 + + packet_value_to_type_mapping = { + 0: QuicPacketType.INITIAL, + 1: QuicPacketType.ZERO_RTT, + 2: QuicPacketType.HANDSHAKE, + 3: QuicPacketType.RETRY, + 4: QuicPacketType.VERSION_NEGOTIATION, + 5: QuicPacketType.ONE_RTT, + } + + # For Initial packets, extract token + token = b"" + if packet_type_value == 0: # Initial packet + if len(data) < offset + 1: + return None + # Token length is variable-length integer + token_len, token_len_bytes = self._decode_varint(data[offset:]) + offset += token_len_bytes + + if len(data) < offset + token_len: + return None + token = data[offset : offset + token_len] + + return QUICPacketInfo( + version=version, + destination_cid=dest_cid, + source_cid=src_cid, + packet_type=packet_value_to_type_mapping.get(packet_type_value) + or QuicPacketType.INITIAL, + token=token, + ) + + except Exception as e: + logger.debug(f"Failed to parse QUIC packet: {e}") + return None + + def _decode_varint(self, data: bytes) -> tuple[int, int]: + """Decode QUIC variable-length integer.""" + if len(data) < 1: + return 0, 0 + + first_byte = data[0] + length_bits = (first_byte & 0xC0) >> 6 + + if length_bits == 0: + return first_byte & 0x3F, 1 + elif length_bits == 1: + if len(data) < 2: + return 0, 0 + return ((first_byte & 0x3F) << 8) | data[1], 2 + elif length_bits == 2: + if len(data) < 4: + return 0, 0 + return ((first_byte & 0x3F) << 24) | (data[1] << 16) | ( + data[2] << 8 + ) | data[3], 4 + else: # length_bits == 3 + if len(data) < 8: + return 0, 0 + value = (first_byte & 0x3F) << 56 + for i in range(1, 8): + value |= data[i] << (8 * (7 - i)) + return value, 8 + + async def _process_packet(self, data: bytes, addr: tuple[str, int]) -> None: + """Process incoming QUIC packet with optimized routing.""" + try: + self._stats["packets_processed"] += 1 + self._stats["bytes_received"] += len(data) + + packet_info = self.parse_quic_packet(data) + if packet_info is None: + self._stats["invalid_packets"] += 1 + return + + dest_cid = packet_info.destination_cid + + # Single lock acquisition with all lookups + async with self._connection_lock: + connection_obj = self._connections.get(dest_cid) + pending_quic_conn = self._pending_connections.get(dest_cid) + + if not connection_obj and not pending_quic_conn: + if packet_info.packet_type == 
QuicPacketType.INITIAL: + pending_quic_conn = await self._handle_new_connection( + data, addr, packet_info + ) + else: + return + + # Process outside the lock + if connection_obj: + await self._handle_established_connection_packet( + connection_obj, data, addr, dest_cid + ) + elif pending_quic_conn: + await self._handle_pending_connection_packet( + pending_quic_conn, data, addr, dest_cid + ) + + except Exception as e: + logger.error(f"Error processing packet from {addr}: {e}") + + async def _handle_established_connection_packet( + self, + connection_obj: QUICConnection, + data: bytes, + addr: tuple[str, int], + dest_cid: bytes, + ) -> None: + """Handle packet for established connection WITHOUT holding connection lock.""" + try: + await self._route_to_connection(connection_obj, data, addr) + + except Exception as e: + logger.error(f"Error handling established connection packet: {e}") + + async def _handle_pending_connection_packet( + self, + quic_conn: QuicConnection, + data: bytes, + addr: tuple[str, int], + dest_cid: bytes, + ) -> None: + """Handle packet for pending connection WITHOUT holding connection lock.""" + try: + logger.debug(f"Handling packet for pending connection {dest_cid.hex()}") + logger.debug(f"Packet size: {len(data)} bytes from {addr}") + + # Feed data to QUIC connection + quic_conn.receive_datagram(data, addr, now=time.time()) + logger.debug("PENDING: Datagram received by QUIC connection") + + # Process events - this is crucial for handshake progression + logger.debug("Processing QUIC events...") + await self._process_quic_events(quic_conn, addr, dest_cid) + + # Send any outgoing packets + logger.debug("Transmitting response...") + await self._transmit_for_connection(quic_conn, addr) + + # Check if handshake completed (with minimal locking) + if quic_conn._handshake_complete: + logger.debug("PENDING: Handshake completed, promoting connection") + await self._promote_pending_connection(quic_conn, addr, dest_cid) + else: + logger.debug("Handshake still in progress") + + except Exception as e: + logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}") + + async def _send_version_negotiation( + self, addr: tuple[str, int], source_cid: bytes + ) -> None: + """Send version negotiation packet to client.""" + try: + self._stats["version_negotiations"] += 1 + + # Construct version negotiation packet + packet = bytearray() + + # First byte: long header (1) + unused bits (0111) + packet.append(0x80 | 0x70) + + # Version: 0 for version negotiation + packet.extend(struct.pack("!I", 0)) + + # Destination connection ID (echo source CID from client) + packet.append(len(source_cid)) + packet.extend(source_cid) + + # Source connection ID (empty for version negotiation) + packet.append(0) + + # Supported versions + for version in sorted(self._supported_versions): + packet.extend(struct.pack("!I", version)) + + # Send the packet + if self._socket: + await self._socket.sendto(bytes(packet), addr) + logger.debug( + f"Sent version negotiation to {addr} " + f"with versions {sorted(self._supported_versions)}" + ) + + except Exception as e: + logger.error(f"Failed to send version negotiation to {addr}: {e}") + + async def _handle_new_connection( + self, data: bytes, addr: tuple[str, int], packet_info: QUICPacketInfo + ) -> QuicConnection | None: + """Handle new connection with proper connection ID handling.""" + try: + logger.debug(f"Starting handshake for {addr}") + + # Find appropriate QUIC configuration + quic_config = None + + for protocol, config in 
self._quic_configs.items():
+                wire_versions = custom_quic_version_to_wire_format(protocol)
+                if wire_versions == packet_info.version:
+                    quic_config = config
+                    break
+
+            if not quic_config:
+                logger.error(
+                    f"No configuration found for version 0x{packet_info.version:08x}"
+                )
+                await self._send_version_negotiation(addr, packet_info.source_cid)
+                return None
+
+            # Create server-side QUIC configuration
+            server_config = create_server_config_from_base(
+                base_config=quic_config,
+                security_manager=self._security_manager,
+                transport_config=self._config,
+            )
+
+            # Validate certificate has libp2p extension
+            if server_config.certificate:
+                cert = server_config.certificate
+                has_libp2p_ext = False
+                for ext in cert.extensions:
+                    if ext.oid == LIBP2P_TLS_EXTENSION_OID:
+                        has_libp2p_ext = True
+                        break
+                logger.debug(f"Certificate has libp2p extension: {has_libp2p_ext}")
+
+                if not has_libp2p_ext:
+                    logger.error("Certificate missing libp2p extension!")
+
+            logger.debug(
+                f"Original destination CID: {packet_info.destination_cid.hex()}"
+            )
+
+            quic_conn = QuicConnection(
+                configuration=server_config,
+                original_destination_connection_id=packet_info.destination_cid,
+            )
+
+            quic_conn._replenish_connection_ids()
+            # Use the first host CID as our routing CID
+            if quic_conn._host_cids:
+                destination_cid = quic_conn._host_cids[0].cid
+                logger.debug(f"Using host CID as routing CID: {destination_cid.hex()}")
+            else:
+                # Fallback to random if no host CIDs generated
+                import secrets
+
+                destination_cid = secrets.token_bytes(8)
+                logger.debug(f"Fallback to random CID: {destination_cid.hex()}")
+
+            logger.debug(f"Generated {len(quic_conn._host_cids)} host CIDs for client")
+
+            logger.debug(
+                f"QUIC connection created for destination CID {destination_cid.hex()}"
+            )
+
+            # Store connection mapping using our generated CID
+            self._pending_connections[destination_cid] = quic_conn
+            self._addr_to_cid[addr] = destination_cid
+            self._cid_to_addr[destination_cid] = addr
+
+            # Process initial packet
+            quic_conn.receive_datagram(data, addr, now=time.time())
+            if quic_conn.tls:
+                if self._security_manager:
+                    try:
+                        quic_conn.tls._request_client_certificate = True
+                        logger.debug(
+                            "request_client_certificate set to True in server TLS"
+                        )
+                    except Exception as e:
+                        logger.error(f"FAILED to apply request_client_certificate: {e}")
+
+            # Process events and send response
+            await self._process_quic_events(quic_conn, addr, destination_cid)
+            await self._transmit_for_connection(quic_conn, addr)
+
+            logger.debug(
+                f"Started handshake for new connection from {addr} "
+                f"(version: 0x{packet_info.version:08x}, cid: {destination_cid.hex()})"
+            )
+
+            return quic_conn
+
+        except Exception as e:
+            logger.error(f"Error handling new connection from {addr}: {e}")
+            self._stats["connections_rejected"] += 1
+            return None
+
+    async def _handle_short_header_packet(
+        self, data: bytes, addr: tuple[str, int]
+    ) -> None:
+        """Handle short header packets for established connections."""
+        try:
+            logger.debug(f" SHORT_HDR: Handling short header packet from {addr}")
+
+            # First, try address-based lookup
+            dest_cid = self._addr_to_cid.get(addr)
+            if dest_cid and dest_cid in self._connections:
+                connection = self._connections[dest_cid]
+                await self._route_to_connection(connection, data, addr)
+                return
+
+            # Fallback: try to extract CID from packet
+            if len(data) >= 9:  # 1 byte header + 8 byte CID
+                potential_cid = data[1:9]
+
+                if potential_cid in 
self._connections:
+                    connection = self._connections[potential_cid]
+
+                    # Update mappings for future packets
+                    self._addr_to_cid[addr] = potential_cid
+                    self._cid_to_addr[potential_cid] = addr
+
+                    await self._route_to_connection(connection, data, addr)
+                    return
+
+            logger.debug(f"āŒ SHORT_HDR: No matching connection found for {addr}")
+
+        except Exception as e:
+            logger.error(f"Error handling short header packet from {addr}: {e}")
+
+    async def _route_to_connection(
+        self, connection: QUICConnection, data: bytes, addr: tuple[str, int]
+    ) -> None:
+        """Route packet to existing connection."""
+        try:
+            # Feed data to the connection's QUIC instance
+            connection._quic.receive_datagram(data, addr, now=time.time())
+
+            # Process events and handle responses
+            await connection._process_quic_events()
+            await connection._transmit()
+
+        except Exception as e:
+            logger.error(f"Error routing packet to connection {addr}: {e}")
+            # Remove problematic connection
+            await self._remove_connection_by_addr(addr)
+
+    async def _handle_pending_connection(
+        self,
+        quic_conn: QuicConnection,
+        data: bytes,
+        addr: tuple[str, int],
+        dest_cid: bytes,
+    ) -> None:
+        """Handle packet for a pending (handshaking) connection."""
+        try:
+            logger.debug(f"Handling packet for pending connection {dest_cid.hex()}")
+
+            # Feed data to QUIC connection
+            quic_conn.receive_datagram(data, addr, now=time.time())
+
+            if quic_conn.tls:
+                logger.debug(f"TLS state after: {quic_conn.tls.state}")
+
+            # Process events - this is crucial for handshake progression
+            await self._process_quic_events(quic_conn, addr, dest_cid)
+
+            # Send any outgoing packets - this is where the response should be sent
+            await self._transmit_for_connection(quic_conn, addr)
+
+            # Check if handshake completed
+            if quic_conn._handshake_complete:
+                logger.debug("PENDING: Handshake completed, promoting connection")
+                await self._promote_pending_connection(quic_conn, addr, dest_cid)
+
+        except Exception as e:
+            logger.error(f"Error handling pending connection {dest_cid.hex()}: {e}")
+
+            # Remove problematic pending connection
+            logger.error(f"Removing problematic connection {dest_cid.hex()}")
+            await self._remove_pending_connection(dest_cid)
+
+    async def _process_quic_events(
+        self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
+    ) -> None:
+        """Process QUIC events with enhanced debugging."""
+        try:
+            events_processed = 0
+            while True:
+                event = quic_conn.next_event()
+                if event is None:
+                    break
+
+                events_processed += 1
+                logger.debug(
+                    "QUIC EVENT: Processing event "
+                    f"{events_processed}: {type(event).__name__}"
+                )
+
+                if isinstance(event, events.ConnectionTerminated):
+                    logger.debug(
+                        f"QUIC EVENT: Connection {dest_cid.hex()} from {addr} "
+                        f"terminated - code: {event.error_code}, "
+                        f"reason: {event.reason_phrase}"
+                    )
+                    await self._remove_connection(dest_cid)
+                    break
+
+                elif isinstance(event, events.HandshakeCompleted):
+                    logger.debug(
+                        "QUIC EVENT: Handshake completed for connection "
+                        f"{dest_cid.hex()}"
+                    )
+                    await self._promote_pending_connection(quic_conn, addr, dest_cid)
+
+                elif isinstance(event, events.StreamDataReceived):
+                    logger.debug(
+                        f"QUIC EVENT: Stream data received on stream {event.stream_id}"
+                    )
+                    if dest_cid in self._connections:
+                        connection = self._connections[dest_cid]
+                        await connection._handle_stream_data(event)
+
+                elif isinstance(event, events.StreamReset):
+                    logger.debug(
+                        f"QUIC 
EVENT: Stream reset on stream {event.stream_id}"
+                    )
+                    if dest_cid in self._connections:
+                        connection = self._connections[dest_cid]
+                        await connection._handle_stream_reset(event)
+
+                elif isinstance(event, events.ConnectionIdIssued):
+                    logger.debug(
+                        f"QUIC EVENT: Connection ID issued: {event.connection_id.hex()}"
+                    )
+                    # Add new CID to the same address mapping
+                    taddr = self._cid_to_addr.get(dest_cid)
+                    if taddr:
+                        # Don't overwrite, but this CID is also valid for this address
+                        logger.debug(
+                            f"QUIC EVENT: New CID {event.connection_id.hex()} "
+                            f"available for {taddr}"
+                        )
+
+                elif isinstance(event, events.ConnectionIdRetired):
+                    logger.info(f"Connection ID retired: {event.connection_id.hex()}")
+                    retired_cid = event.connection_id
+                    if retired_cid in self._cid_to_addr:
+                        addr = self._cid_to_addr[retired_cid]
+                        del self._cid_to_addr[retired_cid]
+                        # Only remove addr mapping if this was the active CID
+                        if self._addr_to_cid.get(addr) == retired_cid:
+                            del self._addr_to_cid[addr]
+                else:
+                    logger.warning(f"Unhandled event type: {type(event).__name__}")
+
+        except Exception as e:
+            logger.debug(f"āŒ EVENT: Error processing events: {e}")
+
+    async def _promote_pending_connection(
+        self, quic_conn: QuicConnection, addr: tuple[str, int], dest_cid: bytes
+    ) -> None:
+        """Promote pending connection - avoid duplicate creation."""
+        try:
+            self._pending_connections.pop(dest_cid, None)
+
+            if dest_cid in self._connections:
+                logger.debug(
+                    f"āš ļø Connection {dest_cid.hex()} already exists in _connections!"
+                )
+                connection = self._connections[dest_cid]
+            else:
+                from .connection import QUICConnection
+
+                host, port = addr
+                quic_version = "quic"
+                remote_maddr = create_quic_multiaddr(host, port, f"/{quic_version}")
+
+                connection = QUICConnection(
+                    quic_connection=quic_conn,
+                    remote_addr=addr,
+                    remote_peer_id=None,
+                    local_peer_id=self._transport._peer_id,
+                    is_initiator=False,
+                    maddr=remote_maddr,
+                    transport=self._transport,
+                    security_manager=self._security_manager,
+                    listener_socket=self._socket,
+                )
+
+                logger.debug(f"šŸ”„ Created NEW QUICConnection for {dest_cid.hex()}")
+
+            self._connections[dest_cid] = connection
+
+            self._addr_to_cid[addr] = dest_cid
+            self._cid_to_addr[dest_cid] = addr
+
+            if self._nursery:
+                connection._nursery = self._nursery
+                await connection.connect(self._nursery)
+                logger.debug(f"Connection connected successfully for {dest_cid.hex()}")
+
+            if self._security_manager:
+                try:
+                    peer_id = await connection._verify_peer_identity_with_security()
+                    if peer_id:
+                        connection.peer_id = peer_id
+                    logger.info(
+                        f"Security verification successful for {dest_cid.hex()}"
+                    )
+                except Exception as e:
+                    logger.error(
+                        f"Security verification failed for {dest_cid.hex()}: {e}"
+                    )
+                    await connection.close()
+                    return
+
+            if self._nursery:
+                connection._nursery = self._nursery
+                await connection._start_background_tasks()
+                logger.debug(
+                    f"Started background tasks for connection {dest_cid.hex()}"
+                )
+
+            try:
+                logger.debug(f"Invoking user callback {dest_cid.hex()}")
+                await self._handler(connection)
+
+            except Exception as e:
+                logger.error(f"Error in user callback: {e}")
+
+            self._stats["connections_accepted"] += 1
+            logger.info(f"Enhanced connection {dest_cid.hex()} established from {addr}")
+
+        except Exception as e:
+            logger.error(f"āŒ Error promoting connection {dest_cid.hex()}: {e}")
+            await self._remove_connection(dest_cid)
+
+    async def _remove_connection(self, dest_cid: bytes) -> None:
+        """Remove connection by connection ID."""
+        try:
+            # Remove connection
+            
connection = self._connections.pop(dest_cid, None) + if connection: + await connection.close() + + # Clean up mappings + addr = self._cid_to_addr.pop(dest_cid, None) + if addr: + self._addr_to_cid.pop(addr, None) + + logger.debug(f"Removed connection {dest_cid.hex()}") + + except Exception as e: + logger.error(f"Error removing connection {dest_cid.hex()}: {e}") + + async def _remove_pending_connection(self, dest_cid: bytes) -> None: + """Remove pending connection by connection ID.""" + try: + self._pending_connections.pop(dest_cid, None) + addr = self._cid_to_addr.pop(dest_cid, None) + if addr: + self._addr_to_cid.pop(addr, None) + logger.debug(f"Removed pending connection {dest_cid.hex()}") + except Exception as e: + logger.error(f"Error removing pending connection {dest_cid.hex()}: {e}") + + async def _remove_connection_by_addr(self, addr: tuple[str, int]) -> None: + """Remove connection by address (fallback method).""" + dest_cid = self._addr_to_cid.get(addr) + if dest_cid: + await self._remove_connection(dest_cid) + + async def _transmit_for_connection( + self, quic_conn: QuicConnection, addr: tuple[str, int] + ) -> None: + """Enhanced transmission diagnostics to analyze datagram content.""" + try: + logger.debug(f" TRANSMIT: Starting transmission to {addr}") + + # Get current timestamp for timing + import time + + now = time.time() + + datagrams = quic_conn.datagrams_to_send(now=now) + logger.debug(f" TRANSMIT: Got {len(datagrams)} datagrams to send") + + if not datagrams: + logger.debug("āš ļø TRANSMIT: No datagrams to send") + return + + for i, (datagram, dest_addr) in enumerate(datagrams): + logger.debug(f" TRANSMIT: Analyzing datagram {i}") + logger.debug(f" TRANSMIT: Datagram size: {len(datagram)} bytes") + logger.debug(f" TRANSMIT: Destination: {dest_addr}") + logger.debug(f" TRANSMIT: Expected destination: {addr}") + + # Analyze datagram content + if len(datagram) > 0: + # QUIC packet format analysis + first_byte = datagram[0] + header_form = (first_byte & 0x80) >> 7 # Bit 7 + + # For long header packets (handshake), analyze further + if header_form == 1: # Long header + # CRYPTO frame type is 0x06 + crypto_frame_found = False + for offset in range(len(datagram)): + if datagram[offset] == 0x06: + crypto_frame_found = True + break + + if not crypto_frame_found: + logger.error("No CRYPTO frame found in datagram!") + # Look for other frame types + frame_types_found = set() + for offset in range(len(datagram)): + frame_type = datagram[offset] + if frame_type in [0x00, 0x01]: # PADDING/PING + frame_types_found.add("PADDING/PING") + elif frame_type == 0x02: # ACK + frame_types_found.add("ACK") + elif frame_type == 0x06: # CRYPTO + frame_types_found.add("CRYPTO") + + if self._socket: + try: + await self._socket.sendto(datagram, addr) + except Exception as send_error: + logger.error(f"Socket send failed: {send_error}") + else: + logger.error("No socket available!") + except Exception as e: + logger.debug(f"Transmission error: {e}") + + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + """Start listening on the given multiaddr with enhanced connection handling.""" + if self._listening: + raise QUICListenError("Already listening") + + if not is_quic_multiaddr(maddr): + raise QUICListenError(f"Invalid QUIC multiaddr: {maddr}") + + if self._transport._background_nursery: + active_nursery = self._transport._background_nursery + logger.debug("Using transport background nursery for listener") + elif nursery: + active_nursery = nursery + 
self._transport._background_nursery = nursery + logger.debug("Using provided nursery for listener") + else: + raise QUICListenError("No nursery available") + + try: + host, port = quic_multiaddr_to_endpoint(maddr) + + # Create and configure socket + self._socket = await self._create_socket(host, port) + self._nursery = active_nursery + + # Get the actual bound address + bound_host, bound_port = self._socket.getsockname() + quic_version = multiaddr_to_quic_version(maddr) + bound_maddr = create_quic_multiaddr(bound_host, bound_port, quic_version) + self._bound_addresses = [bound_maddr] + + self._listening = True + + # Start packet handling loop + active_nursery.start_soon(self._handle_incoming_packets) + + logger.info( + f"QUIC listener started on {bound_maddr} with connection ID support" + ) + return True + + except Exception as e: + await self.close() + raise QUICListenError(f"Failed to start listening: {e}") from e + + async def _create_socket(self, host: str, port: int) -> trio.socket.SocketType: + """Create and configure UDP socket.""" + try: + # Determine address family + try: + import ipaddress + + ip = ipaddress.ip_address(host) + family = socket.AF_INET if ip.version == 4 else socket.AF_INET6 + except ValueError: + family = socket.AF_INET + + # Create UDP socket + sock = trio.socket.socket(family=family, type=socket.SOCK_DGRAM) + + # Set socket options + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, "SO_REUSEPORT"): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + # Bind to address + await sock.bind((host, port)) + + logger.debug(f"Created and bound UDP socket to {host}:{port}") + return sock + + except Exception as e: + raise QUICListenError(f"Failed to create socket: {e}") from e + + async def _handle_incoming_packets(self) -> None: + """Handle incoming UDP packets with enhanced routing.""" + logger.debug("Started enhanced packet handling loop") + + try: + while self._listening and self._socket: + try: + # Receive UDP packet + data, addr = await self._socket.recvfrom(65536) + + # Process packet asynchronously + if self._nursery: + self._nursery.start_soon(self._process_packet, data, addr) + + except trio.ClosedResourceError: + logger.debug("Socket closed, exiting packet handler") + break + except Exception as e: + logger.error(f"Error receiving packet: {e}") + await trio.sleep(0.01) + except trio.Cancelled: + logger.info("Packet handling cancelled") + raise + finally: + logger.debug("Enhanced packet handling loop terminated") + + async def close(self) -> None: + """Close the listener and clean up resources.""" + if self._closed: + return + + self._closed = True + self._listening = False + + try: + # Close all connections + async with self._connection_lock: + for dest_cid in list(self._connections.keys()): + await self._remove_connection(dest_cid) + + for dest_cid in list(self._pending_connections.keys()): + await self._remove_pending_connection(dest_cid) + + # Close socket + if self._socket: + self._socket.close() + self._socket = None + + self._bound_addresses.clear() + + logger.info("QUIC listener closed") + + except Exception as e: + logger.error(f"Error closing listener: {e}") + + async def _remove_connection_by_object( + self, connection_obj: QUICConnection + ) -> None: + """Remove a connection by object reference.""" + try: + # Find the connection ID for this object + connection_cid = None + for cid, tracked_connection in self._connections.items(): + if tracked_connection is connection_obj: + connection_cid = cid + break + + 
if connection_cid:
+                await self._remove_connection(connection_cid)
+                logger.debug(f"Removed connection {connection_cid.hex()}")
+            else:
+                logger.warning("Connection object not found in tracking")
+
+        except Exception as e:
+            logger.error(f"Error removing connection by object: {e}")
+
+    def get_addresses(self) -> list[Multiaddr]:
+        """Get the bound addresses."""
+        return self._bound_addresses.copy()
+
+    async def _handle_new_established_connection(
+        self, connection: QUICConnection
+    ) -> None:
+        """Handle newly established connection by adding to swarm."""
+        try:
+            logger.debug(
+                f"New QUIC connection established from {connection._remote_addr}"
+            )
+
+            if self._transport._swarm:
+                logger.debug("Adding QUIC connection directly to swarm")
+                await self._transport._swarm.add_conn(connection)
+                logger.debug("Successfully added QUIC connection to swarm")
+            else:
+                logger.error("No swarm available for QUIC connection")
+                await connection.close()
+
+        except Exception as e:
+            logger.error(f"Error adding QUIC connection to swarm: {e}")
+            await connection.close()
+
+    def get_addrs(self) -> tuple[Multiaddr, ...]:
+        return tuple(self.get_addresses())
+
+    def is_listening(self) -> bool:
+        """
+        Check if the listener is currently listening for connections.
+
+        Returns:
+            bool: True if the listener is actively listening, False otherwise
+
+        """
+        return self._listening and not self._closed
+
+    def get_stats(self) -> dict[str, int | bool]:
+        """
+        Get listener statistics including the listening state.
+
+        Returns:
+            dict: Statistics dictionary with current state information
+
+        """
+        stats = self._stats.copy()
+        stats["is_listening"] = self.is_listening()
+        stats["active_connections"] = len(self._connections)
+        stats["pending_connections"] = len(self._pending_connections)
+        return stats
diff --git a/libp2p/transport/quic/security.py b/libp2p/transport/quic/security.py
new file mode 100644
index 00000000..43ebfa37
--- /dev/null
+++ b/libp2p/transport/quic/security.py
@@ -0,0 +1,1165 @@
+"""
+QUIC security helpers implementation.
+"""
+
+from dataclasses import dataclass, field
+import logging
+import ssl
+from typing import Any
+
+from cryptography import x509
+from cryptography.hazmat.primitives import hashes, serialization
+from cryptography.hazmat.primitives.asymmetric import ec, rsa
+from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
+from cryptography.x509.base import Certificate
+from cryptography.x509.extensions import Extension, UnrecognizedExtension
+from cryptography.x509.oid import NameOID
+
+from libp2p.crypto.keys import PrivateKey, PublicKey
+from libp2p.crypto.serialization import deserialize_public_key
+from libp2p.peer.id import ID
+
+from .exceptions import (
+    QUICCertificateError,
+    QUICPeerVerificationError,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# libp2p TLS Extension OID - Official libp2p specification
+LIBP2P_TLS_EXTENSION_OID = x509.ObjectIdentifier("1.3.6.1.4.1.53594.1.1")
+
+# Certificate validity period
+CERTIFICATE_VALIDITY_DAYS = 365
+CERTIFICATE_NOT_BEFORE_BUFFER = 3600  # 1 hour before now
+
+
+@dataclass
+class TLSConfig:
+    """TLS configuration for QUIC transport with libp2p extensions."""
+
+    certificate: x509.Certificate
+    private_key: ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey
+    peer_id: ID
+
+    def get_certificate_der(self) -> bytes:
+        """Get certificate in DER format for external use."""
+        return 
self.certificate.public_bytes(serialization.Encoding.DER) + + def get_private_key_der(self) -> bytes: + """Get private key in DER format for external use.""" + return self.private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + def get_certificate_pem(self) -> bytes: + """Get certificate in PEM format.""" + return self.certificate.public_bytes(serialization.Encoding.PEM) + + def get_private_key_pem(self) -> bytes: + """Get private key in PEM format.""" + return self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +class LibP2PExtensionHandler: + """ + Handles libp2p-specific TLS extensions for peer identity verification. + + Based on libp2p TLS specification: + https://github.com/libp2p/specs/blob/master/tls/tls.md + """ + + @staticmethod + def create_signed_key_extension( + libp2p_private_key: PrivateKey, + cert_public_key: bytes, + ) -> bytes: + """ + Create the libp2p Public Key Extension with signed key proof. + + The extension contains: + 1. The libp2p public key + 2. A signature proving ownership of the private key + + Args: + libp2p_private_key: The libp2p identity private key + cert_public_key: The certificate's public key bytes + + Returns: + Encoded extension value + + """ + try: + # Get the libp2p public key + libp2p_public_key = libp2p_private_key.get_public_key() + + # Create the signature payload: "libp2p-tls-handshake:" + cert_public_key + signature_payload = b"libp2p-tls-handshake:" + cert_public_key + + # Sign the payload with the libp2p private key + signature = libp2p_private_key.sign(signature_payload) + + # Get the public key bytes + public_key_bytes = libp2p_public_key.serialize() + + # Create ASN.1 DER encoded structure (go-libp2p compatible) + return LibP2PExtensionHandler._create_asn1_der_extension( + public_key_bytes, signature + ) + + except Exception as e: + raise QUICCertificateError( + f"Failed to create signed key extension: {e}" + ) from e + + @staticmethod + def _create_asn1_der_extension(public_key_bytes: bytes, signature: bytes) -> bytes: + """ + Create ASN.1 DER encoded extension (go-libp2p compatible). 
+ + Structure: + SEQUENCE { + publicKey OCTET STRING, + signature OCTET STRING + } + """ + # Encode public key as OCTET STRING + pubkey_octets = LibP2PExtensionHandler._encode_der_octet_string( + public_key_bytes + ) + + # Encode signature as OCTET STRING + sig_octets = LibP2PExtensionHandler._encode_der_octet_string(signature) + + # Combine into SEQUENCE + sequence_content = pubkey_octets + sig_octets + + # Encode as SEQUENCE + return LibP2PExtensionHandler._encode_der_sequence(sequence_content) + + @staticmethod + def _encode_der_length(length: int) -> bytes: + """Encode length in DER format.""" + if length < 128: + # Short form + return bytes([length]) + else: + # Long form + length_bytes = length.to_bytes( + (length.bit_length() + 7) // 8, byteorder="big" + ) + return bytes([0x80 | len(length_bytes)]) + length_bytes + + @staticmethod + def _encode_der_octet_string(data: bytes) -> bytes: + """Encode data as DER OCTET STRING.""" + return ( + bytes([0x04]) + LibP2PExtensionHandler._encode_der_length(len(data)) + data + ) + + @staticmethod + def _encode_der_sequence(data: bytes) -> bytes: + """Encode data as DER SEQUENCE.""" + return ( + bytes([0x30]) + LibP2PExtensionHandler._encode_der_length(len(data)) + data + ) + + @staticmethod + def parse_signed_key_extension( + extension: Extension[Any], + ) -> tuple[PublicKey, bytes]: + """ + Parse the libp2p Public Key Extension with support for all crypto types. + Handles both ASN.1 DER format (from go-libp2p) and simple binary format. + """ + try: + logger.debug(f"šŸ” Extension type: {type(extension)}") + logger.debug(f"šŸ” Extension.value type: {type(extension.value)}") + + # Extract the raw bytes from the extension + if isinstance(extension.value, UnrecognizedExtension): + raw_bytes = extension.value.value + logger.debug( + "šŸ” Extension is UnrecognizedExtension, using .value property" + ) + else: + raw_bytes = extension.value + logger.debug("šŸ” Extension.value is already bytes") + + logger.debug(f"šŸ” Total extension length: {len(raw_bytes)} bytes") + logger.debug(f"šŸ” Extension hex (first 50 bytes): {raw_bytes[:50].hex()}") + + if not isinstance(raw_bytes, bytes): + raise QUICCertificateError(f"Expected bytes, got {type(raw_bytes)}") + + # Check if this is ASN.1 DER encoded (from go-libp2p) + if len(raw_bytes) >= 4 and raw_bytes[0] == 0x30: + logger.debug("šŸ” Detected ASN.1 DER encoding") + return LibP2PExtensionHandler._parse_asn1_der_extension(raw_bytes) + else: + logger.debug("šŸ” Using simple binary format parsing") + return LibP2PExtensionHandler._parse_simple_binary_extension(raw_bytes) + + except Exception as e: + logger.debug(f"āŒ Extension parsing failed: {e}") + import traceback + + logger.debug(f"āŒ Traceback: {traceback.format_exc()}") + raise QUICCertificateError( + f"Failed to parse signed key extension: {e}" + ) from e + + @staticmethod + def _parse_asn1_der_extension(raw_bytes: bytes) -> tuple[PublicKey, bytes]: + """ + Parse ASN.1 DER encoded extension (go-libp2p format). 
+ + The structure is typically: + SEQUENCE { + publicKey OCTET STRING, + signature OCTET STRING + } + """ + try: + offset = 0 + + # Parse SEQUENCE tag + if raw_bytes[offset] != 0x30: + raise QUICCertificateError( + f"Expected SEQUENCE tag (0x30), got {raw_bytes[offset]:02x}" + ) + offset += 1 + + # Parse SEQUENCE length + seq_length, length_bytes = LibP2PExtensionHandler._parse_der_length( + raw_bytes[offset:] + ) + offset += length_bytes + logger.debug(f"šŸ” SEQUENCE length: {seq_length} bytes") + + # Parse first OCTET STRING (public key) + if raw_bytes[offset] != 0x04: + raise QUICCertificateError( + f"Expected OCTET STRING tag (0x04), got {raw_bytes[offset]:02x}" + ) + offset += 1 + + pubkey_length, length_bytes = LibP2PExtensionHandler._parse_der_length( + raw_bytes[offset:] + ) + offset += length_bytes + logger.debug(f"šŸ” Public key length: {pubkey_length} bytes") + + if len(raw_bytes) < offset + pubkey_length: + raise QUICCertificateError("Extension too short for public key data") + + public_key_bytes = raw_bytes[offset : offset + pubkey_length] + offset += pubkey_length + + # Parse second OCTET STRING (signature) + if offset < len(raw_bytes) and raw_bytes[offset] == 0x04: + offset += 1 + sig_length, length_bytes = LibP2PExtensionHandler._parse_der_length( + raw_bytes[offset:] + ) + offset += length_bytes + logger.debug(f"šŸ” Signature length: {sig_length} bytes") + + if len(raw_bytes) < offset + sig_length: + raise QUICCertificateError("Extension too short for signature data") + + signature_data = raw_bytes[offset : offset + sig_length] + else: + # Signature might be the remaining bytes + signature_data = raw_bytes[offset:] + + logger.debug(f"šŸ” Public key data length: {len(public_key_bytes)} bytes") + logger.debug(f"šŸ” Signature data length: {len(signature_data)} bytes") + + # Deserialize the public key + public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes) + logger.debug(f"šŸ” Successfully deserialized public key: {type(public_key)}") + + # Extract signature based on key type + signature = LibP2PExtensionHandler._extract_signature_by_key_type( + public_key, signature_data + ) + + return public_key, signature + + except Exception as e: + raise QUICCertificateError( + f"Failed to parse ASN.1 DER extension: {e}" + ) from e + + @staticmethod + def _parse_der_length(data: bytes) -> tuple[int, int]: + """ + Parse DER length encoding. + Returns (length_value, bytes_consumed). + """ + if not data: + raise QUICCertificateError("No data for DER length") + + first_byte = data[0] + + # Short form (length < 128) + if first_byte < 0x80: + return first_byte, 1 + + # Long form + num_bytes = first_byte & 0x7F + if len(data) < 1 + num_bytes: + raise QUICCertificateError("Insufficient data for DER long form length") + + length = 0 + for i in range(1, num_bytes + 1): + length = (length << 8) | data[i] + + return length, 1 + num_bytes + + @staticmethod + def _parse_simple_binary_extension(raw_bytes: bytes) -> tuple[PublicKey, bytes]: + """ + Parse simple binary format extension (original py-libp2p format). 
+ Format: [4-byte pubkey length][pubkey][4-byte sig length][signature] + """ + offset = 0 + + # Parse public key length and data + if len(raw_bytes) < 4: + raise QUICCertificateError("Extension too short for public key length") + + public_key_length = int.from_bytes( + raw_bytes[offset : offset + 4], byteorder="big" + ) + logger.debug(f"šŸ” Public key length: {public_key_length} bytes") + offset += 4 + + if len(raw_bytes) < offset + public_key_length: + raise QUICCertificateError("Extension too short for public key data") + + public_key_bytes = raw_bytes[offset : offset + public_key_length] + offset += public_key_length + + # Parse signature length and data + if len(raw_bytes) < offset + 4: + raise QUICCertificateError("Extension too short for signature length") + + signature_length = int.from_bytes( + raw_bytes[offset : offset + 4], byteorder="big" + ) + logger.debug(f"šŸ” Signature length: {signature_length} bytes") + offset += 4 + + if len(raw_bytes) < offset + signature_length: + raise QUICCertificateError("Extension too short for signature data") + + signature_data = raw_bytes[offset : offset + signature_length] + + # Deserialize the public key + public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes) + logger.debug(f"šŸ” Successfully deserialized public key: {type(public_key)}") + + # Extract signature based on key type + signature = LibP2PExtensionHandler._extract_signature_by_key_type( + public_key, signature_data + ) + + return public_key, signature + + @staticmethod + def _extract_signature_by_key_type( + public_key: PublicKey, signature_data: bytes + ) -> bytes: + """ + Extract the actual signature from signature_data based on the key type. + Different crypto libraries have different signature formats. + """ + if not hasattr(public_key, "get_type"): + logger.debug("āš ļø Public key has no get_type method, using signature as-is") + return signature_data + + key_type = public_key.get_type() + key_type_name = key_type.name if hasattr(key_type, "name") else str(key_type) + logger.debug(f"šŸ” Processing signature for key type: {key_type_name}") + + # Handle different key types + if key_type_name == "Ed25519": + return LibP2PExtensionHandler._extract_ed25519_signature(signature_data) + + elif key_type_name == "Secp256k1": + return LibP2PExtensionHandler._extract_secp256k1_signature(signature_data) + + elif key_type_name == "RSA": + return LibP2PExtensionHandler._extract_rsa_signature(signature_data) + + elif key_type_name in ["ECDSA", "ECC_P256"]: + return LibP2PExtensionHandler._extract_ecdsa_signature(signature_data) + + else: + logger.debug( + f"āš ļø Unknown key type {key_type_name}, using generic extraction" + ) + return LibP2PExtensionHandler._extract_generic_signature(signature_data) + + @staticmethod + def _extract_ed25519_signature(signature_data: bytes) -> bytes: + """Extract Ed25519 signature (must be exactly 64 bytes).""" + logger.debug("šŸ”§ Extracting Ed25519 signature") + + if len(signature_data) == 64: + logger.debug("āœ… Ed25519 signature is already 64 bytes") + return signature_data + + logger.debug( + f"āš ļø Ed25519 signature is {len(signature_data)} bytes, extracting 64 bytes" + ) + + # Look for the payload marker and extract signature before it + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index >= 64: + # The signature is likely the first 64 bytes before the payload + signature = signature_data[:64] + logger.debug("šŸ”§ Using first 64 bytes as Ed25519 signature") + return 
signature
+
+        else:
+            # Fallback: try to extract first 64 bytes
+            if len(signature_data) >= 64:
+                signature = signature_data[:64]
+                logger.debug("šŸ”§ Fallback: using first 64 bytes")
+                return signature
+            else:
+                logger.debug(
+                    f"Cannot extract 64 bytes from {len(signature_data)} byte signature"
+                )
+                return signature_data
+
+    @staticmethod
+    def _extract_secp256k1_signature(signature_data: bytes) -> bytes:
+        """
+        Extract Secp256k1 signature. Secp256k1 can use either DER-encoded
+        or raw format depending on the implementation.
+        """
+        logger.debug("šŸ”§ Extracting Secp256k1 signature")
+
+        # Look for payload marker to separate signature from payload
+        payload_marker = b"libp2p-tls-handshake:"
+        marker_index = signature_data.find(payload_marker)
+
+        if marker_index > 0:
+            signature = signature_data[:marker_index]
+            logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker")
+
+            # Check if it's DER-encoded (starts with 0x30)
+            if len(signature) >= 2 and signature[0] == 0x30:
+                logger.debug("šŸ” Secp256k1 signature appears to be DER-encoded")
+                return LibP2PExtensionHandler._validate_der_signature(signature)
+            else:
+                logger.debug("šŸ” Secp256k1 signature appears to be raw format")
+                return signature
+        else:
+            # No marker found, check if the whole data is DER-encoded
+            if len(signature_data) >= 2 and signature_data[0] == 0x30:
+                logger.debug(
+                    "šŸ” Secp256k1 signature appears to be DER-encoded (no marker)"
+                )
+                return LibP2PExtensionHandler._validate_der_signature(signature_data)
+            else:
+                logger.debug("šŸ” Using Secp256k1 signature data as-is")
+                return signature_data
+
+    @staticmethod
+    def _extract_rsa_signature(signature_data: bytes) -> bytes:
+        """
+        Extract RSA signature.
+        RSA signatures are typically raw bytes with length matching the key size.
+        """
+        logger.debug("šŸ”§ Extracting RSA signature")
+
+        # Look for payload marker to separate signature from payload
+        payload_marker = b"libp2p-tls-handshake:"
+        marker_index = signature_data.find(payload_marker)
+
+        if marker_index > 0:
+            signature = signature_data[:marker_index]
+            logger.debug(
+                f"šŸ”§ Using {len(signature)} bytes before payload marker for RSA"
+            )
+            return signature
+        else:
+            logger.debug("šŸ” Using RSA signature data as-is")
+            return signature_data
+
+    @staticmethod
+    def _extract_ecdsa_signature(signature_data: bytes) -> bytes:
+        """
+        Extract ECDSA signature (typically DER-encoded ASN.1).
+        ECDSA signatures start with 0x30 (ASN.1 SEQUENCE).
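+        For reference, a DER-encoded ECDSA signature is laid out as
+        ``0x30 <len> 0x02 <len(r)> <r> 0x02 <len(s)> <s>`` (this is a
+        property of DER itself, not something defined by this module).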
+ """ + logger.debug("šŸ”§ Extracting ECDSA signature") + + # Look for payload marker to separate signature from payload + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index > 0: + signature = signature_data[:marker_index] + logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker") + + # Validate DER encoding for ECDSA + if len(signature) >= 2 and signature[0] == 0x30: + return LibP2PExtensionHandler._validate_der_signature(signature) + else: + logger.debug( + "āš ļø ECDSA signature doesn't start with DER header, using as-is" + ) + return signature + else: + # Check if the whole data is DER-encoded + if len(signature_data) >= 2 and signature_data[0] == 0x30: + logger.debug("šŸ” ECDSA signature appears to be DER-encoded (no marker)") + return LibP2PExtensionHandler._validate_der_signature(signature_data) + else: + logger.debug("šŸ” Using ECDSA signature data as-is") + return signature_data + + @staticmethod + def _extract_generic_signature(signature_data: bytes) -> bytes: + """ + Generic signature extraction for unknown key types. + Tries to detect DER encoding or extract based on payload marker. + """ + logger.debug("šŸ”§ Extracting signature using generic method") + + # Look for payload marker to separate signature from payload + payload_marker = b"libp2p-tls-handshake:" + marker_index = signature_data.find(payload_marker) + + if marker_index > 0: + signature = signature_data[:marker_index] + logger.debug(f"šŸ”§ Using {len(signature)} bytes before payload marker") + + # Check if it's DER-encoded + if len(signature) >= 2 and signature[0] == 0x30: + return LibP2PExtensionHandler._validate_der_signature(signature) + else: + return signature + else: + # Check if the whole data is DER-encoded + if len(signature_data) >= 2 and signature_data[0] == 0x30: + logger.debug( + "šŸ” Generic signature appears to be DER-encoded (no marker)" + ) + return LibP2PExtensionHandler._validate_der_signature(signature_data) + else: + logger.debug("šŸ” Using signature data as-is") + return signature_data + + @staticmethod + def _validate_der_signature(signature: bytes) -> bytes: + """ + Validate and potentially fix DER-encoded signatures. + DER signatures have the format: 30 [length] ... + """ + if len(signature) < 2: + return signature + + if signature[0] != 0x30: + logger.debug("āš ļø Signature doesn't start with DER SEQUENCE tag") + return signature + + # Get the DER length + der_length = signature[1] + expected_total_length = der_length + 2 + + logger.debug( + f"šŸ” DER signature: length byte = {der_length}, " + f"expected total = {expected_total_length}, " + f"actual length = {len(signature)}" + ) + + if len(signature) == expected_total_length: + logger.debug("āœ… DER signature length is correct") + return signature + elif len(signature) > expected_total_length: + logger.debug( + "Truncating DER signature from " + f"{len(signature)} to {expected_total_length} bytes" + ) + return signature[:expected_total_length] + else: + logger.debug("DER signature is shorter than expected, using as-is") + return signature + + +class LibP2PKeyConverter: + """ + Converts between libp2p key formats and cryptography library formats. + Handles different key types: Ed25519, Secp256k1, RSA, ECDSA. + """ + + @staticmethod + def libp2p_to_tls_private_key( + libp2p_key: PrivateKey, + ) -> ec.EllipticCurvePrivateKey | rsa.RSAPrivateKey: + """ + Convert libp2p private key to TLS-compatible private key. 
+ + For certificate generation, we create a separate ephemeral key + rather than using the libp2p identity key directly. + """ + # For QUIC, we prefer ECDSA keys for smaller certificates + # Generate ephemeral P-256 key for certificate signing + private_key = ec.generate_private_key(ec.SECP256R1()) + return private_key + + @staticmethod + def serialize_public_key(public_key: PublicKey) -> bytes: + """Serialize libp2p public key to bytes.""" + return public_key.serialize() + + @staticmethod + def deserialize_public_key(key_bytes: bytes) -> PublicKey: + """ + Deserialize libp2p public key from protobuf bytes. + + Args: + key_bytes: Protobuf-serialized public key bytes + + Returns: + Deserialized PublicKey instance + + """ + try: + # Use the official libp2p deserialization function + return deserialize_public_key(key_bytes) + except Exception as e: + raise QUICCertificateError(f"Failed to deserialize public key: {e}") from e + + +class CertificateGenerator: + """ + Generates X.509 certificates with libp2p peer identity extensions. + Follows libp2p TLS specification for QUIC transport. + """ + + def __init__(self) -> None: + self.extension_handler = LibP2PExtensionHandler() + self.key_converter = LibP2PKeyConverter() + + def generate_certificate( + self, + libp2p_private_key: PrivateKey, + peer_id: ID, + validity_days: int = CERTIFICATE_VALIDITY_DAYS, + ) -> TLSConfig: + """ + Generate a TLS certificate with embedded libp2p peer identity. + Fixed to use datetime objects for validity periods. + + Args: + libp2p_private_key: The libp2p identity private key + peer_id: The libp2p peer ID + validity_days: Certificate validity period in days + + Returns: + TLSConfig with certificate and private key + + Raises: + QUICCertificateError: If certificate generation fails + + """ + try: + # Generate ephemeral private key for certificate + cert_private_key = self.key_converter.libp2p_to_tls_private_key( + libp2p_private_key + ) + cert_public_key = cert_private_key.public_key() + + # Get certificate public key bytes for extension + cert_public_key_bytes = cert_public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + # Create libp2p extension with signed key proof + extension_data = self.extension_handler.create_signed_key_extension( + libp2p_private_key, cert_public_key_bytes + ) + + from datetime import datetime, timedelta, timezone + + now = datetime.now(timezone.utc) + not_before = now - timedelta(minutes=1) + not_after = now + timedelta(days=validity_days) + + # Generate serial number + serial_number = int(now.timestamp()) + + certificate = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, peer_id.to_base58())] # type: ignore + ) + ) + .issuer_name( + x509.Name( + [x509.NameAttribute(NameOID.COMMON_NAME, peer_id.to_base58())] # type: ignore + ) + ) + .public_key(cert_public_key) + .serial_number(serial_number) + .not_valid_before(not_before) + .not_valid_after(not_after) + .add_extension( + x509.UnrecognizedExtension( + oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data + ), + critical=False, + ) + .sign(cert_private_key, hashes.SHA256()) + ) + + logger.info(f"Generated libp2p TLS certificate for peer {peer_id}") + logger.debug(f"Certificate valid from {not_before} to {not_after}") + + return TLSConfig( + certificate=certificate, private_key=cert_private_key, peer_id=peer_id + ) + + except Exception as e: + raise QUICCertificateError(f"Failed to generate certificate: {e}") from e 
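+
+# Illustrative usage sketch (comments only, not part of the module API).
+# ``libp2p_key`` and ``peer_id`` stand in for a real libp2p private key and
+# peer ID; everything else referenced here is defined in this file:
+#
+#     generator = CertificateGenerator()
+#     tls_config = generator.generate_certificate(libp2p_key, peer_id)
+#     cert_pem = tls_config.get_certificate_pem()
+#     key_der = tls_config.get_private_key_der()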
+ + +class PeerAuthenticator: + """ + Authenticates remote peers using libp2p TLS certificates. + Validates both TLS certificate integrity and libp2p peer identity. + """ + + def __init__(self) -> None: + self.extension_handler = LibP2PExtensionHandler() + + def verify_peer_certificate( + self, certificate: x509.Certificate, expected_peer_id: ID | None = None + ) -> ID: + """ + Verify a peer's TLS certificate and extract/validate peer identity. + + Args: + certificate: The peer's TLS certificate + expected_peer_id: Expected peer ID (for outbound connections) + + Returns: + The verified peer ID + + Raises: + QUICPeerVerificationError: If verification fails + + """ + try: + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + + if certificate.not_valid_after_utc < now: + raise QUICPeerVerificationError("Certificate has expired") + + if certificate.not_valid_before_utc > now: + raise QUICPeerVerificationError("Certificate not yet valid") + + # Extract libp2p extension + libp2p_extension = None + for extension in certificate.extensions: + if extension.oid == LIBP2P_TLS_EXTENSION_OID: + libp2p_extension = extension + break + + if not libp2p_extension: + raise QUICPeerVerificationError("Certificate missing libp2p extension") + + assert libp2p_extension.value is not None + logger.debug(f"Extension type: {type(libp2p_extension)}") + logger.debug(f"Extension value type: {type(libp2p_extension.value)}") + if hasattr(libp2p_extension.value, "__len__"): + logger.debug(f"Extension value length: {len(libp2p_extension.value)}") + logger.debug(f"Extension value: {libp2p_extension.value}") + # Parse the extension to get public key and signature + public_key, signature = self.extension_handler.parse_signed_key_extension( + libp2p_extension + ) + + # Get certificate public key for signature verification + cert_public_key_bytes = certificate.public_key().public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + # Verify the signature proves ownership of the libp2p private key + signature_payload = b"libp2p-tls-handshake:" + cert_public_key_bytes + + try: + public_key.verify(signature_payload, signature) + except Exception as e: + raise QUICPeerVerificationError( + f"Invalid signature in libp2p extension: {e}" + ) + + # Derive peer ID from public key + derived_peer_id = ID.from_pubkey(public_key) + + # Verify against expected peer ID if provided + if expected_peer_id and derived_peer_id != expected_peer_id: + logger.debug(f"Expected Peer id: {expected_peer_id}") + logger.debug(f"Derived Peer ID: {derived_peer_id}") + raise QUICPeerVerificationError( + f"Peer ID mismatch: expected {expected_peer_id}, " + f"got {derived_peer_id}" + ) + + logger.debug( + f"Successfully verified peer certificate for {derived_peer_id}" + ) + return derived_peer_id + + except QUICPeerVerificationError: + raise + except Exception as e: + raise QUICPeerVerificationError( + f"Certificate verification failed: {e}" + ) from e + + +@dataclass +class QUICTLSSecurityConfig: + """ + Type-safe TLS security configuration for QUIC transport. 
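+
+    A minimal construction sketch (``cert`` and ``key`` stand in for a
+    certificate/private-key pair such as the one produced by
+    ``CertificateGenerator.generate_certificate``)::
+
+        config = QUICTLSSecurityConfig(certificate=cert, private_key=key)
+        config.debug_config()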
+ """ + + # Core TLS components (required) + certificate: Certificate + private_key: EllipticCurvePrivateKey | RSAPrivateKey + + # Certificate chain (optional) + certificate_chain: list[Certificate] = field(default_factory=list) + + # ALPN protocols + alpn_protocols: list[str] = field(default_factory=lambda: ["libp2p"]) + + # TLS verification settings + verify_mode: ssl.VerifyMode = ssl.CERT_NONE + check_hostname: bool = False + request_client_certificate: bool = False + + # Optional peer ID for validation + peer_id: ID | None = None + + # Configuration metadata + is_client_config: bool = False + config_name: str | None = None + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate the TLS configuration.""" + if self.certificate is None: + raise ValueError("Certificate is required") + + if self.private_key is None: + raise ValueError("Private key is required") + + if not isinstance(self.certificate, x509.Certificate): + raise TypeError( + f"Certificate must be x509.Certificate, got {type(self.certificate)}" + ) + + if not isinstance( + self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey) + ): + raise TypeError( + f"Private key must be EC or RSA key, got {type(self.private_key)}" + ) + + if not self.alpn_protocols: + raise ValueError("At least one ALPN protocol is required") + + def validate_certificate_key_match(self) -> bool: + """ + Validate that the certificate and private key match. + + Returns: + True if certificate and private key match + + """ + try: + from cryptography.hazmat.primitives import serialization + + # Get public keys from both certificate and private key + cert_public_key = self.certificate.public_key() + private_public_key = self.private_key.public_key() + + # Compare their PEM representations + cert_pub_pem = cert_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + private_pub_pem = private_public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + return cert_pub_pem == private_pub_pem + + except Exception: + return False + + def has_libp2p_extension(self) -> bool: + """ + Check if the certificate has the required libp2p extension. + + Returns: + True if libp2p extension is present + + """ + try: + for ext in self.certificate.extensions: + if ext.oid == LIBP2P_TLS_EXTENSION_OID: + return True + return False + except Exception: + return False + + def is_certificate_valid(self) -> bool: + """ + Check if the certificate is currently valid (not expired). + + Returns: + True if certificate is valid + + """ + try: + from datetime import datetime, timezone + + now = datetime.now(timezone.utc) + not_before = self.certificate.not_valid_before_utc + not_after = self.certificate.not_valid_after_utc + + return not_before <= now <= not_after + except Exception: + return False + + def get_certificate_info(self) -> dict[Any, Any]: + """ + Get certificate information for debugging. 
+
+        Returns:
+            Dictionary with certificate details
+
+        """
+        try:
+            return {
+                "subject": str(self.certificate.subject),
+                "issuer": str(self.certificate.issuer),
+                "serial_number": self.certificate.serial_number,
+                "not_valid_before_utc": self.certificate.not_valid_before_utc,
+                "not_valid_after_utc": self.certificate.not_valid_after_utc,
+                "has_libp2p_extension": self.has_libp2p_extension(),
+                "is_valid": self.is_certificate_valid(),
+                "certificate_key_match": self.validate_certificate_key_match(),
+            }
+        except Exception as e:
+            return {"error": str(e)}
+
+    def debug_config(self) -> None:
+        """Log debugging information about this configuration."""
+        logger.debug(
+            f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ==="
+        )
+        logger.debug(f"Is client config: {self.is_client_config}")
+        logger.debug(f"ALPN protocols: {self.alpn_protocols}")
+        logger.debug(f"Verify mode: {self.verify_mode}")
+        logger.debug(f"Check hostname: {self.check_hostname}")
+        logger.debug(f"Certificate chain length: {len(self.certificate_chain)}")
+
+        cert_info: dict[Any, Any] = self.get_certificate_info()
+        for key, value in cert_info.items():
+            logger.debug(f"Certificate {key}: {value}")
+
+        logger.debug(f"Private key type: {type(self.private_key).__name__}")
+        if hasattr(self.private_key, "key_size"):
+            logger.debug(f"Private key size: {self.private_key.key_size}")
+
+
+def create_server_tls_config(
+    certificate: Certificate,
+    private_key: EllipticCurvePrivateKey | RSAPrivateKey,
+    peer_id: ID | None = None,
+    **kwargs: Any,
+) -> QUICTLSSecurityConfig:
+    """
+    Create a server TLS configuration.
+
+    Args:
+        certificate: X.509 certificate
+        private_key: Private key corresponding to certificate
+        peer_id: Optional peer ID for validation
+        kwargs: Additional configuration parameters
+
+    Returns:
+        Server TLS configuration
+
+    """
+    return QUICTLSSecurityConfig(
+        certificate=certificate,
+        private_key=private_key,
+        peer_id=peer_id,
+        is_client_config=False,
+        config_name="server",
+        verify_mode=ssl.CERT_NONE,
+        check_hostname=False,
+        request_client_certificate=True,
+        **kwargs,
+    )
+
+
+def create_client_tls_config(
+    certificate: Certificate,
+    private_key: EllipticCurvePrivateKey | RSAPrivateKey,
+    peer_id: ID | None = None,
+    **kwargs: Any,
+) -> QUICTLSSecurityConfig:
+    """
+    Create a client TLS configuration.
+
+    Args:
+        certificate: X.509 certificate
+        private_key: Private key corresponding to certificate
+        peer_id: Optional peer ID for validation
+        kwargs: Additional configuration parameters
+
+    Returns:
+        Client TLS configuration
+
+    """
+    return QUICTLSSecurityConfig(
+        certificate=certificate,
+        private_key=private_key,
+        peer_id=peer_id,
+        is_client_config=True,
+        config_name="client",
+        verify_mode=ssl.CERT_NONE,
+        check_hostname=False,
+        **kwargs,
+    )
+
+
+class QUICTLSConfigManager:
+    """
+    Manages TLS configuration for QUIC transport with libp2p security.
+    Integrates with aioquic's TLS configuration system.
+    """
+
+    def __init__(self, libp2p_private_key: PrivateKey, peer_id: ID) -> None:
+        self.libp2p_private_key = libp2p_private_key
+        self.peer_id = peer_id
+        self.certificate_generator = CertificateGenerator()
+        self.peer_authenticator = PeerAuthenticator()
+
+        # Generate certificate for this peer
+        self.tls_config = self.certificate_generator.generate_certificate(
+            libp2p_private_key, peer_id
+        )
+
+    def create_server_config(self) -> QUICTLSSecurityConfig:
+        """
+        Create server configuration using the new class-based approach.
+ + Returns: + QUICTLSSecurityConfig instance for server + + """ + config = create_server_tls_config( + certificate=self.tls_config.certificate, + private_key=self.tls_config.private_key, + peer_id=self.peer_id, + ) + + return config + + def create_client_config(self) -> QUICTLSSecurityConfig: + """ + Create client configuration using the new class-based approach. + + Returns: + QUICTLSSecurityConfig instance for client + + """ + config = create_client_tls_config( + certificate=self.tls_config.certificate, + private_key=self.tls_config.private_key, + peer_id=self.peer_id, + ) + + return config + + def verify_peer_identity( + self, peer_certificate: x509.Certificate, expected_peer_id: ID | None = None + ) -> ID: + """ + Verify remote peer's identity from their TLS certificate. + + Args: + peer_certificate: Remote peer's TLS certificate + expected_peer_id: Expected peer ID (for outbound connections) + + Returns: + Verified peer ID + + """ + return self.peer_authenticator.verify_peer_certificate( + peer_certificate, expected_peer_id + ) + + def get_local_peer_id(self) -> ID: + """Get the local peer ID.""" + return self.peer_id + + +# Factory function for creating QUIC security transport +def create_quic_security_transport( + libp2p_private_key: PrivateKey, peer_id: ID +) -> QUICTLSConfigManager: + """ + Factory function to create QUIC security transport. + + Args: + libp2p_private_key: The libp2p identity private key + peer_id: The libp2p peer ID + + Returns: + Configured QUIC TLS manager + + """ + return QUICTLSConfigManager(libp2p_private_key, peer_id) diff --git a/libp2p/transport/quic/stream.py b/libp2p/transport/quic/stream.py new file mode 100644 index 00000000..dac8925e --- /dev/null +++ b/libp2p/transport/quic/stream.py @@ -0,0 +1,656 @@ +""" +QUIC Stream implementation +Provides stream interface over QUIC's native multiplexing. 
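+
+Each QUICStream wraps one native QUIC stream, so no separate muxer framing is
+needed. Illustrative handler-side usage (a sketch; ``read``, ``write`` and
+``close`` are defined on QUICStream below)::
+
+    data = await stream.read(1024)
+    await stream.write(b"ECHO: " + data)
+    await stream.close()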
+""" + +from enum import Enum +import logging +import time +from types import TracebackType +from typing import TYPE_CHECKING, Any, cast + +import trio + +from .exceptions import ( + QUICStreamBackpressureError, + QUICStreamClosedError, + QUICStreamResetError, + QUICStreamTimeoutError, +) + +if TYPE_CHECKING: + from libp2p.abc import IMuxedStream + from libp2p.custom_types import TProtocol + + from .connection import QUICConnection +else: + IMuxedStream = cast(type, object) + TProtocol = cast(type, object) + +logger = logging.getLogger(__name__) + + +class StreamState(Enum): + """Stream lifecycle states following libp2p patterns.""" + + OPEN = "open" + WRITE_CLOSED = "write_closed" + READ_CLOSED = "read_closed" + CLOSED = "closed" + RESET = "reset" + + +class StreamDirection(Enum): + """Stream direction for tracking initiator.""" + + INBOUND = "inbound" + OUTBOUND = "outbound" + + +class StreamTimeline: + """Track stream lifecycle events for debugging and monitoring.""" + + def __init__(self) -> None: + self.created_at = time.time() + self.opened_at: float | None = None + self.first_data_at: float | None = None + self.closed_at: float | None = None + self.reset_at: float | None = None + self.error_code: int | None = None + + def record_open(self) -> None: + self.opened_at = time.time() + + def record_first_data(self) -> None: + if self.first_data_at is None: + self.first_data_at = time.time() + + def record_close(self) -> None: + self.closed_at = time.time() + + def record_reset(self, error_code: int) -> None: + self.reset_at = time.time() + self.error_code = error_code + + +class QUICStream(IMuxedStream): + """ + QUIC Stream implementation following libp2p IMuxedStream interface. + + Based on patterns from go-libp2p and js-libp2p, this implementation: + - Leverages QUIC's native multiplexing and flow control + - Integrates with libp2p resource management + - Provides comprehensive error handling with QUIC-specific codes + - Supports bidirectional communication with independent close semantics + - Implements proper stream lifecycle management + """ + + def __init__( + self, + connection: "QUICConnection", + stream_id: int, + direction: StreamDirection, + remote_addr: tuple[str, int], + resource_scope: Any | None = None, + ): + """ + Initialize QUIC stream. 
+
+        Args:
+            connection: Parent QUIC connection
+            stream_id: QUIC stream identifier
+            direction: Stream direction (inbound/outbound)
+            remote_addr: Remote address this stream is connected to
+            resource_scope: Resource manager scope for memory accounting
+
+        """
+        self._connection = connection
+        self._stream_id = stream_id
+        self._direction = direction
+        self._resource_scope = resource_scope
+
+        # libp2p interface compliance
+        self._protocol: TProtocol | None = None
+        self._metadata: dict[str, Any] = {}
+        self._remote_addr = remote_addr
+
+        # Stream state management
+        self._state = StreamState.OPEN
+        self._state_lock = trio.Lock()
+
+        # Flow control and buffering
+        self._receive_buffer = bytearray()
+        self._receive_buffer_lock = trio.Lock()
+        self._receive_event = trio.Event()
+        self._backpressure_event = trio.Event()
+        self._backpressure_event.set()  # Initially no backpressure
+
+        # Close/reset state
+        self._write_closed = False
+        self._read_closed = False
+        self._close_event = trio.Event()
+        self._reset_error_code: int | None = None
+
+        # Lifecycle tracking
+        self._timeline = StreamTimeline()
+        self._timeline.record_open()
+
+        # Resource accounting
+        self._memory_reserved = 0
+
+        # Stream constant configurations
+        self.READ_TIMEOUT = connection._transport._config.STREAM_READ_TIMEOUT
+        self.WRITE_TIMEOUT = connection._transport._config.STREAM_WRITE_TIMEOUT
+        self.FLOW_CONTROL_WINDOW_SIZE = (
+            connection._transport._config.STREAM_FLOW_CONTROL_WINDOW
+        )
+        self.MAX_RECEIVE_BUFFER_SIZE = (
+            connection._transport._config.MAX_STREAM_RECEIVE_BUFFER
+        )
+
+        if self._resource_scope:
+            self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
+
+        logger.debug(
+            f"Created QUIC stream {stream_id} "
+            f"({direction.value}, connection: {connection.remote_peer_id()})"
+        )
+
+    # Properties for libp2p interface compliance
+
+    @property
+    def protocol(self) -> TProtocol | None:
+        """Get the protocol identifier for this stream."""
+        return self._protocol
+
+    @protocol.setter
+    def protocol(self, protocol_id: TProtocol) -> None:
+        """Set the protocol identifier for this stream."""
+        self._protocol = protocol_id
+        self._metadata["protocol"] = protocol_id
+        logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
+
+    @property
+    def stream_id(self) -> str:
+        """Get stream ID as string for libp2p compatibility."""
+        return str(self._stream_id)
+
+    @property
+    def muxed_conn(self) -> "QUICConnection":  # type: ignore
+        """Get the parent muxed connection."""
+        return self._connection
+
+    @property
+    def state(self) -> StreamState:
+        """Get current stream state."""
+        return self._state
+
+    @property
+    def direction(self) -> StreamDirection:
+        """Get stream direction."""
+        return self._direction
+
+    @property
+    def is_initiator(self) -> bool:
+        """Check if this stream was locally initiated."""
+        return self._direction == StreamDirection.OUTBOUND
+
+    # Core stream operations
+
+    async def read(self, n: int | None = None) -> bytes:
+        """
+        Read data from the stream with QUIC flow control.
+
+        Args:
+            n: Maximum number of bytes to read. If None or -1, read all available.
+ + Returns: + Data read from stream + + Raises: + QUICStreamClosedError: Stream is closed + QUICStreamResetError: Stream was reset + QUICStreamTimeoutError: Read timeout exceeded + + """ + if n is None: + n = -1 + + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + raise QUICStreamClosedError(f"Stream {self.stream_id} is closed") + + if self._read_closed: + # Return any remaining buffered data, then EOF + async with self._receive_buffer_lock: + if self._receive_buffer: + data = self._extract_data_from_buffer(n) + self._timeline.record_first_data() + return data + return b"" + + # Wait for data with timeout + timeout = self.READ_TIMEOUT + try: + with trio.move_on_after(timeout) as cancel_scope: + while True: + async with self._receive_buffer_lock: + if self._receive_buffer: + data = self._extract_data_from_buffer(n) + self._timeline.record_first_data() + return data + + # Check if stream was closed while waiting + if self._read_closed: + return b"" + + # Wait for more data + await self._receive_event.wait() + self._receive_event = trio.Event() # Reset for next wait + + if cancel_scope.cancelled_caught: + raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}") + + return b"" + except QUICStreamResetError: + # Stream was reset while reading + raise + except Exception as e: + logger.error(f"Error reading from stream {self.stream_id}: {e}") + await self._handle_stream_error(e) + raise + + async def write(self, data: bytes) -> None: + """ + Write data to the stream with QUIC flow control. + + Args: + data: Data to write + + Raises: + QUICStreamClosedError: Stream is closed for writing + QUICStreamBackpressureError: Flow control window exhausted + QUICStreamResetError: Stream was reset + + """ + if not data: + return + + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + raise QUICStreamClosedError(f"Stream {self.stream_id} is closed") + + if self._write_closed: + raise QUICStreamClosedError( + f"Stream {self.stream_id} write side is closed" + ) + + try: + # Handle flow control backpressure + await self._backpressure_event.wait() + + # Send data through QUIC connection + self._connection._quic.send_stream_data(self._stream_id, data) + await self._connection._transmit() + + self._timeline.record_first_data() + logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}") + + except Exception as e: + logger.error(f"Error writing to stream {self.stream_id}: {e}") + # Convert QUIC-specific errors + if "flow control" in str(e).lower(): + raise QUICStreamBackpressureError(f"Flow control limit reached: {e}") + await self._handle_stream_error(e) + raise + + async def close(self) -> None: + """ + Close the stream gracefully (both read and write sides). + + This implements proper close semantics where both sides + are closed and resources are cleaned up. 
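+
+        Calling ``close()`` again on an already closed or reset stream is a
+        no-op, so it is safe to invoke from multiple cleanup paths.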
+ """ + async with self._state_lock: + if self._state in (StreamState.CLOSED, StreamState.RESET): + return + + logger.debug(f"Closing stream {self.stream_id}") + + # Close both sides + if not self._write_closed: + await self.close_write() + if not self._read_closed: + await self.close_read() + + # Update state and cleanup + async with self._state_lock: + self._state = StreamState.CLOSED + + await self._cleanup_resources() + self._timeline.record_close() + self._close_event.set() + + logger.debug(f"Stream {self.stream_id} closed") + + async def close_write(self) -> None: + """Close the write side of the stream.""" + if self._write_closed: + return + + try: + # Send FIN to close write side + self._connection._quic.send_stream_data( + self._stream_id, b"", end_stream=True + ) + await self._connection._transmit() + + self._write_closed = True + + async with self._state_lock: + if self._read_closed: + self._state = StreamState.CLOSED + else: + self._state = StreamState.WRITE_CLOSED + + logger.debug(f"Stream {self.stream_id} write side closed") + + except Exception as e: + logger.error(f"Error closing write side of stream {self.stream_id}: {e}") + + async def close_read(self) -> None: + """Close the read side of the stream.""" + if self._read_closed: + return + + try: + self._read_closed = True + + async with self._state_lock: + if self._write_closed: + self._state = StreamState.CLOSED + else: + self._state = StreamState.READ_CLOSED + + # Wake up any pending reads + self._receive_event.set() + + logger.debug(f"Stream {self.stream_id} read side closed") + + except Exception as e: + logger.error(f"Error closing read side of stream {self.stream_id}: {e}") + + async def reset(self, error_code: int = 0) -> None: + """ + Reset the stream with the given error code. + + Args: + error_code: QUIC error code for the reset + + """ + async with self._state_lock: + if self._state == StreamState.RESET: + return + + logger.debug( + f"Resetting stream {self.stream_id} with error code {error_code}" + ) + + self._state = StreamState.RESET + self._reset_error_code = error_code + + try: + # Send QUIC reset frame + self._connection._quic.reset_stream(self._stream_id, error_code) + await self._connection._transmit() + + except Exception as e: + logger.error(f"Error sending reset for stream {self.stream_id}: {e}") + finally: + # Always cleanup resources + await self._cleanup_resources() + self._timeline.record_reset(error_code) + self._close_event.set() + + def is_closed(self) -> bool: + """Check if stream is completely closed.""" + return self._state in (StreamState.CLOSED, StreamState.RESET) + + def is_reset(self) -> bool: + """Check if stream was reset.""" + return self._state == StreamState.RESET + + def can_read(self) -> bool: + """Check if stream can be read from.""" + return not self._read_closed and self._state not in ( + StreamState.CLOSED, + StreamState.RESET, + ) + + def can_write(self) -> bool: + """Check if stream can be written to.""" + return not self._write_closed and self._state not in ( + StreamState.CLOSED, + StreamState.RESET, + ) + + async def handle_data_received(self, data: bytes, end_stream: bool) -> None: + """ + Handle data received from the QUIC connection. 
+ + Args: + data: Received data + end_stream: Whether this is the last data (FIN received) + + """ + if self._state == StreamState.RESET: + return + + if data: + async with self._receive_buffer_lock: + if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE: + logger.warning( + f"Stream {self.stream_id} receive buffer overflow, " + f"dropping {len(data)} bytes" + ) + return + + self._receive_buffer.extend(data) + self._timeline.record_first_data() + + # Notify waiting readers + self._receive_event.set() + + logger.debug(f"Stream {self.stream_id} received {len(data)} bytes") + + if end_stream: + self._read_closed = True + async with self._state_lock: + if self._write_closed: + self._state = StreamState.CLOSED + else: + self._state = StreamState.READ_CLOSED + + # Wake up readers to process remaining data and EOF + self._receive_event.set() + + logger.debug(f"Stream {self.stream_id} received FIN") + + async def handle_stop_sending(self, error_code: int) -> None: + """ + Handle STOP_SENDING frame from remote peer. + + When a STOP_SENDING frame is received, the peer is requesting that we + stop sending data on this stream. We respond by resetting the stream. + + Args: + error_code: Error code from the STOP_SENDING frame + + """ + logger.debug( + f"Stream {self.stream_id} handling STOP_SENDING (error_code={error_code})" + ) + + self._write_closed = True + + # Wake up any pending write operations + self._backpressure_event.set() + + async with self._state_lock: + if self.direction == StreamDirection.OUTBOUND: + self._state = StreamState.CLOSED + elif self._read_closed: + self._state = StreamState.CLOSED + else: + # Only write side closed - add WRITE_CLOSED state if needed + self._state = StreamState.WRITE_CLOSED + + # Send RESET_STREAM in response (QUIC protocol requirement) + try: + self._connection._quic.reset_stream(int(self.stream_id), error_code) + await self._connection._transmit() + logger.debug(f"Sent RESET_STREAM for stream {self.stream_id}") + except Exception as e: + logger.warning( + f"Could not send RESET_STREAM for stream {self.stream_id}: {e}" + ) + + async def handle_reset(self, error_code: int) -> None: + """ + Handle stream reset from remote peer. + + Args: + error_code: QUIC error code from reset frame + + """ + logger.debug( + f"Stream {self.stream_id} reset by peer with error code {error_code}" + ) + + async with self._state_lock: + self._state = StreamState.RESET + self._reset_error_code = error_code + + await self._cleanup_resources() + self._timeline.record_reset(error_code) + self._close_event.set() + + # Wake up any pending operations + self._receive_event.set() + self._backpressure_event.set() + + async def handle_flow_control_update(self, available_window: int) -> None: + """ + Handle flow control window updates. 
+
+        Args:
+            available_window: Available flow control window size
+
+        """
+        if available_window > 0:
+            self._backpressure_event.set()
+            logger.debug(
+                f"Stream {self.stream_id} flow control "
+                f"window updated: {available_window}"
+            )
+        else:
+            self._backpressure_event = trio.Event()  # Reset to blocking state
+            logger.debug(f"Stream {self.stream_id} flow control window exhausted")
+
+    def _extract_data_from_buffer(self, n: int) -> bytes:
+        """Extract data from receive buffer with specified limit."""
+        if n == -1:
+            # Read all available data
+            data = bytes(self._receive_buffer)
+            self._receive_buffer.clear()
+        else:
+            # Read up to n bytes
+            data = bytes(self._receive_buffer[:n])
+            self._receive_buffer = self._receive_buffer[n:]
+
+        return data
+
+    async def _handle_stream_error(self, error: Exception) -> None:
+        """Handle errors by resetting the stream."""
+        logger.error(f"Stream {self.stream_id} error: {error}")
+        await self.reset(error_code=1)  # Generic error code
+
+    def _reserve_memory(self, size: int) -> None:
+        """Reserve memory with resource manager."""
+        if self._resource_scope:
+            try:
+                self._resource_scope.reserve_memory(size)
+                self._memory_reserved += size
+            except Exception as e:
+                logger.warning(
+                    f"Failed to reserve memory for stream {self.stream_id}: {e}"
+                )
+
+    def _release_memory(self, size: int) -> None:
+        """Release memory with resource manager."""
+        if self._resource_scope and size > 0:
+            try:
+                self._resource_scope.release_memory(size)
+                self._memory_reserved = max(0, self._memory_reserved - size)
+            except Exception as e:
+                logger.warning(
+                    f"Failed to release memory for stream {self.stream_id}: {e}"
+                )
+
+    async def _cleanup_resources(self) -> None:
+        """Clean up stream resources."""
+        # Release all reserved memory
+        if self._memory_reserved > 0:
+            self._release_memory(self._memory_reserved)
+
+        # Clear receive buffer
+        async with self._receive_buffer_lock:
+            self._receive_buffer.clear()
+
+        # Remove from connection's stream registry
+        self._connection._remove_stream(self._stream_id)
+
+        logger.debug(f"Stream {self.stream_id} resources cleaned up")
+
+    # Abstract method implementations
+
+    def get_remote_address(self) -> tuple[str, int]:
+        return self._remote_addr
+
+    async def __aenter__(self) -> "QUICStream":
+        """Enter the async context manager."""
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Exit the async context manager and close the stream."""
+        logger.debug("Exiting the context and closing the stream")
+        await self.close()
+
+    def set_deadline(self, ttl: int) -> bool:
+        """
+        Set a deadline for the stream. QUIC does not support deadlines natively,
+        so this method always returns False to indicate the operation is unsupported.
+
+        :param ttl: Time-to-live in seconds (ignored).
+        :return: False, as deadlines are not supported.
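+
+        Callers that need timeouts can use a trio cancellation scope around the
+        stream operation instead (illustrative sketch):
+
+            with trio.move_on_after(ttl):
+                data = await stream.read()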
+        """
+        return False
+
+    # String representation for debugging
+
+    def __repr__(self) -> str:
+        return (
+            f"QUICStream(id={self.stream_id}, "
+            f"state={self._state.value}, "
+            f"direction={self._direction.value}, "
+            f"protocol={self._protocol})"
+        )
+
+    def __str__(self) -> str:
+        return f"QUICStream({self.stream_id})"
diff --git a/libp2p/transport/quic/transport.py b/libp2p/transport/quic/transport.py
new file mode 100644
index 00000000..ef0df368
--- /dev/null
+++ b/libp2p/transport/quic/transport.py
@@ -0,0 +1,491 @@
+"""
+QUIC Transport implementation
+"""
+
+import copy
+import logging
+import ssl
+from typing import TYPE_CHECKING, cast
+
+from aioquic.quic.configuration import (
+    QuicConfiguration,
+)
+from aioquic.quic.connection import (
+    QuicConnection as NativeQUICConnection,
+)
+from aioquic.quic.logger import QuicLogger
+import multiaddr
+import trio
+
+from libp2p.abc import (
+    ITransport,
+)
+from libp2p.crypto.keys import (
+    PrivateKey,
+)
+from libp2p.custom_types import TProtocol, TQUICConnHandlerFn
+from libp2p.peer.id import (
+    ID,
+)
+from libp2p.transport.quic.security import QUICTLSSecurityConfig
+from libp2p.transport.quic.utils import (
+    create_client_config_from_base,
+    create_server_config_from_base,
+    get_alpn_protocols,
+    is_quic_multiaddr,
+    multiaddr_to_quic_version,
+    quic_multiaddr_to_endpoint,
+    quic_version_to_wire_format,
+)
+
+if TYPE_CHECKING:
+    from libp2p.network.swarm import Swarm
+else:
+    Swarm = cast(type, object)
+
+from .config import (
+    QUICTransportConfig,
+)
+from .connection import (
+    QUICConnection,
+)
+from .exceptions import (
+    QUICDialError,
+    QUICListenError,
+    QUICSecurityError,
+)
+from .listener import (
+    QUICListener,
+)
+from .security import (
+    QUICTLSConfigManager,
+    create_quic_security_transport,
+)
+
+QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1
+QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29
+
+logger = logging.getLogger(__name__)
+
+
+class QUICTransport(ITransport):
+    """
+    QUIC transport implementation following the libp2p ITransport interface.
+    """
+
+    def __init__(
+        self, private_key: PrivateKey, config: QUICTransportConfig | None = None
+    ) -> None:
+        """
+        Initialize QUIC transport with security integration.
+ + Args: + private_key: libp2p private key for identity and TLS cert generation + config: QUIC transport configuration options + + """ + self._private_key = private_key + self._peer_id = ID.from_pubkey(private_key.get_public_key()) + self._config = config or QUICTransportConfig() + + # Connection management + self._connections: dict[str, QUICConnection] = {} + self._listeners: list[QUICListener] = [] + + # Security manager for TLS integration + self._security_manager = create_quic_security_transport( + self._private_key, self._peer_id + ) + + # QUIC configurations for different versions + self._quic_configs: dict[TProtocol, QuicConfiguration] = {} + self._setup_quic_configurations() + + # Resource management + self._closed = False + self._nursery_manager = trio.CapacityLimiter(1) + self._background_nursery: trio.Nursery | None = None + + self._swarm: Swarm | None = None + + logger.debug( + f"Initialized QUIC transport with security for peer {self._peer_id}" + ) + + def set_background_nursery(self, nursery: trio.Nursery) -> None: + """Set the nursery to use for background tasks (called by swarm).""" + self._background_nursery = nursery + logger.debug("Transport background nursery set") + + def set_swarm(self, swarm: Swarm) -> None: + """Set the swarm for adding incoming connections.""" + self._swarm = swarm + + def _setup_quic_configurations(self) -> None: + """Setup QUIC configurations.""" + try: + # Get TLS configuration from security manager + server_tls_config = self._security_manager.create_server_config() + client_tls_config = self._security_manager.create_client_config() + + # Base server configuration + base_server_config = QuicConfiguration( + is_client=False, + alpn_protocols=get_alpn_protocols(), + verify_mode=self._config.verify_mode, + max_datagram_frame_size=self._config.max_datagram_size, + idle_timeout=self._config.idle_timeout, + ) + + # Base client configuration + base_client_config = QuicConfiguration( + is_client=True, + alpn_protocols=get_alpn_protocols(), + verify_mode=self._config.verify_mode, + max_datagram_frame_size=self._config.max_datagram_size, + idle_timeout=self._config.idle_timeout, + ) + + # Apply TLS configuration + self._apply_tls_configuration(base_server_config, server_tls_config) + self._apply_tls_configuration(base_client_config, client_tls_config) + + # QUIC v1 (RFC 9000) configurations + if self._config.enable_v1: + quic_v1_server_config = create_server_config_from_base( + base_server_config, self._security_manager, self._config + ) + quic_v1_server_config.supported_versions = [ + quic_version_to_wire_format(QUIC_V1_PROTOCOL) + ] + + quic_v1_client_config = create_client_config_from_base( + base_client_config, self._security_manager, self._config + ) + quic_v1_client_config.supported_versions = [ + quic_version_to_wire_format(QUIC_V1_PROTOCOL) + ] + + # Store both server and client configs for v1 + self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_server")] = ( + quic_v1_server_config + ) + self._quic_configs[TProtocol(f"{QUIC_V1_PROTOCOL}_client")] = ( + quic_v1_client_config + ) + + # QUIC draft-29 configurations for compatibility + if self._config.enable_draft29: + draft29_server_config: QuicConfiguration = copy.copy(base_server_config) + draft29_server_config.supported_versions = [ + quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL) + ] + + draft29_client_config = copy.copy(base_client_config) + draft29_client_config.supported_versions = [ + quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL) + ] + + 
self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_server")] = (
+                    draft29_server_config
+                )
+                self._quic_configs[TProtocol(f"{QUIC_DRAFT29_PROTOCOL}_client")] = (
+                    draft29_client_config
+                )
+
+            logger.debug("QUIC configurations initialized with libp2p TLS security")
+
+        except Exception as e:
+            raise QUICSecurityError(
+                f"Failed to setup QUIC TLS configurations: {e}"
+            ) from e
+
+    def _apply_tls_configuration(
+        self, config: QuicConfiguration, tls_config: QUICTLSSecurityConfig
+    ) -> None:
+        """
+        Apply TLS configuration to a QUIC configuration using aioquic's actual API.
+
+        Args:
+            config: QuicConfiguration to update
+            tls_config: QUICTLSSecurityConfig instance from the security manager
+
+        """
+        try:
+            config.certificate = tls_config.certificate
+            config.private_key = tls_config.private_key
+            config.certificate_chain = tls_config.certificate_chain
+            config.alpn_protocols = tls_config.alpn_protocols
+            config.verify_mode = ssl.CERT_NONE
+
+            logger.debug("Successfully applied TLS configuration to QUIC config")
+
+        except Exception as e:
+            raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
+
+    async def dial(
+        self,
+        maddr: multiaddr.Multiaddr,
+    ) -> QUICConnection:
+        """
+        Dial a remote peer using QUIC transport with security verification.
+
+        Args:
+            maddr: Multiaddr of the remote peer, including its peer ID
+                (e.g., /ip4/1.2.3.4/udp/4001/quic-v1/p2p/<peer-id>)
+
+        Returns:
+            Established QUIC connection to the remote peer
+
+        Raises:
+            QUICDialError: If dialing fails
+            QUICSecurityError: If security verification fails
+
+        """
+        if self._closed:
+            raise QUICDialError("Transport is closed")
+
+        if not is_quic_multiaddr(maddr):
+            raise QUICDialError(f"Invalid QUIC multiaddr: {maddr}")
+
+        try:
+            # Extract connection details from multiaddr
+            host, port = quic_multiaddr_to_endpoint(maddr)
+            remote_peer_id = maddr.get_peer_id()
+            if remote_peer_id is not None:
+                remote_peer_id = ID.from_base58(remote_peer_id)
+
+            if remote_peer_id is None:
+                logger.error("Unable to derive peer id from multiaddr")
+                raise QUICDialError("Unable to derive peer id from multiaddr")
+            quic_version = multiaddr_to_quic_version(maddr)
+
+            # Get appropriate QUIC client configuration
+            config_key = TProtocol(f"{quic_version}_client")
+            logger.debug(
+                f"Looking up client config {config_key} "
+                f"among {list(self._quic_configs.keys())}"
+            )
+            config = self._quic_configs.get(config_key)
+            if not config:
+                raise QUICDialError(f"Unsupported QUIC version: {quic_version}")
+
+            config.is_client = True
+            config.quic_logger = QuicLogger()
+
+            # Ensure client certificate is properly set for mutual authentication
+            if not config.certificate or not config.private_key:
+                logger.warning(
+                    "Client config missing certificate - applying TLS config"
+                )
+                client_tls_config = self._security_manager.create_client_config()
+                self._apply_tls_configuration(config, client_tls_config)
+
+            logger.info(
+                f"Dialing QUIC connection to {host}:{port} (version: {quic_version})"
+            )
+
+            logger.debug("Starting QUIC Connection")
+            # Create QUIC connection using aioquic's sans-IO core
+            native_quic_connection = NativeQUICConnection(configuration=config)
+
+            # Create trio-based QUIC connection wrapper with security
+            connection = QUICConnection(
+                quic_connection=native_quic_connection,
+                remote_addr=(host, port),
+                remote_peer_id=remote_peer_id,
+                local_peer_id=self._peer_id,
+                is_initiator=True,
+                maddr=maddr,
+                transport=self,
+                security_manager=self._security_manager,
+            )
+            logger.debug("QUIC Connection Created")
+
+            if self._background_nursery is None:
+                logger.error("No nursery set to execute background tasks")
+                raise QUICDialError("No nursery found to execute tasks")
+
+            await connection.connect(self._background_nursery)
+
+            # Store connection for management
+            conn_id = f"{host}:{port}"
+            self._connections[conn_id] = connection
+
+            return connection
+
+        except Exception as e:
+            logger.error(f"Failed to dial QUIC connection to {maddr}: {e}")
+            raise QUICDialError(f"Dial failed: {e}") from e
+
+    async def _verify_peer_identity(
+        self, connection: QUICConnection, expected_peer_id: ID
+    ) -> None:
+        """
+        Verify remote peer identity after TLS handshake.
+
+        Args:
+            connection: The established QUIC connection
+            expected_peer_id: Expected peer ID
+
+        Raises:
+            QUICSecurityError: If peer verification fails
+
+        """
+        try:
+            # Get peer certificate from the connection
+            peer_certificate = await connection.get_peer_certificate()
+
+            if not peer_certificate:
+                raise QUICSecurityError("No peer certificate available")
+
+            # Verify peer identity using security manager
+            verified_peer_id = self._security_manager.verify_peer_identity(
+                peer_certificate, expected_peer_id
+            )
+
+            if verified_peer_id != expected_peer_id:
+                raise QUICSecurityError(
+                    "Peer ID verification failed: expected "
+                    f"{expected_peer_id}, got {verified_peer_id}"
+                )
+
+            logger.debug(f"Peer identity verified: {verified_peer_id}")
+
+        except Exception as e:
+            raise QUICSecurityError(f"Peer identity verification failed: {e}") from e
+
+    def create_listener(self, handler_function: TQUICConnHandlerFn) -> QUICListener:
+        """
+        Create a QUIC listener with integrated security.
+
+        Args:
+            handler_function: Function to handle new connections
+
+        Returns:
+            QUIC listener instance
+
+        Raises:
+            QUICListenError: If transport is closed
+
+        """
+        if self._closed:
+            raise QUICListenError("Transport is closed")
+
+        # Get server configurations for the listener
+        server_configs = {
+            version: config
+            for version, config in self._quic_configs.items()
+            if version.endswith("_server")
+        }
+
+        listener = QUICListener(
+            transport=self,
+            handler_function=handler_function,
+            quic_configs=server_configs,
+            config=self._config,
+            security_manager=self._security_manager,
+        )
+
+        self._listeners.append(listener)
+        logger.debug("Created QUIC listener with security")
+        return listener
+
+    def can_dial(self, maddr: multiaddr.Multiaddr) -> bool:
+        """
+        Check if this transport can dial the given multiaddr.
+
+        Args:
+            maddr: Multiaddr to check
+
+        Returns:
+            True if this transport can dial the address
+
+        """
+        return is_quic_multiaddr(maddr)
+
+    def protocols(self) -> list[TProtocol]:
+        """
+        Get supported protocol identifiers.
+
+        Returns:
+            List of supported protocol strings
+
+        """
+        protocols = [QUIC_V1_PROTOCOL]
+        if self._config.enable_draft29:
+            protocols.append(QUIC_DRAFT29_PROTOCOL)
+        return protocols
+
+    def listen_order(self) -> int:
+        """
+        Get the listen order priority for this transport.
+        Matches go-libp2p's ListenOrder = 1 for QUIC.
+
+        Returns:
+            Priority order for listening (lower = higher priority)
+
+        """
+        return 1
+
+    async def close(self) -> None:
+        """Close the transport and cleanup resources."""
+        if self._closed:
+            return
+
+        self._closed = True
+        logger.debug("Closing QUIC transport")
+
+        # Close all active connections and listeners concurrently using trio nursery
+        async with trio.open_nursery() as nursery:
+            # Close all connections
+            for connection in self._connections.values():
+                nursery.start_soon(connection.close)
+
+            # Close all listeners
+            for listener in self._listeners:
+                nursery.start_soon(listener.close)
+
+        self._connections.clear()
+        self._listeners.clear()
+
+        logger.debug("QUIC transport closed")
+
+    async def _cleanup_terminated_connection(self, connection: QUICConnection) -> None:
+        """Clean up a terminated connection from all listeners."""
+        try:
+            for listener in self._listeners:
+                await listener._remove_connection_by_object(connection)
+            logger.debug(
+                "✅ TRANSPORT: Cleaned up terminated connection from all listeners"
+            )
+        except Exception as e:
+            logger.error(f"❌ TRANSPORT: Error cleaning up terminated connection: {e}")
+
+    def get_stats(self) -> dict[str, int | list[str] | object]:
+        """Get transport statistics including security info."""
+        return {
+            "active_connections": len(self._connections),
+            "active_listeners": len(self._listeners),
+            "supported_protocols": self.protocols(),
+            "local_peer_id": str(self._peer_id),
+            "security_enabled": True,
+            "tls_configured": True,
+        }
+
+    def get_security_manager(self) -> QUICTLSConfigManager:
+        """
+        Get the security manager for this transport.
+
+        Returns:
+            The QUIC TLS configuration manager
+
+        """
+        return self._security_manager
+
+    def get_listener_socket(self) -> trio.socket.SocketType | None:
+        """Get the socket from the first active listener."""
+        for listener in self._listeners:
+            if listener.is_listening() and listener._socket:
+                return listener._socket
+        return None
diff --git a/libp2p/transport/quic/utils.py b/libp2p/transport/quic/utils.py
new file mode 100644
index 00000000..37b7880b
--- /dev/null
+++ b/libp2p/transport/quic/utils.py
@@ -0,0 +1,466 @@
+"""
+Multiaddr utilities for QUIC transport.
+Essential utilities required for QUIC transport implementation.
+Based on go-libp2p and js-libp2p QUIC implementations.
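+
+Illustrative usage of the helpers below (a sketch; addresses are examples only):
+
+    maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic-v1")
+    assert is_quic_multiaddr(maddr)
+    host, port = quic_multiaddr_to_endpoint(maddr)  # ("127.0.0.1", 4001)
+    version = multiaddr_to_quic_version(maddr)  # "quic-v1"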
+""" + +import ipaddress +import logging +import ssl + +from aioquic.quic.configuration import QuicConfiguration +import multiaddr + +from libp2p.custom_types import TProtocol +from libp2p.transport.quic.security import QUICTLSConfigManager + +from .config import QUICTransportConfig +from .exceptions import QUICInvalidMultiaddrError, QUICUnsupportedVersionError + +logger = logging.getLogger(__name__) + +# Protocol constants +QUIC_V1_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_V1 +QUIC_DRAFT29_PROTOCOL = QUICTransportConfig.PROTOCOL_QUIC_DRAFT29 +UDP_PROTOCOL = "udp" +IP4_PROTOCOL = "ip4" +IP6_PROTOCOL = "ip6" + +SERVER_CONFIG_PROTOCOL_V1 = f"{QUIC_V1_PROTOCOL}_server" +CLIENT_CONFIG_PROTCOL_V1 = f"{QUIC_V1_PROTOCOL}_client" + +SERVER_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_server" +CLIENT_CONFIG_PROTOCOL_DRAFT_29 = f"{QUIC_DRAFT29_PROTOCOL}_client" + +CUSTOM_QUIC_VERSION_MAPPING: dict[str, int] = { + SERVER_CONFIG_PROTOCOL_V1: 0x00000001, # RFC 9000 + CLIENT_CONFIG_PROTCOL_V1: 0x00000001, # RFC 9000 + SERVER_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29 + CLIENT_CONFIG_PROTOCOL_DRAFT_29: 0xFF00001D, # draft-29 +} + +# QUIC version to wire format mappings (required for aioquic) +QUIC_VERSION_MAPPINGS: dict[TProtocol, int] = { + QUIC_V1_PROTOCOL: 0x00000001, # RFC 9000 + QUIC_DRAFT29_PROTOCOL: 0xFF00001D, # draft-29 +} + +# ALPN protocols for libp2p over QUIC +LIBP2P_ALPN_PROTOCOLS: list[str] = ["libp2p"] + + +def is_quic_multiaddr(maddr: multiaddr.Multiaddr) -> bool: + """ + Check if a multiaddr represents a QUIC address. + + Valid QUIC multiaddrs: + - /ip4/127.0.0.1/udp/4001/quic-v1 + - /ip4/127.0.0.1/udp/4001/quic + - /ip6/::1/udp/4001/quic-v1 + - /ip6/::1/udp/4001/quic + + Args: + maddr: Multiaddr to check + + Returns: + True if the multiaddr represents a QUIC address + + """ + try: + addr_str = str(maddr) + + # Check for required components + has_ip = f"/{IP4_PROTOCOL}/" in addr_str or f"/{IP6_PROTOCOL}/" in addr_str + has_udp = f"/{UDP_PROTOCOL}/" in addr_str + has_quic = ( + f"/{QUIC_V1_PROTOCOL}" in addr_str + or f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str + or "/quic" in addr_str + ) + + return has_ip and has_udp and has_quic + + except Exception: + return False + + +def quic_multiaddr_to_endpoint(maddr: multiaddr.Multiaddr) -> tuple[str, int]: + """ + Extract host and port from a QUIC multiaddr. + + Args: + maddr: QUIC multiaddr + + Returns: + Tuple of (host, port) + + Raises: + QUICInvalidMultiaddrError: If multiaddr is not a valid QUIC address + + """ + if not is_quic_multiaddr(maddr): + raise QUICInvalidMultiaddrError(f"Not a valid QUIC multiaddr: {maddr}") + + try: + host = None + port = None + + # Try to get IPv4 address + try: + host = maddr.value_for_protocol(multiaddr.protocols.P_IP4) # type: ignore + except Exception: + pass + + # Try to get IPv6 address if IPv4 not found + if host is None: + try: + host = maddr.value_for_protocol(multiaddr.protocols.P_IP6) # type: ignore + except Exception: + pass + + # Get UDP port + try: + port_str = maddr.value_for_protocol(multiaddr.protocols.P_UDP) # type: ignore + port = int(port_str) + except Exception: + pass + + if host is None or port is None: + raise QUICInvalidMultiaddrError(f"Could not extract host/port from {maddr}") + + return host, port + + except Exception as e: + raise QUICInvalidMultiaddrError( + f"Failed to parse QUIC multiaddr {maddr}: {e}" + ) from e + + +def multiaddr_to_quic_version(maddr: multiaddr.Multiaddr) -> TProtocol: + """ + Determine QUIC version from multiaddr. 
+
+    Args:
+        maddr: QUIC multiaddr
+
+    Returns:
+        QUIC version identifier ("quic-v1" or "quic")
+
+    Raises:
+        QUICInvalidMultiaddrError: If multiaddr doesn't contain QUIC protocol
+
+    """
+    try:
+        addr_str = str(maddr)
+
+        if f"/{QUIC_V1_PROTOCOL}" in addr_str:
+            return QUIC_V1_PROTOCOL  # RFC 9000
+        elif f"/{QUIC_DRAFT29_PROTOCOL}" in addr_str:
+            return QUIC_DRAFT29_PROTOCOL  # draft-29
+        else:
+            raise QUICInvalidMultiaddrError(f"No QUIC protocol found in {maddr}")
+
+    except Exception as e:
+        raise QUICInvalidMultiaddrError(
+            f"Failed to determine QUIC version from {maddr}: {e}"
+        ) from e
+
+
+def create_quic_multiaddr(
+    host: str, port: int, version: str = "quic-v1"
+) -> multiaddr.Multiaddr:
+    """
+    Create a QUIC multiaddr from host, port, and version.
+
+    Args:
+        host: IP address (IPv4 or IPv6)
+        port: UDP port number
+        version: QUIC version ("quic-v1" or "quic")
+
+    Returns:
+        QUIC multiaddr
+
+    Raises:
+        QUICInvalidMultiaddrError: If invalid parameters provided
+
+    """
+    try:
+        # Determine IP version
+        try:
+            ip = ipaddress.ip_address(host)
+            if isinstance(ip, ipaddress.IPv4Address):
+                ip_proto = IP4_PROTOCOL
+            else:
+                ip_proto = IP6_PROTOCOL
+        except ValueError:
+            raise QUICInvalidMultiaddrError(f"Invalid IP address: {host}")
+
+        # Validate port
+        if not (0 <= port <= 65535):
+            raise QUICInvalidMultiaddrError(f"Invalid port: {port}")
+
+        # Validate and normalize QUIC version
+        if version == "quic-v1" or version == "/quic-v1":
+            quic_proto = QUIC_V1_PROTOCOL
+        elif version == "quic" or version == "/quic":
+            quic_proto = QUIC_DRAFT29_PROTOCOL
+        else:
+            raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
+
+        # Construct multiaddr
+        addr_str = f"/{ip_proto}/{host}/{UDP_PROTOCOL}/{port}/{quic_proto}"
+        return multiaddr.Multiaddr(addr_str)
+
+    except Exception as e:
+        raise QUICInvalidMultiaddrError(f"Failed to create QUIC multiaddr: {e}") from e
+
+
+def quic_version_to_wire_format(version: TProtocol) -> int:
+    """
+    Convert QUIC version string to wire format integer for aioquic.
+
+    Args:
+        version: QUIC version string ("quic-v1" or "quic")
+
+    Returns:
+        Wire format version number
+
+    Raises:
+        QUICUnsupportedVersionError: If version is not supported
+
+    """
+    wire_version = QUIC_VERSION_MAPPINGS.get(version)
+    if wire_version is None:
+        raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
+
+    return wire_version
+
+
+def custom_quic_version_to_wire_format(version: TProtocol) -> int:
+    """
+    Convert a role-specific QUIC config key to its wire format integer for aioquic.
+
+    Args:
+        version: Config key string (e.g. "quic-v1_server" or "quic_client")
+
+    Returns:
+        Wire format version number
+
+    Raises:
+        QUICUnsupportedVersionError: If the config key is not supported
+
+    """
+    wire_version = CUSTOM_QUIC_VERSION_MAPPING.get(version)
+    if wire_version is None:
+        raise QUICUnsupportedVersionError(f"Unsupported QUIC version: {version}")
+
+    return wire_version
+
+
+def get_alpn_protocols() -> list[str]:
+    """
+    Get ALPN protocols for libp2p over QUIC.
+
+    Returns:
+        List of ALPN protocol identifiers
+
+    """
+    return LIBP2P_ALPN_PROTOCOLS.copy()
+
+
+def normalize_quic_multiaddr(maddr: multiaddr.Multiaddr) -> multiaddr.Multiaddr:
+    """
+    Normalize a QUIC multiaddr to canonical form.
+
+    Args:
+        maddr: Input QUIC multiaddr
+
+    Returns:
+        Normalized multiaddr
+
+    Raises:
+        QUICInvalidMultiaddrError: If not a valid QUIC multiaddr
+
+    """
+    if not is_quic_multiaddr(maddr):
+        raise QUICInvalidMultiaddrError(f"Not a QUIC multiaddr: {maddr}")
+
+    host, port = quic_multiaddr_to_endpoint(maddr)
+    version = multiaddr_to_quic_version(maddr)
+
+    return create_quic_multiaddr(host, port, version)
+
+
+def create_server_config_from_base(
+    base_config: QuicConfiguration,
+    security_manager: QUICTLSConfigManager | None = None,
+    transport_config: QUICTransportConfig | None = None,
+) -> QuicConfiguration:
+    """
+    Create a server configuration without using deepcopy.
+    Manually copies attributes while handling cryptography objects properly.
+    """
+    try:
+        # Create new server configuration from scratch
+        server_config = QuicConfiguration(is_client=False)
+        server_config.verify_mode = ssl.CERT_NONE
+
+        # Copy basic configuration attributes (these are safe to copy)
+        copyable_attrs = [
+            "alpn_protocols",
+            "verify_mode",
+            "max_datagram_frame_size",
+            "idle_timeout",
+            "max_concurrent_streams",
+            "supported_versions",
+            "max_data",
+            "max_stream_data",
+            "stateless_retry",
+            "quantum_readiness_test",
+        ]
+
+        for attr in copyable_attrs:
+            if hasattr(base_config, attr):
+                value = getattr(base_config, attr)
+                if value is not None:
+                    setattr(server_config, attr, value)
+
+        # Handle cryptography objects - these need direct reference, not copying
+        crypto_attrs = [
+            "certificate",
+            "private_key",
+            "certificate_chain",
+            "ca_certs",
+        ]
+
+        for attr in crypto_attrs:
+            if hasattr(base_config, attr):
+                value = getattr(base_config, attr)
+                if value is not None:
+                    setattr(server_config, attr, value)
+
+        # Apply security manager configuration if available
+        if security_manager:
+            try:
+                server_tls_config = security_manager.create_server_config()
+
+                # Override with security manager's TLS configuration
+                if server_tls_config.certificate:
+                    server_config.certificate = server_tls_config.certificate
+                if server_tls_config.private_key:
+                    server_config.private_key = server_tls_config.private_key
+                if server_tls_config.certificate_chain:
+                    server_config.certificate_chain = (
+                        server_tls_config.certificate_chain
+                    )
+                if server_tls_config.alpn_protocols:
+                    server_config.alpn_protocols = server_tls_config.alpn_protocols
+
+                # Request a client certificate so the listener can verify the
+                # remote peer's identity (libp2p requires mutual TLS)
+                server_tls_config.request_client_certificate = True
+                server_config._libp2p_request_client_cert = True  # type: ignore
+
+            except Exception as e:
+                logger.warning(f"Failed to apply security manager config: {e}")
+
+        # Set transport-specific defaults if provided
+        if transport_config:
+            if server_config.idle_timeout == 0:
+                server_config.idle_timeout = getattr(
+                    transport_config, "idle_timeout", 30.0
+                )
+            if server_config.max_datagram_frame_size is None:
+                server_config.max_datagram_frame_size = getattr(
+                    transport_config, "max_datagram_size", 1200
+                )
+        # Ensure we have ALPN protocols
+        if not server_config.alpn_protocols:
+            server_config.alpn_protocols = ["libp2p"]
+
+        logger.debug("Successfully created server config without deepcopy")
+        return server_config
+
+    except Exception as e:
+        logger.error(f"Failed to create server config: {e}")
+        raise
+
+
+def create_client_config_from_base(
+    base_config: QuicConfiguration,
+    security_manager: QUICTLSConfigManager | None = None,
+    transport_config: QUICTransportConfig | None
= None, +) -> QuicConfiguration: + """ + Create a client configuration without using deepcopy. + """ + try: + # Create new client configuration from scratch + client_config = QuicConfiguration(is_client=True) + client_config.verify_mode = ssl.CERT_NONE + + # Copy basic configuration attributes + copyable_attrs = [ + "alpn_protocols", + "verify_mode", + "max_datagram_frame_size", + "idle_timeout", + "max_concurrent_streams", + "supported_versions", + "max_data", + "max_stream_data", + "quantum_readiness_test", + ] + + for attr in copyable_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(client_config, attr, value) + + # Handle cryptography objects - these need direct reference, not copying + crypto_attrs = [ + "certificate", + "private_key", + "certificate_chain", + "ca_certs", + ] + + for attr in crypto_attrs: + if hasattr(base_config, attr): + value = getattr(base_config, attr) + if value is not None: + setattr(client_config, attr, value) + + # Apply security manager configuration if available + if security_manager: + try: + client_tls_config = security_manager.create_client_config() + + # Override with security manager's TLS configuration + if client_tls_config.certificate: + client_config.certificate = client_tls_config.certificate + if client_tls_config.private_key: + client_config.private_key = client_tls_config.private_key + if client_tls_config.certificate_chain: + client_config.certificate_chain = ( + client_tls_config.certificate_chain + ) + if client_tls_config.alpn_protocols: + client_config.alpn_protocols = client_tls_config.alpn_protocols + + except Exception as e: + logger.warning(f"Failed to apply security manager config: {e}") + + # Ensure we have ALPN protocols + if not client_config.alpn_protocols: + client_config.alpn_protocols = ["libp2p"] + + logger.debug("Successfully created client config without deepcopy") + return client_config + + except Exception as e: + logger.error(f"Failed to create client config: {e}") + raise diff --git a/libp2p/transport/upgrader.py b/libp2p/transport/upgrader.py index 8b47fff4..40ba5321 100644 --- a/libp2p/transport/upgrader.py +++ b/libp2p/transport/upgrader.py @@ -1,9 +1,7 @@ from libp2p.abc import ( - IListener, IMuxedConn, IRawConnection, ISecureConn, - ITransport, ) from libp2p.custom_types import ( TMuxerOptions, @@ -43,10 +41,6 @@ class TransportUpgrader: self.security_multistream = SecurityMultistream(secure_transports_by_protocol) self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol) - def upgrade_listener(self, transport: ITransport, listeners: IListener) -> None: - """Upgrade multiaddr listeners to libp2p-transport listeners.""" - # TODO: Figure out what to do with this function. 
- async def upgrade_security( self, raw_conn: IRawConnection, diff --git a/libp2p/utils/__init__.py b/libp2p/utils/__init__.py index 0f78bfcb..b881eb92 100644 --- a/libp2p/utils/__init__.py +++ b/libp2p/utils/__init__.py @@ -15,6 +15,13 @@ from libp2p.utils.version import ( get_agent_version, ) +from libp2p.utils.address_validation import ( + get_available_interfaces, + get_optimal_binding_address, + expand_wildcard_address, + find_free_port, +) + __all__ = [ "decode_uvarint_from_stream", "encode_delim", @@ -26,4 +33,8 @@ __all__ = [ "decode_varint_from_bytes", "decode_varint_with_size", "read_length_prefixed_protobuf", + "get_available_interfaces", + "get_optimal_binding_address", + "expand_wildcard_address", + "find_free_port", ] diff --git a/libp2p/utils/address_validation.py b/libp2p/utils/address_validation.py new file mode 100644 index 00000000..77b797a1 --- /dev/null +++ b/libp2p/utils/address_validation.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import socket + +from multiaddr import Multiaddr + +try: + from multiaddr.utils import ( # type: ignore + get_network_addrs, + get_thin_waist_addresses, + ) + + _HAS_THIN_WAIST = True +except ImportError: # pragma: no cover - only executed in older environments + _HAS_THIN_WAIST = False + get_thin_waist_addresses = None # type: ignore + get_network_addrs = None # type: ignore + + +def _safe_get_network_addrs(ip_version: int) -> list[str]: + """ + Internal safe wrapper. Returns a list of IP addresses for the requested IP version. + Falls back to minimal defaults when Thin Waist helpers are missing. + + :param ip_version: 4 or 6 + """ + if _HAS_THIN_WAIST and get_network_addrs: + try: + return get_network_addrs(ip_version) or [] + except Exception: # pragma: no cover - defensive + return [] + # Fallback behavior (very conservative) + if ip_version == 4: + return ["127.0.0.1"] + if ip_version == 6: + return ["::1"] + return [] + + +def find_free_port() -> int: + """Find a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to a free port provided by the OS + return s.getsockname()[1] + + +def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]: + """ + Internal safe expansion wrapper. Returns a list of Multiaddr objects. + If Thin Waist isn't available, returns [addr] (identity). + """ + if _HAS_THIN_WAIST and get_thin_waist_addresses: + try: + if port is not None: + return get_thin_waist_addresses(addr, port=port) or [] + return get_thin_waist_addresses(addr) or [] + except Exception: # pragma: no cover - defensive + return [addr] + return [addr] + + +def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]: + """ + Discover available network interfaces (IPv4 + IPv6 if supported) for binding. + + :param port: Port number to bind to. + :param protocol: Transport protocol (e.g., "tcp" or "udp"). + :return: List of Multiaddr objects representing candidate interface addresses. 
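+
+    Example (illustrative; actual results depend on the host's interfaces):
+
+        addrs = get_available_interfaces(8000)
+        # e.g. [Multiaddr("/ip4/192.168.1.5/tcp/8000"),
+        #       Multiaddr("/ip4/127.0.0.1/tcp/8000")]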
+ """ + addrs: list[Multiaddr] = [] + + # IPv4 enumeration + seen_v4: set[str] = set() + + for ip in _safe_get_network_addrs(4): + seen_v4.add(ip) + addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}")) + + # Ensure IPv4 loopback is always included when IPv4 interfaces are discovered + if seen_v4 and "127.0.0.1" not in seen_v4: + addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")) + + # TODO: IPv6 support temporarily disabled due to libp2p handshake issues + # IPv6 connections fail during protocol negotiation (SecurityUpgradeFailure) + # Re-enable IPv6 support once the following issues are resolved: + # - libp2p security handshake over IPv6 + # - multiselect protocol over IPv6 + # - connection establishment over IPv6 + # + # seen_v6: set[str] = set() + # for ip in _safe_get_network_addrs(6): + # seen_v6.add(ip) + # addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}")) + # + # # Always include IPv6 loopback for testing purposes when IPv6 is available + # # This ensures IPv6 functionality can be tested even without global IPv6 addresses + # if "::1" not in seen_v6: + # addrs.append(Multiaddr(f"/ip6/::1/{protocol}/{port}")) + + # Fallback if nothing discovered + if not addrs: + addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")) + + return addrs + + +def expand_wildcard_address( + addr: Multiaddr, port: int | None = None +) -> list[Multiaddr]: + """ + Expand a wildcard (e.g. /ip4/0.0.0.0/tcp/0) into all concrete interfaces. + + :param addr: Multiaddr to expand. + :param port: Optional override for port selection. + :return: List of concrete Multiaddr instances. + """ + expanded = _safe_expand(addr, port=port) + if not expanded: # Safety fallback + return [addr] + return expanded + + +def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr: + """ + Choose an optimal address for an example to bind to: + - Prefer non-loopback IPv4 + - Then non-loopback IPv6 + - Fallback to loopback + - Fallback to wildcard + + :param port: Port number. + :param protocol: Transport protocol. + :return: A single Multiaddr chosen heuristically. + """ + candidates = get_available_interfaces(port, protocol) + + def is_non_loopback(ma: Multiaddr) -> bool: + s = str(ma) + return not ("/ip4/127." in s or "/ip6/::1" in s) + + for c in candidates: + if "/ip4/" in str(c) and is_non_loopback(c): + return c + for c in candidates: + if "/ip6/" in str(c) and is_non_loopback(c): + return c + for c in candidates: + if "/ip4/127." 
in str(c) or "/ip6/::1" in str(c): + return c + + # As a final fallback, produce a wildcard + return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}") + + +__all__ = [ + "get_available_interfaces", + "get_optimal_binding_address", + "expand_wildcard_address", + "find_free_port", +] diff --git a/libp2p/utils/logging.py b/libp2p/utils/logging.py index 3458a41e..b23136f5 100644 --- a/libp2p/utils/logging.py +++ b/libp2p/utils/logging.py @@ -1,7 +1,4 @@ import atexit -from datetime import ( - datetime, -) import logging import logging.handlers import os @@ -21,6 +18,9 @@ log_queue: "queue.Queue[Any]" = queue.Queue() # Store the current listener to stop it on exit _current_listener: logging.handlers.QueueListener | None = None +# Store the handlers for proper cleanup +_current_handlers: list[logging.Handler] = [] + # Event to track when the listener is ready _listener_ready = threading.Event() @@ -95,7 +95,7 @@ def setup_logging() -> None: - Child loggers inherit their parent's level unless explicitly set - The root libp2p logger controls the default level """ - global _current_listener, _listener_ready + global _current_listener, _listener_ready, _current_handlers # Reset the event _listener_ready.clear() @@ -105,6 +105,12 @@ def setup_logging() -> None: _current_listener.stop() _current_listener = None + # Close and clear existing handlers + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + _current_handlers.clear() + # Get the log level from environment variable debug_str = os.environ.get("LIBP2P_DEBUG", "") @@ -148,13 +154,10 @@ def setup_logging() -> None: log_path = Path(log_file) log_path.parent.mkdir(parents=True, exist_ok=True) else: - # Default log file with timestamp and unique identifier - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") - unique_id = os.urandom(4).hex() # Add a unique identifier to prevent collisions - if os.name == "nt": # Windows - log_file = f"C:\\Windows\\Temp\\py-libp2p_{timestamp}_{unique_id}.log" - else: # Unix-like - log_file = f"/tmp/py-libp2p_{timestamp}_{unique_id}.log" + # Use cross-platform temp file creation + from libp2p.utils.paths import create_temp_file + + log_file = str(create_temp_file(prefix="py-libp2p_", suffix=".log")) # Print the log file path so users know where to find it print(f"Logging to: {log_file}", file=sys.stderr) @@ -195,6 +198,9 @@ def setup_logging() -> None: logger.setLevel(level) logger.propagate = False # Prevent message duplication + # Store handlers globally for cleanup + _current_handlers.extend(handlers) + # Start the listener AFTER configuring all loggers _current_listener = logging.handlers.QueueListener( log_queue, *handlers, respect_handler_level=True @@ -209,7 +215,13 @@ def setup_logging() -> None: @atexit.register def cleanup_logging() -> None: """Clean up logging resources on exit.""" - global _current_listener + global _current_listener, _current_handlers if _current_listener is not None: _current_listener.stop() _current_listener = None + + # Close all file handlers to ensure proper cleanup on Windows + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + _current_handlers.clear() diff --git a/libp2p/utils/paths.py b/libp2p/utils/paths.py new file mode 100644 index 00000000..23f10dc6 --- /dev/null +++ b/libp2p/utils/paths.py @@ -0,0 +1,267 @@ +""" +Cross-platform path utilities for py-libp2p. 
+ +This module provides standardized path operations to ensure consistent +behavior across Windows, macOS, and Linux platforms. +""" + +import os +from pathlib import Path +import sys +import tempfile +from typing import Union + +PathLike = Union[str, Path] + + +def get_temp_dir() -> Path: + """ + Get cross-platform temporary directory. + + Returns: + Path: Platform-specific temporary directory path + + """ + return Path(tempfile.gettempdir()) + + +def get_project_root() -> Path: + """ + Get the project root directory. + + Returns: + Path: Path to the py-libp2p project root + + """ + # Navigate from libp2p/utils/paths.py to project root + return Path(__file__).parent.parent.parent + + +def join_paths(*parts: PathLike) -> Path: + """ + Cross-platform path joining. + + Args: + *parts: Path components to join + + Returns: + Path: Joined path using platform-appropriate separator + + """ + return Path(*parts) + + +def ensure_dir_exists(path: PathLike) -> Path: + """ + Ensure directory exists, create if needed. + + Args: + path: Directory path to ensure exists + + Returns: + Path: Path object for the directory + + """ + path_obj = Path(path) + path_obj.mkdir(parents=True, exist_ok=True) + return path_obj + + +def get_config_dir() -> Path: + """ + Get user config directory (cross-platform). + + Returns: + Path: Platform-specific config directory + + """ + if os.name == "nt": # Windows + appdata = os.environ.get("APPDATA", "") + if appdata: + return Path(appdata) / "py-libp2p" + else: + # Fallback to user home directory + return Path.home() / "AppData" / "Roaming" / "py-libp2p" + else: # Unix-like (Linux, macOS) + return Path.home() / ".config" / "py-libp2p" + + +def get_script_dir(script_path: PathLike | None = None) -> Path: + """ + Get the directory containing a script file. + + Args: + script_path: Path to the script file. If None, uses __file__ + + Returns: + Path: Directory containing the script + + Raises: + RuntimeError: If script path cannot be determined + + """ + if script_path is None: + # This will be the directory of the calling script + import inspect + + frame = inspect.currentframe() + if frame and frame.f_back: + script_path = frame.f_back.f_globals.get("__file__") + else: + raise RuntimeError("Could not determine script path") + + if script_path is None: + raise RuntimeError("Script path is None") + + return Path(script_path).parent.absolute() + + +def create_temp_file(prefix: str = "py-libp2p_", suffix: str = ".log") -> Path: + """ + Create a temporary file with a unique name. + + Args: + prefix: File name prefix + suffix: File name suffix + + Returns: + Path: Path to the created temporary file + + """ + temp_dir = get_temp_dir() + # Create a unique filename using timestamp and random bytes + import secrets + import time + + timestamp = time.strftime("%Y%m%d_%H%M%S") + microseconds = f"{time.time() % 1:.6f}"[2:] # Get microseconds as string + unique_id = secrets.token_hex(4) + filename = f"{prefix}{timestamp}_{microseconds}_{unique_id}{suffix}" + + temp_file = temp_dir / filename + # Create the file by touching it + temp_file.touch() + return temp_file + + +def resolve_relative_path(base_path: PathLike, relative_path: PathLike) -> Path: + """ + Resolve a relative path from a base path. 
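+
+    Example (illustrative; ``resolve()`` may expand symlinks on some systems):
+
+        resolve_relative_path("/srv/app", "conf/settings.toml")
+        # -> Path("/srv/app/conf/settings.toml")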
+ + Args: + base_path: Base directory path + relative_path: Relative path to resolve + + Returns: + Path: Resolved absolute path + + """ + base = Path(base_path).resolve() + relative = Path(relative_path) + + if relative.is_absolute(): + return relative + else: + return (base / relative).resolve() + + +def normalize_path(path: PathLike) -> Path: + """ + Normalize a path, resolving any symbolic links and relative components. + + Args: + path: Path to normalize + + Returns: + Path: Normalized absolute path + + """ + return Path(path).resolve() + + +def get_venv_path() -> Path | None: + """ + Get virtual environment path if active. + + Returns: + Path: Virtual environment path if active, None otherwise + + """ + venv_path = os.environ.get("VIRTUAL_ENV") + if venv_path: + return Path(venv_path) + return None + + +def get_python_executable() -> Path: + """ + Get current Python executable path. + + Returns: + Path: Path to the current Python executable + + """ + return Path(sys.executable) + + +def find_executable(name: str) -> Path | None: + """ + Find executable in system PATH. + + Args: + name: Name of the executable to find + + Returns: + Path: Path to executable if found, None otherwise + + """ + # Check if name already contains path + if os.path.dirname(name): + path = Path(name) + if path.exists() and os.access(path, os.X_OK): + return path + return None + + # Search in PATH + for path_dir in os.environ.get("PATH", "").split(os.pathsep): + if not path_dir: + continue + path = Path(path_dir) / name + if path.exists() and os.access(path, os.X_OK): + return path + + return None + + +def get_script_binary_path() -> Path: + """ + Get path to script's binary directory. + + Returns: + Path: Directory containing the script's binary + + """ + return get_python_executable().parent + + +def get_binary_path(binary_name: str) -> Path | None: + """ + Find binary in PATH or virtual environment. + + Args: + binary_name: Name of the binary to find + + Returns: + Path: Path to binary if found, None otherwise + + """ + # First check in virtual environment if active + venv_path = get_venv_path() + if venv_path: + venv_bin = venv_path / "bin" if os.name != "nt" else venv_path / "Scripts" + binary_path = venv_bin / binary_name + if binary_path.exists() and os.access(binary_path, os.X_OK): + return binary_path + + # Fall back to system PATH + return find_executable(binary_name) diff --git a/newsfragments/763.feature.rst b/newsfragments/763.feature.rst new file mode 100644 index 00000000..838b0cae --- /dev/null +++ b/newsfragments/763.feature.rst @@ -0,0 +1 @@ +Add QUIC transport support for faster, more efficient peer-to-peer connections with native stream multiplexing. diff --git a/newsfragments/811.feature.rst b/newsfragments/811.feature.rst new file mode 100644 index 00000000..47a0aa68 --- /dev/null +++ b/newsfragments/811.feature.rst @@ -0,0 +1 @@ + Added Thin Waist address validation utilities (with support for interface enumeration, optimal binding, and wildcard expansion). 
diff --git a/newsfragments/811.internal.rst b/newsfragments/811.internal.rst new file mode 100644 index 00000000..59804430 --- /dev/null +++ b/newsfragments/811.internal.rst @@ -0,0 +1,7 @@ +Add Thin Waist address validation utilities and integrate into echo example + +- Add ``libp2p/utils/address_validation.py`` with dynamic interface discovery +- Implement ``get_available_interfaces()``, ``get_optimal_binding_address()``, and ``expand_wildcard_address()`` +- Update echo example to use dynamic address discovery instead of hardcoded wildcard +- Add safe fallbacks for environments lacking Thin Waist support +- Temporarily disable IPv6 support due to libp2p handshake issues (TODO: re-enable when resolved) diff --git a/newsfragments/815.feature.rst b/newsfragments/815.feature.rst new file mode 100644 index 00000000..8fcf6fea --- /dev/null +++ b/newsfragments/815.feature.rst @@ -0,0 +1 @@ +KAD-DHT now includes signed-peer-records in its protobuf message schema for more secure peer-discovery. diff --git a/newsfragments/837.bugfix.rst b/newsfragments/837.bugfix.rst new file mode 100644 index 00000000..47919c23 --- /dev/null +++ b/newsfragments/837.bugfix.rst @@ -0,0 +1 @@ +Added multiselect type consistency in the negotiate method and updated all usages of the method. diff --git a/newsfragments/843.bugfix.rst b/newsfragments/843.bugfix.rst new file mode 100644 index 00000000..6160bbc7 --- /dev/null +++ b/newsfragments/843.bugfix.rst @@ -0,0 +1 @@ +Fixed message ID type inconsistency in handle_ihave and improved message ID parsing in handle_iwant in the pubsub module. diff --git a/newsfragments/849.feature.rst b/newsfragments/849.feature.rst new file mode 100644 index 00000000..73ad1453 --- /dev/null +++ b/newsfragments/849.feature.rst @@ -0,0 +1 @@ +Add automatic peer dialing in bootstrap module using trio.Nursery. diff --git a/newsfragments/855.internal.rst b/newsfragments/855.internal.rst new file mode 100644 index 00000000..2c425dde --- /dev/null +++ b/newsfragments/855.internal.rst @@ -0,0 +1 @@ +Improved PubsubNotifee integration tests and added failure scenario coverage. diff --git a/newsfragments/863.bugfix.rst b/newsfragments/863.bugfix.rst new file mode 100644 index 00000000..64de57b4 --- /dev/null +++ b/newsfragments/863.bugfix.rst @@ -0,0 +1,5 @@ +Fix multi-address listening bug in swarm.listen() + +- Fix early return in swarm.listen() that prevented listening on all addresses +- Add comprehensive tests for multi-address listening functionality +- Ensure all available interfaces are properly bound and connectable diff --git a/newsfragments/874.feature.rst b/newsfragments/874.feature.rst new file mode 100644 index 00000000..bef1d3bc --- /dev/null +++ b/newsfragments/874.feature.rst @@ -0,0 +1 @@ +Enhanced Swarm networking with retry logic, exponential backoff, and multi-connection support. Added configurable retry mechanisms that automatically recover from transient connection failures using exponential backoff with jitter to prevent thundering herd problems. Introduced connection pooling that allows multiple concurrent connections per peer for improved performance and fault tolerance. Added load balancing across connections and automatic connection health management. All enhancements are fully backward compatible and can be configured through new RetryConfig and ConnectionConfig classes.
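Newsfragment 874 describes retry with exponential backoff and jitter. As a rough illustration, the delay growth can be sketched from the ``RetryConfig`` fields exercised in the new tests (``initial_delay``, ``max_delay``, ``backoff_multiplier``, ``jitter_factor``); this is a reconstruction for exposition, not the library's actual ``Swarm`` internals: .. code-block:: python  import random  def backoff_delay( attempt: int, initial_delay: float = 0.1, max_delay: float = 30.0, backoff_multiplier: float = 2.0, jitter_factor: float = 0.1, ) -> float:  """Illustrative exponential backoff with jitter; not the library's exact code."""  # Grow the delay exponentially with each attempt, capped at max_delay.  delay = min(initial_delay * (backoff_multiplier**attempt), max_delay)  # Add proportional random jitter so many peers retrying at once  # do not reconnect in lockstep (the "thundering herd" problem).  return delay + delay * jitter_factor * random.random()  # Attempts 0, 1, 2 -> roughly 0.1s, 0.2s, 0.4s, each plus up to 10% jitter.  for attempt in range(3):  print(f"attempt {attempt}: {backoff_delay(attempt):.3f}s")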
diff --git a/newsfragments/883.internal.rst b/newsfragments/883.internal.rst new file mode 100644 index 00000000..a9ca3a0e --- /dev/null +++ b/newsfragments/883.internal.rst @@ -0,0 +1,5 @@ +Remove unused upgrade_listener function from transport upgrader + +- Remove unused ``upgrade_listener`` function from ``libp2p/transport/upgrader.py`` (Issue 2 from #726) +- Clean up unused imports related to the removed function +- Improve code maintainability by removing dead code diff --git a/newsfragments/886.bugfix.rst b/newsfragments/886.bugfix.rst new file mode 100644 index 00000000..1ebf38d1 --- /dev/null +++ b/newsfragments/886.bugfix.rst @@ -0,0 +1,2 @@ +Fixed cross-platform path handling by replacing hardcoded OS-specific +paths with standardized utilities in core modules and examples. diff --git a/newsfragments/889.feature.rst b/newsfragments/889.feature.rst new file mode 100644 index 00000000..7e42f18e --- /dev/null +++ b/newsfragments/889.feature.rst @@ -0,0 +1 @@ +PubSub routers now include signed-peer-records in RPC messages for secure peer-info exchange. diff --git a/pyproject.toml b/pyproject.toml index 7f6ff7bb..6bec3e76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,15 @@ maintainers = [ { name = "Dave Grantham", email = "dwg@linuxprogrammer.org" }, ] dependencies = [ + "aioquic>=1.2.0", "base58>=1.0.3", - "coincurve>=10.0.0", + "coincurve==21.0.0", "exceptiongroup>=1.2.0; python_version < '3.11'", + "fastecdsa==2.3.2; sys_platform != 'win32'", "grpcio>=1.41.0", "lru-dict>=1.1.6", - # "multiaddr>=0.0.9", - "multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@db8124e2321f316d3b7d2733c7df11d6ad9c03e6", + # "multiaddr (>=0.0.9,<0.0.10)", + "multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@3ea7f866fda9268ee92506edf9d8e975274bf941", + "mypy-protobuf>=3.0.0", "noiseprotocol>=0.3.0", "protobuf>=4.25.0,<5.0.0", @@ -53,6 +55,7 @@ Homepage = "https://github.com/libp2p/py-libp2p" [project.scripts] chat-demo = "examples.chat.chat:main" echo-demo = "examples.echo.echo:main" +echo-quic-demo = "examples.echo.echo_quic:main" ping-demo = "examples.ping.ping:main" identify-demo = "examples.identify.identify:main" identify-push-demo = "examples.identify_push.identify_push_demo:run_main" @@ -78,6 +81,7 @@ dev = [ "pytest>=7.0.0", "pytest-xdist>=2.4.0", "pytest-trio>=0.5.2", + "pytest-timeout>=2.4.0", "factory-boy>=2.12.0,<3.0.0", "ruff>=0.11.10", "pyrefly (>=0.17.1,<0.18.0)", @@ -89,11 +93,12 @@ docs = [ "tomli; python_version < '3.11'", ] test = [ + "factory-boy>=2.12.0,<3.0.0", "p2pclient==0.2.0", "pytest>=7.0.0", - "pytest-xdist>=2.4.0", + "pytest-timeout>=2.4.0", "pytest-trio>=0.5.2", - "factory-boy>=2.12.0,<3.0.0", + "pytest-xdist>=2.4.0", ] [tool.setuptools] @@ -283,4 +288,5 @@ project_excludes = [ "**/*pb2.py", "**/*.pyi", ".venv/**", + "./tests/interop/nim_libp2p", ] diff --git a/scripts/audit_paths.py b/scripts/audit_paths.py new file mode 100644 index 00000000..80df11f8 --- /dev/null +++ b/scripts/audit_paths.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +""" +Audit script to identify path handling issues in the py-libp2p codebase. + +This script scans for patterns that should be migrated to use the new +cross-platform path utilities. +""" + +import argparse +from pathlib import Path +import re +from typing import Any + + +def scan_for_path_issues(directory: Path) -> dict[str, list[dict[str, Any]]]: + """ + Scan for path handling issues in the codebase.
+ + Args: + directory: Root directory to scan + + Returns: + Dictionary mapping issue types to lists of found issues + + """ + issues = { + "hard_coded_slash": [], + "os_path_join": [], + "temp_hardcode": [], + "os_path_dirname": [], + "os_path_abspath": [], + "direct_path_concat": [], + } + + # Patterns to search for + patterns = { + "hard_coded_slash": r'["\'][^"\']*\/[^"\']*["\']', + "os_path_join": r"os\.path\.join\(", + "temp_hardcode": r'["\']\/tmp\/|["\']C:\\\\', + "os_path_dirname": r"os\.path\.dirname\(", + "os_path_abspath": r"os\.path\.abspath\(", + "direct_path_concat": r'["\'][^"\']*["\']\s*\+\s*["\'][^"\']*["\']', + } + + # Files to exclude + exclude_patterns = [ + r"__pycache__", + r"\.git", + r"\.pytest_cache", + r"\.mypy_cache", + r"\.ruff_cache", + r"env/", + r"venv/", + r"\.venv/", + ] + + for py_file in directory.rglob("*.py"): + # Skip excluded files + if any(re.search(pattern, str(py_file)) for pattern in exclude_patterns): + continue + + try: + content = py_file.read_text(encoding="utf-8") + except UnicodeDecodeError: + print(f"Warning: Could not read {py_file} (encoding issue)") + continue + + for issue_type, pattern in patterns.items(): + matches = re.finditer(pattern, content, re.MULTILINE) + for match in matches: + line_num = content[: match.start()].count("\n") + 1 + line_content = content.split("\n")[line_num - 1].strip() + + issues[issue_type].append( + { + "file": py_file, + "line": line_num, + "content": match.group(), + "full_line": line_content, + "relative_path": py_file.relative_to(directory), + } + ) + + return issues + + +def generate_migration_suggestions(issues: dict[str, list[dict[str, Any]]]) -> str: + """ + Generate migration suggestions for found issues. + + Args: + issues: Dictionary of found issues + + Returns: + Formatted string with migration suggestions + + """ + suggestions = [] + + for issue_type, issue_list in issues.items(): + if not issue_list: + continue + + suggestions.append(f"\n## {issue_type.replace('_', ' ').title()}") + suggestions.append(f"Found {len(issue_list)} instances:") + + for issue in issue_list[:10]: # Show first 10 examples + suggestions.append(f"\n### {issue['relative_path']}:{issue['line']}") + suggestions.append("```python") + suggestions.append("# Current code:") + suggestions.append(f"{issue['full_line']}") + suggestions.append("```") + + # Add migration suggestion based on issue type + if issue_type == "os_path_join": + suggestions.append("```python") + suggestions.append("# Suggested fix:") + suggestions.append("from libp2p.utils.paths import join_paths") + suggestions.append( + "# Replace os.path.join(a, b, c) with join_paths(a, b, c)" + ) + suggestions.append("```") + elif issue_type == "temp_hardcode": + suggestions.append("```python") + suggestions.append("# Suggested fix:") + suggestions.append( + "from libp2p.utils.paths import get_temp_dir, create_temp_file" + ) + temp_fix_msg = ( + "# Replace hard-coded temp paths with get_temp_dir() or " + "create_temp_file()" + ) + suggestions.append(temp_fix_msg) + suggestions.append("```") + elif issue_type == "os_path_dirname": + suggestions.append("```python") + suggestions.append("# Suggested fix:") + suggestions.append("from libp2p.utils.paths import get_script_dir") + script_dir_fix_msg = ( + "# Replace os.path.dirname(os.path.abspath(__file__)) with " + "get_script_dir(__file__)" + ) + suggestions.append(script_dir_fix_msg) + suggestions.append("```") + + if len(issue_list) > 10: + suggestions.append(f"\n... 
and {len(issue_list) - 10} more instances") + + return "\n".join(suggestions) + + +def generate_summary_report(issues: dict[str, list[dict[str, Any]]]) -> str: + """ + Generate a summary report of all found issues. + + Args: + issues: Dictionary of found issues + + Returns: + Formatted summary report + + """ + total_issues = sum(len(issue_list) for issue_list in issues.values()) + + report = [ + "# Cross-Platform Path Handling Audit Report", + "", + "## Summary", + f"Total issues found: {total_issues}", + "", + "## Issue Breakdown:", + ] + + for issue_type, issue_list in issues.items(): + if issue_list: + issue_title = issue_type.replace("_", " ").title() + instances_count = len(issue_list) + report.append(f"- **{issue_title}**: {instances_count} instances") + + report.append("") + report.append("## Priority Matrix:") + report.append("") + report.append("| Priority | Issue Type | Risk Level | Impact |") + report.append("|----------|------------|------------|---------|") + + priority_map = { + "temp_hardcode": ( + "šŸ”“ P0", + "HIGH", + "Core functionality fails on different platforms", + ), + "os_path_join": ("🟔 P1", "MEDIUM", "Examples and utilities may break"), + "os_path_dirname": ("🟔 P1", "MEDIUM", "Script location detection issues"), + "hard_coded_slash": ("🟢 P2", "LOW", "Future-proofing and consistency"), + "os_path_abspath": ("🟢 P2", "LOW", "Path resolution consistency"), + "direct_path_concat": ("🟢 P2", "LOW", "String concatenation issues"), + } + + for issue_type, issue_list in issues.items(): + if issue_list: + priority, risk, impact = priority_map.get( + issue_type, ("🟢 P2", "LOW", "General improvement") + ) + issue_title = issue_type.replace("_", " ").title() + report.append(f"| {priority} | {issue_title} | {risk} | {impact} |") + + return "\n".join(report) + + +def main(): + """Main function to run the audit.""" + parser = argparse.ArgumentParser( + description="Audit py-libp2p codebase for path handling issues" + ) + parser.add_argument( + "--directory", + default=".", + help="Directory to scan (default: current directory)", + ) + parser.add_argument("--output", help="Output file for detailed report") + parser.add_argument( + "--summary-only", action="store_true", help="Only show summary report" + ) + + args = parser.parse_args() + + directory = Path(args.directory) + if not directory.exists(): + print(f"Error: Directory {directory} does not exist") + return 1 + + print("šŸ” Scanning for path handling issues...") + issues = scan_for_path_issues(directory) + + # Generate and display summary + summary = generate_summary_report(issues) + print(summary) + + if not args.summary_only: + # Generate detailed suggestions + suggestions = generate_migration_suggestions(issues) + + if args.output: + with open(args.output, "w", encoding="utf-8") as f: + f.write(summary) + f.write(suggestions) + print(f"\nšŸ“„ Detailed report saved to {args.output}") + else: + print(suggestions) + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/core/host/test_basic_host.py b/tests/core/host/test_basic_host.py index ed21ad80..635f2863 100644 --- a/tests/core/host/test_basic_host.py +++ b/tests/core/host/test_basic_host.py @@ -1,3 +1,10 @@ +from unittest.mock import ( + AsyncMock, + MagicMock, +) + +import pytest + from libp2p import ( new_swarm, ) @@ -10,6 +17,9 @@ from libp2p.host.basic_host import ( from libp2p.host.defaults import ( get_default_protocols, ) +from libp2p.host.exceptions import ( + StreamFailure, +) def test_default_protocols(): @@ -22,3 +32,30 @@ def 
test_default_protocols(): # NOTE: comparing keys for equality as handlers may be closures that do not compare # in the way this test is concerned with assert handlers.keys() == get_default_protocols(host).keys() + + +@pytest.mark.trio +async def test_swarm_stream_handler_no_protocol_selected(monkeypatch): + key_pair = create_new_key_pair() + swarm = new_swarm(key_pair) + host = BasicHost(swarm) + + # Create a mock net_stream + net_stream = MagicMock() + net_stream.reset = AsyncMock() + net_stream.muxed_conn.peer_id = "peer-test" + + # Monkeypatch negotiate to simulate "no protocol selected" + async def fake_negotiate(comm, timeout): + return None, None + + monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate) + + # Now run the handler and expect StreamFailure + with pytest.raises( + StreamFailure, match="Failed to negotiate protocol: no protocol selected" + ): + await host._swarm_stream_handler(net_stream) + + # Ensure reset was called since negotiation failed + net_stream.reset.assert_awaited() diff --git a/tests/core/host/test_live_peers.py b/tests/core/host/test_live_peers.py index 1d7948ad..e5af42ba 100644 --- a/tests/core/host/test_live_peers.py +++ b/tests/core/host/test_live_peers.py @@ -164,8 +164,8 @@ async def test_live_peers_unexpected_drop(security_protocol): assert peer_a_id in host_b.get_live_peers() # Simulate unexpected connection drop by directly closing the connection - conn = host_a.get_network().connections[peer_b_id] - await conn.muxed_conn.close() + conns = host_a.get_network().connections[peer_b_id] + await conns[0].muxed_conn.close() # Allow for connection cleanup await trio.sleep(0.1) diff --git a/tests/core/kad_dht/test_kad_dht.py b/tests/core/kad_dht/test_kad_dht.py index a6f73074..5bf4f3e8 100644 --- a/tests/core/kad_dht/test_kad_dht.py +++ b/tests/core/kad_dht/test_kad_dht.py @@ -9,11 +9,15 @@ This module tests core functionality of the Kademlia DHT including: import hashlib import logging +import os +from unittest.mock import patch import uuid import pytest +import multiaddr import trio +from libp2p.crypto.rsa import create_new_key_pair from libp2p.kad_dht.kad_dht import ( DHTMode, KadDHT, @@ -21,9 +25,13 @@ from libp2p.kad_dht.kad_dht import ( from libp2p.kad_dht.utils import ( create_key_from_binary, ) +from libp2p.peer.envelope import Envelope, seal_record +from libp2p.peer.id import ID +from libp2p.peer.peer_record import PeerRecord from libp2p.peer.peerinfo import ( PeerInfo, ) +from libp2p.peer.peerstore import create_signed_peer_record from libp2p.tools.async_service import ( background_trio_service, ) @@ -76,10 +84,52 @@ async def test_find_node(dht_pair: tuple[KadDHT, KadDHT]): """Test that nodes can find each other in the DHT.""" dht_a, dht_b = dht_pair + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both nodes will have records of each other before the next FIND_NODE + # req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Node A should be able to find Node B with trio.fail_after(TEST_TIMEOUT): found_info = await dht_a.find_peer(dht_b.host.get_id()) + # Verifies if the senderRecord in the FIND_NODE request is correctly processed + assert isinstance( + dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()), Envelope + ) +
+ + # Verifies if the senderRecord in the FIND_NODE response is correctly processed + assert isinstance( + dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()), Envelope + ) + + # These are the records that were sent between the peers during the FIND_NODE req + envelope_a_find_peer = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_peer = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_peer, Envelope) + assert isinstance(envelope_b_find_peer, Envelope) + + record_a_find_peer = envelope_a_find_peer.record() + record_b_find_peer = envelope_b_find_peer.record() + + # This proves that both records are the same, and the latest cached signed record + # was passed between the peers during FIND_NODE execution, which proves the + # signed-record transfer/re-issuing works correctly in FIND_NODE executions. + assert record_a.seq == record_a_find_peer.seq + assert record_b.seq == record_b_find_peer.seq + # Verify that the found peer has the correct peer ID assert found_info is not None, "Failed to find the target peer" assert found_info.peer_id == dht_b.host.get_id(), "Found incorrect peer ID" @@ -104,14 +154,44 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): await dht_a.routing_table.add_peer(peer_b_info) print("Routing table of a has ", dht_a.routing_table.get_peer_ids()) + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both nodes will have records of each other before the PUT_VALUE req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Store the value using the first node (this will also store locally) with trio.fail_after(TEST_TIMEOUT): await dht_a.put_value(key, value) + # These are the records that were sent between the peers during the PUT_VALUE req + envelope_a_put_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_put_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_put_value, Envelope) + assert isinstance(envelope_b_put_value, Envelope) + + record_a_put_value = envelope_a_put_value.record() + record_b_put_value = envelope_b_put_value.record() + + # This proves that both records are the same, and the latest cached signed record + # was passed between the peers during PUT_VALUE execution, which proves the + # signed-record transfer/re-issuing works correctly in PUT_VALUE executions.
+ assert record_a.seq == record_a_put_value.seq + assert record_b.seq == record_b_put_value.seq + # # Log debugging information logger.debug("Put value with key %s...", key.hex()[:10]) logger.debug("Node A value store: %s", dht_a.value_store.store) - print("hello test") # # Allow more time for the value to propagate await trio.sleep(0.5) @@ -126,6 +206,26 @@ async def test_put_and_get_value(dht_pair: tuple[KadDHT, KadDHT]): print("the value stored in node b is", dht_b.get_value_store_size()) logger.debug("Retrieved value: %s", retrieved_value) + # These are the records in each peer after the GET_VALUE execution + envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_get_value, Envelope) + assert isinstance(envelope_b_get_value, Envelope) + + record_a_get_value = envelope_a_get_value.record() + record_b_get_value = envelope_b_get_value.record() + + # This proves that there was no record exchange between the nodes during GET_VALUE + # execution, as dht_b already had the key/value pair stored locally after the + # PUT_VALUE execution. + assert record_a_get_value.seq == record_a_put_value.seq + assert record_b_get_value.seq == record_b_put_value.seq + # Verify that the retrieved value matches the original assert retrieved_value == value, "Retrieved value does not match the stored value" @@ -142,11 +242,44 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): # Store content on the first node dht_a.value_store.put(content_id, content) + # An extra FIND_NODE req is sent between the 2 nodes during DHT creation, + # so both nodes will have records of each other before the ADD_PROVIDER req is sent + envelope_a = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + envelope_b = dht_b.host.get_peerstore().get_peer_record(dht_a.host.get_id()) + + assert isinstance(envelope_a, Envelope) + assert isinstance(envelope_b, Envelope) + + record_a = envelope_a.record() + record_b = envelope_b.record() + # Advertise the first node as a provider with trio.fail_after(TEST_TIMEOUT): success = await dht_a.provide(content_id) assert success, "Failed to advertise as provider" + # These are the records that were sent between the peers during + # the ADD_PROVIDER req + envelope_a_add_prov = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_add_prov = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_add_prov, Envelope) + assert isinstance(envelope_b_add_prov, Envelope) + + record_a_add_prov = envelope_a_add_prov.record() + record_b_add_prov = envelope_b_add_prov.record() + + # This proves that both records are the same: the latest cached signed record + # was passed between the peers during ADD_PROVIDER execution, which proves the + # signed-record transfer/re-issuing of the latest record works correctly in + # ADD_PROVIDER executions.
+ assert record_a.seq == record_a_add_prov.seq + assert record_b.seq == record_b_add_prov.seq + # Allow time for the provider record to propagate await trio.sleep(0.1) @@ -154,6 +287,26 @@ async def test_provide_and_find_providers(dht_pair: tuple[KadDHT, KadDHT]): with trio.fail_after(TEST_TIMEOUT): providers = await dht_b.find_providers(content_id) + # These are the records in each peer after the find_provider execution + envelope_a_find_prov = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_find_prov = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_find_prov, Envelope) + assert isinstance(envelope_b_find_prov, Envelope) + + record_a_find_prov = envelope_a_find_prov.record() + record_b_find_prov = envelope_b_find_prov.record() + + # This proves that both records are the same, as dht_b already + # has the provider record for the content_id, after the ADD_PROVIDER + # advertisement by dht_a + assert record_a_find_prov.seq == record_a_add_prov.seq + assert record_b_find_prov.seq == record_b_add_prov.seq + # Verify that we found the first node as a provider assert providers, "No providers found" assert any(p.peer_id == dht_a.local_peer_id for p in providers), ( @@ -166,3 +319,143 @@ assert retrieved_value == content, ( "Retrieved content does not match the original" ) + + # These are the record states of each peer after the GET_VALUE execution + envelope_a_get_value = dht_a.host.get_peerstore().get_peer_record( + dht_b.host.get_id() + ) + envelope_b_get_value = dht_b.host.get_peerstore().get_peer_record( + dht_a.host.get_id() + ) + + assert isinstance(envelope_a_get_value, Envelope) + assert isinstance(envelope_b_get_value, Envelope) + + record_a_get_value = envelope_a_get_value.record() + record_b_get_value = envelope_b_get_value.record() + + # This proves that both records are the same, meaning that the latest cached + # signed-record transfer happened during the GET_VALUE execution by dht_b, + # which means the signed-record transfer/re-issuing works correctly + # in GET_VALUE executions.
+ assert record_a_find_prov.seq == record_a_get_value.seq + assert record_b_find_prov.seq == record_b_get_value.seq + + # Create a new provider record in dht_a + provider_key_pair = create_new_key_pair() + provider_peer_id = ID.from_pubkey(provider_key_pair.public_key) + provider_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + provider_peer_info = PeerInfo(peer_id=provider_peer_id, addrs=[provider_addr]) + + # Generate a random content ID + content_2 = f"random-content-{uuid.uuid4()}".encode() + content_id_2 = hashlib.sha256(content_2).digest() + + provider_signed_envelope = create_signed_peer_record( + provider_peer_id, [provider_addr], provider_key_pair.private_key + ) + assert ( + dht_a.host.get_peerstore().consume_peer_record(provider_signed_envelope, 7200) + is True + ) + + # Store this provider record in dht_a + dht_a.provider_store.add_provider(content_id_2, provider_peer_info) + + # Fetch the provider-record via peer-discovery at dht_b's end + peerinfo = await dht_b.provider_store.find_providers(content_id_2) + + assert len(peerinfo) == 1 + assert peerinfo[0].peer_id == provider_peer_id + provider_envelope = dht_b.host.get_peerstore().get_peer_record(provider_peer_id) + + # This proves that the signed-envelope of provider is consumed on dht_b's end + assert provider_envelope is not None + assert ( + provider_signed_envelope.marshal_envelope() + == provider_envelope.marshal_envelope() + ) + + +@pytest.mark.trio +async def test_reissue_when_listen_addrs_change(dht_pair: tuple[KadDHT, KadDHT]): + dht_a, dht_b = dht_pair + + # Warm-up: A stores B's current record + with trio.fail_after(10): + await dht_a.find_peer(dht_b.host.get_id()) + + env0 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env0, Envelope) + seq0 = env0.record().seq + + # Simulate B's listen addrs changing (different port) + new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + + # Patch just for the duration we force B to respond: + with patch.object(dht_b.host, "get_addrs", return_value=[new_addr]): + # Force B to send a response (which should include a fresh SPR) + with trio.fail_after(10): + await dht_a.peer_routing._query_peer_for_closest( + dht_b.host.get_id(), os.urandom(32) + ) + + # A should now hold B's new record with a bumped seq + env1 = dht_a.host.get_peerstore().get_peer_record(dht_b.host.get_id()) + assert isinstance(env1, Envelope) + seq1 = env1.record().seq + + # This proves that upon the change in listen_addrs, we issue new records + assert seq1 > seq0, f"Expected seq to bump after addr change, got {seq0} -> {seq1}" + + +@pytest.mark.trio +async def test_dht_req_fail_with_invalid_record_transfer( + dht_pair: tuple[KadDHT, KadDHT], +): + """ + Test showing that storing and retrieving values in the DHT fails + if invalid signed-records are sent.
+ """ + dht_a, dht_b = dht_pair + peer_b_info = PeerInfo(dht_b.host.get_id(), dht_b.host.get_addrs()) + + # Generate a random key and value + key = create_key_from_binary(b"test-key") + value = b"test-value" + + # First add the value directly to node A's store to verify storage works + dht_a.value_store.put(key, value) + local_value = dht_a.value_store.get(key) + assert local_value == value, "Local value storage failed" + await dht_a.routing_table.add_peer(peer_b_info) + + # Corrupt dht_a's local peer_record + envelope = dht_a.host.get_peerstore().get_local_record() + if envelope is not None: + true_record = envelope.record() + key_pair = create_new_key_pair() + + if envelope is not None: + envelope.public_key = key_pair.public_key + dht_a.host.get_peerstore().set_local_record(envelope) + + await dht_a.put_value(key, value) + retrieved_value = dht_b.value_store.get(key) + + # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving + # the corrupted invalid record + assert retrieved_value is None + + # Create a corrupt envelope with correct signature but false peer_id + false_record = PeerRecord(ID.from_pubkey(key_pair.public_key), true_record.addrs) + false_envelope = seal_record(false_record, dht_a.host.get_private_key()) + + dht_a.host.get_peerstore().set_local_record(false_envelope) + + await dht_a.put_value(key, value) + retrieved_value = dht_b.value_store.get(key) + + # This proves that DHT_B rejected DHT_A PUT_RECORD req upon receiving + # the record with a different peer_id regardless of a valid signature + assert retrieved_value is None diff --git a/tests/core/kad_dht/test_unit_peer_routing.py b/tests/core/kad_dht/test_unit_peer_routing.py index ffe20655..6e15ce7e 100644 --- a/tests/core/kad_dht/test_unit_peer_routing.py +++ b/tests/core/kad_dht/test_unit_peer_routing.py @@ -57,7 +57,10 @@ class TestPeerRouting: def mock_host(self): """Create a mock host for testing.""" host = Mock() - host.get_id.return_value = create_valid_peer_id("local") + key_pair = create_new_key_pair() + host.get_id.return_value = ID.from_pubkey(key_pair.public_key) + host.get_public_key.return_value = key_pair.public_key + host.get_private_key.return_value = key_pair.private_key host.get_addrs.return_value = [Multiaddr("/ip4/127.0.0.1/tcp/8000")] host.get_peerstore.return_value = Mock() host.new_stream = AsyncMock() diff --git a/tests/core/network/test_enhanced_swarm.py b/tests/core/network/test_enhanced_swarm.py new file mode 100644 index 00000000..e63de126 --- /dev/null +++ b/tests/core/network/test_enhanced_swarm.py @@ -0,0 +1,325 @@ +import time +from typing import cast +from unittest.mock import Mock + +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p.abc import INetConn, INetStream +from libp2p.network.exceptions import SwarmException +from libp2p.network.swarm import ( + ConnectionConfig, + RetryConfig, + Swarm, +) +from libp2p.peer.id import ID + + +class MockConnection(INetConn): + """Mock connection for testing.""" + + def __init__(self, peer_id: ID, is_closed: bool = False): + self.peer_id = peer_id + self._is_closed = is_closed + self.streams = set() # Track streams properly + # Mock the muxed_conn attribute that Swarm expects + self.muxed_conn = Mock() + self.muxed_conn.peer_id = peer_id + # Required by INetConn interface + self.event_started = trio.Event() + + async def close(self): + self._is_closed = True + + @property + def is_closed(self) -> bool: + return self._is_closed + + async def new_stream(self) -> INetStream: + # Create a mock stream and add it
to the connection's stream set + mock_stream = Mock(spec=INetStream) + self.streams.add(mock_stream) + return mock_stream + + def get_streams(self) -> tuple[INetStream, ...]: + """Return all streams associated with this connection.""" + return tuple(self.streams) + + def get_transport_addresses(self) -> list[Multiaddr]: + """Mock implementation of get_transport_addresses.""" + return [] + + +class MockNetStream(INetStream): + """Mock network stream for testing.""" + + def __init__(self, peer_id: ID): + self.peer_id = peer_id + + +@pytest.mark.trio +async def test_retry_config_defaults(): + """Test RetryConfig default values.""" + config = RetryConfig() + assert config.max_retries == 3 + assert config.initial_delay == 0.1 + assert config.max_delay == 30.0 + assert config.backoff_multiplier == 2.0 + assert config.jitter_factor == 0.1 + + +@pytest.mark.trio +async def test_connection_config_defaults(): + """Test ConnectionConfig default values.""" + config = ConnectionConfig() + assert config.max_connections_per_peer == 3 + assert config.connection_timeout == 30.0 + assert config.load_balancing_strategy == "round_robin" + + +@pytest.mark.trio +async def test_enhanced_swarm_constructor(): + """Test enhanced Swarm constructor with new configuration.""" + # Create mock dependencies + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Test with default config + swarm = Swarm(peer_id, peerstore, upgrader, transport) + assert swarm.retry_config.max_retries == 3 + assert swarm.connection_config.max_connections_per_peer == 3 + assert isinstance(swarm.connections, dict) + + # Test with custom config + custom_retry = RetryConfig(max_retries=5, initial_delay=0.5) + custom_conn = ConnectionConfig(max_connections_per_peer=5) + + swarm = Swarm(peer_id, peerstore, upgrader, transport, custom_retry, custom_conn) + assert swarm.retry_config.max_retries == 5 + assert swarm.retry_config.initial_delay == 0.5 + assert swarm.connection_config.max_connections_per_peer == 5 + + +@pytest.mark.trio +async def test_swarm_backoff_calculation(): + """Test exponential backoff calculation with jitter.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + retry_config = RetryConfig( + initial_delay=0.1, max_delay=1.0, backoff_multiplier=2.0, jitter_factor=0.1 + ) + + swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) + + # Test backoff calculation + delay1 = swarm._calculate_backoff_delay(0) + delay2 = swarm._calculate_backoff_delay(1) + delay3 = swarm._calculate_backoff_delay(2) + + # Should increase exponentially + assert delay2 > delay1 + assert delay3 > delay2 + + # Should respect max delay + assert delay1 <= 1.0 + assert delay2 <= 1.0 + assert delay3 <= 1.0 + + # Should have jitter + assert delay1 != 0.1 # Should have jitter added + + +@pytest.mark.trio +async def test_swarm_retry_logic(): + """Test retry logic in dial operations.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + # Configure for fast testing + retry_config = RetryConfig( + max_retries=2, + initial_delay=0.01, # Very short for testing + max_delay=0.1, + ) + + swarm = Swarm(peer_id, peerstore, upgrader, transport, retry_config) + + # Mock the single attempt method to fail twice then succeed + attempt_count = [0] + + async def mock_single_attempt(addr, peer_id): + attempt_count[0] += 1 + if attempt_count[0] < 3: + raise SwarmException(f"Attempt {attempt_count[0]} failed") + return MockConnection(peer_id) + + 
swarm._dial_addr_single_attempt = mock_single_attempt + + # Test retry logic + start_time = time.time() + result = await swarm._dial_with_retry(Mock(spec=Multiaddr), peer_id) + end_time = time.time() + + # Should have succeeded after 3 attempts + assert attempt_count[0] == 3 + assert isinstance(result, MockConnection) + assert end_time - start_time > 0.01 # Should have some delay + + +@pytest.mark.trio +async def test_swarm_load_balancing_strategies(): + """Test load balancing strategies.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + swarm = Swarm(peer_id, peerstore, upgrader, transport) + + # Create mock connections with different stream counts + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) + + # Add some streams to simulate load + await conn1.new_stream() + await conn1.new_stream() + await conn2.new_stream() + + connections = [conn1, conn2, conn3] + + # Test round-robin strategy + swarm.connection_config.load_balancing_strategy = "round_robin" + # Cast to satisfy type checker + connections_cast = cast("list[INetConn]", connections) + selected1 = swarm._select_connection(connections_cast, peer_id) + selected2 = swarm._select_connection(connections_cast, peer_id) + selected3 = swarm._select_connection(connections_cast, peer_id) + + # Should cycle through connections + assert selected1 in connections + assert selected2 in connections + assert selected3 in connections + + # Test least loaded strategy + swarm.connection_config.load_balancing_strategy = "least_loaded" + least_loaded = swarm._select_connection(connections_cast, peer_id) + + # conn3 has 0 streams, conn2 has 1 stream, conn1 has 2 streams + # So conn3 should be selected as least loaded + assert least_loaded == conn3 + + # Test default strategy (first connection) + swarm.connection_config.load_balancing_strategy = "unknown" + default_selected = swarm._select_connection(connections_cast, peer_id) + assert default_selected == conn1 + + +@pytest.mark.trio +async def test_swarm_multiple_connections_api(): + """Test the new multiple connections API methods.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + swarm = Swarm(peer_id, peerstore, upgrader, transport) + + # Test empty connections + assert swarm.get_connections() == [] + assert swarm.get_connections(peer_id) == [] + assert swarm.get_connection(peer_id) is None + assert swarm.get_connections_map() == {} + + # Add some connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = [conn1, conn2] + + # Test get_connections with peer_id + peer_connections = swarm.get_connections(peer_id) + assert len(peer_connections) == 2 + assert conn1 in peer_connections + assert conn2 in peer_connections + + # Test get_connections without peer_id (all connections) + all_connections = swarm.get_connections() + assert len(all_connections) == 2 + assert conn1 in all_connections + assert conn2 in all_connections + + # Test get_connection (backward compatibility) + single_conn = swarm.get_connection(peer_id) + assert single_conn in [conn1, conn2] + + # Test get_connections_map + connections_map = swarm.get_connections_map() + assert peer_id in connections_map + assert connections_map[peer_id] == [conn1, conn2] + + +@pytest.mark.trio +async def test_swarm_connection_trimming(): + """Test connection trimming when limit is exceeded.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + 
transport = Mock() + + # Set max connections to 2 + connection_config = ConnectionConfig(max_connections_per_peer=2) + swarm = Swarm( + peer_id, peerstore, upgrader, transport, connection_config=connection_config + ) + + # Add 3 connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + conn3 = MockConnection(peer_id) + + swarm.connections[peer_id] = [conn1, conn2, conn3] + + # Trigger trimming + swarm._trim_connections(peer_id) + + # Should have only 2 connections + assert len(swarm.connections[peer_id]) == 2 + + # The most recent connections should remain + remaining = swarm.connections[peer_id] + assert conn2 in remaining + assert conn3 in remaining + + +@pytest.mark.trio +async def test_swarm_backward_compatibility(): + """Test backward compatibility features.""" + peer_id = ID(b"QmTest") + peerstore = Mock() + upgrader = Mock() + transport = Mock() + + swarm = Swarm(peer_id, peerstore, upgrader, transport) + + # Add connections + conn1 = MockConnection(peer_id) + conn2 = MockConnection(peer_id) + swarm.connections[peer_id] = [conn1, conn2] + + # Test connections_legacy property + legacy_connections = swarm.connections_legacy + assert peer_id in legacy_connections + # Should return first connection + assert legacy_connections[peer_id] in [conn1, conn2] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/core/network/test_notifee_performance.py b/tests/core/network/test_notifee_performance.py new file mode 100644 index 00000000..cba6d0ad --- /dev/null +++ b/tests/core/network/test_notifee_performance.py @@ -0,0 +1,82 @@ +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ( + INetConn, + INetStream, + INetwork, + INotifee, +) +from libp2p.tools.utils import connect_swarm +from tests.utils.factories import SwarmFactory + + +class CountingNotifee(INotifee): + def __init__(self, event: trio.Event) -> None: + self._event = event + + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def connected(self, network: INetwork, conn: INetConn) -> None: + self._event.set() + + async def disconnected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + +class SlowNotifee(INotifee): + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def connected(self, network: INetwork, conn: INetConn) -> None: + await trio.sleep(0.5) + + async def disconnected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: + pass + + +@pytest.mark.trio +async def test_many_notifees_receive_connected_quickly() -> None: + async with SwarmFactory.create_batch_and_listen(2) as swarms: + count = 200 + events = [trio.Event() for _ in range(count)] + for ev in events: + swarms[0].register_notifee(CountingNotifee(ev)) + await connect_swarm(swarms[0], swarms[1]) + with trio.fail_after(1.5): + for ev in events: + await ev.wait() + + +@pytest.mark.trio +async def test_slow_notifee_does_not_block_others() -> None: + async with 
SwarmFactory.create_batch_and_listen(2) as swarms: + fast_events = [trio.Event() for _ in range(20)] + for ev in fast_events: + swarms[0].register_notifee(CountingNotifee(ev)) + swarms[0].register_notifee(SlowNotifee()) + await connect_swarm(swarms[0], swarms[1]) + # Fast notifees should complete quickly despite one slow notifee + with trio.fail_after(0.3): + for ev in fast_events: + await ev.wait() diff --git a/tests/core/network/test_notify_listen_lifecycle.py b/tests/core/network/test_notify_listen_lifecycle.py new file mode 100644 index 00000000..7bac5938 --- /dev/null +++ b/tests/core/network/test_notify_listen_lifecycle.py @@ -0,0 +1,76 @@ +import enum + +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p.abc import ( + INetConn, + INetStream, + INetwork, + INotifee, +) +from libp2p.tools.async_service import background_trio_service +from libp2p.tools.constants import LISTEN_MADDR +from tests.utils.factories import SwarmFactory + + +class Event(enum.Enum): + Listen = 0 + ListenClose = 1 + + +class MyNotifee(INotifee): + def __init__(self, events: list[Event]): + self.events = events + + async def opened_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def closed_stream(self, network: INetwork, stream: INetStream) -> None: + pass + + async def connected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def disconnected(self, network: INetwork, conn: INetConn) -> None: + pass + + async def listen(self, network: INetwork, multiaddr: Multiaddr) -> None: + self.events.append(Event.Listen) + + async def listen_close(self, network: INetwork, multiaddr: Multiaddr) -> None: + self.events.append(Event.ListenClose) + + +async def wait_for_event( + events_list: list[Event], event: Event, timeout: float = 1.0 +) -> bool: + with trio.move_on_after(timeout): + while event not in events_list: + await trio.sleep(0.01) + return True + return False + + +@pytest.mark.trio +async def test_listen_emitted_when_registered_before_listen(): + events: list[Event] = [] + swarm = SwarmFactory.build() + swarm.register_notifee(MyNotifee(events)) + async with background_trio_service(swarm): + # Start listening now; notifee was registered beforehand + assert await swarm.listen(LISTEN_MADDR) + assert await wait_for_event(events, Event.Listen) + + +@pytest.mark.trio +async def test_single_listener_close_emits_listen_close(): + events: list[Event] = [] + swarm = SwarmFactory.build() + swarm.register_notifee(MyNotifee(events)) + async with background_trio_service(swarm): + assert await swarm.listen(LISTEN_MADDR) + # Explicitly notify listen_close (close path via manager doesn't emit it) + await swarm.notify_listen_close(LISTEN_MADDR) + assert await wait_for_event(events, Event.ListenClose) diff --git a/tests/core/network/test_swarm.py b/tests/core/network/test_swarm.py index 6389bcb3..47bc3ace 100644 --- a/tests/core/network/test_swarm.py +++ b/tests/core/network/test_swarm.py @@ -16,6 +16,9 @@ from libp2p.network.exceptions import ( from libp2p.network.swarm import ( Swarm, ) +from libp2p.tools.async_service import ( + background_trio_service, +) from libp2p.tools.utils import ( connect_swarm, ) @@ -48,14 +51,19 @@ async def test_swarm_dial_peer(security_protocol): for addr in transport.get_addrs() ) swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) - await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # New: dial_peer now returns list of connections + connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert 
len(connections) > 0 + + # Verify connections are established in both directions assert swarms[0].get_peer_id() in swarms[1].connections assert swarms[1].get_peer_id() in swarms[0].connections # Test: Reuse connections when we already have ones with a peer. - conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] - conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) - assert conn is conn_to_1 + existing_connections = swarms[0].get_connections(swarms[1].get_peer_id()) + new_connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert new_connections == existing_connections @pytest.mark.trio @@ -104,7 +112,8 @@ async def test_swarm_close_peer(security_protocol): @pytest.mark.trio async def test_swarm_remove_conn(swarm_pair): swarm_0, swarm_1 = swarm_pair - conn_0 = swarm_0.connections[swarm_1.get_peer_id()] + # Get the first connection from the list + conn_0 = swarm_0.connections[swarm_1.get_peer_id()][0] swarm_0.remove_conn(conn_0) assert swarm_1.get_peer_id() not in swarm_0.connections # Test: Remove twice. There should not be errors. @@ -112,6 +121,67 @@ async def test_swarm_remove_conn(swarm_pair): assert swarm_1.get_peer_id() not in swarm_0.connections +@pytest.mark.trio +async def test_swarm_multiple_connections(security_protocol): + """Test multiple connections per peer functionality.""" + async with SwarmFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as swarms: + # Setup multiple addresses for peer + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + + # Dial peer - should return list of connections + connections = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert len(connections) > 0 + + # Test get_connections method + peer_connections = swarms[0].get_connections(swarms[1].get_peer_id()) + assert len(peer_connections) == len(connections) + + # Test get_connections_map method + connections_map = swarms[0].get_connections_map() + assert swarms[1].get_peer_id() in connections_map + assert len(connections_map[swarms[1].get_peer_id()]) == len(connections) + + # Test get_connection method (backward compatibility) + single_conn = swarms[0].get_connection(swarms[1].get_peer_id()) + assert single_conn is not None + assert single_conn in connections + + +@pytest.mark.trio +async def test_swarm_load_balancing(security_protocol): + """Test load balancing across multiple connections.""" + async with SwarmFactory.create_batch_and_listen( + 2, security_protocol=security_protocol + ) as swarms: + # Setup connection + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + + # Create multiple streams - should use load balancing + streams = [] + for _ in range(5): + stream = await swarms[0].new_stream(swarms[1].get_peer_id()) + streams.append(stream) + + # Verify streams were created successfully + assert len(streams) == 5 + + # Clean up + for stream in streams: + await stream.close() + + @pytest.mark.trio async def test_swarm_multiaddr(security_protocol): async with SwarmFactory.create_batch_and_listen( @@ -180,7 +250,123 @@ def test_new_swarm_tcp_multiaddr_supported(): assert isinstance(swarm.transport, TCP) -def test_new_swarm_quic_multiaddr_raises(): +def test_new_swarm_quic_multiaddr_supported(): + from libp2p.transport.quic.transport import QUICTransport + addr = 
Multiaddr("/ip4/127.0.0.1/udp/9999/quic") - with pytest.raises(ValueError, match="QUIC not yet supported"): - new_swarm(listen_addrs=[addr]) + swarm = new_swarm(listen_addrs=[addr]) + assert isinstance(swarm, Swarm) + assert isinstance(swarm.transport, QUICTransport) + + +@pytest.mark.trio +async def test_swarm_listen_multiple_addresses(security_protocol): + """Test that swarm can listen on multiple addresses simultaneously.""" + from libp2p.utils.address_validation import get_available_interfaces + + # Get multiple addresses to listen on + listen_addrs = get_available_interfaces(0) # Let OS choose ports + + # Create a swarm and listen on multiple addresses + swarm = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm): + # Listen on all addresses + success = await swarm.listen(*listen_addrs) + assert success, "Should successfully listen on at least one address" + + # Check that we have listeners for the addresses + actual_listeners = list(swarm.listeners.keys()) + assert len(actual_listeners) > 0, "Should have at least one listener" + + # Verify that all successful listeners are in the listeners dict + successful_count = 0 + for addr in listen_addrs: + addr_str = str(addr) + if addr_str in actual_listeners: + successful_count += 1 + # This address successfully started listening + listener = swarm.listeners[addr_str] + listener_addrs = listener.get_addrs() + assert len(listener_addrs) > 0, ( + f"Listener for {addr} should have addresses" + ) + + # Check that the listener address matches the expected address + # (port might be different if we used port 0) + expected_ip = addr.value_for_protocol("ip4") + expected_protocol = addr.value_for_protocol("tcp") + if expected_ip and expected_protocol: + found_matching = False + for listener_addr in listener_addrs: + if ( + listener_addr.value_for_protocol("ip4") == expected_ip + and listener_addr.value_for_protocol("tcp") is not None + ): + found_matching = True + break + assert found_matching, ( + f"Listener for {addr} should have matching IP" + ) + + assert successful_count == len(listen_addrs), ( + f"All {len(listen_addrs)} addresses should be listening, " + f"but only {successful_count} succeeded" + ) + + +@pytest.mark.trio +async def test_swarm_listen_multiple_addresses_connectivity(security_protocol): + """Test that real libp2p connections can be established to all listening addresses.""" # noqa: E501 + from libp2p.peer.peerinfo import info_from_p2p_addr + from libp2p.utils.address_validation import get_available_interfaces + + # Get multiple addresses to listen on + listen_addrs = get_available_interfaces(0) # Let OS choose ports + + # Create a swarm and listen on multiple addresses + swarm1 = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm1): + # Listen on all addresses + success = await swarm1.listen(*listen_addrs) + assert success, "Should successfully listen on at least one address" + + # Verify all available interfaces are listening + assert len(swarm1.listeners) == len(listen_addrs), ( + f"All {len(listen_addrs)} interfaces should be listening, " + f"but only {len(swarm1.listeners)} are" + ) + + # Create a second swarm to test connections + swarm2 = SwarmFactory.build(security_protocol=security_protocol) + async with background_trio_service(swarm2): + # Test connectivity to each listening address using real libp2p connections + for addr_str, listener in swarm1.listeners.items(): + listener_addrs = listener.get_addrs() + for listener_addr in 
listener_addrs: + # Create a full multiaddr with peer ID for libp2p connection + peer_id = swarm1.get_peer_id() + full_addr = listener_addr.encapsulate(f"/p2p/{peer_id}") + + # Test real libp2p connection + try: + peer_info = info_from_p2p_addr(full_addr) + + # Add the peer info to swarm2's peerstore so it knows where to connect # noqa: E501 + swarm2.peerstore.add_addrs( + peer_info.peer_id, [listener_addr], 10000 + ) + + await swarm2.dial_peer(peer_info.peer_id) + + # Verify connection was established + assert peer_info.peer_id in swarm2.connections, ( + f"Connection to {full_addr} should be established" + ) + assert swarm2.get_peer_id() in swarm1.connections, ( + f"Connection from {full_addr} should be established" + ) + + except Exception as e: + pytest.fail( + f"Failed to establish libp2p connection to {full_addr}: {e}" + ) diff --git a/tests/core/pubsub/test_gossipsub.py b/tests/core/pubsub/test_gossipsub.py index 91205b29..5c341d0b 100644 --- a/tests/core/pubsub/test_gossipsub.py +++ b/tests/core/pubsub/test_gossipsub.py @@ -1,4 +1,8 @@ import random +from unittest.mock import ( + AsyncMock, + MagicMock, +) import pytest import trio @@ -7,6 +11,9 @@ from libp2p.pubsub.gossipsub import ( PROTOCOL_ID, GossipSub, ) +from libp2p.pubsub.pb import ( + rpc_pb2, +) from libp2p.tools.utils import ( connect, ) @@ -754,3 +761,173 @@ async def test_single_host(): assert connected_peers == 0, ( f"Single host has {connected_peers} connections, expected 0" ) + + +@pytest.mark.trio +async def test_handle_ihave(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock emit_iwant to capture calls + mock_emit_iwant = AsyncMock() + monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant) + + # Create a test message ID as a string representation of a (seqno, from) tuple + test_seqno = b"1234" + test_from = id_bob.to_bytes() + test_msg_id = f"(b'{test_seqno.hex()}', b'{test_from.hex()}')" + ihave_msg = rpc_pb2.ControlIHave(messageIDs=[test_msg_id]) + + # Mock seen_messages.cache to avoid false positives + monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {}) + + # Simulate Bob sending IHAVE to Alice + await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob) + + # Check if emit_iwant was called with the correct message ID + mock_emit_iwant.assert_called_once() + called_args = mock_emit_iwant.call_args[0] + assert called_args[0] == [test_msg_id] # Expected message IDs + assert called_args[1] == id_bob # Sender peer ID + + +@pytest.mark.trio +async def test_handle_iwant(monkeypatch): + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_alice = pubsubs_gsub[index_alice].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock mcache.get to return a 
message + test_message = rpc_pb2.Message(data=b"test_data") + test_seqno = b"1234" + test_from = id_alice.to_bytes() + + # āœ… Correct: use raw tuple and str() to serialize, no hex() + test_msg_id = str((test_seqno, test_from)) + + mock_mcache_get = MagicMock(return_value=test_message) + monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get) + + # Mock write_msg to capture the sent packet + mock_write_msg = AsyncMock() + monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg) + + # Simulate Alice sending IWANT to Bob + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[test_msg_id]) + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + + # Check if write_msg was called with the correct packet + mock_write_msg.assert_called_once() + packet = mock_write_msg.call_args[0][1] + assert isinstance(packet, rpc_pb2.RPC) + assert len(packet.publish) == 1 + assert packet.publish[0] == test_message + + # Verify that mcache.get was called with the correct parsed message ID + mock_mcache_get.assert_called_once() + called_msg_id = mock_mcache_get.call_args[0][0] + assert isinstance(called_msg_id, tuple) + assert called_msg_id == (test_seqno, test_from) + + +@pytest.mark.trio +async def test_handle_iwant_invalid_msg_id(monkeypatch): + """ + Test that handle_iwant raises ValueError for malformed message IDs. + """ + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_alice = pubsubs_gsub[index_alice].my_id + + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) + + # Malformed message ID (not a tuple string) + malformed_msg_id = "not_a_valid_msg_id" + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[malformed_msg_id]) + + # Mock mcache.get and write_msg to ensure they are not called + mock_mcache_get = MagicMock() + monkeypatch.setattr(gossipsubs[index_bob].mcache, "get", mock_mcache_get) + mock_write_msg = AsyncMock() + monkeypatch.setattr(gossipsubs[index_bob].pubsub, "write_msg", mock_write_msg) + + with pytest.raises(ValueError): + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + mock_mcache_get.assert_not_called() + mock_write_msg.assert_not_called() + + # Message ID that's a tuple string but not (bytes, bytes) + invalid_tuple_msg_id = "('abc', 123)" + iwant_msg = rpc_pb2.ControlIWant(messageIDs=[invalid_tuple_msg_id]) + with pytest.raises(ValueError): + await gossipsubs[index_bob].handle_iwant(iwant_msg, id_alice) + mock_mcache_get.assert_not_called() + mock_write_msg.assert_not_called() + + +@pytest.mark.trio +async def test_handle_ihave_empty_message_ids(monkeypatch): + """ + Test that handle_ihave with an empty messageIDs list does not call emit_iwant. 
+ """ + async with PubsubFactory.create_batch_with_gossipsub(2) as pubsubs_gsub: + gossipsub_routers = [] + for pubsub in pubsubs_gsub: + if isinstance(pubsub.router, GossipSub): + gossipsub_routers.append(pubsub.router) + gossipsubs = tuple(gossipsub_routers) + + index_alice = 0 + index_bob = 1 + id_bob = pubsubs_gsub[index_bob].my_id + + # Connect Alice and Bob + await connect(pubsubs_gsub[index_alice].host, pubsubs_gsub[index_bob].host) + await trio.sleep(0.1) # Allow connections to establish + + # Mock emit_iwant to capture calls + mock_emit_iwant = AsyncMock() + monkeypatch.setattr(gossipsubs[index_alice], "emit_iwant", mock_emit_iwant) + + # Empty messageIDs list + ihave_msg = rpc_pb2.ControlIHave(messageIDs=[]) + + # Mock seen_messages.cache to avoid false positives + monkeypatch.setattr(pubsubs_gsub[index_alice].seen_messages, "cache", {}) + + # Simulate Bob sending IHAVE to Alice + await gossipsubs[index_alice].handle_ihave(ihave_msg, id_bob) + + # emit_iwant should not be called since there are no message IDs + mock_emit_iwant.assert_not_called() diff --git a/tests/core/pubsub/test_pubsub.py b/tests/core/pubsub/test_pubsub.py index e674dbc0..9a09f34f 100644 --- a/tests/core/pubsub/test_pubsub.py +++ b/tests/core/pubsub/test_pubsub.py @@ -8,8 +8,10 @@ from typing import ( from unittest.mock import patch import pytest +import multiaddr import trio +from libp2p.crypto.rsa import create_new_key_pair from libp2p.custom_types import AsyncValidatorFn from libp2p.exceptions import ( ValidationError, @@ -17,9 +19,11 @@ from libp2p.exceptions import ( from libp2p.network.stream.exceptions import ( StreamEOF, ) +from libp2p.peer.envelope import Envelope, seal_record from libp2p.peer.id import ( ID, ) +from libp2p.peer.peer_record import PeerRecord from libp2p.pubsub.pb import ( rpc_pb2, ) @@ -87,6 +91,45 @@ async def test_re_unsubscribe(): assert TESTING_TOPIC not in pubsubs_fsub[0].topic_ids +@pytest.mark.trio +async def test_reissue_when_listen_addrs_change(): + async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub: + await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host) + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Yield to let 0 notify 1 + await trio.sleep(1) + assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC] + + # Check whether signed-records were transfered properly in the subscribe call + envelope_b_sub = ( + pubsubs_fsub[1] + .host.get_peerstore() + .get_peer_record(pubsubs_fsub[0].host.get_id()) + ) + assert isinstance(envelope_b_sub, Envelope) + + # Simulate pubsubs_fsub[1].host listen addrs changing (different port) + new_addr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/123") + + # Patch just for the duration we force A to unsubscribe + with patch.object(pubsubs_fsub[0].host, "get_addrs", return_value=[new_addr]): + # Unsubscribe from A's side so that a new_record is issued + await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC) + await trio.sleep(1) + + # B should be holding A's new record with bumped seq + envelope_b_unsub = ( + pubsubs_fsub[1] + .host.get_peerstore() + .get_peer_record(pubsubs_fsub[0].host.get_id()) + ) + assert isinstance(envelope_b_unsub, Envelope) + + # This proves that a freshly signed record was issued rather than + # the latest-cached-one creating one. 
+        assert envelope_b_sub.record().seq < envelope_b_unsub.record().seq
+
+
 @pytest.mark.trio
 async def test_peers_subscribe():
     async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
@@ -95,11 +138,71 @@ async def test_peers_subscribe():
         # Yield to let 0 notify 1
         await trio.sleep(1)
         assert pubsubs_fsub[0].my_id in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
+
+        # Check whether signed records were transferred properly in the subscribe call
+        envelope_b_sub = (
+            pubsubs_fsub[1]
+            .host.get_peerstore()
+            .get_peer_record(pubsubs_fsub[0].host.get_id())
+        )
+        assert isinstance(envelope_b_sub, Envelope)
+
         await pubsubs_fsub[0].unsubscribe(TESTING_TOPIC)
         # Yield to let 0 notify 1
         await trio.sleep(1)
         assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics[TESTING_TOPIC]
+        envelope_b_unsub = (
+            pubsubs_fsub[1]
+            .host.get_peerstore()
+            .get_peer_record(pubsubs_fsub[0].host.get_id())
+        )
+        assert isinstance(envelope_b_unsub, Envelope)
+
+        # The latest cached record should have been re-sent unchanged rather
+        # than a fresh one issued, so the seq stays the same.
+        assert envelope_b_sub.record().seq == envelope_b_unsub.record().seq
+
+
+@pytest.mark.trio
+async def test_peer_subscribe_fail_upon_invalid_record_transfer():
+    async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
+        await connect(pubsubs_fsub[0].host, pubsubs_fsub[1].host)
+
+        # Corrupt host_a's local peer record
+        envelope = pubsubs_fsub[0].host.get_peerstore().get_local_record()
+        key_pair = create_new_key_pair()
+        if envelope is not None:
+            true_record = envelope.record()
+            envelope.public_key = key_pair.public_key
+            pubsubs_fsub[0].host.get_peerstore().set_local_record(envelope)
+
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        # Yield to let 0 notify 1
+        await trio.sleep(1)
+        assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
+            TESTING_TOPIC, set()
+        )
+
+        # Create a corrupt envelope with a valid signature but the wrong peer ID
+        false_record = PeerRecord(
+            ID.from_pubkey(key_pair.public_key), true_record.addrs
+        )
+        false_envelope = seal_record(
+            false_record, pubsubs_fsub[0].host.get_private_key()
+        )
+
+        pubsubs_fsub[0].host.get_peerstore().set_local_record(false_envelope)
+
+        await pubsubs_fsub[0].subscribe(TESTING_TOPIC)
+        # Yield to let 0 notify 1
+        await trio.sleep(1)
+        assert pubsubs_fsub[0].my_id not in pubsubs_fsub[1].peer_topics.get(
+            TESTING_TOPIC, set()
+        )
+
 
 @pytest.mark.trio
 async def test_get_hello_packet():
diff --git a/tests/core/pubsub/test_pubsub_notifee_integration.py b/tests/core/pubsub/test_pubsub_notifee_integration.py
new file mode 100644
index 00000000..e35dfeb1
--- /dev/null
+++ b/tests/core/pubsub/test_pubsub_notifee_integration.py
@@ -0,0 +1,90 @@
+from typing import cast
+
+import pytest
+import trio
+
+from libp2p.tools.utils import connect
+from tests.utils.factories import PubsubFactory
+
+
+@pytest.mark.trio
+async def test_connected_enqueues_and_adds_peer():
+    async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
+        await connect(p0.host, p1.host)
+        await p0.wait_until_ready()
+        # Wait until peer is added via queue processing
+        with trio.fail_after(1.0):
+            while p1.my_id not in p0.peers:
+                await trio.sleep(0.01)
+        assert p1.my_id in p0.peers
+
+
+@pytest.mark.trio
+async def test_disconnected_enqueues_and_removes_peer():
+    async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1):
+        await connect(p0.host, p1.host)
+        await p0.wait_until_ready()
+        # Ensure the peer is present first
+        with trio.fail_after(1.0):
+            while p1.my_id not in
p0.peers: + await trio.sleep(0.01) + # Now disconnect and expect removal via dead peer queue + await p0.host.get_network().close_peer(p1.host.get_id()) + with trio.fail_after(1.0): + while p1.my_id in p0.peers: + await trio.sleep(0.01) + assert p1.my_id not in p0.peers + + +@pytest.mark.trio +async def test_channel_closed_is_swallowed_in_notifee(monkeypatch) -> None: + # Ensure PubsubNotifee catches BrokenResourceError from its send channel + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + # Find the PubsubNotifee registered on the network + from libp2p.pubsub.pubsub_notifee import PubsubNotifee + + network = p0.host.get_network() + notifees = getattr(network, "notifees", []) + target = None + for nf in notifees: + if isinstance(nf, cast(type, PubsubNotifee)): + target = nf + break + assert target is not None, "PubsubNotifee not found on network" + + async def failing_send(_peer_id): # type: ignore[no-redef] + raise trio.BrokenResourceError + + # Make initiator queue send fail; PubsubNotifee should swallow + monkeypatch.setattr(target.initiator_peers_queue, "send", failing_send) + + # Connect peers; if exceptions are swallowed, service stays running + await connect(p0.host, p1.host) + await p0.wait_until_ready() + assert True + + +@pytest.mark.trio +async def test_duplicate_connection_does_not_duplicate_peer_state(): + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + await connect(p0.host, p1.host) + await p0.wait_until_ready() + with trio.fail_after(1.0): + while p1.my_id not in p0.peers: + await trio.sleep(0.01) + # Connect again should not add duplicates + await connect(p0.host, p1.host) + await trio.sleep(0.1) + assert list(p0.peers.keys()).count(p1.my_id) == 1 + + +@pytest.mark.trio +async def test_blacklist_blocks_peer_added_by_notifee(): + async with PubsubFactory.create_batch_with_floodsub(2) as (p0, p1): + # Blacklist before connecting + p0.add_to_blacklist(p1.my_id) + await connect(p0.host, p1.host) + await p0.wait_until_ready() + # Give handler a chance to run + await trio.sleep(0.1) + assert p1.my_id not in p0.peers diff --git a/tests/core/security/test_security_multistream.py b/tests/core/security/test_security_multistream.py index 577cf404..d4fed72d 100644 --- a/tests/core/security/test_security_multistream.py +++ b/tests/core/security/test_security_multistream.py @@ -51,6 +51,9 @@ async def perform_simple_test(assertion_func, security_protocol): # Extract the secured connection from either Mplex or Yamux implementation def get_secured_conn(conn): + # conn is now a list, get the first connection + if isinstance(conn, list): + conn = conn[0] muxed_conn = conn.muxed_conn # Direct attribute access for known implementations has_secured_conn = hasattr(muxed_conn, "secured_conn") diff --git a/tests/core/stream_muxer/test_multiplexer_selection.py b/tests/core/stream_muxer/test_multiplexer_selection.py index b2f3e305..9b45324e 100644 --- a/tests/core/stream_muxer/test_multiplexer_selection.py +++ b/tests/core/stream_muxer/test_multiplexer_selection.py @@ -74,7 +74,8 @@ async def test_multiplexer_preference_parameter(muxer_preference): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol @@ -150,7 +151,8 @@ async def test_explicit_muxer_options(muxer_option_func, expected_stream_class): assert len(connections) > 0, 
"Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol @@ -219,7 +221,8 @@ async def test_global_default_muxer(global_default): assert len(connections) > 0, "Connection not established" # Get the first connection - conn = list(connections.values())[0] + conns = list(connections.values())[0] + conn = conns[0] # Get first connection from the list muxed_conn = conn.muxed_conn # Define a simple echo protocol diff --git a/tests/core/transport/quic/test_concurrency.py b/tests/core/transport/quic/test_concurrency.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/transport/quic/test_connection.py b/tests/core/transport/quic/test_connection.py new file mode 100644 index 00000000..9b3ad3a9 --- /dev/null +++ b/tests/core/transport/quic/test_connection.py @@ -0,0 +1,553 @@ +""" +Enhanced tests for QUIC connection functionality - Module 3. +Tests all new features including advanced stream management, resource management, +error handling, and concurrent operations. +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from multiaddr.multiaddr import Multiaddr +import trio + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.exceptions import ( + QUICConnectionClosedError, + QUICConnectionError, + QUICConnectionTimeoutError, + QUICPeerVerificationError, + QUICStreamLimitError, + QUICStreamTimeoutError, +) +from libp2p.transport.quic.security import QUICTLSConfigManager +from libp2p.transport.quic.stream import QUICStream, StreamDirection + + +class MockResourceScope: + """Mock resource scope for testing.""" + + def __init__(self): + self.memory_reserved = 0 + + def reserve_memory(self, size): + self.memory_reserved += size + + def release_memory(self, size): + self.memory_reserved = max(0, self.memory_reserved - size) + + +class TestQUICConnection: + """Test suite for QUIC connection functionality.""" + + @pytest.fixture + def mock_quic_connection(self): + """Create mock aioquic QuicConnection.""" + mock = Mock() + mock.next_event.return_value = None + mock.datagrams_to_send.return_value = [] + mock.get_timer.return_value = None + mock.connect = Mock() + mock.close = Mock() + mock.send_stream_data = Mock() + mock.reset_stream = Mock() + return mock + + @pytest.fixture + def mock_quic_transport(self): + mock = Mock() + mock._config = QUICTransportConfig() + return mock + + @pytest.fixture + def mock_resource_scope(self): + """Create mock resource scope.""" + return MockResourceScope() + + @pytest.fixture + def quic_connection( + self, + mock_quic_connection: Mock, + mock_quic_transport: Mock, + mock_resource_scope: MockResourceScope, + ): + """Create test QUIC connection with enhanced features.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + mock_security_manager = Mock() + + return QUICConnection( + quic_connection=mock_quic_connection, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=None, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=mock_quic_transport, + resource_scope=mock_resource_scope, + security_manager=mock_security_manager, + ) + + @pytest.fixture 
+    def server_connection(self, mock_quic_connection, mock_resource_scope):
+        """Create server-side QUIC connection."""
+        private_key = create_new_key_pair().private_key
+        peer_id = ID.from_pubkey(private_key.get_public_key())
+
+        return QUICConnection(
+            quic_connection=mock_quic_connection,
+            remote_addr=("127.0.0.1", 4001),
+            remote_peer_id=peer_id,
+            local_peer_id=peer_id,
+            is_initiator=False,
+            maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
+            transport=Mock(),
+            resource_scope=mock_resource_scope,
+        )
+
+    # Basic functionality tests
+
+    def test_connection_initialization_enhanced(
+        self, quic_connection, mock_resource_scope
+    ):
+        """Test enhanced connection initialization."""
+        assert quic_connection._remote_addr == ("127.0.0.1", 4001)
+        assert quic_connection.is_initiator is True
+        assert not quic_connection.is_closed
+        assert not quic_connection.is_established
+        assert len(quic_connection._streams) == 0
+        assert quic_connection._resource_scope == mock_resource_scope
+        assert quic_connection._outbound_stream_count == 0
+        assert quic_connection._inbound_stream_count == 0
+        assert len(quic_connection._stream_accept_queue) == 0
+
+    def test_stream_id_calculation_enhanced(self):
+        """Test enhanced stream ID calculation for client/server."""
+        # Client connection (initiator)
+        client_conn = QUICConnection(
+            quic_connection=Mock(),
+            remote_addr=("127.0.0.1", 4001),
+            remote_peer_id=None,
+            local_peer_id=Mock(),
+            is_initiator=True,
+            maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
+            transport=Mock(),
+        )
+        assert client_conn._next_stream_id == 0  # Client starts with 0
+
+        # Server connection (not initiator)
+        server_conn = QUICConnection(
+            quic_connection=Mock(),
+            remote_addr=("127.0.0.1", 4001),
+            remote_peer_id=None,
+            local_peer_id=Mock(),
+            is_initiator=False,
+            maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
+            transport=Mock(),
+        )
+        assert server_conn._next_stream_id == 1  # Server starts with 1
+
+    def test_incoming_stream_detection_enhanced(self, quic_connection):
+        """Test enhanced incoming stream detection logic."""
+        # For a client (initiator), odd stream IDs are incoming
+        assert quic_connection._is_incoming_stream(1) is True  # Server-initiated
+        assert quic_connection._is_incoming_stream(0) is False  # Client-initiated
+        assert quic_connection._is_incoming_stream(5) is True  # Server-initiated
+        assert quic_connection._is_incoming_stream(4) is False  # Client-initiated
+
+    # Stream management tests
+
+    @pytest.mark.trio
+    async def test_open_stream_basic(self, quic_connection):
+        """Test basic stream opening."""
+        quic_connection._started = True
+
+        stream = await quic_connection.open_stream()
+
+        assert isinstance(stream, QUICStream)
+        assert stream.stream_id == "0"
+        assert stream.direction == StreamDirection.OUTBOUND
+        assert 0 in quic_connection._streams
+        assert quic_connection._outbound_stream_count == 1
+
+    @pytest.mark.trio
+    async def test_open_stream_limit_reached(self, quic_connection):
+        """Test stream limit enforcement."""
+        quic_connection._started = True
+        quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
+
+        with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
+            await quic_connection.open_stream()
+
+    @pytest.mark.trio
+    async def test_open_stream_timeout(self, quic_connection: QUICConnection):
+        """Test stream opening timeout."""
+        quic_connection._started = True
+        pytest.skip("Stream creation timeout path is not exercised yet")
+
+        # Mock the stream ID lock to simulate a slow operation
+        async def slow_acquire():
+            await trio.sleep(10)  # Longer than the timeout
+
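+        # The patch below swaps the lock's acquire() for the deliberately slow
+        # coroutine above, so open_stream(timeout=0.1) should hit its internal
+        # timeout before the lock is ever acquired.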
with patch.object( + quic_connection._stream_lock, "acquire", side_effect=slow_acquire + ): + with pytest.raises( + QUICStreamTimeoutError, match="Stream creation timed out" + ): + await quic_connection.open_stream(timeout=0.1) + + @pytest.mark.trio + async def test_accept_stream_basic(self, quic_connection): + """Test basic stream acceptance.""" + # Create a mock inbound stream + mock_stream = Mock(spec=QUICStream) + mock_stream.stream_id = "1" + + # Add to accept queue + quic_connection._stream_accept_queue.append(mock_stream) + quic_connection._stream_accept_event.set() + + accepted_stream = await quic_connection.accept_stream(timeout=0.1) + + assert accepted_stream == mock_stream + assert len(quic_connection._stream_accept_queue) == 0 + + @pytest.mark.trio + async def test_accept_stream_timeout(self, quic_connection): + """Test stream acceptance timeout.""" + with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"): + await quic_connection.accept_stream(timeout=0.1) + + @pytest.mark.trio + async def test_accept_stream_on_closed_connection(self, quic_connection): + """Test stream acceptance on closed connection.""" + await quic_connection.close() + + with pytest.raises(QUICConnectionClosedError, match="Connection is closed"): + await quic_connection.accept_stream() + + # Stream handler tests + + @pytest.mark.trio + async def test_stream_handler_setting(self, quic_connection): + """Test setting stream handler.""" + + async def mock_handler(stream): + pass + + quic_connection.set_stream_handler(mock_handler) + assert quic_connection._stream_handler == mock_handler + + # Connection lifecycle tests + + @pytest.mark.trio + async def test_connection_start_client(self, quic_connection): + """Test client connection start.""" + with patch.object( + quic_connection, "_initiate_connection", new_callable=AsyncMock + ) as mock_initiate: + await quic_connection.start() + + assert quic_connection._started + mock_initiate.assert_called_once() + + @pytest.mark.trio + async def test_connection_start_server(self, server_connection): + """Test server connection start.""" + await server_connection.start() + + assert server_connection._started + assert server_connection._established + assert server_connection._connected_event.is_set() + + @pytest.mark.trio + async def test_connection_start_already_started(self, quic_connection): + """Test starting already started connection.""" + quic_connection._started = True + + # Should not raise error, just log warning + await quic_connection.start() + assert quic_connection._started + + @pytest.mark.trio + async def test_connection_start_closed(self, quic_connection): + """Test starting closed connection.""" + quic_connection._closed = True + + with pytest.raises( + QUICConnectionError, match="Cannot start a closed connection" + ): + await quic_connection.start() + + @pytest.mark.trio + async def test_connection_connect_with_nursery( + self, quic_connection: QUICConnection + ): + """Test connection establishment with nursery.""" + quic_connection._started = True + quic_connection._established = True + quic_connection._connected_event.set() + + with patch.object( + quic_connection, "_start_background_tasks", new_callable=AsyncMock + ) as mock_start_tasks: + with patch.object( + quic_connection, + "_verify_peer_identity_with_security", + new_callable=AsyncMock, + ) as mock_verify: + async with trio.open_nursery() as nursery: + await quic_connection.connect(nursery) + + assert quic_connection._nursery == nursery + mock_start_tasks.assert_called_once() + 
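+                    # Peer identity verification should likewise have run
+                    # exactly once during connect().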
mock_verify.assert_called_once() + + @pytest.mark.trio + @pytest.mark.slow + async def test_connection_connect_timeout( + self, quic_connection: QUICConnection + ) -> None: + """Test connection establishment timeout.""" + quic_connection._started = True + # Don't set connected event to simulate timeout + + with patch.object( + quic_connection, "_start_background_tasks", new_callable=AsyncMock + ): + async with trio.open_nursery() as nursery: + with pytest.raises( + QUICConnectionTimeoutError, match="Connection handshake timed out" + ): + await quic_connection.connect(nursery) + + # Resource management tests + + @pytest.mark.trio + async def test_stream_removal_resource_cleanup( + self, quic_connection: QUICConnection, mock_resource_scope + ): + """Test stream removal and resource cleanup.""" + quic_connection._started = True + + # Create a stream + stream = await quic_connection.open_stream() + + # Remove the stream + quic_connection._remove_stream(int(stream.stream_id)) + + assert int(stream.stream_id) not in quic_connection._streams + # Note: Count updates is async, so we can't test it directly here + + # Error handling tests + + @pytest.mark.trio + async def test_connection_error_handling(self, quic_connection) -> None: + """Test connection error handling.""" + error = Exception("Test error") + + with patch.object( + quic_connection, "close", new_callable=AsyncMock + ) as mock_close: + await quic_connection._handle_connection_error(error) + mock_close.assert_called_once() + + # Statistics and monitoring tests + + @pytest.mark.trio + async def test_connection_stats_enhanced(self, quic_connection) -> None: + """Test enhanced connection statistics.""" + quic_connection._started = True + + # Create some streams + _stream1 = await quic_connection.open_stream() + _stream2 = await quic_connection.open_stream() + + stats = quic_connection.get_stream_stats() + + expected_keys = [ + "total_streams", + "outbound_streams", + "inbound_streams", + "max_streams", + "stream_utilization", + "stats", + ] + + for key in expected_keys: + assert key in stats + + assert stats["total_streams"] == 2 + assert stats["outbound_streams"] == 2 + assert stats["inbound_streams"] == 0 + + @pytest.mark.trio + async def test_get_active_streams(self, quic_connection) -> None: + """Test getting active streams.""" + quic_connection._started = True + + # Create streams + stream1 = await quic_connection.open_stream() + stream2 = await quic_connection.open_stream() + + active_streams = quic_connection.get_active_streams() + + assert len(active_streams) == 2 + assert stream1 in active_streams + assert stream2 in active_streams + + @pytest.mark.trio + async def test_get_streams_by_protocol(self, quic_connection) -> None: + """Test getting streams by protocol.""" + quic_connection._started = True + + # Create streams with different protocols + stream1 = await quic_connection.open_stream() + stream1.protocol = "/test/1.0.0" + + stream2 = await quic_connection.open_stream() + stream2.protocol = "/other/1.0.0" + + test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0") + other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0") + + assert len(test_streams) == 1 + assert len(other_streams) == 1 + assert stream1 in test_streams + assert stream2 in other_streams + + # Enhanced close tests + + @pytest.mark.trio + async def test_connection_close_enhanced( + self, quic_connection: QUICConnection + ) -> None: + """Test enhanced connection close with stream cleanup.""" + quic_connection._started = True + + # 
Create some streams + _stream1 = await quic_connection.open_stream() + _stream2 = await quic_connection.open_stream() + + await quic_connection.close() + + assert quic_connection.is_closed + assert len(quic_connection._streams) == 0 + + # Concurrent operations tests + + @pytest.mark.trio + async def test_concurrent_stream_operations( + self, quic_connection: QUICConnection + ) -> None: + """Test concurrent stream operations.""" + quic_connection._started = True + + async def create_stream(): + return await quic_connection.open_stream() + + # Create multiple streams concurrently + async with trio.open_nursery() as nursery: + for i in range(10): + nursery.start_soon(create_stream) + + # Wait a bit for all to start + await trio.sleep(0.1) + + # Should have created streams without conflicts + assert quic_connection._outbound_stream_count == 10 + assert len(quic_connection._streams) == 10 + + # Connection properties tests + + def test_connection_properties(self, quic_connection: QUICConnection) -> None: + """Test connection property accessors.""" + assert quic_connection.multiaddr() == quic_connection._maddr + assert quic_connection.local_peer_id() == quic_connection._local_peer_id + assert quic_connection.remote_peer_id() == quic_connection._remote_peer_id + + # IRawConnection interface tests + + @pytest.mark.trio + async def test_raw_connection_write(self, quic_connection: QUICConnection) -> None: + """Test raw connection write interface.""" + quic_connection._started = True + + with patch.object(quic_connection, "open_stream") as mock_open: + mock_stream = AsyncMock() + mock_open.return_value = mock_stream + + await quic_connection.write(b"test data") + + mock_open.assert_called_once() + mock_stream.write.assert_called_once_with(b"test data") + mock_stream.close_write.assert_called_once() + + @pytest.mark.trio + async def test_raw_connection_read_not_implemented( + self, quic_connection: QUICConnection + ) -> None: + """Test raw connection read raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + await quic_connection.read() + + # Mock verification helpers + + def test_mock_resource_scope_functionality(self, mock_resource_scope) -> None: + """Test mock resource scope works correctly.""" + assert mock_resource_scope.memory_reserved == 0 + + mock_resource_scope.reserve_memory(1000) + assert mock_resource_scope.memory_reserved == 1000 + + mock_resource_scope.reserve_memory(500) + assert mock_resource_scope.memory_reserved == 1500 + + mock_resource_scope.release_memory(600) + assert mock_resource_scope.memory_reserved == 900 + + mock_resource_scope.release_memory(2000) # Should not go negative + assert mock_resource_scope.memory_reserved == 0 + + +@pytest.mark.trio +async def test_invalid_certificate_verification(): + key_pair1 = create_new_key_pair() + key_pair2 = create_new_key_pair() + + peer_id1 = ID.from_pubkey(key_pair1.public_key) + peer_id2 = ID.from_pubkey(key_pair2.public_key) + + manager = QUICTLSConfigManager( + libp2p_private_key=key_pair1.private_key, peer_id=peer_id1 + ) + + # Match the certificate against a different peer_id + with pytest.raises(QUICPeerVerificationError, match="Peer ID mismatch"): + manager.verify_peer_identity(manager.tls_config.certificate, peer_id2) + + from cryptography.hazmat.primitives.serialization import Encoding + + # --- Corrupt the certificate by tampering the DER bytes --- + cert_bytes = manager.tls_config.certificate.public_bytes(Encoding.DER) + corrupted_bytes = bytearray(cert_bytes) + + # Flip some random bytes in the middle 
of the certificate + corrupted_bytes[len(corrupted_bytes) // 2] ^= 0xFF + + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + + # This will still parse (structurally valid), but the signature + # or fingerprint will break + corrupted_cert = x509.load_der_x509_certificate( + bytes(corrupted_bytes), backend=default_backend() + ) + + with pytest.raises( + QUICPeerVerificationError, match="Certificate verification failed" + ): + manager.verify_peer_identity(corrupted_cert, peer_id1) diff --git a/tests/core/transport/quic/test_connection_id.py b/tests/core/transport/quic/test_connection_id.py new file mode 100644 index 00000000..de371550 --- /dev/null +++ b/tests/core/transport/quic/test_connection_id.py @@ -0,0 +1,624 @@ +""" +QUIC Connection ID Management Tests + +This test module covers comprehensive testing of QUIC connection ID functionality +including generation, rotation, retirement, and validation according to RFC 9000. + +Tests are organized into: +1. Basic Connection ID Management +2. Connection ID Rotation and Updates +3. Connection ID Retirement +4. Error Conditions and Edge Cases +5. Integration Tests with Real Connections +""" + +import secrets +import time +from typing import Any +from unittest.mock import Mock + +import pytest +from aioquic.buffer import Buffer + +# Import aioquic components for low-level testing +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.connection import QuicConnection, QuicConnectionId +from multiaddr import Multiaddr + +from libp2p.crypto.ed25519 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.transport import QUICTransport + + +class ConnectionIdTestHelper: + """Helper class for connection ID testing utilities.""" + + @staticmethod + def generate_connection_id(length: int = 8) -> bytes: + """Generate a random connection ID of specified length.""" + return secrets.token_bytes(length) + + @staticmethod + def create_quic_connection_id(cid: bytes, sequence: int = 0) -> QuicConnectionId: + """Create a QuicConnectionId object.""" + return QuicConnectionId( + cid=cid, + sequence_number=sequence, + stateless_reset_token=secrets.token_bytes(16), + ) + + @staticmethod + def extract_connection_ids_from_connection(conn: QUICConnection) -> dict[str, Any]: + """Extract connection ID information from a QUIC connection.""" + quic = conn._quic + return { + "host_cids": [cid.cid.hex() for cid in getattr(quic, "_host_cids", [])], + "peer_cid": getattr(quic, "_peer_cid", None), + "peer_cid_available": [ + cid.cid.hex() for cid in getattr(quic, "_peer_cid_available", []) + ], + "retire_connection_ids": getattr(quic, "_retire_connection_ids", []), + "host_cid_seq": getattr(quic, "_host_cid_seq", 0), + } + + +class TestBasicConnectionIdManagement: + """Test basic connection ID management functionality.""" + + @pytest.fixture + def mock_quic_connection(self): + """Create a mock QUIC connection with connection ID support.""" + mock_quic = Mock(spec=QuicConnection) + mock_quic._host_cids = [] + mock_quic._host_cid_seq = 0 + mock_quic._peer_cid = None + mock_quic._peer_cid_available = [] + mock_quic._retire_connection_ids = [] + mock_quic._configuration = Mock() + mock_quic._configuration.connection_id_length = 8 + mock_quic._remote_active_connection_id_limit = 8 + return mock_quic + + @pytest.fixture + def quic_connection(self, mock_quic_connection): 
+ """Create a QUICConnection instance for testing.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + return QUICConnection( + quic_connection=mock_quic_connection, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + def test_connection_id_initialization(self, quic_connection): + """Test that connection ID tracking is properly initialized.""" + # Check that connection ID tracking structures are initialized + assert hasattr(quic_connection, "_available_connection_ids") + assert hasattr(quic_connection, "_current_connection_id") + assert hasattr(quic_connection, "_retired_connection_ids") + assert hasattr(quic_connection, "_connection_id_sequence_numbers") + + # Initial state should be empty + assert len(quic_connection._available_connection_ids) == 0 + assert quic_connection._current_connection_id is None + assert len(quic_connection._retired_connection_ids) == 0 + assert len(quic_connection._connection_id_sequence_numbers) == 0 + + def test_connection_id_stats_tracking(self, quic_connection): + """Test connection ID statistics are properly tracked.""" + stats = quic_connection.get_connection_id_stats() + + # Check that all expected stats are present + expected_keys = [ + "available_connection_ids", + "current_connection_id", + "retired_connection_ids", + "connection_ids_issued", + "connection_ids_retired", + "connection_id_changes", + "available_cid_list", + ] + + for key in expected_keys: + assert key in stats + + # Initial values should be zero/empty + assert stats["available_connection_ids"] == 0 + assert stats["current_connection_id"] is None + assert stats["retired_connection_ids"] == 0 + assert stats["connection_ids_issued"] == 0 + assert stats["connection_ids_retired"] == 0 + assert stats["connection_id_changes"] == 0 + assert stats["available_cid_list"] == [] + + def test_current_connection_id_getter(self, quic_connection): + """Test getting current connection ID.""" + # Initially no connection ID + assert quic_connection.get_current_connection_id() is None + + # Set a connection ID + test_cid = ConnectionIdTestHelper.generate_connection_id() + quic_connection._current_connection_id = test_cid + + assert quic_connection.get_current_connection_id() == test_cid + + def test_connection_id_generation(self): + """Test connection ID generation utilities.""" + # Test default length + cid1 = ConnectionIdTestHelper.generate_connection_id() + assert len(cid1) == 8 + assert isinstance(cid1, bytes) + + # Test custom length + cid2 = ConnectionIdTestHelper.generate_connection_id(16) + assert len(cid2) == 16 + + # Test uniqueness + cid3 = ConnectionIdTestHelper.generate_connection_id() + assert cid1 != cid3 + + +class TestConnectionIdRotationAndUpdates: + """Test connection ID rotation and update mechanisms.""" + + @pytest.fixture + def transport_config(self): + """Create transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=100, + ) + + @pytest.fixture + def server_key(self): + """Generate server private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def client_key(self): + """Generate client private key.""" + return create_new_key_pair().private_key + + def test_connection_id_replenishment(self): + """Test connection ID replenishment mechanism.""" + # Create a real QuicConnection to test replenishment 
+ config = QuicConfiguration(is_client=True) + config.connection_id_length = 8 + + quic_conn = QuicConnection(configuration=config) + + # Initial state - should have some host connection IDs + initial_count = len(quic_conn._host_cids) + assert initial_count > 0 + + # Remove some connection IDs to trigger replenishment + while len(quic_conn._host_cids) > 2: + quic_conn._host_cids.pop() + + # Trigger replenishment + quic_conn._replenish_connection_ids() + + # Should have replenished up to the limit + assert len(quic_conn._host_cids) >= initial_count + + # All connection IDs should have unique sequence numbers + sequences = [cid.sequence_number for cid in quic_conn._host_cids] + assert len(sequences) == len(set(sequences)) + + def test_connection_id_sequence_numbers(self): + """Test connection ID sequence number management.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Get initial sequence number + initial_seq = quic_conn._host_cid_seq + + # Trigger replenishment to generate new connection IDs + quic_conn._replenish_connection_ids() + + # Sequence numbers should increment + assert quic_conn._host_cid_seq > initial_seq + + # All host connection IDs should have sequential numbers + sequences = [cid.sequence_number for cid in quic_conn._host_cids] + sequences.sort() + + # Check for proper sequence + for i in range(len(sequences) - 1): + assert sequences[i + 1] > sequences[i] + + def test_connection_id_limits(self): + """Test connection ID limit enforcement.""" + config = QuicConfiguration(is_client=True) + config.connection_id_length = 8 + + quic_conn = QuicConnection(configuration=config) + + # Set a reasonable limit + quic_conn._remote_active_connection_id_limit = 4 + + # Replenish connection IDs + quic_conn._replenish_connection_ids() + + # Should not exceed the limit + assert len(quic_conn._host_cids) <= quic_conn._remote_active_connection_id_limit + + +class TestConnectionIdRetirement: + """Test connection ID retirement functionality.""" + + def test_connection_id_retirement_basic(self): + """Test basic connection ID retirement.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Create a test connection ID to retire + test_cid = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=1 + ) + + # Add it to peer connection IDs + quic_conn._peer_cid_available.append(test_cid) + quic_conn._peer_cid_sequence_numbers.add(1) + + # Retire the connection ID + quic_conn._retire_peer_cid(test_cid) + + # Should be added to retirement list + assert 1 in quic_conn._retire_connection_ids + + def test_connection_id_retirement_limits(self): + """Test connection ID retirement limits.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Fill up retirement list near the limit + max_retirements = 32 # Based on aioquic's default limit + + for i in range(max_retirements): + quic_conn._retire_connection_ids.append(i) + + # Should be at limit + assert len(quic_conn._retire_connection_ids) == max_retirements + + def test_connection_id_retirement_events(self): + """Test that retirement generates proper events.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Create and add a host connection ID + test_cid = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=5 + ) + 
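+        # Register the CID as one of our host CIDs (again aioquic internal
+        # state) so the retirement frame built below has something to retire.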
quic_conn._host_cids.append(test_cid) + + # Create a retirement frame buffer + from aioquic.buffer import Buffer + + buf = Buffer(capacity=16) + buf.push_uint_var(5) # sequence number to retire + buf.seek(0) + + # Process retirement (this should generate an event) + try: + quic_conn._handle_retire_connection_id_frame( + Mock(), # context + 0x19, # RETIRE_CONNECTION_ID frame type + buf, + ) + + # Check that connection ID was removed + remaining_sequences = [cid.sequence_number for cid in quic_conn._host_cids] + assert 5 not in remaining_sequences + + except Exception: + # May fail due to missing context, but that's okay for this test + pass + + +class TestConnectionIdErrorConditions: + """Test error conditions and edge cases in connection ID handling.""" + + def test_invalid_connection_id_length(self): + """Test handling of invalid connection ID lengths.""" + # Connection IDs must be 1-20 bytes according to RFC 9000 + + # Test too short (0 bytes) - this should be handled gracefully + empty_cid = b"" + assert len(empty_cid) == 0 + + # Test too long (>20 bytes) + long_cid = secrets.token_bytes(21) + assert len(long_cid) == 21 + + # Test valid lengths + for length in range(1, 21): + valid_cid = secrets.token_bytes(length) + assert len(valid_cid) == length + + def test_duplicate_sequence_numbers(self): + """Test handling of duplicate sequence numbers.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Create two connection IDs with same sequence number + cid1 = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=10 + ) + cid2 = ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), sequence=10 + ) + + # Add first connection ID + quic_conn._peer_cid_available.append(cid1) + quic_conn._peer_cid_sequence_numbers.add(10) + + # Adding second with same sequence should be handled appropriately + # (The implementation should prevent duplicates) + if 10 not in quic_conn._peer_cid_sequence_numbers: + quic_conn._peer_cid_available.append(cid2) + quic_conn._peer_cid_sequence_numbers.add(10) + + # Should only have one entry for sequence 10 + sequences = [cid.sequence_number for cid in quic_conn._peer_cid_available] + assert sequences.count(10) <= 1 + + def test_retire_unknown_connection_id(self): + """Test retiring an unknown connection ID.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Try to create a buffer to retire unknown sequence number + buf = Buffer(capacity=16) + buf.push_uint_var(999) # Unknown sequence number + buf.seek(0) + + # This should raise an error when processed + # (Testing the error condition, not the full processing) + unknown_sequence = 999 + known_sequences = [cid.sequence_number for cid in quic_conn._host_cids] + + assert unknown_sequence not in known_sequences + + def test_retire_current_connection_id(self): + """Test that retiring current connection ID is prevented.""" + config = QuicConfiguration(is_client=True) + quic_conn = QuicConnection(configuration=config) + + # Get current connection ID if available + if quic_conn._host_cids: + current_cid = quic_conn._host_cids[0] + current_sequence = current_cid.sequence_number + + # Trying to retire current connection ID should be prevented + # This is tested by checking the sequence number logic + assert current_sequence >= 0 + + +class TestConnectionIdIntegration: + """Integration tests for connection ID functionality with 
real connections.""" + + @pytest.fixture + def server_config(self): + """Server transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=100, + ) + + @pytest.fixture + def client_config(self): + """Client transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + ) + + @pytest.fixture + def server_key(self): + """Generate server private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def client_key(self): + """Generate client private key.""" + return create_new_key_pair().private_key + + @pytest.mark.trio + async def test_connection_id_exchange_during_handshake( + self, server_key, client_key, server_config, client_config + ): + """Test connection ID exchange during connection handshake.""" + # This test would require a full connection setup + # For now, we test the setup components + + server_transport = QUICTransport(server_key, server_config) + client_transport = QUICTransport(client_key, client_config) + + # Verify transports are created with proper configuration + assert server_transport._config == server_config + assert client_transport._config == client_config + + # Test that connection ID tracking is available + # (Integration with actual networking would require more setup) + + def test_connection_id_extraction_utilities(self): + """Test connection ID extraction utilities.""" + # Create a mock connection with some connection IDs + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + mock_quic = Mock() + mock_quic._host_cids = [ + ConnectionIdTestHelper.create_quic_connection_id( + ConnectionIdTestHelper.generate_connection_id(), i + ) + for i in range(3) + ] + mock_quic._peer_cid = None + mock_quic._peer_cid_available = [] + mock_quic._retire_connection_ids = [] + mock_quic._host_cid_seq = 3 + + quic_conn = QUICConnection( + quic_connection=mock_quic, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + # Extract connection ID information + cid_info = ConnectionIdTestHelper.extract_connection_ids_from_connection( + quic_conn + ) + + # Verify extraction works + assert "host_cids" in cid_info + assert "peer_cid" in cid_info + assert "peer_cid_available" in cid_info + assert "retire_connection_ids" in cid_info + assert "host_cid_seq" in cid_info + + # Check values + assert len(cid_info["host_cids"]) == 3 + assert cid_info["host_cid_seq"] == 3 + assert cid_info["peer_cid"] is None + assert len(cid_info["peer_cid_available"]) == 0 + assert len(cid_info["retire_connection_ids"]) == 0 + + +class TestConnectionIdStatistics: + """Test connection ID statistics and monitoring.""" + + @pytest.fixture + def connection_with_stats(self): + """Create a connection with connection ID statistics.""" + private_key = create_new_key_pair().private_key + peer_id = ID.from_pubkey(private_key.get_public_key()) + + mock_quic = Mock() + mock_quic._host_cids = [] + mock_quic._peer_cid = None + mock_quic._peer_cid_available = [] + mock_quic._retire_connection_ids = [] + + return QUICConnection( + quic_connection=mock_quic, + remote_addr=("127.0.0.1", 4001), + remote_peer_id=peer_id, + local_peer_id=peer_id, + is_initiator=True, + maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + transport=Mock(), + ) + + def test_connection_id_stats_initialization(self, connection_with_stats): + 
"""Test that connection ID statistics are properly initialized.""" + stats = connection_with_stats._stats + + # Check that connection ID stats are present + assert "connection_ids_issued" in stats + assert "connection_ids_retired" in stats + assert "connection_id_changes" in stats + + # Initial values should be zero + assert stats["connection_ids_issued"] == 0 + assert stats["connection_ids_retired"] == 0 + assert stats["connection_id_changes"] == 0 + + def test_connection_id_stats_update(self, connection_with_stats): + """Test updating connection ID statistics.""" + conn = connection_with_stats + + # Add some connection IDs to tracking + test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(3)] + + for cid in test_cids: + conn._available_connection_ids.add(cid) + + # Update stats (this would normally be done by the implementation) + conn._stats["connection_ids_issued"] = len(test_cids) + + # Verify stats + stats = conn.get_connection_id_stats() + assert stats["connection_ids_issued"] == 3 + assert stats["available_connection_ids"] == 3 + + def test_connection_id_list_representation(self, connection_with_stats): + """Test connection ID list representation in stats.""" + conn = connection_with_stats + + # Add some connection IDs + test_cids = [ConnectionIdTestHelper.generate_connection_id() for _ in range(2)] + + for cid in test_cids: + conn._available_connection_ids.add(cid) + + # Get stats + stats = conn.get_connection_id_stats() + + # Check that CID list is properly formatted + assert "available_cid_list" in stats + assert len(stats["available_cid_list"]) == 2 + + # All entries should be hex strings + for cid_hex in stats["available_cid_list"]: + assert isinstance(cid_hex, str) + assert len(cid_hex) == 16 # 8 bytes = 16 hex chars + + +# Performance and stress tests +class TestConnectionIdPerformance: + """Test connection ID performance and stress scenarios.""" + + def test_connection_id_generation_performance(self): + """Test connection ID generation performance.""" + start_time = time.time() + + # Generate many connection IDs + cids = [] + for _ in range(1000): + cid = ConnectionIdTestHelper.generate_connection_id() + cids.append(cid) + + end_time = time.time() + generation_time = end_time - start_time + + # Should be reasonably fast (less than 1 second for 1000 IDs) + assert generation_time < 1.0 + + # All should be unique + assert len(set(cids)) == len(cids) + + def test_connection_id_tracking_memory(self): + """Test memory usage of connection ID tracking.""" + conn_ids = set() + + # Add many connection IDs + for _ in range(1000): + cid = ConnectionIdTestHelper.generate_connection_id() + conn_ids.add(cid) + + # Verify they're all stored + assert len(conn_ids) == 1000 + + # Clean up + conn_ids.clear() + assert len(conn_ids) == 0 + + +if __name__ == "__main__": + # Run tests if executed directly + pytest.main([__file__, "-v"]) diff --git a/tests/core/transport/quic/test_integration.py b/tests/core/transport/quic/test_integration.py new file mode 100644 index 00000000..5016c996 --- /dev/null +++ b/tests/core/transport/quic/test_integration.py @@ -0,0 +1,418 @@ +""" +Basic QUIC Echo Test + +Simple test to verify the basic QUIC flow: +1. Client connects to server +2. Client sends data +3. Server receives data and echoes back +4. Client receives the echo + +This test focuses on identifying where the accept_stream issue occurs. 
+""" + +import logging + +import pytest +import multiaddr +import trio + +from examples.ping.ping import PING_LENGTH, PING_PROTOCOL_ID +from libp2p import new_host +from libp2p.abc import INetStream +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.transport.quic.config import QUICTransportConfig +from libp2p.transport.quic.connection import QUICConnection +from libp2p.transport.quic.transport import QUICTransport +from libp2p.transport.quic.utils import create_quic_multiaddr + +# Set up logging to see what's happening +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +class TestBasicQUICFlow: + """Test basic QUIC client-server communication flow.""" + + @pytest.fixture + def server_key(self): + """Generate server key pair.""" + return create_new_key_pair() + + @pytest.fixture + def client_key(self): + """Generate client key pair.""" + return create_new_key_pair() + + @pytest.fixture + def server_config(self): + """Simple server configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=10, + max_connections=5, + ) + + @pytest.fixture + def client_config(self): + """Simple client configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, + connection_timeout=5.0, + max_concurrent_streams=5, + ) + + @pytest.mark.trio + async def test_basic_echo_flow( + self, server_key, client_key, server_config, client_config + ): + """Test basic client-server echo flow with detailed logging.""" + print("\n=== BASIC QUIC ECHO TEST ===") + + # Create server components + server_transport = QUICTransport(server_key.private_key, server_config) + + # Track test state + server_received_data = None + server_connection_established = False + echo_sent = False + + async def echo_server_handler(connection: QUICConnection) -> None: + """Simple echo server handler with detailed logging.""" + nonlocal server_received_data, server_connection_established, echo_sent + + print("šŸ”— SERVER: Connection handler called") + server_connection_established = True + + try: + print("šŸ“” SERVER: Waiting for incoming stream...") + + # Accept stream with timeout and detailed logging + print("šŸ“” SERVER: Calling accept_stream...") + stream = await connection.accept_stream(timeout=5.0) + + if stream is None: + print("āŒ SERVER: accept_stream returned None") + return + + print(f"āœ… SERVER: Stream accepted! 
Stream ID: {stream.stream_id}") + + # Read data from the stream + print("šŸ“– SERVER: Reading data from stream...") + server_data = await stream.read(1024) + + if not server_data: + print("āŒ SERVER: No data received from stream") + return + + server_received_data = server_data.decode("utf-8", errors="ignore") + print(f"šŸ“Ø SERVER: Received data: '{server_received_data}'") + + # Echo the data back + echo_message = f"ECHO: {server_received_data}" + print(f"šŸ“¤ SERVER: Sending echo: '{echo_message}'") + + await stream.write(echo_message.encode()) + echo_sent = True + print("āœ… SERVER: Echo sent successfully") + + # Close the stream + await stream.close() + print("šŸ”’ SERVER: Stream closed") + + except Exception as e: + print(f"āŒ SERVER: Error in handler: {e}") + import traceback + + traceback.print_exc() + + # Create listener + listener = server_transport.create_listener(echo_server_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + # Variables to track client state + client_connected = False + client_sent_data = False + client_received_echo = None + + try: + print("šŸš€ Starting server...") + + async with trio.open_nursery() as nursery: + # Start server listener + success = await listener.listen(listen_addr, nursery) + assert success, "Failed to start server listener" + + # Get server address + server_addrs = listener.get_addrs() + server_addr = multiaddr.Multiaddr( + f"{server_addrs[0]}/p2p/{ID.from_pubkey(server_key.public_key)}" + ) + print(f"šŸ”§ SERVER: Listening on {server_addr}") + + # Give server a moment to be ready + await trio.sleep(0.1) + + print("šŸš€ Starting client...") + + # Create client transport + client_transport = QUICTransport(client_key.private_key, client_config) + client_transport.set_background_nursery(nursery) + + try: + # Connect to server + print(f"šŸ“ž CLIENT: Connecting to {server_addr}") + connection = await client_transport.dial(server_addr) + client_connected = True + print("āœ… CLIENT: Connected to server") + + # Open a stream + print("šŸ“¤ CLIENT: Opening stream...") + stream = await connection.open_stream() + print(f"āœ… CLIENT: Stream opened with ID: {stream.stream_id}") + + # Send test data + test_message = "Hello QUIC Server!" 
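+                # The test assumes the whole payload is delivered in a single
+                # stream.read(1024) call on the server side.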
+ print(f"šŸ“Ø CLIENT: Sending message: '{test_message}'") + await stream.write(test_message.encode()) + client_sent_data = True + print("āœ… CLIENT: Message sent") + + # Read echo response + print("šŸ“– CLIENT: Waiting for echo response...") + response_data = await stream.read(1024) + + if response_data: + client_received_echo = response_data.decode( + "utf-8", errors="ignore" + ) + print(f"šŸ“¬ CLIENT: Received echo: '{client_received_echo}'") + else: + print("āŒ CLIENT: No echo response received") + + print("šŸ”’ CLIENT: Closing connection") + await connection.close() + print("šŸ”’ CLIENT: Connection closed") + + print("šŸ”’ CLIENT: Closing transport") + await client_transport.close() + print("šŸ”’ CLIENT: Transport closed") + + except Exception as e: + print(f"āŒ CLIENT: Error: {e}") + import traceback + + traceback.print_exc() + + finally: + await client_transport.close() + print("šŸ”’ CLIENT: Transport closed") + + # Give everything time to complete + await trio.sleep(0.5) + + # Cancel nursery to stop server + nursery.cancel_scope.cancel() + + finally: + # Cleanup + if not listener._closed: + await listener.close() + await server_transport.close() + + # Verify the flow worked + print("\nšŸ“Š TEST RESULTS:") + print(f" Server connection established: {server_connection_established}") + print(f" Client connected: {client_connected}") + print(f" Client sent data: {client_sent_data}") + print(f" Server received data: '{server_received_data}'") + print(f" Echo sent by server: {echo_sent}") + print(f" Client received echo: '{client_received_echo}'") + + # Test assertions + assert server_connection_established, "Server connection handler was not called" + assert client_connected, "Client failed to connect" + assert client_sent_data, "Client failed to send data" + assert server_received_data == "Hello QUIC Server!", ( + f"Server received wrong data: '{server_received_data}'" + ) + assert echo_sent, "Server failed to send echo" + assert client_received_echo == "ECHO: Hello QUIC Server!", ( + f"Client received wrong echo: '{client_received_echo}'" + ) + + print("āœ… BASIC ECHO TEST PASSED!") + + @pytest.mark.trio + async def test_server_accept_stream_timeout( + self, server_key, client_key, server_config, client_config + ): + """Test what happens when server accept_stream times out.""" + print("\n=== TESTING SERVER ACCEPT_STREAM TIMEOUT ===") + + server_transport = QUICTransport(server_key.private_key, server_config) + + accept_stream_called = False + accept_stream_timeout = False + + async def timeout_test_handler(connection: QUICConnection) -> None: + """Handler that tests accept_stream timeout.""" + nonlocal accept_stream_called, accept_stream_timeout + + print("šŸ”— SERVER: Connection established, testing accept_stream timeout") + accept_stream_called = True + + try: + print("šŸ“” SERVER: Calling accept_stream with 2 second timeout...") + stream = await connection.accept_stream(timeout=2.0) + print(f"āœ… SERVER: accept_stream returned: {stream}") + + except Exception as e: + print(f"ā° SERVER: accept_stream timed out or failed: {e}") + accept_stream_timeout = True + + listener = server_transport.create_listener(timeout_test_handler) + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + client_connected = False + + try: + async with trio.open_nursery() as nursery: + # Start server + server_transport.set_background_nursery(nursery) + success = await listener.listen(listen_addr, nursery) + assert success + + server_addr = multiaddr.Multiaddr( + 
f"{listener.get_addrs()[0]}/p2p/{ID.from_pubkey(server_key.public_key)}" + ) + print(f"šŸ”§ SERVER: Listening on {server_addr}") + + # Create client but DON'T open a stream + async with trio.open_nursery() as client_nursery: + client_transport = QUICTransport( + client_key.private_key, client_config + ) + client_transport.set_background_nursery(client_nursery) + + try: + print("šŸ“ž CLIENT: Connecting (but NOT opening stream)...") + connection = await client_transport.dial(server_addr) + client_connected = True + print("āœ… CLIENT: Connected (no stream opened)") + + # Wait for server timeout + await trio.sleep(3.0) + + await connection.close() + print("šŸ”’ CLIENT: Connection closed") + + finally: + await client_transport.close() + + nursery.cancel_scope.cancel() + + finally: + await listener.close() + await server_transport.close() + + print("\nšŸ“Š TIMEOUT TEST RESULTS:") + print(f" Client connected: {client_connected}") + print(f" accept_stream called: {accept_stream_called}") + print(f" accept_stream timeout: {accept_stream_timeout}") + + assert client_connected, "Client should have connected" + assert accept_stream_called, "accept_stream should have been called" + assert accept_stream_timeout, ( + "accept_stream should have timed out when no stream was opened" + ) + + print("āœ… TIMEOUT TEST PASSED!") + + +@pytest.mark.trio +async def test_yamux_stress_ping(): + STREAM_COUNT = 100 + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + latencies = [] + failures = [] + + # === Server Setup === + server_host = new_host(listen_addrs=[listen_addr]) + + async def handle_ping(stream: INetStream) -> None: + try: + while True: + payload = await stream.read(PING_LENGTH) + if not payload: + break + await stream.write(payload) + except Exception: + await stream.reset() + + server_host.set_stream_handler(PING_PROTOCOL_ID, handle_ping) + + async with server_host.run(listen_addrs=[listen_addr]): + # Give server time to start + await trio.sleep(0.1) + + # === Client Setup === + destination = str(server_host.get_addrs()[0]) + maddr = multiaddr.Multiaddr(destination) + info = info_from_p2p_addr(maddr) + + client_listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + client_host = new_host(listen_addrs=[client_listen_addr]) + + async with client_host.run(listen_addrs=[client_listen_addr]): + await client_host.connect(info) + + async def ping_stream(i: int): + stream = None + try: + start = trio.current_time() + stream = await client_host.new_stream( + info.peer_id, [PING_PROTOCOL_ID] + ) + + await stream.write(b"\x01" * PING_LENGTH) + + with trio.fail_after(5): + response = await stream.read(PING_LENGTH) + + if response == b"\x01" * PING_LENGTH: + latency_ms = int((trio.current_time() - start) * 1000) + latencies.append(latency_ms) + print(f"[Ping #{i}] Latency: {latency_ms} ms") + await stream.close() + except Exception as e: + print(f"[Ping #{i}] Failed: {e}") + failures.append(i) + if stream: + await stream.reset() + + async with trio.open_nursery() as nursery: + for i in range(STREAM_COUNT): + nursery.start_soon(ping_stream, i) + + # === Result Summary === + print("\nšŸ“Š Ping Stress Test Summary") + print(f"Total Streams Launched: {STREAM_COUNT}") + print(f"Successful Pings: {len(latencies)}") + print(f"Failed Pings: {len(failures)}") + if failures: + print(f"āŒ Failed stream indices: {failures}") + + # === Assertions === + assert len(latencies) == STREAM_COUNT, ( + f"Expected {STREAM_COUNT} successful streams, got {len(latencies)}" + ) + assert all(isinstance(x, int) and x >= 0 
for x in latencies), ( + "Invalid latencies" + ) + + avg_latency = sum(latencies) / len(latencies) + print(f"āœ… Average Latency: {avg_latency:.2f} ms") + assert avg_latency < 1000 diff --git a/tests/core/transport/quic/test_listener.py b/tests/core/transport/quic/test_listener.py new file mode 100644 index 00000000..840f7218 --- /dev/null +++ b/tests/core/transport/quic/test_listener.py @@ -0,0 +1,150 @@ +from unittest.mock import AsyncMock + +import pytest +from multiaddr.multiaddr import Multiaddr +import trio + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.transport.quic.exceptions import ( + QUICListenError, +) +from libp2p.transport.quic.listener import QUICListener +from libp2p.transport.quic.transport import ( + QUICTransport, + QUICTransportConfig, +) +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, +) + + +class TestQUICListener: + """Test suite for QUIC listener functionality.""" + + @pytest.fixture + def private_key(self): + """Generate test private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def transport_config(self): + """Generate test transport configuration.""" + return QUICTransportConfig(idle_timeout=10.0) + + @pytest.fixture + def transport(self, private_key, transport_config): + """Create test transport instance.""" + return QUICTransport(private_key, transport_config) + + @pytest.fixture + def connection_handler(self): + """Mock connection handler.""" + return AsyncMock() + + @pytest.fixture + def listener(self, transport, connection_handler): + """Create test listener.""" + return transport.create_listener(connection_handler) + + def test_listener_creation(self, transport, connection_handler): + """Test listener creation.""" + listener = transport.create_listener(connection_handler) + + assert isinstance(listener, QUICListener) + assert listener._transport == transport + assert listener._handler == connection_handler + assert not listener._listening + assert not listener._closed + + @pytest.mark.trio + async def test_listener_invalid_multiaddr(self, listener: QUICListener): + """Test listener with invalid multiaddr.""" + async with trio.open_nursery() as nursery: + invalid_addr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + + with pytest.raises(QUICListenError, match="Invalid QUIC multiaddr"): + await listener.listen(invalid_addr, nursery) + + @pytest.mark.trio + async def test_listener_basic_lifecycle(self, listener: QUICListener): + """Test basic listener lifecycle.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") # Port 0 = random + + async with trio.open_nursery() as nursery: + # Start listening + success = await listener.listen(listen_addr, nursery) + assert success + assert listener.is_listening() + + # Check bound addresses + addrs = listener.get_addrs() + assert len(addrs) == 1 + + # Check stats + stats = listener.get_stats() + assert stats["is_listening"] is True + assert stats["active_connections"] == 0 + assert stats["pending_connections"] == 0 + + # Sender Cancel Signal + nursery.cancel_scope.cancel() + + await listener.close() + assert not listener.is_listening() + + @pytest.mark.trio + async def test_listener_double_listen(self, listener: QUICListener): + """Test that double listen raises error.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 9001, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + await trio.sleep(0.01) + + addrs = listener.get_addrs() + assert len(addrs) > 
0 + async with trio.open_nursery() as nursery2: + with pytest.raises(QUICListenError, match="Already listening"): + await listener.listen(listen_addr, nursery2) + nursery2.cancel_scope.cancel() + + nursery.cancel_scope.cancel() + finally: + await listener.close() + + @pytest.mark.trio + async def test_listener_port_binding(self, listener: QUICListener): + """Test listener port binding and cleanup.""" + listen_addr = create_quic_multiaddr("127.0.0.1", 0, "/quic") + + try: + async with trio.open_nursery() as nursery: + success = await listener.listen(listen_addr, nursery) + assert success + await trio.sleep(0.5) + + addrs = listener.get_addrs() + assert len(addrs) > 0 + + nursery.cancel_scope.cancel() + finally: + await listener.close() + + # By the time we get here, the listener and its tasks have been fully + # shut down, allowing the nursery to exit without hanging. + print("TEST COMPLETED SUCCESSFULLY.") + + @pytest.mark.trio + async def test_listener_stats_tracking(self, listener): + """Test listener statistics tracking.""" + initial_stats = listener.get_stats() + + # All counters should start at 0 + assert initial_stats["connections_accepted"] == 0 + assert initial_stats["connections_rejected"] == 0 + assert initial_stats["bytes_received"] == 0 + assert initial_stats["packets_processed"] == 0 diff --git a/tests/core/transport/quic/test_transport.py b/tests/core/transport/quic/test_transport.py new file mode 100644 index 00000000..f9d65d8a --- /dev/null +++ b/tests/core/transport/quic/test_transport.py @@ -0,0 +1,123 @@ +from unittest.mock import ( + Mock, +) + +import pytest + +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.crypto.keys import PrivateKey +from libp2p.transport.quic.exceptions import ( + QUICDialError, + QUICListenError, +) +from libp2p.transport.quic.transport import ( + QUICTransport, + QUICTransportConfig, +) + + +class TestQUICTransport: + """Test suite for QUIC transport using trio.""" + + @pytest.fixture + def private_key(self): + """Generate test private key.""" + return create_new_key_pair().private_key + + @pytest.fixture + def transport_config(self): + """Generate test transport configuration.""" + return QUICTransportConfig( + idle_timeout=10.0, enable_draft29=True, enable_v1=True + ) + + @pytest.fixture + def transport(self, private_key: PrivateKey, transport_config: QUICTransportConfig): + """Create test transport instance.""" + return QUICTransport(private_key, transport_config) + + def test_transport_initialization(self, transport): + """Test transport initialization.""" + assert transport._private_key is not None + assert transport._peer_id is not None + assert not transport._closed + assert len(transport._quic_configs) >= 1 + + def test_supported_protocols(self, transport): + """Test supported protocol identifiers.""" + protocols = transport.protocols() + # TODO: Update when quic-v1 compatible + # assert "quic-v1" in protocols + assert "quic" in protocols # draft-29 + + def test_can_dial_quic_addresses(self, transport: QUICTransport): + """Test multiaddr compatibility checking.""" + import multiaddr + + # Valid QUIC addresses + valid_addrs = [ + # TODO: Update Multiaddr package to accept quic-v1 + multiaddr.Multiaddr( + f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_DRAFT29}" + ), + multiaddr.Multiaddr( + 
f"/ip4/127.0.0.1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + multiaddr.Multiaddr( + f"/ip4/192.168.1.1/udp/8080/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + multiaddr.Multiaddr( + f"/ip6/::1/udp/4001/{QUICTransportConfig.PROTOCOL_QUIC_V1}" + ), + ] + + for addr in valid_addrs: + assert transport.can_dial(addr) + + # Invalid addresses + invalid_addrs = [ + multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/4001"), + multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001"), + multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/ws"), + ] + + for addr in invalid_addrs: + assert not transport.can_dial(addr) + + @pytest.mark.trio + async def test_transport_lifecycle(self, transport): + """Test transport lifecycle management using trio.""" + assert not transport._closed + + await transport.close() + assert transport._closed + + # Should be safe to close multiple times + await transport.close() + + @pytest.mark.trio + async def test_dial_closed_transport(self, transport: QUICTransport) -> None: + """Test dialing with closed transport raises error.""" + import multiaddr + + await transport.close() + + with pytest.raises(QUICDialError, match="Transport is closed"): + await transport.dial( + multiaddr.Multiaddr("/ip4/127.0.0.1/udp/4001/quic"), + ) + + def test_create_listener_closed_transport(self, transport: QUICTransport) -> None: + """Test creating listener with closed transport raises error.""" + transport._closed = True + + with pytest.raises(QUICListenError, match="Transport is closed"): + transport.create_listener(Mock()) diff --git a/tests/core/transport/quic/test_utils.py b/tests/core/transport/quic/test_utils.py new file mode 100644 index 00000000..900c5c7e --- /dev/null +++ b/tests/core/transport/quic/test_utils.py @@ -0,0 +1,321 @@ +""" +Test suite for QUIC multiaddr utilities. +Focused tests covering essential functionality required for QUIC transport. 
+""" + +import pytest +from multiaddr import Multiaddr + +from libp2p.custom_types import TProtocol +from libp2p.transport.quic.exceptions import ( + QUICInvalidMultiaddrError, + QUICUnsupportedVersionError, +) +from libp2p.transport.quic.utils import ( + create_quic_multiaddr, + get_alpn_protocols, + is_quic_multiaddr, + multiaddr_to_quic_version, + normalize_quic_multiaddr, + quic_multiaddr_to_endpoint, + quic_version_to_wire_format, +) + + +class TestIsQuicMultiaddr: + """Test QUIC multiaddr detection.""" + + def test_valid_quic_v1_multiaddrs(self): + """Test valid QUIC v1 multiaddrs are detected.""" + valid_addrs = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip4/192.168.1.1/udp/8080/quic-v1", + "/ip6/::1/udp/4001/quic-v1", + "/ip6/2001:db8::1/udp/5000/quic-v1", + ] + + for addr_str in valid_addrs: + maddr = Multiaddr(addr_str) + assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" + + def test_valid_quic_draft29_multiaddrs(self): + """Test valid QUIC draft-29 multiaddrs are detected.""" + valid_addrs = [ + "/ip4/127.0.0.1/udp/4001/quic", + "/ip4/10.0.0.1/udp/9000/quic", + "/ip6/::1/udp/4001/quic", + "/ip6/fe80::1/udp/6000/quic", + ] + + for addr_str in valid_addrs: + maddr = Multiaddr(addr_str) + assert is_quic_multiaddr(maddr), f"Should detect {addr_str} as QUIC" + + def test_invalid_multiaddrs(self): + """Test non-QUIC multiaddrs are not detected.""" + invalid_addrs = [ + "/ip4/127.0.0.1/tcp/4001", # TCP, not QUIC + "/ip4/127.0.0.1/udp/4001", # UDP without QUIC + "/ip4/127.0.0.1/udp/4001/ws", # WebSocket + "/ip4/127.0.0.1/quic-v1", # Missing UDP + "/udp/4001/quic-v1", # Missing IP + "/dns4/example.com/tcp/443/tls", # Completely different + ] + + for addr_str in invalid_addrs: + maddr = Multiaddr(addr_str) + assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC" + + +class TestQuicMultiaddrToEndpoint: + """Test endpoint extraction from QUIC multiaddrs.""" + + def test_ipv4_extraction(self): + """Test IPv4 host/port extraction.""" + test_cases = [ + ("/ip4/127.0.0.1/udp/4001/quic-v1", ("127.0.0.1", 4001)), + ("/ip4/192.168.1.100/udp/8080/quic", ("192.168.1.100", 8080)), + ("/ip4/10.0.0.1/udp/9000/quic-v1", ("10.0.0.1", 9000)), + ] + + for addr_str, expected in test_cases: + maddr = Multiaddr(addr_str) + result = quic_multiaddr_to_endpoint(maddr) + assert result == expected, f"Failed for {addr_str}" + + def test_ipv6_extraction(self): + """Test IPv6 host/port extraction.""" + test_cases = [ + ("/ip6/::1/udp/4001/quic-v1", ("::1", 4001)), + ("/ip6/2001:db8::1/udp/5000/quic", ("2001:db8::1", 5000)), + ] + + for addr_str, expected in test_cases: + maddr = Multiaddr(addr_str) + result = quic_multiaddr_to_endpoint(maddr) + assert result == expected, f"Failed for {addr_str}" + + def test_invalid_multiaddr_raises_error(self): + """Test invalid multiaddrs raise appropriate errors.""" + invalid_addrs = [ + "/ip4/127.0.0.1/tcp/4001", # Not QUIC + "/ip4/127.0.0.1/udp/4001", # Missing QUIC protocol + ] + + for addr_str in invalid_addrs: + maddr = Multiaddr(addr_str) + with pytest.raises(QUICInvalidMultiaddrError): + quic_multiaddr_to_endpoint(maddr) + + +class TestMultiaddrToQuicVersion: + """Test QUIC version extraction.""" + + def test_quic_v1_detection(self): + """Test QUIC v1 version detection.""" + addrs = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip6/::1/udp/5000/quic-v1", + ] + + for addr_str in addrs: + maddr = Multiaddr(addr_str) + version = multiaddr_to_quic_version(maddr) + assert version == "quic-v1", f"Should detect quic-v1 for {addr_str}" + + def 
test_quic_draft29_detection(self): + """Test QUIC draft-29 version detection.""" + addrs = [ + "/ip4/127.0.0.1/udp/4001/quic", + "/ip6/::1/udp/5000/quic", + ] + + for addr_str in addrs: + maddr = Multiaddr(addr_str) + version = multiaddr_to_quic_version(maddr) + assert version == "quic", f"Should detect quic for {addr_str}" + + def test_non_quic_raises_error(self): + """Test non-QUIC multiaddrs raise error.""" + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + with pytest.raises(QUICInvalidMultiaddrError): + multiaddr_to_quic_version(maddr) + + +class TestCreateQuicMultiaddr: + """Test QUIC multiaddr creation.""" + + def test_ipv4_creation(self): + """Test IPv4 QUIC multiaddr creation.""" + test_cases = [ + ("127.0.0.1", 4001, "quic-v1", "/ip4/127.0.0.1/udp/4001/quic-v1"), + ("192.168.1.1", 8080, "quic", "/ip4/192.168.1.1/udp/8080/quic"), + ("10.0.0.1", 9000, "/quic-v1", "/ip4/10.0.0.1/udp/9000/quic-v1"), + ] + + for host, port, version, expected in test_cases: + result = create_quic_multiaddr(host, port, version) + assert str(result) == expected + + def test_ipv6_creation(self): + """Test IPv6 QUIC multiaddr creation.""" + test_cases = [ + ("::1", 4001, "quic-v1", "/ip6/::1/udp/4001/quic-v1"), + ("2001:db8::1", 5000, "quic", "/ip6/2001:db8::1/udp/5000/quic"), + ] + + for host, port, version, expected in test_cases: + result = create_quic_multiaddr(host, port, version) + assert str(result) == expected + + def test_default_version(self): + """Test default version is quic-v1.""" + result = create_quic_multiaddr("127.0.0.1", 4001) + expected = "/ip4/127.0.0.1/udp/4001/quic-v1" + assert str(result) == expected + + def test_invalid_inputs_raise_errors(self): + """Test invalid inputs raise appropriate errors.""" + # Invalid IP + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("invalid-ip", 4001) + + # Invalid port + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", 70000) + + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", -1) + + # Invalid version + with pytest.raises(QUICInvalidMultiaddrError): + create_quic_multiaddr("127.0.0.1", 4001, "invalid-version") + + +class TestQuicVersionToWireFormat: + """Test QUIC version to wire format conversion.""" + + def test_supported_versions(self): + """Test supported version conversions.""" + test_cases = [ + ("quic-v1", 0x00000001), # RFC 9000 + ("quic", 0xFF00001D), # draft-29 + ] + + for version, expected_wire in test_cases: + result = quic_version_to_wire_format(TProtocol(version)) + assert result == expected_wire, f"Failed for version {version}" + + def test_unsupported_version_raises_error(self): + """Test unsupported versions raise error.""" + with pytest.raises(QUICUnsupportedVersionError): + quic_version_to_wire_format(TProtocol("unsupported-version")) + + +class TestGetAlpnProtocols: + """Test ALPN protocol retrieval.""" + + def test_returns_libp2p_protocols(self): + """Test returns expected libp2p ALPN protocols.""" + protocols = get_alpn_protocols() + assert protocols == ["libp2p"] + assert isinstance(protocols, list) + + def test_returns_copy(self): + """Test returns a copy, not the original list.""" + protocols1 = get_alpn_protocols() + protocols2 = get_alpn_protocols() + + # Modify one list + protocols1.append("test") + + # Other list should be unchanged + assert protocols2 == ["libp2p"] + + +class TestNormalizeQuicMultiaddr: + """Test QUIC multiaddr normalization.""" + + def test_already_normalized(self): + """Test already normalized multiaddrs pass 
through.""" + addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" + maddr = Multiaddr(addr_str) + + result = normalize_quic_multiaddr(maddr) + assert str(result) == addr_str + + def test_normalize_different_versions(self): + """Test normalization works for different QUIC versions.""" + test_cases = [ + "/ip4/127.0.0.1/udp/4001/quic-v1", + "/ip4/127.0.0.1/udp/4001/quic", + "/ip6/::1/udp/5000/quic-v1", + ] + + for addr_str in test_cases: + maddr = Multiaddr(addr_str) + result = normalize_quic_multiaddr(maddr) + + # Should be valid QUIC multiaddr + assert is_quic_multiaddr(result) + + # Should be parseable + host, port = quic_multiaddr_to_endpoint(result) + version = multiaddr_to_quic_version(result) + + # Should match original + orig_host, orig_port = quic_multiaddr_to_endpoint(maddr) + orig_version = multiaddr_to_quic_version(maddr) + + assert host == orig_host + assert port == orig_port + assert version == orig_version + + def test_non_quic_raises_error(self): + """Test non-QUIC multiaddrs raise error.""" + maddr = Multiaddr("/ip4/127.0.0.1/tcp/4001") + with pytest.raises(QUICInvalidMultiaddrError): + normalize_quic_multiaddr(maddr) + + +class TestIntegration: + """Integration tests for utility functions working together.""" + + def test_round_trip_conversion(self): + """Test creating and parsing multiaddrs works correctly.""" + test_cases = [ + ("127.0.0.1", 4001, "quic-v1"), + ("::1", 5000, "quic"), + ("192.168.1.100", 8080, "quic-v1"), + ] + + for host, port, version in test_cases: + # Create multiaddr + maddr = create_quic_multiaddr(host, port, version) + + # Should be detected as QUIC + assert is_quic_multiaddr(maddr) + + # Should extract original values + extracted_host, extracted_port = quic_multiaddr_to_endpoint(maddr) + extracted_version = multiaddr_to_quic_version(maddr) + + assert extracted_host == host + assert extracted_port == port + assert extracted_version == version + + # Should normalize to same value + normalized = normalize_quic_multiaddr(maddr) + assert str(normalized) == str(maddr) + + def test_wire_format_integration(self): + """Test wire format conversion works with version detection.""" + addr_str = "/ip4/127.0.0.1/udp/4001/quic-v1" + maddr = Multiaddr(addr_str) + + # Extract version and convert to wire format + version = multiaddr_to_quic_version(maddr) + wire_format = quic_version_to_wire_format(version) + + # Should be QUIC v1 wire format + assert wire_format == 0x00000001 diff --git a/tests/examples/test_echo_thin_waist.py b/tests/examples/test_echo_thin_waist.py new file mode 100644 index 00000000..2bcb52b1 --- /dev/null +++ b/tests/examples/test_echo_thin_waist.py @@ -0,0 +1,109 @@ +import contextlib +import os +from pathlib import Path +import subprocess +import sys +import time + +from multiaddr import Multiaddr +from multiaddr.protocols import P_IP4, P_IP6, P_P2P, P_TCP + +# pytestmark = pytest.mark.timeout(20) # Temporarily disabled for debugging + +# This test is intentionally lightweight and can be marked as 'integration'. +# It ensures the echo example runs and prints the new Thin Waist lines using +# Trio primitives. 
+ +current_file = Path(__file__) +project_root = current_file.parent.parent.parent +EXAMPLES_DIR: Path = project_root / "examples" / "echo" + + +def test_echo_example_starts_and_prints_thin_waist(monkeypatch, tmp_path): + """Run echo server and validate printed multiaddr and peer id.""" + # Run echo example as server + cmd = [sys.executable, "-u", str(EXAMPLES_DIR / "echo.py"), "-p", "0"] + env = {**os.environ, "PYTHONUNBUFFERED": "1"} + proc: subprocess.Popen[str] = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + env=env, + ) + + if proc.stdout is None: + proc.terminate() + raise RuntimeError("Process stdout is None") + out_stream = proc.stdout + + peer_id: str | None = None + printed_multiaddr: str | None = None + saw_waiting = False + + start = time.time() + timeout_s = 8.0 + try: + while time.time() - start < timeout_s: + line = out_stream.readline() + if not line: + time.sleep(0.05) + continue + s = line.strip() + if s.startswith("I am "): + peer_id = s.partition("I am ")[2] + if s.startswith("echo-demo -d "): + printed_multiaddr = s.partition("echo-demo -d ")[2] + if "Waiting for incoming connections..." in s: + saw_waiting = True + break + finally: + with contextlib.suppress(ProcessLookupError): + proc.terminate() + with contextlib.suppress(ProcessLookupError): + proc.kill() + + assert peer_id, "Did not capture peer ID line" + assert printed_multiaddr, "Did not capture multiaddr line" + assert saw_waiting, "Did not capture waiting-for-connections line" + + # Validate multiaddr structure using py-multiaddr protocol methods + ma = Multiaddr(printed_multiaddr) # should parse without error + + # Check that the multiaddr contains the p2p protocol + try: + peer_id_from_multiaddr = ma.value_for_protocol("p2p") + assert peer_id_from_multiaddr is not None, ( + "Multiaddr missing p2p protocol value" + ) + assert peer_id_from_multiaddr == peer_id, ( + f"Peer ID mismatch: {peer_id_from_multiaddr} != {peer_id}" + ) + except Exception as e: + raise AssertionError(f"Failed to extract p2p protocol value: {e}") + + # Validate the multiaddr structure by checking protocols + protocols = ma.protocols() + + # Should have at least IP, TCP, and P2P protocols + assert any(p.code == P_IP4 or p.code == P_IP6 for p in protocols), ( + "Missing IP protocol" + ) + assert any(p.code == P_TCP for p in protocols), "Missing TCP protocol" + assert any(p.code == P_P2P for p in protocols), "Missing P2P protocol" + + # Extract the p2p part and validate it matches the captured peer ID + p2p_part = Multiaddr(f"/p2p/{peer_id}") + try: + # Decapsulate the p2p part to get the transport address + transport_addr = ma.decapsulate(p2p_part) + # Verify the decapsulated address doesn't contain p2p + transport_protocols = transport_addr.protocols() + assert not any(p.code == P_P2P for p in transport_protocols), ( + "Decapsulation failed - still contains p2p" + ) + # Verify the original multiaddr can be reconstructed + reconstructed = transport_addr.encapsulate(p2p_part) + assert str(reconstructed) == str(ma), "Reconstruction failed" + except Exception as e: + raise AssertionError(f"Multiaddr decapsulation failed: {e}") diff --git a/tests/examples/test_quic_echo_example.py b/tests/examples/test_quic_echo_example.py new file mode 100644 index 00000000..fc843f4b --- /dev/null +++ b/tests/examples/test_quic_echo_example.py @@ -0,0 +1,6 @@ +def test_echo_quic_example(): + """Test that the QUIC echo example can be imported and has required functions.""" + from examples.echo import echo_quic + 
+ assert hasattr(echo_quic, "main") + assert hasattr(echo_quic, "run") diff --git a/tests/interop/nim_libp2p/.gitignore b/tests/interop/nim_libp2p/.gitignore new file mode 100644 index 00000000..7bcc01ea --- /dev/null +++ b/tests/interop/nim_libp2p/.gitignore @@ -0,0 +1,8 @@ +nimble.develop +nimble.paths + +*.nimble +nim-libp2p/ + +nim_echo_server +config.nims diff --git a/tests/interop/nim_libp2p/conftest.py b/tests/interop/nim_libp2p/conftest.py new file mode 100644 index 00000000..5765a09d --- /dev/null +++ b/tests/interop/nim_libp2p/conftest.py @@ -0,0 +1,119 @@ +import fcntl +import logging +from pathlib import Path +import shutil +import subprocess +import time + +import pytest + +logger = logging.getLogger(__name__) + + +def check_nim_available(): + """Check if nim compiler is available.""" + return shutil.which("nim") is not None and shutil.which("nimble") is not None + + +def check_nim_binary_built(): + """Check if nim echo server binary is built.""" + current_dir = Path(__file__).parent + binary_path = current_dir / "nim_echo_server" + return binary_path.exists() and binary_path.stat().st_size > 0 + + +def run_nim_setup_with_lock(): + """Run nim setup with file locking to prevent parallel execution.""" + current_dir = Path(__file__).parent + lock_file = current_dir / ".setup_lock" + setup_script = current_dir / "scripts" / "setup_nim_echo.sh" + + if not setup_script.exists(): + raise RuntimeError(f"Setup script not found: {setup_script}") + + # Try to acquire lock + try: + with open(lock_file, "w") as f: + # Non-blocking lock attempt + fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + + # Double-check binary doesn't exist (another worker might have built it) + if check_nim_binary_built(): + logger.info("Binary already exists, skipping setup") + return + + logger.info("Acquired setup lock, running nim-libp2p setup...") + + # Make setup script executable and run it + setup_script.chmod(0o755) + result = subprocess.run( + [str(setup_script)], + cwd=current_dir, + capture_output=True, + text=True, + timeout=300, # 5 minute timeout + ) + + if result.returncode != 0: + raise RuntimeError( + f"Setup failed (exit {result.returncode}):\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}" + ) + + # Verify binary was built + if not check_nim_binary_built(): + raise RuntimeError("nim_echo_server binary not found after setup") + + logger.info("nim-libp2p setup completed successfully") + + except BlockingIOError: + # Another worker is running setup, wait for it to complete + logger.info("Another worker is running setup, waiting...") + + # Wait for setup to complete (check every 2 seconds, max 5 minutes) + for _ in range(150): # 150 * 2 = 300 seconds = 5 minutes + if check_nim_binary_built(): + logger.info("Setup completed by another worker") + return + time.sleep(2) + + raise TimeoutError("Timed out waiting for setup to complete") + + finally: + # Clean up lock file + try: + lock_file.unlink(missing_ok=True) + except Exception: + pass + + +@pytest.fixture(scope="function") # Changed to function scope +def nim_echo_binary(): + """Get nim echo server binary path.""" + current_dir = Path(__file__).parent + binary_path = current_dir / "nim_echo_server" + + if not binary_path.exists(): + pytest.skip( + "nim_echo_server binary not found. 
" + "Run setup script: ./scripts/setup_nim_echo.sh" + ) + + return binary_path + + +@pytest.fixture +async def nim_server(nim_echo_binary): + """Start and stop nim echo server for tests.""" + # Import here to avoid circular imports + # pyrefly: ignore + from test_echo_interop import NimEchoServer + + server = NimEchoServer(nim_echo_binary) + + try: + peer_id, listen_addr = await server.start() + yield server, peer_id, listen_addr + finally: + await server.stop() diff --git a/tests/interop/nim_libp2p/nim_echo_server.nim b/tests/interop/nim_libp2p/nim_echo_server.nim new file mode 100644 index 00000000..a4f581d9 --- /dev/null +++ b/tests/interop/nim_libp2p/nim_echo_server.nim @@ -0,0 +1,108 @@ +{.used.} + +import chronos +import stew/byteutils +import libp2p + +## +# Simple Echo Protocol Implementation for py-libp2p Interop Testing +## +const EchoCodec = "/echo/1.0.0" + +type EchoProto = ref object of LPProtocol + +proc new(T: typedesc[EchoProto]): T = + proc handle(conn: Connection, proto: string) {.async: (raises: [CancelledError]).} = + try: + echo "Echo server: Received connection from ", conn.peerId + + # Read and echo messages in a loop + while not conn.atEof: + try: + # Read length-prefixed message using nim-libp2p's readLp + let message = await conn.readLp(1024 * 1024) # Max 1MB + if message.len == 0: + echo "Echo server: Empty message, closing connection" + break + + let messageStr = string.fromBytes(message) + echo "Echo server: Received (", message.len, " bytes): ", messageStr + + # Echo back using writeLp + await conn.writeLp(message) + echo "Echo server: Echoed message back" + + except CatchableError as e: + echo "Echo server: Error processing message: ", e.msg + break + + except CancelledError as e: + echo "Echo server: Connection cancelled" + raise e + except CatchableError as e: + echo "Echo server: Exception in handler: ", e.msg + finally: + echo "Echo server: Connection closed" + await conn.close() + + return T.new(codecs = @[EchoCodec], handler = handle) + +## +# Create QUIC-enabled switch +## +proc createSwitch(ma: MultiAddress, rng: ref HmacDrbgContext): Switch = + var switch = SwitchBuilder + .new() + .withRng(rng) + .withAddress(ma) + .withQuicTransport() + .build() + result = switch + +## +# Main server +## +proc main() {.async.} = + let + rng = newRng() + localAddr = MultiAddress.init("/ip4/0.0.0.0/udp/0/quic-v1").tryGet() + echoProto = EchoProto.new() + + echo "=== Nim Echo Server for py-libp2p Interop ===" + + # Create switch + let switch = createSwitch(localAddr, rng) + switch.mount(echoProto) + + # Start server + await switch.start() + + # Print connection info + echo "Peer ID: ", $switch.peerInfo.peerId + echo "Listening on:" + for addr in switch.peerInfo.addrs: + echo " ", $addr, "/p2p/", $switch.peerInfo.peerId + echo "Protocol: ", EchoCodec + echo "Ready for py-libp2p connections!" + echo "" + + # Keep running + try: + await sleepAsync(100.hours) + except CancelledError: + echo "Shutting down..." 
+ finally: + await switch.stop() + +# Graceful shutdown handler +proc signalHandler() {.noconv.} = + echo "\nShutdown signal received" + quit(0) + +when isMainModule: + setControlCHook(signalHandler) + try: + waitFor(main()) + except CatchableError as e: + echo "Error: ", e.msg + quit(1) diff --git a/tests/interop/nim_libp2p/scripts/setup_nim_echo.sh b/tests/interop/nim_libp2p/scripts/setup_nim_echo.sh new file mode 100755 index 00000000..f80b2d27 --- /dev/null +++ b/tests/interop/nim_libp2p/scripts/setup_nim_echo.sh @@ -0,0 +1,74 @@ +#!/usr/bin/env bash +# tests/interop/nim_libp2p/scripts/setup_nim_echo.sh +# Cache-aware setup that skips installation if packages exist + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="${SCRIPT_DIR}/.." + +# Colors +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +main() { + log_info "Setting up nim echo server for interop testing..." + + # Check if nim is available + if ! command -v nim &> /dev/null || ! command -v nimble &> /dev/null; then + log_error "Nim not found. Please install nim first." + exit 1 + fi + + cd "${PROJECT_DIR}" + + # Create logs directory + mkdir -p logs + + # Check if binary already exists + if [[ -f "nim_echo_server" ]]; then + log_info "nim_echo_server already exists, skipping build" + return 0 + fi + + # Check if libp2p is already installed (cache-aware) + if nimble list -i | grep -q "libp2p"; then + log_info "libp2p already installed, skipping installation" + else + log_info "Installing nim-libp2p globally..." + nimble install -y libp2p + fi + + log_info "Building nim echo server..." + # Compile the echo server + nim c \ + -d:release \ + -d:chronicles_log_level=INFO \ + -d:libp2p_quic_support \ + -d:chronos_event_loop=iocp \ + -d:ssl \ + --opt:speed \ + --mm:orc \ + --verbosity:1 \ + -o:nim_echo_server \ + nim_echo_server.nim + + # Verify binary was created + if [[ -f "nim_echo_server" ]]; then + log_info "āœ… nim_echo_server built successfully" + log_info "Binary size: $(ls -lh nim_echo_server | awk '{print $5}')" + else + log_error "āŒ Failed to build nim_echo_server" + exit 1 + fi + + log_info "šŸŽ‰ Setup complete!" 
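+
+    # Note: re-running this script is intentionally cheap - the nimble install
+    # and the compile step above are both skipped when their artifacts exist.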
+}
+
+main "$@"
diff --git a/tests/interop/nim_libp2p/test_echo_interop.py b/tests/interop/nim_libp2p/test_echo_interop.py
new file mode 100644
index 00000000..8e2b3e33
--- /dev/null
+++ b/tests/interop/nim_libp2p/test_echo_interop.py
@@ -0,0 +1,195 @@
+import logging
+from pathlib import Path
+import subprocess
+import time
+
+import pytest
+import multiaddr
+import trio
+
+from libp2p import new_host
+from libp2p.crypto.secp256k1 import create_new_key_pair
+from libp2p.custom_types import TProtocol
+from libp2p.peer.peerinfo import info_from_p2p_addr
+from libp2p.utils.varint import encode_varint_prefixed, read_varint_prefixed_bytes
+
+# Configuration
+PROTOCOL_ID = TProtocol("/echo/1.0.0")
+TEST_TIMEOUT = 30
+SERVER_START_TIMEOUT = 10.0
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class NimEchoServer:
+    """Simple nim echo server manager."""
+
+    def __init__(self, binary_path: Path):
+        self.binary_path = binary_path
+        self.process: None | subprocess.Popen = None
+        self.peer_id = None
+        self.listen_addr = None
+
+    async def start(self):
+        """Start nim echo server and get connection info."""
+        logger.info(f"Starting nim echo server: {self.binary_path}")
+
+        self.process = subprocess.Popen(
+            [str(self.binary_path)],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            universal_newlines=True,
+            bufsize=1,
+        )
+
+        # Parse output for connection info
+        start_time = time.time()
+        while time.time() - start_time < SERVER_START_TIMEOUT:
+            # poll() returns the exit code once the process has exited, which can
+            # be 0; compare against None so a clean early exit is still detected.
+            if (
+                self.process
+                and self.process.poll() is not None
+                and self.process.stdout
+            ):
+                output = self.process.stdout.read()
+                raise RuntimeError(f"Server exited early: {output}")
+
+            reader = self.process.stdout if self.process else None
+            if reader:
+                line = reader.readline().strip()
+                if not line:
+                    continue
+
+                logger.info(f"Server: {line}")
+
+                if line.startswith("Peer ID:"):
+                    self.peer_id = line.split(":", 1)[1].strip()
+
+                elif "/quic-v1/p2p/" in line and self.peer_id:
+                    if line.strip().startswith("/"):
+                        self.listen_addr = line.strip()
+                        logger.info(f"Server ready: {self.listen_addr}")
+                        return self.peer_id, self.listen_addr
+
+        await self.stop()
+        raise TimeoutError(f"Server failed to start within {SERVER_START_TIMEOUT}s")
+
+    async def stop(self):
+        """Stop the server."""
+        if self.process:
+            logger.info("Stopping nim echo server...")
+            try:
+                self.process.terminate()
+                self.process.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                self.process.kill()
+                self.process.wait()
+            self.process = None
+
+
+async def run_echo_test(server_addr: str, messages: list[str]):
+    """Test echo protocol against nim server with proper timeout handling."""
+    # Create py-libp2p QUIC client with shorter timeouts
+
+    host = new_host(
+        enable_quic=True,
+        key_pair=create_new_key_pair(),
+    )
+
+    listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/udp/0/quic-v1")
+    responses = []
+
+    try:
+        async with host.run(listen_addrs=[listen_addr]):
+            logger.info(f"Connecting to nim server: {server_addr}")
+
+            # Connect to nim server
+            maddr = multiaddr.Multiaddr(server_addr)
+            info = info_from_p2p_addr(maddr)
+            await host.connect(info)
+
+            # Create stream
+            stream = await host.new_stream(info.peer_id, [PROTOCOL_ID])
+            logger.info("Stream created")
+
+            # Test each message
+            for i, message in enumerate(messages, 1):
+                logger.info(f"Testing message {i}: {message}")
+
+                # Send with varint length prefix
+                data = message.encode("utf-8")
+                prefixed_data = encode_varint_prefixed(data)
+                await stream.write(prefixed_data)
+
+                # Read response
+                response_data
= await read_varint_prefixed_bytes(stream) + response = response_data.decode("utf-8") + + logger.info(f"Got echo: {response}") + responses.append(response) + + # Verify echo + assert message == response, ( + f"Echo failed: sent {message!r}, got {response!r}" + ) + + await stream.close() + logger.info("āœ… All messages echoed correctly") + + finally: + await host.close() + + return responses + + +@pytest.mark.trio +@pytest.mark.timeout(TEST_TIMEOUT) +async def test_basic_echo_interop(nim_server): + """Test basic echo functionality between py-libp2p and nim-libp2p.""" + server, peer_id, listen_addr = nim_server + + test_messages = [ + "Hello from py-libp2p!", + "QUIC transport working", + "Echo test successful!", + "Unicode: ƑoĆ«l, 测试, Ψυχή", + ] + + logger.info(f"Testing against nim server: {peer_id}") + + # Run test with timeout + with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup + responses = await run_echo_test(listen_addr, test_messages) + + # Verify all messages echoed correctly + assert len(responses) == len(test_messages) + for sent, received in zip(test_messages, responses): + assert sent == received + + logger.info("āœ… Basic echo interop test passed!") + + +@pytest.mark.trio +@pytest.mark.timeout(TEST_TIMEOUT) +async def test_large_message_echo(nim_server): + """Test echo with larger messages.""" + server, peer_id, listen_addr = nim_server + + large_messages = [ + "x" * 1024, + "y" * 5000, + ] + + logger.info("Testing large message echo...") + + # Run test with timeout + with trio.move_on_after(TEST_TIMEOUT - 2): # Leave 2s buffer for cleanup + responses = await run_echo_test(listen_addr, large_messages) + + assert len(responses) == len(large_messages) + for sent, received in zip(large_messages, responses): + assert sent == received + + logger.info("āœ… Large message echo test passed!") + + +if __name__ == "__main__": + # Run tests directly + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 75639e36..c006200f 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -669,8 +669,8 @@ async def swarm_conn_pair_factory( async with swarm_pair_factory( security_protocol=security_protocol, muxer_opt=muxer_opt ) as swarms: - conn_0 = swarms[0].connections[swarms[1].get_peer_id()] - conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + conn_0 = swarms[0].connections[swarms[1].get_peer_id()][0] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()][0] yield cast(SwarmConn, conn_0), cast(SwarmConn, conn_1) diff --git a/tests/utils/test_address_validation.py b/tests/utils/test_address_validation.py new file mode 100644 index 00000000..5b108d09 --- /dev/null +++ b/tests/utils/test_address_validation.py @@ -0,0 +1,56 @@ +import os + +import pytest +from multiaddr import Multiaddr + +from libp2p.utils.address_validation import ( + expand_wildcard_address, + get_available_interfaces, + get_optimal_binding_address, +) + + +@pytest.mark.parametrize("proto", ["tcp"]) +def test_get_available_interfaces(proto: str) -> None: + interfaces = get_available_interfaces(0, protocol=proto) + assert len(interfaces) > 0 + for addr in interfaces: + assert isinstance(addr, Multiaddr) + assert f"/{proto}/" in str(addr) + + +def test_get_optimal_binding_address() -> None: + addr = get_optimal_binding_address(0) + assert isinstance(addr, Multiaddr) + # At least IPv4 or IPv6 prefix present + s = str(addr) + assert ("/ip4/" in s) or ("/ip6/" in s) + + +def test_expand_wildcard_address_ipv4() -> None: + 
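+    # Expansion should replace the 0.0.0.0 wildcard with each concrete local
+    # interface; only the shape is asserted since interfaces vary per machine.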
wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0") + expanded = expand_wildcard_address(wildcard) + assert len(expanded) > 0 + for e in expanded: + assert isinstance(e, Multiaddr) + assert "/tcp/" in str(e) + + +def test_expand_wildcard_address_port_override() -> None: + wildcard = Multiaddr("/ip4/0.0.0.0/tcp/7000") + overridden = expand_wildcard_address(wildcard, port=9001) + assert len(overridden) > 0 + for e in overridden: + assert str(e).endswith("/tcp/9001") + + +@pytest.mark.skipif( + os.environ.get("NO_IPV6") == "1", + reason="Environment disallows IPv6", +) +def test_expand_wildcard_address_ipv6() -> None: + wildcard = Multiaddr("/ip6/::/tcp/0") + expanded = expand_wildcard_address(wildcard) + assert len(expanded) > 0 + for e in expanded: + assert "/ip6/" in str(e) diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py index 603af5e1..06be05c7 100644 --- a/tests/utils/test_logging.py +++ b/tests/utils/test_logging.py @@ -15,6 +15,7 @@ import pytest import trio from libp2p.utils.logging import ( + _current_handlers, _current_listener, _listener_ready, log_queue, @@ -24,13 +25,19 @@ from libp2p.utils.logging import ( def _reset_logging(): """Reset all logging state.""" - global _current_listener, _listener_ready + global _current_listener, _listener_ready, _current_handlers # Stop existing listener if any if _current_listener is not None: _current_listener.stop() _current_listener = None + # Close all file handlers to ensure proper cleanup on Windows + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.close() + _current_handlers.clear() + # Reset the event _listener_ready = threading.Event() @@ -174,6 +181,15 @@ async def test_custom_log_file(clean_env): if _current_listener is not None: _current_listener.stop() + # Give a moment for the listener to fully stop + await trio.sleep(0.05) + + # Close all file handlers to release the file + for handler in _current_handlers: + if isinstance(handler, logging.FileHandler): + handler.flush() # Ensure all writes are flushed + handler.close() + # Check if the file exists and contains our message assert log_file.exists() content = log_file.read_text() @@ -185,16 +201,15 @@ async def test_default_log_file(clean_env): """Test logging to the default file path.""" os.environ["LIBP2P_DEBUG"] = "INFO" - with patch("libp2p.utils.logging.datetime") as mock_datetime: - # Mock the timestamp to have a predictable filename - mock_datetime.now.return_value.strftime.return_value = "20240101_120000" + with patch("libp2p.utils.paths.create_temp_file") as mock_create_temp: + # Mock the temp file creation to return a predictable path + mock_temp_file = ( + Path(tempfile.gettempdir()) / "test_py-libp2p_20240101_120000.log" + ) + mock_create_temp.return_value = mock_temp_file # Remove the log file if it exists - if os.name == "nt": # Windows - log_file = Path("C:/Windows/Temp/20240101_120000_py-libp2p.log") - else: # Unix-like - log_file = Path("/tmp/20240101_120000_py-libp2p.log") - log_file.unlink(missing_ok=True) + mock_temp_file.unlink(missing_ok=True) setup_logging() @@ -211,9 +226,18 @@ async def test_default_log_file(clean_env): if _current_listener is not None: _current_listener.stop() - # Check the default log file - if log_file.exists(): # Only check content if we have write permission - content = log_file.read_text() + # Give a moment for the listener to fully stop + await trio.sleep(0.05) + + # Close all file handlers to release the file + for handler in _current_handlers: + if isinstance(handler, 
logging.FileHandler): + handler.flush() # Ensure all writes are flushed + handler.close() + + # Check the mocked temp file + if mock_temp_file.exists(): + content = mock_temp_file.read_text() assert "Test message" in content diff --git a/tests/utils/test_paths.py b/tests/utils/test_paths.py new file mode 100644 index 00000000..421fc557 --- /dev/null +++ b/tests/utils/test_paths.py @@ -0,0 +1,290 @@ +""" +Tests for cross-platform path utilities. +""" + +import os +from pathlib import Path +import tempfile + +import pytest + +from libp2p.utils.paths import ( + create_temp_file, + ensure_dir_exists, + find_executable, + get_binary_path, + get_config_dir, + get_project_root, + get_python_executable, + get_script_binary_path, + get_script_dir, + get_temp_dir, + get_venv_path, + join_paths, + normalize_path, + resolve_relative_path, +) + + +class TestPathUtilities: + """Test cross-platform path utilities.""" + + def test_get_temp_dir(self): + """Test that temp directory is accessible and exists.""" + temp_dir = get_temp_dir() + assert isinstance(temp_dir, Path) + assert temp_dir.exists() + assert temp_dir.is_dir() + # Should match system temp directory + assert temp_dir == Path(tempfile.gettempdir()) + + def test_get_project_root(self): + """Test that project root is correctly determined.""" + project_root = get_project_root() + assert isinstance(project_root, Path) + assert project_root.exists() + # Should contain pyproject.toml + assert (project_root / "pyproject.toml").exists() + # Should contain libp2p directory + assert (project_root / "libp2p").exists() + + def test_join_paths(self): + """Test cross-platform path joining.""" + # Test with strings + result = join_paths("a", "b", "c") + expected = Path("a") / "b" / "c" + assert result == expected + + # Test with mixed types + result = join_paths("a", Path("b"), "c") + expected = Path("a") / "b" / "c" + assert result == expected + + # Test with absolute path + result = join_paths("/absolute", "path") + expected = Path("/absolute") / "path" + assert result == expected + + def test_ensure_dir_exists(self, tmp_path): + """Test directory creation and existence checking.""" + # Test creating new directory + new_dir = tmp_path / "new_dir" + result = ensure_dir_exists(new_dir) + assert result == new_dir + assert new_dir.exists() + assert new_dir.is_dir() + + # Test creating nested directory + nested_dir = tmp_path / "parent" / "child" / "grandchild" + result = ensure_dir_exists(nested_dir) + assert result == nested_dir + assert nested_dir.exists() + assert nested_dir.is_dir() + + # Test with existing directory + result = ensure_dir_exists(new_dir) + assert result == new_dir + assert new_dir.exists() + + def test_get_config_dir(self): + """Test platform-specific config directory.""" + config_dir = get_config_dir() + assert isinstance(config_dir, Path) + + if os.name == "nt": # Windows + # Should be in AppData/Roaming or user home + assert "AppData" in str(config_dir) or "py-libp2p" in str(config_dir) + else: # Unix-like + # Should be in ~/.config + assert ".config" in str(config_dir) + assert "py-libp2p" in str(config_dir) + + def test_get_script_dir(self): + """Test script directory detection.""" + # Test with current file + script_dir = get_script_dir(__file__) + assert isinstance(script_dir, Path) + assert script_dir.exists() + assert script_dir.is_dir() + # Should contain this test file + assert (script_dir / "test_paths.py").exists() + + def test_create_temp_file(self): + """Test temporary file creation.""" + temp_file = create_temp_file() + 
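+        # Default naming convention asserted below: a "py-libp2p_" prefix and
+        # ".log" suffix, placed in the system temp directory.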
assert isinstance(temp_file, Path) + assert temp_file.parent == get_temp_dir() + assert temp_file.name.startswith("py-libp2p_") + assert temp_file.name.endswith(".log") + + # Test with custom prefix and suffix + temp_file = create_temp_file(prefix="test_", suffix=".txt") + assert temp_file.name.startswith("test_") + assert temp_file.name.endswith(".txt") + + def test_resolve_relative_path(self, tmp_path): + """Test relative path resolution.""" + base_path = tmp_path / "base" + base_path.mkdir() + + # Test relative path + relative_path = "subdir/file.txt" + result = resolve_relative_path(base_path, relative_path) + expected = (base_path / "subdir" / "file.txt").resolve() + assert result == expected + + # Test absolute path (platform-agnostic) + if os.name == "nt": # Windows + absolute_path = "C:\\absolute\\path" + else: # Unix-like + absolute_path = "/absolute/path" + result = resolve_relative_path(base_path, absolute_path) + assert result == Path(absolute_path) + + def test_normalize_path(self, tmp_path): + """Test path normalization.""" + # Test with relative path + relative_path = tmp_path / ".." / "normalize_test" + result = normalize_path(relative_path) + assert result.is_absolute() + assert "normalize_test" in str(result) + + # Test with absolute path + absolute_path = tmp_path / "test_file" + result = normalize_path(absolute_path) + assert result.is_absolute() + assert result == absolute_path.resolve() + + def test_get_venv_path(self, monkeypatch): + """Test virtual environment path detection.""" + # Test when no virtual environment is active + # Temporarily clear VIRTUAL_ENV to test the "no venv" case + monkeypatch.delenv("VIRTUAL_ENV", raising=False) + result = get_venv_path() + assert result is None + + # Test when virtual environment is active + test_venv_path = "/path/to/venv" + monkeypatch.setenv("VIRTUAL_ENV", test_venv_path) + result = get_venv_path() + assert result == Path(test_venv_path) + + def test_get_python_executable(self): + """Test Python executable path detection.""" + result = get_python_executable() + assert isinstance(result, Path) + assert result.exists() + assert result.name.startswith("python") + + def test_find_executable(self): + """Test executable finding in PATH.""" + # Test with non-existent executable + result = find_executable("nonexistent_executable") + assert result is None + + # Test with existing executable (python should be available) + result = find_executable("python") + if result: + assert isinstance(result, Path) + assert result.exists() + + def test_get_script_binary_path(self): + """Test script binary path detection.""" + result = get_script_binary_path() + assert isinstance(result, Path) + assert result.exists() + assert result.is_dir() + + def test_get_binary_path(self, monkeypatch): + """Test binary path resolution with virtual environment.""" + # Test when no virtual environment is active + result = get_binary_path("python") + if result: + assert isinstance(result, Path) + assert result.exists() + + # Test when virtual environment is active + test_venv_path = "/path/to/venv" + monkeypatch.setenv("VIRTUAL_ENV", test_venv_path) + # This test is more complex as it depends on the actual venv structure + # We'll just verify the function doesn't crash + result = get_binary_path("python") + # Result can be None if binary not found in venv + if result: + assert isinstance(result, Path) + + +class TestCrossPlatformCompatibility: + """Test cross-platform compatibility.""" + + def test_config_dir_platform_specific_windows(self, monkeypatch): + 
"""Test config directory respects Windows conventions.""" + import platform + + # Only run this test on Windows systems + if platform.system() != "Windows": + pytest.skip("This test only runs on Windows systems") + + monkeypatch.setattr("os.name", "nt") + monkeypatch.setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming") + config_dir = get_config_dir() + assert "AppData" in str(config_dir) + assert "py-libp2p" in str(config_dir) + + def test_path_separators_consistent(self): + """Test that path separators are handled consistently.""" + # Test that join_paths uses platform-appropriate separators + result = join_paths("dir1", "dir2", "file.txt") + expected = Path("dir1") / "dir2" / "file.txt" + assert result == expected + + # Test that the result uses correct separators for the platform + if os.name == "nt": # Windows + assert "\\" in str(result) or "/" in str(result) + else: # Unix-like + assert "/" in str(result) + + def test_temp_file_uniqueness(self): + """Test that temporary files have unique names.""" + files = set() + for _ in range(10): + temp_file = create_temp_file() + assert temp_file not in files + files.add(temp_file) + + +class TestBackwardCompatibility: + """Test backward compatibility with existing code patterns.""" + + def test_path_operations_equivalent(self): + """Test that new path operations are equivalent to old os.path operations.""" + # Test join_paths vs os.path.join + parts = ["a", "b", "c"] + new_result = join_paths(*parts) + old_result = Path(os.path.join(*parts)) + assert new_result == old_result + + # Test get_script_dir vs os.path.dirname(os.path.abspath(__file__)) + new_script_dir = get_script_dir(__file__) + old_script_dir = Path(os.path.dirname(os.path.abspath(__file__))) + assert new_script_dir == old_script_dir + + def test_existing_functionality_preserved(self): + """Ensure no existing functionality is broken.""" + # Test that all functions return Path objects + assert isinstance(get_temp_dir(), Path) + assert isinstance(get_project_root(), Path) + assert isinstance(join_paths("a", "b"), Path) + assert isinstance(ensure_dir_exists(tempfile.gettempdir()), Path) + assert isinstance(get_config_dir(), Path) + assert isinstance(get_script_dir(__file__), Path) + assert isinstance(create_temp_file(), Path) + assert isinstance(resolve_relative_path(".", "test"), Path) + assert isinstance(normalize_path("."), Path) + assert isinstance(get_python_executable(), Path) + assert isinstance(get_script_binary_path(), Path) + + # Test optional return types + venv_path = get_venv_path() + if venv_path is not None: + assert isinstance(venv_path, Path)