Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore

This commit is contained in:
Soham Bhoir
2025-06-22 15:26:51 +05:30
committed by GitHub
28 changed files with 4554 additions and 9 deletions

View File

@ -59,6 +59,7 @@ PB = libp2p/crypto/pb/crypto.proto \
libp2p/security/noise/pb/noise.proto \
libp2p/identity/identify/pb/identify.proto \
libp2p/host/autonat/pb/autonat.proto \
libp2p/relay/circuit_v2/pb/circuit.proto \
libp2p/kad_dht/pb/kademlia.proto
PY = $(PB:.proto=_pb2.py)

View File

@ -0,0 +1,499 @@
Circuit Relay v2 Example
========================
This example demonstrates how to use Circuit Relay v2 in py-libp2p. It includes three components:
1. A relay node that provides relay services
2. A destination node that accepts relayed connections
3. A source node that connects to the destination through the relay
Prerequisites
-------------
First, ensure you have py-libp2p installed:
.. code-block:: console
$ python -m pip install libp2p
Collecting libp2p
...
Successfully installed libp2p-x.x.x
Relay Node
----------
Create a file named ``relay_node.py`` with the following content:
.. code-block:: python
import trio
import logging
import multiaddr
import traceback
from libp2p import new_host
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
from libp2p.relay.circuit_v2.config import RelayConfig
from libp2p.tools.async_service import background_trio_service
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("relay_node")
async def run_relay():
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9000")
host = new_host()
config = RelayConfig(
enable_hop=True, # Act as a relay
enable_stop=True, # Accept relayed connections
enable_client=False, # Don't use other relays
max_circuit_duration=3600, # 1 hour
max_circuit_bytes=1024 * 1024 * 10, # 10MB
)
# Initialize the relay protocol with allow_hop=True to act as a relay
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=True)
print(f"Created relay protocol with hop enabled: {protocol.allow_hop}")
# Start the protocol service
async with host.run(listen_addrs=[listen_addr]):
peer_id = host.get_id()
print("\n" + "="*50)
print(f"Relay node started with ID: {peer_id}")
print(f"Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/{peer_id}")
print("="*50 + "\n")
print(f"Listening on: {host.get_addrs()}")
try:
async with background_trio_service(protocol):
print("Protocol service started")
transport = CircuitV2Transport(host, protocol, config)
print("Relay service started successfully")
print(f"Relay limits: {protocol.limits}")
while True:
await trio.sleep(10)
print("Relay node still running...")
print(f"Active connections: {len(host.get_network().connections)}")
except Exception as e:
print(f"Error in relay service: {e}")
traceback.print_exc()
if __name__ == "__main__":
try:
trio.run(run_relay)
except Exception as e:
print(f"Error running relay: {e}")
traceback.print_exc()
Destination Node
----------------
Create a file named ``destination_node.py`` with the following content:
.. code-block:: python
import trio
import logging
import multiaddr
import traceback
import sys
from libp2p import new_host
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
from libp2p.relay.circuit_v2.config import RelayConfig
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.async_service import background_trio_service
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("destination_node")
async def handle_echo_stream(stream):
"""Handle incoming stream by echoing received data."""
try:
print(f"New echo stream from: {stream.get_protocol()}")
while True:
data = await stream.read(1024)
if not data:
print("Stream closed by remote")
break
message = data.decode('utf-8')
print(f"Received: {message}")
response = f"Echo: {message}".encode('utf-8')
await stream.write(response)
print(f"Sent response: Echo: {message}")
except Exception as e:
print(f"Error handling stream: {e}")
traceback.print_exc()
finally:
await stream.close()
print("Stream closed")
async def run_destination(relay_peer_id=None):
"""
Run a simple destination node that accepts connections.
This is a simplified version that doesn't use the relay functionality.
"""
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/9001")
host = new_host()
# Configure as a relay receiver (stop)
config = RelayConfig(
enable_stop=True, # Accept relayed connections
enable_client=True, # Use relays for outbound connections
max_circuit_duration=3600, # 1 hour
max_circuit_bytes=1024 * 1024 * 10, # 10MB
)
# Initialize the relay protocol
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=False)
async with host.run(listen_addrs=[listen_addr]):
# Print host information
dest_peer_id = host.get_id()
print("\n" + "="*50)
print(f"Destination node started with ID: {dest_peer_id}")
print(f"Use this ID in the source node: {dest_peer_id}")
print("="*50 + "\n")
print(f"Listening on: {host.get_addrs()}")
# Set stream handler for the echo protocol
host.set_stream_handler("/echo/1.0.0", handle_echo_stream)
print("Registered echo protocol handler")
# Start the protocol service in the background
async with background_trio_service(protocol):
print("Protocol service started")
# Create and register the transport
transport = CircuitV2Transport(host, protocol, config)
print("Transport created")
# Create a listener for relayed connections
listener = transport.create_listener(handle_echo_stream)
print("Created relay listener")
# Start listening for relayed connections
async with trio.open_nursery() as nursery:
await listener.listen("/p2p-circuit", nursery)
print("Destination node ready to accept relayed connections")
if not relay_peer_id:
print("No relay peer ID provided. Please enter the relay's peer ID:")
print("Waiting for relay peer ID input...")
while True:
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
try:
relay_peer_id = input("Enter relay peer ID: ").strip()
if relay_peer_id:
break
except EOFError:
await trio.sleep(5)
else:
print("No terminal detected. Waiting for relay peer ID as command line argument.")
await trio.sleep(10)
continue
# Connect to the relay node with the provided relay peer ID
relay_addr_str = f"/ip4/127.0.0.1/tcp/9000/p2p/{relay_peer_id}"
print(f"Connecting to relay at {relay_addr_str}")
try:
# Convert string address to multiaddr, then to peer info
relay_maddr = multiaddr.Multiaddr(relay_addr_str)
relay_peer_info = info_from_p2p_addr(relay_maddr)
await host.connect(relay_peer_info)
print("Connected to relay successfully")
# Add the relay to the transport's discovery
transport.discovery._add_relay(relay_peer_info.peer_id)
print(f"Added relay {relay_peer_info.peer_id} to discovery")
# Keep the node running
while True:
await trio.sleep(10)
print("Destination node still running...")
except Exception as e:
print(f"Failed to connect to relay: {e}")
traceback.print_exc()
if __name__ == "__main__":
print("Starting destination node...")
relay_id = None
if len(sys.argv) > 1:
relay_id = sys.argv[1]
print(f"Using provided relay ID: {relay_id}")
trio.run(run_destination, relay_id)
Source Node
-----------
Create a file named ``source_node.py`` with the following content:
.. code-block:: python
import trio
import logging
import multiaddr
import traceback
import sys
from libp2p import new_host
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.id import ID
from libp2p.relay.circuit_v2.protocol import CircuitV2Protocol
from libp2p.relay.circuit_v2.transport import CircuitV2Transport
from libp2p.relay.circuit_v2.config import RelayConfig
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.tools.async_service import background_trio_service
from libp2p.relay.circuit_v2.discovery import RelayInfo
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("source_node")
async def run_source(relay_peer_id=None, destination_peer_id=None):
# Create a libp2p host
listen_addr = multiaddr.Multiaddr("/ip4/0.0.0.0/tcp/9002")
host = new_host()
# Configure as a relay client
config = RelayConfig(
enable_client=True, # Use relays for outbound connections
max_circuit_duration=3600, # 1 hour
max_circuit_bytes=1024 * 1024 * 10, # 10MB
)
# Initialize the relay protocol
protocol = CircuitV2Protocol(host, limits=config.limits, allow_hop=False)
# Start the protocol service
async with host.run(listen_addrs=[listen_addr]):
# Print host information
print(f"Source node started with ID: {host.get_id()}")
print(f"Listening on: {host.get_addrs()}")
# Start the protocol service in the background
async with background_trio_service(protocol):
print("Protocol service started")
# Create and register the transport
transport = CircuitV2Transport(host, protocol, config)
# Get relay peer ID if not provided
if not relay_peer_id:
print("No relay peer ID provided. Please enter the relay's peer ID:")
while True:
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
try:
relay_peer_id = input("Enter relay peer ID: ").strip()
if relay_peer_id:
break
except EOFError:
await trio.sleep(5)
else:
print("No terminal detected. Waiting for relay peer ID as command line argument.")
await trio.sleep(10)
continue
# Connect to the relay node with the provided relay peer ID
relay_addr_str = f"/ip4/127.0.0.1/tcp/9000/p2p/{relay_peer_id}"
print(f"Connecting to relay at {relay_addr_str}")
try:
# Convert string address to multiaddr, then to peer info
relay_maddr = multiaddr.Multiaddr(relay_addr_str)
relay_peer_info = info_from_p2p_addr(relay_maddr)
await host.connect(relay_peer_info)
print("Connected to relay successfully")
# Manually add the relay to the discovery service
relay_id = relay_peer_info.peer_id
now = trio.current_time()
# Create relay info and add it to discovery
relay_info = RelayInfo(
peer_id=relay_id,
discovered_at=now,
last_seen=now
)
transport.discovery._discovered_relays[relay_id] = relay_info
print(f"Added relay {relay_id} to discovery")
# Start relay discovery in the background
async with background_trio_service(transport.discovery):
print("Relay discovery started")
# Wait for relay discovery
await trio.sleep(5)
print("Relay discovery completed")
# Get destination peer ID if not provided
if not destination_peer_id:
print("No destination peer ID provided. Please enter the destination's peer ID:")
while True:
if sys.stdin.isatty(): # Only try to read from stdin if it's a terminal
try:
destination_peer_id = input("Enter destination peer ID: ").strip()
if destination_peer_id:
break
except EOFError:
await trio.sleep(5)
else:
print("No terminal detected. Waiting for destination peer ID as command line argument.")
await trio.sleep(10)
continue
print(f"Attempting to connect to {destination_peer_id} via relay")
# Check if we have any discovered relays
discovered_relays = list(transport.discovery._discovered_relays.keys())
print(f"Discovered relays: {discovered_relays}")
try:
# Create a circuit relay multiaddr for the destination
dest_id = ID.from_base58(destination_peer_id)
# Create a circuit multiaddr that includes the relay
# Format: /ip4/127.0.0.1/tcp/9000/p2p/RELAY_ID/p2p-circuit/p2p/DEST_ID
circuit_addr = multiaddr.Multiaddr(f"{relay_addr_str}/p2p-circuit/p2p/{destination_peer_id}")
print(f"Created circuit address: {circuit_addr}")
# Dial using the circuit address
connection = await transport.dial(circuit_addr)
print("Connection established through relay!")
# Open a stream using the echo protocol
stream = await connection.new_stream("/echo/1.0.0")
# Send messages periodically
for i in range(5):
message = f"Hello from source, message {i+1}"
print(f"Sending: {message}")
await stream.write(message.encode('utf-8'))
response = await stream.read(1024)
print(f"Received: {response.decode('utf-8')}")
await trio.sleep(1)
# Close the stream
await stream.close()
print("Stream closed")
except Exception as e:
print(f"Error connecting through relay: {e}")
print("Detailed error:")
traceback.print_exc()
# Keep the node running for a while
await trio.sleep(30)
print("Source node shutting down")
except Exception as e:
print(f"Error: {e}")
traceback.print_exc()
if __name__ == "__main__":
relay_id = None
dest_id = None
# Parse command line arguments if provided
if len(sys.argv) > 1:
relay_id = sys.argv[1]
print(f"Using provided relay ID: {relay_id}")
if len(sys.argv) > 2:
dest_id = sys.argv[2]
print(f"Using provided destination ID: {dest_id}")
trio.run(run_source, relay_id, dest_id)
Running the Example
-------------------
1. First, start the relay node:
.. code-block:: console
$ python relay_node.py
Created relay protocol with hop enabled: True
==================================================
Relay node started with ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Relay node multiaddr: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
==================================================
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx>]
Protocol service started
Relay service started successfully
Relay limits: RelayLimits(duration=3600, data=10485760, max_circuit_conns=8, max_reservations=4)
Note the relay node\'s peer ID (in this example: `QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx`). You\'ll need this for the other nodes.
2. Next, start the destination node:
.. code-block:: console
$ python destination_node.py
Starting destination node...
==================================================
Destination node started with ID: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
Use this ID in the source node: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
==================================================
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9001/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s>]
Registered echo protocol handler
Protocol service started
Transport created
Created relay listener
Destination node ready to accept relayed connections
No relay peer ID provided. Please enter the relay\'s peer ID:
Waiting for relay peer ID input...
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connecting to relay at /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connected to relay successfully
Added relay QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx to discovery
Destination node still running...
Note the destination node's peer ID (in this example: `QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s`). You'll need this for the source node.
3. Finally, start the source node:
.. code-block:: console
$ python source_node.py
Source node started with ID: QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3
Listening on: [<Multiaddr /ip4/0.0.0.0/tcp/9002/p2p/QmPyM56cgmFoHTgvMgGfDWRdVRQznmxCDDDg2dJ8ygVXj3>]
Protocol service started
No relay peer ID provided. Please enter the relay\'s peer ID:
Enter relay peer ID: QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connecting to relay at /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
Connected to relay successfully
Added relay QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx to discovery
Relay discovery started
Relay discovery completed
No destination peer ID provided. Please enter the destination\'s peer ID:
Enter destination peer ID: QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
Attempting to connect to QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s via relay
Discovered relays: [<libp2p.peer.id.ID (QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx)>]
Created circuit address: /ip4/127.0.0.1/tcp/9000/p2p/QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx/p2p-circuit/p2p/QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
At this point, the source node will establish a connection through the relay to the destination node and start sending messages.
4. Alternatively, you can provide the peer IDs as command-line arguments:
.. code-block:: console
# For the destination node (provide relay ID)
$ python destination_node.py QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx
# For the source node (provide both relay and destination IDs)
$ python source_node.py QmaUigQJ9nJERa6GaZuyfaiX91QjYwoQJ46JS3k7ys7SLx QmPBr38KeQG2ibyL4fxq6yJWpfoVNCqJMHBdNyn1Qe4h5s
This example demonstrates how to use Circuit Relay v2 to establish connections between peers that cannot connect directly. The peer IDs are dynamically generated for each node, and the relay facilitates communication between the source and destination nodes.

View File

@ -11,4 +11,5 @@ Examples
examples.echo
examples.ping
examples.pubsub
examples.circuit_relay
examples.kademlia

View File

@ -12,10 +12,6 @@ The Python implementation of the libp2p networking stack
getting_started
release_notes
.. toctree::
:maxdepth: 1
:caption: Community
.. toctree::
:maxdepth: 1
:caption: py-libp2p

View File

@ -0,0 +1,22 @@
libp2p.relay.circuit_v2.pb package
==================================
Submodules
----------
libp2p.relay.circuit_v2.pb.circuit_pb2 module
---------------------------------------------
.. automodule:: libp2p.relay.circuit_v2.pb.circuit_pb2
:members:
:show-inheritance:
:undoc-members:
Module contents
---------------
.. automodule:: libp2p.relay.circuit_v2.pb
:members:
:show-inheritance:
:undoc-members:
:no-index:

View File

@ -0,0 +1,70 @@
libp2p.relay.circuit_v2 package
===============================
Subpackages
-----------
.. toctree::
:maxdepth: 4
libp2p.relay.circuit_v2.pb
Submodules
----------
libp2p.relay.circuit_v2.protocol module
---------------------------------------
.. automodule:: libp2p.relay.circuit_v2.protocol
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.transport module
----------------------------------------
.. automodule:: libp2p.relay.circuit_v2.transport
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.discovery module
----------------------------------------
.. automodule:: libp2p.relay.circuit_v2.discovery
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.resources module
----------------------------------------
.. automodule:: libp2p.relay.circuit_v2.resources
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.config module
-------------------------------------
.. automodule:: libp2p.relay.circuit_v2.config
:members:
:show-inheritance:
:undoc-members:
libp2p.relay.circuit_v2.protocol_buffer module
----------------------------------------------
.. automodule:: libp2p.relay.circuit_v2.protocol_buffer
:members:
:show-inheritance:
:undoc-members:
Module contents
---------------
.. automodule:: libp2p.relay.circuit_v2
:members:
:show-inheritance:
:undoc-members:
:no-index:

19
docs/libp2p.relay.rst Normal file
View File

@ -0,0 +1,19 @@
libp2p.relay package
====================
Subpackages
-----------
.. toctree::
:maxdepth: 4
libp2p.relay.circuit_v2
Module contents
---------------
.. automodule:: libp2p.relay
:members:
:show-inheritance:
:undoc-members:
:no-index:

View File

@ -16,6 +16,7 @@ Subpackages
libp2p.peer
libp2p.protocol_muxer
libp2p.pubsub
libp2p.relay
libp2p.security
libp2p.stream_muxer
libp2p.tools

28
libp2p/relay/__init__.py Normal file
View File

@ -0,0 +1,28 @@
"""
Relay module for libp2p.
This package includes implementations of circuit relay protocols
for enabling connectivity between peers behind NATs or firewalls.
"""
# Import the circuit_v2 module to make it accessible
# through the relay package
from libp2p.relay.circuit_v2 import (
PROTOCOL_ID,
CircuitV2Protocol,
CircuitV2Transport,
RelayDiscovery,
RelayLimits,
RelayResourceManager,
Reservation,
)
__all__ = [
"CircuitV2Protocol",
"CircuitV2Transport",
"PROTOCOL_ID",
"RelayDiscovery",
"RelayLimits",
"RelayResourceManager",
"Reservation",
]

View File

@ -0,0 +1,32 @@
"""
Circuit Relay v2 implementation for libp2p.
This package implements the Circuit Relay v2 protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
"""
from .discovery import (
RelayDiscovery,
)
from .protocol import (
PROTOCOL_ID,
CircuitV2Protocol,
)
from .resources import (
RelayLimits,
RelayResourceManager,
Reservation,
)
from .transport import (
CircuitV2Transport,
)
__all__ = [
"CircuitV2Protocol",
"PROTOCOL_ID",
"RelayLimits",
"Reservation",
"RelayResourceManager",
"CircuitV2Transport",
"RelayDiscovery",
]

View File

@ -0,0 +1,92 @@
"""
Configuration management for Circuit Relay v2.
This module handles configuration for relay roles, resource limits,
and discovery settings.
"""
from dataclasses import (
dataclass,
field,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from .resources import (
RelayLimits,
)
@dataclass
class RelayConfig:
"""Configuration for Circuit Relay v2."""
# Role configuration
enable_hop: bool = False # Whether to act as a relay (hop)
enable_stop: bool = True # Whether to accept relayed connections (stop)
enable_client: bool = True # Whether to use relays for dialing
# Resource limits
limits: RelayLimits | None = None
# Discovery configuration
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
min_relays: int = 3
max_relays: int = 20
discovery_interval: int = 300 # seconds
# Connection configuration
reservation_ttl: int = 3600 # seconds
max_circuit_duration: int = 3600 # seconds
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
def __post_init__(self) -> None:
"""Initialize default values."""
if self.limits is None:
self.limits = RelayLimits(
duration=self.max_circuit_duration,
data=self.max_circuit_bytes,
max_circuit_conns=8,
max_reservations=4,
)
@dataclass
class HopConfig:
"""Configuration specific to relay (hop) nodes."""
# Resource limits per IP
max_reservations_per_ip: int = 8
max_circuits_per_ip: int = 16
# Rate limiting
reservation_rate_per_ip: int = 4 # per minute
circuit_rate_per_ip: int = 8 # per minute
# Resource quotas
max_circuits_total: int = 64
max_reservations_total: int = 32
# Bandwidth limits
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
@dataclass
class ClientConfig:
"""Configuration specific to relay clients."""
# Relay selection
min_relay_score: float = 0.5
max_relay_latency: float = 1.0 # seconds
# Auto-relay settings
enable_auto_relay: bool = True
auto_relay_timeout: int = 30 # seconds
max_auto_relay_attempts: int = 3
# Reservation management
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
max_concurrent_reservations: int = 2

View File

@ -0,0 +1,537 @@
"""
Discovery module for Circuit Relay v2.
This module handles discovering and tracking relay nodes in the network.
"""
from dataclasses import (
dataclass,
)
import logging
import time
from typing import (
Any,
Protocol as TypingProtocol,
cast,
runtime_checkable,
)
import trio
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.peer.id import (
ID,
)
from libp2p.tools.async_service import (
Service,
)
from .pb.circuit_pb2 import (
HopMessage,
)
from .protocol import (
PROTOCOL_ID,
)
from .protocol_buffer import (
StatusCode,
)
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
# Constants
MAX_RELAYS_TO_TRACK = 10
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
STREAM_TIMEOUT = 10 # seconds
# Extended interfaces for type checking
@runtime_checkable
class IHostWithMultiselect(TypingProtocol):
"""Extended host interface with multiselect attribute."""
@property
def multiselect(self) -> Any:
"""Get the multiselect component."""
...
@dataclass
class RelayInfo:
"""Information about a discovered relay."""
peer_id: ID
discovered_at: float
last_seen: float
has_reservation: bool = False
reservation_expires_at: float | None = None
reservation_data_limit: int | None = None
class RelayDiscovery(Service):
"""
Discovery service for Circuit Relay v2 nodes.
This service discovers and keeps track of available relay nodes, and optionally
makes reservations with them.
"""
def __init__(
self,
host: IHost,
auto_reserve: bool = False,
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
max_relays: int = MAX_RELAYS_TO_TRACK,
) -> None:
"""
Initialize the discovery service.
Parameters
----------
host : IHost
The libp2p host this discovery service is running on
auto_reserve : bool
Whether to automatically make reservations with discovered relays
discovery_interval : int
How often to run discovery, in seconds
max_relays : int
Maximum number of relays to track
"""
super().__init__()
self.host = host
self.auto_reserve = auto_reserve
self.discovery_interval = discovery_interval
self.max_relays = max_relays
self._discovered_relays: dict[ID, RelayInfo] = {}
self._protocol_cache: dict[
ID, set[str]
] = {} # Cache protocol info to reduce queries
self.event_started = trio.Event()
self.is_running = False
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
"""Run the discovery service."""
try:
self.is_running = True
self.event_started.set()
task_status.started()
# Main discovery loop
async with trio.open_nursery() as nursery:
# Run initial discovery
nursery.start_soon(self.discover_relays)
# Set up periodic discovery
while True:
await trio.sleep(self.discovery_interval)
if not self.manager.is_running:
break
nursery.start_soon(self.discover_relays)
# Cleanup expired relays and reservations
await self._cleanup_expired()
finally:
self.is_running = False
async def discover_relays(self) -> None:
r"""
Discover relay nodes in the network.
This method queries the network for peers that support the
Circuit Relay v2 protocol.
"""
logger.debug("Starting relay discovery")
try:
# Get connected peers
connected_peers = self.host.get_connected_peers()
logger.debug(
"Checking %d connected peers for relay support", len(connected_peers)
)
# Check each peer if they support the relay protocol
for peer_id in connected_peers:
if peer_id == self.host.get_id():
continue # Skip ourselves
if peer_id in self._discovered_relays:
# Update last seen time for existing relay
self._discovered_relays[peer_id].last_seen = time.time()
continue
# Check if peer supports the relay protocol
with trio.move_on_after(5): # Don't wait too long for protocol info
if await self._supports_relay_protocol(peer_id):
await self._add_relay(peer_id)
# Limit number of relays we track
if len(self._discovered_relays) > self.max_relays:
# Sort by last seen time and keep only the most recent ones
sorted_relays = sorted(
self._discovered_relays.items(),
key=lambda x: x[1].last_seen,
reverse=True,
)
to_remove = sorted_relays[self.max_relays :]
for peer_id, _ in to_remove:
del self._discovered_relays[peer_id]
logger.debug(
"Discovery completed, tracking %d relays", len(self._discovered_relays)
)
except Exception as e:
logger.error("Error during relay discovery: %s", str(e))
async def _supports_relay_protocol(self, peer_id: ID) -> bool:
"""
Check if a peer supports the relay protocol.
Parameters
----------
peer_id : ID
The ID of the peer to check
Returns
-------
bool
True if the peer supports the relay protocol, False otherwise
"""
# Check cache first
if peer_id in self._protocol_cache:
return PROTOCOL_ID in self._protocol_cache[peer_id]
# Method 1: Try peerstore
result = await self._check_via_peerstore(peer_id)
if result is not None:
return result
# Method 2: Try direct stream connection
result = await self._check_via_direct_connection(peer_id)
if result is not None:
return result
# Method 3: Try protocols from mux
result = await self._check_via_mux(peer_id)
if result is not None:
return result
# Default: Cannot determine, assume false
return False
async def _check_via_peerstore(self, peer_id: ID) -> bool | None:
"""Check protocol support via peerstore."""
try:
peerstore = self.host.get_peerstore()
proto_getter = peerstore.get_protocols
if not callable(proto_getter):
return None
try:
# Try to get protocols
proto_result = proto_getter(peer_id)
# Get protocols list
protocols_list = []
if hasattr(proto_result, "__await__"):
protocols_list = await cast(Any, proto_result)
else:
protocols_list = proto_result
# Check result
if protocols_list is not None:
protocols = set(protocols_list)
self._protocol_cache[peer_id] = protocols
return PROTOCOL_ID in protocols
return False
except Exception as e:
logger.debug("Error getting protocols: %s", str(e))
return None
except Exception as e:
logger.debug("Error accessing peerstore: %s", str(e))
return None
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
"""Check protocol support via direct connection."""
try:
with trio.fail_after(STREAM_TIMEOUT):
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if stream:
await stream.close()
self._protocol_cache[peer_id] = {PROTOCOL_ID}
return True
return False
except Exception as e:
logger.debug(
"Failed to open relay protocol stream to %s: %s", peer_id, str(e)
)
return None
async def _check_via_mux(self, peer_id: ID) -> bool | None:
"""Check protocol support via mux protocols."""
try:
if not (hasattr(self.host, "get_mux") and self.host.get_mux() is not None):
return None
mux = self.host.get_mux()
if not hasattr(mux, "protocols"):
return None
peer_protocols = set()
# Get protocols from mux with proper type safety
available_protocols = []
if hasattr(mux, "get_protocols"):
# Get protocols with proper typing
mux_protocols = mux.get_protocols()
if isinstance(mux_protocols, (list, tuple)):
available_protocols = list(mux_protocols)
for protocol in available_protocols:
try:
with trio.fail_after(2): # Quick check
# Ensure we have a proper protocol object
# Use string representation since we can't use isinstance
is_tprotocol = str(type(protocol)) == str(type(TProtocol))
protocol_obj = (
protocol if is_tprotocol else TProtocol(str(protocol))
)
stream = await self.host.new_stream(peer_id, [protocol_obj])
if stream:
peer_protocols.add(str(protocol_obj))
await stream.close()
except Exception:
pass # Ignore errors when closing the stream
self._protocol_cache[peer_id] = peer_protocols
protocol_str = str(PROTOCOL_ID)
for protocol in peer_protocols:
if protocol == protocol_str:
return True
return False
except Exception as e:
logger.debug("Error checking protocols via mux: %s", str(e))
return None
async def _add_relay(self, peer_id: ID) -> None:
"""
Add a peer as a relay and optionally make a reservation.
Parameters
----------
peer_id : ID
The ID of the peer to add as a relay
"""
now = time.time()
relay_info = RelayInfo(
peer_id=peer_id,
discovered_at=now,
last_seen=now,
)
self._discovered_relays[peer_id] = relay_info
logger.debug("Added relay %s to discovered relays", peer_id)
# If auto-reserve is enabled, make a reservation with this relay
if self.auto_reserve:
await self.make_reservation(peer_id)
async def make_reservation(self, peer_id: ID) -> bool:
"""
Make a reservation with a relay.
Parameters
----------
peer_id : ID
The ID of the relay to make a reservation with
Returns
-------
bool
True if reservation succeeded, False otherwise
"""
if peer_id not in self._discovered_relays:
logger.error("Cannot make reservation with unknown relay %s", peer_id)
return False
stream = None
try:
logger.debug("Making reservation with relay %s", peer_id)
# Open a stream to the relay with timeout
try:
with trio.fail_after(STREAM_TIMEOUT):
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
if not stream:
logger.error("Failed to open stream to relay %s", peer_id)
return False
except trio.TooSlowError:
logger.error("Timeout opening stream to relay %s", peer_id)
return False
try:
# Create and send reservation request
request = HopMessage(
type=HopMessage.RESERVE,
peer=self.host.get_id().to_bytes(),
)
with trio.fail_after(STREAM_TIMEOUT):
await stream.write(request.SerializeToString())
# Wait for response
response_bytes = await stream.read()
if not response_bytes:
logger.error("No response received from relay %s", peer_id)
return False
# Parse response
response = HopMessage()
response.ParseFromString(response_bytes)
# Check if reservation was successful
if response.type == HopMessage.RESERVE and response.HasField(
"status"
):
# Access status code directly from protobuf object
status_code = getattr(response.status, "code", StatusCode.OK)
if status_code == StatusCode.OK:
# Update relay info with reservation details
relay_info = self._discovered_relays[peer_id]
relay_info.has_reservation = True
if response.HasField("reservation") and response.HasField(
"limit"
):
relay_info.reservation_expires_at = (
response.reservation.expire
)
relay_info.reservation_data_limit = response.limit.data
logger.debug(
"Successfully made reservation with relay %s", peer_id
)
return True
# Reservation failed
error_message = "Unknown error"
if response.HasField("status"):
# Access message directly from protobuf object
error_message = getattr(response.status, "message", "")
logger.warning(
"Reservation request rejected by relay %s: %s",
peer_id,
error_message,
)
return False
except trio.TooSlowError:
logger.error(
"Timeout during reservation process with relay %s", peer_id
)
return False
except Exception as e:
logger.error("Error making reservation with relay %s: %s", peer_id, str(e))
return False
finally:
# Always close the stream
if stream:
try:
await stream.close()
except Exception:
pass # Ignore errors when closing the stream
return False
async def _cleanup_expired(self) -> None:
"""Clean up expired relays and reservations."""
now = time.time()
to_remove = []
for peer_id, relay_info in self._discovered_relays.items():
# Check if relay hasn't been seen in a while (3x discovery interval)
if now - relay_info.last_seen > self.discovery_interval * 3:
to_remove.append(peer_id)
continue
# Check if reservation has expired
if (
relay_info.has_reservation
and relay_info.reservation_expires_at
and now > relay_info.reservation_expires_at
):
relay_info.has_reservation = False
relay_info.reservation_expires_at = None
relay_info.reservation_data_limit = None
# If auto-reserve is enabled, try to renew
if self.auto_reserve:
await self.make_reservation(peer_id)
# Remove expired relays
for peer_id in to_remove:
del self._discovered_relays[peer_id]
if peer_id in self._protocol_cache:
del self._protocol_cache[peer_id]
def get_relays(self) -> list[ID]:
"""
Get a list of discovered relay peer IDs.
Returns
-------
list[ID]
List of discovered relay peer IDs
"""
return list(self._discovered_relays.keys())
def get_relay_info(self, peer_id: ID) -> RelayInfo | None:
"""
Get information about a specific relay.
Parameters
----------
peer_id : ID
The ID of the relay to get information about
Returns
-------
Optional[RelayInfo]
Information about the relay, or None if not found
"""
return self._discovered_relays.get(peer_id)
def get_relay(self) -> ID | None:
"""
Get a single relay peer ID for connection purposes.
Prioritizes relays with active reservations.
Returns
-------
Optional[ID]
ID of a discovered relay, or None if no relays found
"""
if not self._discovered_relays:
return None
# First try to find a relay with an active reservation
for peer_id, relay_info in self._discovered_relays.items():
if relay_info and relay_info.has_reservation:
return peer_id
return next(iter(self._discovered_relays.keys()), None)

View File

@ -0,0 +1,16 @@
"""
Protocol buffer package for circuit_v2.
Contains generated protobuf code for circuit_v2 relay protocol.
"""
# Import the classes to be accessible directly from the package
from .circuit_pb2 import (
HopMessage,
Limit,
Reservation,
Status,
StopMessage,
)
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage"]

View File

@ -0,0 +1,55 @@
syntax = "proto3";
package circuit.pb.v2;
// Circuit v2 message types
message HopMessage {
enum Type {
RESERVE = 0;
CONNECT = 1;
STATUS = 2;
}
Type type = 1;
bytes peer = 2;
Reservation reservation = 3;
Limit limit = 4;
Status status = 5;
}
message StopMessage {
enum Type {
CONNECT = 0;
STATUS = 1;
}
Type type = 1;
bytes peer = 2;
Status status = 3;
}
message Reservation {
bytes voucher = 1;
bytes signature = 2;
int64 expire = 3;
}
message Limit {
int64 duration = 1;
int64 data = 2;
}
message Status {
enum Code {
OK = 0;
RESERVATION_REFUSED = 100;
RESOURCE_LIMIT_EXCEEDED = 101;
PERMISSION_DENIED = 102;
CONNECTION_FAILED = 200;
DIAL_REFUSED = 201;
STOP_FAILED = 300;
MALFORMED_MESSAGE = 400;
}
Code code = 1;
string message = 2;
}

View File

@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: libp2p/relay/circuit_v2/pb/circuit.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\xf3\x01\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\"\x92\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_HOPMESSAGE._serialized_start=60
_HOPMESSAGE._serialized_end=303
_HOPMESSAGE_TYPE._serialized_start=259
_HOPMESSAGE_TYPE._serialized_end=303
_STOPMESSAGE._serialized_start=306
_STOPMESSAGE._serialized_end=452
_STOPMESSAGE_TYPE._serialized_start=421
_STOPMESSAGE_TYPE._serialized_end=452
_RESERVATION._serialized_start=454
_RESERVATION._serialized_end=519
_LIMIT._serialized_start=521
_LIMIT._serialized_end=560
_STATUS._serialized_start=563
_STATUS._serialized_end=809
_STATUS_CODE._serialized_start=633
_STATUS_CODE._serialized_end=809
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,184 @@
"""
@generated by mypy-protobuf. Do not edit manually!
isort:skip_file
"""
import builtins
import google.protobuf.descriptor
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import sys
import typing
if sys.version_info >= (3, 10):
import typing as typing_extensions
else:
import typing_extensions
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
@typing.final
class HopMessage(google.protobuf.message.Message):
"""Circuit v2 message types"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Type:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HopMessage._Type.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
RESERVE: HopMessage._Type.ValueType # 0
CONNECT: HopMessage._Type.ValueType # 1
STATUS: HopMessage._Type.ValueType # 2
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
RESERVE: HopMessage.Type.ValueType # 0
CONNECT: HopMessage.Type.ValueType # 1
STATUS: HopMessage.Type.ValueType # 2
TYPE_FIELD_NUMBER: builtins.int
PEER_FIELD_NUMBER: builtins.int
RESERVATION_FIELD_NUMBER: builtins.int
LIMIT_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
type: global___HopMessage.Type.ValueType
peer: builtins.bytes
@property
def reservation(self) -> global___Reservation: ...
@property
def limit(self) -> global___Limit: ...
@property
def status(self) -> global___Status: ...
def __init__(
self,
*,
type: global___HopMessage.Type.ValueType = ...,
peer: builtins.bytes = ...,
reservation: global___Reservation | None = ...,
limit: global___Limit | None = ...,
status: global___Status | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["limit", b"limit", "reservation", b"reservation", "status", b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["limit", b"limit", "peer", b"peer", "reservation", b"reservation", "status", b"status", "type", b"type"]) -> None: ...
global___HopMessage = HopMessage
@typing.final
class StopMessage(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Type:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[StopMessage._Type.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
CONNECT: StopMessage._Type.ValueType # 0
STATUS: StopMessage._Type.ValueType # 1
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
CONNECT: StopMessage.Type.ValueType # 0
STATUS: StopMessage.Type.ValueType # 1
TYPE_FIELD_NUMBER: builtins.int
PEER_FIELD_NUMBER: builtins.int
STATUS_FIELD_NUMBER: builtins.int
type: global___StopMessage.Type.ValueType
peer: builtins.bytes
@property
def status(self) -> global___Status: ...
def __init__(
self,
*,
type: global___StopMessage.Type.ValueType = ...,
peer: builtins.bytes = ...,
status: global___Status | None = ...,
) -> None: ...
def HasField(self, field_name: typing.Literal["status", b"status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing.Literal["peer", b"peer", "status", b"status", "type", b"type"]) -> None: ...
global___StopMessage = StopMessage
@typing.final
class Reservation(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
VOUCHER_FIELD_NUMBER: builtins.int
SIGNATURE_FIELD_NUMBER: builtins.int
EXPIRE_FIELD_NUMBER: builtins.int
voucher: builtins.bytes
signature: builtins.bytes
expire: builtins.int
def __init__(
self,
*,
voucher: builtins.bytes = ...,
signature: builtins.bytes = ...,
expire: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["expire", b"expire", "signature", b"signature", "voucher", b"voucher"]) -> None: ...
global___Reservation = Reservation
@typing.final
class Limit(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
DURATION_FIELD_NUMBER: builtins.int
DATA_FIELD_NUMBER: builtins.int
duration: builtins.int
data: builtins.int
def __init__(
self,
*,
duration: builtins.int = ...,
data: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["data", b"data", "duration", b"duration"]) -> None: ...
global___Limit = Limit
@typing.final
class Status(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class _Code:
ValueType = typing.NewType("ValueType", builtins.int)
V: typing_extensions.TypeAlias = ValueType
class _CodeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Status._Code.ValueType], builtins.type):
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
OK: Status._Code.ValueType # 0
RESERVATION_REFUSED: Status._Code.ValueType # 100
RESOURCE_LIMIT_EXCEEDED: Status._Code.ValueType # 101
PERMISSION_DENIED: Status._Code.ValueType # 102
CONNECTION_FAILED: Status._Code.ValueType # 200
DIAL_REFUSED: Status._Code.ValueType # 201
STOP_FAILED: Status._Code.ValueType # 300
MALFORMED_MESSAGE: Status._Code.ValueType # 400
class Code(_Code, metaclass=_CodeEnumTypeWrapper): ...
OK: Status.Code.ValueType # 0
RESERVATION_REFUSED: Status.Code.ValueType # 100
RESOURCE_LIMIT_EXCEEDED: Status.Code.ValueType # 101
PERMISSION_DENIED: Status.Code.ValueType # 102
CONNECTION_FAILED: Status.Code.ValueType # 200
DIAL_REFUSED: Status.Code.ValueType # 201
STOP_FAILED: Status.Code.ValueType # 300
MALFORMED_MESSAGE: Status.Code.ValueType # 400
CODE_FIELD_NUMBER: builtins.int
MESSAGE_FIELD_NUMBER: builtins.int
code: global___Status.Code.ValueType
message: builtins.str
def __init__(
self,
*,
code: global___Status.Code.ValueType = ...,
message: builtins.str = ...,
) -> None: ...
def ClearField(self, field_name: typing.Literal["code", b"code", "message", b"message"]) -> None: ...
global___Status = Status

View File

@ -0,0 +1,800 @@
"""
Circuit Relay v2 protocol implementation.
This module implements the Circuit Relay v2 protocol as specified in:
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
"""
import logging
import time
from typing import (
Any,
Protocol as TypingProtocol,
cast,
runtime_checkable,
)
import trio
from libp2p.abc import (
IHost,
INetStream,
)
from libp2p.custom_types import (
TProtocol,
)
from libp2p.io.abc import (
ReadWriteCloser,
)
from libp2p.peer.id import (
ID,
)
from libp2p.stream_muxer.mplex.exceptions import (
MplexStreamEOF,
MplexStreamReset,
)
from libp2p.tools.async_service import (
Service,
)
from .pb.circuit_pb2 import (
HopMessage,
Limit,
Reservation,
Status as PbStatus,
StopMessage,
)
from .protocol_buffer import (
StatusCode,
create_status,
)
from .resources import (
RelayLimits,
RelayResourceManager,
)
logger = logging.getLogger("libp2p.relay.circuit_v2")
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
# Default limits for relay resources
DEFAULT_RELAY_LIMITS = RelayLimits(
duration=60 * 60, # 1 hour
data=1024 * 1024 * 1024, # 1GB
max_circuit_conns=8,
max_reservations=4,
)
# Stream operation timeouts
STREAM_READ_TIMEOUT = 15 # seconds
STREAM_WRITE_TIMEOUT = 15 # seconds
STREAM_CLOSE_TIMEOUT = 10 # seconds
MAX_READ_RETRIES = 5 # Maximum number of read retries
# Extended interfaces for type checking
@runtime_checkable
class IHostWithStreamHandlers(TypingProtocol):
"""Extended host interface with stream handler methods."""
def remove_stream_handler(self, protocol_id: TProtocol) -> None:
"""Remove a stream handler for a protocol."""
...
@runtime_checkable
class INetStreamWithExtras(TypingProtocol):
"""Extended net stream interface with additional methods."""
def get_remote_peer_id(self) -> ID:
"""Get the remote peer ID."""
...
def is_open(self) -> bool:
"""Check if the stream is open."""
...
def is_closed(self) -> bool:
"""Check if the stream is closed."""
...
class CircuitV2Protocol(Service):
"""
CircuitV2Protocol implements the Circuit Relay v2 protocol.
This protocol allows peers to establish connections through relay nodes
when direct connections are not possible (e.g., due to NAT).
"""
def __init__(
self,
host: IHost,
limits: RelayLimits | None = None,
allow_hop: bool = False,
) -> None:
"""
Initialize a Circuit Relay v2 protocol instance.
Parameters
----------
host : IHost
The libp2p host instance
limits : RelayLimits | None
Resource limits for the relay
allow_hop : bool
Whether to allow this node to act as a relay
"""
self.host = host
self.limits = limits or DEFAULT_RELAY_LIMITS
self.allow_hop = allow_hop
self.resource_manager = RelayResourceManager(self.limits)
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
self.event_started = trio.Event()
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
"""Run the protocol service."""
try:
# Register protocol handlers
if self.allow_hop:
logger.debug("Registering stream handlers for relay protocol")
self.host.set_stream_handler(PROTOCOL_ID, self._handle_hop_stream)
self.host.set_stream_handler(STOP_PROTOCOL_ID, self._handle_stop_stream)
logger.debug("Stream handlers registered successfully")
# Signal that we're ready
self.event_started.set()
task_status.started()
logger.debug("Protocol service started")
# Wait for service to be stopped
await self.manager.wait_finished()
finally:
# Clean up any active relay connections
for src_stream, dst_stream in self._active_relays.values():
await self._close_stream(src_stream)
await self._close_stream(dst_stream)
self._active_relays.clear()
# Unregister protocol handlers
if self.allow_hop:
try:
# Cast host to extended interface with remove_stream_handler
host_with_handlers = cast(IHostWithStreamHandlers, self.host)
host_with_handlers.remove_stream_handler(PROTOCOL_ID)
host_with_handlers.remove_stream_handler(STOP_PROTOCOL_ID)
except Exception as e:
logger.error("Error unregistering stream handlers: %s", str(e))
async def _close_stream(self, stream: INetStream | None) -> None:
"""Helper function to safely close a stream."""
if stream is None:
return
try:
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
await stream.close()
except Exception:
try:
await stream.reset()
except Exception:
pass
async def _read_stream_with_retry(
self,
stream: INetStream,
max_retries: int = MAX_READ_RETRIES,
) -> bytes | None:
"""
Helper function to read from a stream with retries.
Parameters
----------
stream : INetStream
The stream to read from
max_retries : int
Maximum number of read retries
Returns
-------
Optional[bytes]
The data read from the stream, or None if the stream is closed/reset
Raises
------
trio.TooSlowError
If read timeout occurs after all retries
Exception
For other unexpected errors
"""
retries = 0
last_error: Any = None
backoff_time = 0.2 # Base backoff time in seconds
while retries < max_retries:
try:
with trio.fail_after(STREAM_READ_TIMEOUT):
# Try reading with timeout
logger.debug(
"Attempting to read from stream (attempt %d/%d)",
retries + 1,
max_retries,
)
data = await stream.read()
if not data: # EOF
logger.debug("Stream EOF detected")
return None
logger.debug("Successfully read %d bytes from stream", len(data))
return data
except trio.WouldBlock:
# Just retry immediately if we would block
retries += 1
logger.debug(
"Stream would block (attempt %d/%d), retrying...",
retries,
max_retries,
)
await trio.sleep(backoff_time * retries) # Increased backoff time
continue
except (MplexStreamEOF, MplexStreamReset):
# Stream closed/reset - no point retrying
logger.debug("Stream closed/reset during read")
return None
except trio.TooSlowError as e:
last_error = e
retries += 1
logger.debug(
"Read timeout (attempt %d/%d), retrying...", retries, max_retries
)
if retries < max_retries:
# Wait longer before retry with increasing backoff
await trio.sleep(backoff_time * retries) # Increased backoff
continue
except Exception as e:
logger.error("Unexpected error reading from stream: %s", str(e))
last_error = e
retries += 1
if retries < max_retries:
await trio.sleep(backoff_time * retries) # Increased backoff
continue
raise
if last_error:
if isinstance(last_error, trio.TooSlowError):
logger.error("Read timed out after %d retries", max_retries)
raise last_error
return None
async def _handle_hop_stream(self, stream: INetStream) -> None:
"""
Handle incoming HOP streams.
This handler processes relay requests from other peers.
"""
try:
# Try to get peer ID first
try:
# Cast to extended interface with get_remote_peer_id
stream_with_peer_id = cast(INetStreamWithExtras, stream)
remote_peer_id = stream_with_peer_id.get_remote_peer_id()
remote_id = str(remote_peer_id)
except Exception:
# Fall back to address if peer ID not available
remote_addr = stream.get_remote_address()
remote_id = f"peer at {remote_addr}" if remote_addr else "unknown peer"
logger.debug("Handling hop stream from %s", remote_id)
# First, handle the read timeout gracefully
try:
with trio.fail_after(
STREAM_READ_TIMEOUT * 2
): # Double the timeout for reading
msg_bytes = await stream.read()
if not msg_bytes:
logger.error(
"Empty read from stream from %s",
remote_id,
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
pb_status.message = "Empty message received"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure message is sent
return
except trio.TooSlowError:
logger.error(
"Timeout reading from hop stream from %s",
remote_id,
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.CONNECTION_FAILED))
pb_status.message = "Stream read timeout"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure the message is sent
return
except Exception as e:
logger.error(
"Error reading from hop stream from %s: %s",
remote_id,
str(e),
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
pb_status.message = f"Read error: {str(e)}"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure the message is sent
return
# Parse the message
try:
hop_msg = HopMessage()
hop_msg.ParseFromString(msg_bytes)
except Exception as e:
logger.error(
"Error parsing hop message from %s: %s",
remote_id,
str(e),
)
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
pb_status.message = f"Parse error: {str(e)}"
response = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
await stream.write(response.SerializeToString())
await trio.sleep(0.5) # Longer wait to ensure the message is sent
return
# Process based on message type
if hop_msg.type == HopMessage.RESERVE:
logger.debug("Handling RESERVE message from %s", remote_id)
await self._handle_reserve(stream, hop_msg)
# For RESERVE requests, let the client close the stream
return
elif hop_msg.type == HopMessage.CONNECT:
logger.debug("Handling CONNECT message from %s", remote_id)
await self._handle_connect(stream, hop_msg)
else:
logger.error("Invalid message type %d from %s", hop_msg.type, remote_id)
# Send a nice error response using _send_status method
await self._send_status(
stream,
StatusCode.MALFORMED_MESSAGE,
f"Invalid message type: {hop_msg.type}",
)
except Exception as e:
logger.error(
"Unexpected error handling hop stream from %s: %s", remote_id, str(e)
)
try:
# Send a nice error response using _send_status method
await self._send_status(
stream,
StatusCode.MALFORMED_MESSAGE,
f"Internal error: {str(e)}",
)
except Exception as e2:
logger.error(
"Failed to send error response to %s: %s", remote_id, str(e2)
)
async def _handle_stop_stream(self, stream: INetStream) -> None:
"""
Handle incoming STOP streams.
This handler processes incoming relay connections from the destination side.
"""
try:
# Read the incoming message with timeout
with trio.fail_after(STREAM_READ_TIMEOUT):
msg_bytes = await stream.read()
stop_msg = StopMessage()
stop_msg.ParseFromString(msg_bytes)
if stop_msg.type != StopMessage.CONNECT:
# Use direct attribute access to create status object for error response
await self._send_stop_status(
stream,
StatusCode.MALFORMED_MESSAGE,
"Invalid message type",
)
await self._close_stream(stream)
return
# Get the source stream from active relays
peer_id = ID(stop_msg.peer)
if peer_id not in self._active_relays:
# Use direct attribute access to create status object for error response
await self._send_stop_status(
stream,
StatusCode.CONNECTION_FAILED,
"No pending relay connection",
)
await self._close_stream(stream)
return
src_stream, _ = self._active_relays[peer_id]
self._active_relays[peer_id] = (src_stream, stream)
# Send success status to both sides
await self._send_status(
src_stream,
StatusCode.OK,
"Connection established",
)
await self._send_stop_status(
stream,
StatusCode.OK,
"Connection established",
)
# Start relaying data
async with trio.open_nursery() as nursery:
nursery.start_soon(self._relay_data, src_stream, stream, peer_id)
nursery.start_soon(self._relay_data, stream, src_stream, peer_id)
except trio.TooSlowError:
logger.error("Timeout reading from stop stream")
await self._send_stop_status(
stream,
StatusCode.CONNECTION_FAILED,
"Stream read timeout",
)
await self._close_stream(stream)
except Exception as e:
logger.error("Error handling stop stream: %s", str(e))
try:
await self._send_stop_status(
stream,
StatusCode.MALFORMED_MESSAGE,
str(e),
)
await self._close_stream(stream)
except Exception:
pass
async def _handle_reserve(self, stream: INetStream, msg: Any) -> None:
"""Handle a reservation request."""
peer_id = None
try:
peer_id = ID(msg.peer)
logger.debug("Handling reservation request from peer %s", peer_id)
# Check if we can accept more reservations
if not self.resource_manager.can_accept_reservation(peer_id):
logger.debug("Reservation limit exceeded for peer %s", peer_id)
# Send status message with STATUS type
status = create_status(
code=StatusCode.RESOURCE_LIMIT_EXCEEDED,
message="Reservation limit exceeded",
)
status_msg = HopMessage(
type=HopMessage.STATUS,
status=status.to_pb(),
)
await stream.write(status_msg.SerializeToString())
return
# Accept reservation
logger.debug("Accepting reservation from peer %s", peer_id)
ttl = self.resource_manager.reserve(peer_id)
# Send reservation success response
with trio.fail_after(STREAM_WRITE_TIMEOUT):
status = create_status(
code=StatusCode.OK, message="Reservation accepted"
)
response = HopMessage(
type=HopMessage.STATUS,
status=status.to_pb(),
reservation=Reservation(
expire=int(time.time() + ttl),
voucher=b"", # We don't use vouchers yet
signature=b"", # We don't use signatures yet
),
limit=Limit(
duration=self.limits.duration,
data=self.limits.data,
),
)
# Log the response message details for debugging
logger.debug(
"Sending reservation response: type=%s, status=%s, ttl=%d",
response.type,
getattr(response.status, "code", "unknown"),
ttl,
)
# Send the response with increased timeout
await stream.write(response.SerializeToString())
# Add a small wait to ensure the message is fully sent
await trio.sleep(0.1)
logger.debug("Reservation response sent successfully")
except Exception as e:
logger.error("Error handling reservation request: %s", str(e))
if cast(INetStreamWithExtras, stream).is_open():
try:
# Send error response
await self._send_status(
stream,
StatusCode.INTERNAL_ERROR,
f"Failed to process reservation: {str(e)}",
)
except Exception as send_err:
logger.error("Failed to send error response: %s", str(send_err))
finally:
# Always close the stream when done with reservation
if cast(INetStreamWithExtras, stream).is_open():
try:
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
await stream.close()
except Exception as close_err:
logger.error("Error closing stream: %s", str(close_err))
async def _handle_connect(self, stream: INetStream, msg: Any) -> None:
"""Handle a connect request."""
peer_id = ID(msg.peer)
dst_stream: INetStream | None = None
# Verify reservation if provided
if msg.HasField("reservation"):
if not self.resource_manager.verify_reservation(peer_id, msg.reservation):
await self._send_status(
stream,
StatusCode.PERMISSION_DENIED,
"Invalid reservation",
)
await stream.reset()
return
# Check resource limits
if not self.resource_manager.can_accept_connection(peer_id):
await self._send_status(
stream,
StatusCode.RESOURCE_LIMIT_EXCEEDED,
"Connection limit exceeded",
)
await stream.reset()
return
try:
# Store the source stream with properly typed None
self._active_relays[peer_id] = (stream, None)
# Try to connect to the destination with timeout
with trio.fail_after(STREAM_READ_TIMEOUT):
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
if not dst_stream:
raise ConnectionError("Could not connect to destination")
# Send STOP CONNECT message
stop_msg = StopMessage(
type=StopMessage.CONNECT,
# Cast to extended interface with get_remote_peer_id
peer=cast(INetStreamWithExtras, stream)
.get_remote_peer_id()
.to_bytes(),
)
await dst_stream.write(stop_msg.SerializeToString())
# Wait for response from destination
resp_bytes = await dst_stream.read()
resp = StopMessage()
resp.ParseFromString(resp_bytes)
# Handle status attributes from the response
if resp.HasField("status"):
# Get code and message attributes with defaults
status_code = getattr(resp.status, "code", StatusCode.OK)
# Get message with default
status_msg = getattr(resp.status, "message", "Unknown error")
else:
status_code = StatusCode.OK
status_msg = "No status provided"
if status_code != StatusCode.OK:
raise ConnectionError(
f"Destination rejected connection: {status_msg}"
)
# Update active relays with destination stream
self._active_relays[peer_id] = (stream, dst_stream)
# Update reservation connection count
reservation = self.resource_manager._reservations.get(peer_id)
if reservation:
reservation.active_connections += 1
# Send success status
await self._send_status(
stream,
StatusCode.OK,
"Connection established",
)
# Start relaying data
async with trio.open_nursery() as nursery:
nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)
except (trio.TooSlowError, ConnectionError) as e:
logger.error("Error establishing relay connection: %s", str(e))
await self._send_status(
stream,
StatusCode.CONNECTION_FAILED,
str(e),
)
if peer_id in self._active_relays:
del self._active_relays[peer_id]
# Clean up reservation connection count on failure
reservation = self.resource_manager._reservations.get(peer_id)
if reservation:
reservation.active_connections -= 1
await stream.reset()
if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
await dst_stream.reset()
except Exception as e:
logger.error("Unexpected error in connect handler: %s", str(e))
await self._send_status(
stream,
StatusCode.CONNECTION_FAILED,
"Internal error",
)
if peer_id in self._active_relays:
del self._active_relays[peer_id]
await stream.reset()
if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
await dst_stream.reset()
async def _relay_data(
self,
src_stream: INetStream,
dst_stream: INetStream,
peer_id: ID,
) -> None:
"""
Relay data between two streams.
Parameters
----------
src_stream : INetStream
Source stream to read from
dst_stream : INetStream
Destination stream to write to
peer_id : ID
ID of the peer being relayed
"""
try:
while True:
# Read data with retries
data = await self._read_stream_with_retry(src_stream)
if not data:
logger.info("Source stream closed/reset")
break
# Write data with timeout
try:
with trio.fail_after(STREAM_WRITE_TIMEOUT):
await dst_stream.write(data)
except trio.TooSlowError:
logger.error("Timeout writing to destination stream")
break
except Exception as e:
logger.error("Error writing to destination stream: %s", str(e))
break
# Update resource usage
reservation = self.resource_manager._reservations.get(peer_id)
if reservation:
reservation.data_used += len(data)
if reservation.data_used >= reservation.limits.data:
logger.warning("Data limit exceeded for peer %s", peer_id)
break
except Exception as e:
logger.error("Error relaying data: %s", str(e))
finally:
# Clean up streams and remove from active relays
await src_stream.reset()
await dst_stream.reset()
if peer_id in self._active_relays:
del self._active_relays[peer_id]
async def _send_status(
self,
stream: ReadWriteCloser,
code: int,
message: str,
) -> None:
"""Send a status message."""
try:
logger.debug("Sending status message with code %s: %s", code, message)
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(
Any, int(code)
) # Cast to Any to avoid type errors
pb_status.message = message
status_msg = HopMessage(
type=HopMessage.STATUS,
status=pb_status,
)
msg_bytes = status_msg.SerializeToString()
logger.debug("Status message serialized (%d bytes)", len(msg_bytes))
await stream.write(msg_bytes)
logger.debug("Status message sent, waiting for processing")
# Wait longer to ensure the message is sent
await trio.sleep(1.5)
logger.debug("Status message sending completed")
except trio.TooSlowError:
logger.error(
"Timeout sending status message: code=%s, message=%s", code, message
)
except Exception as e:
logger.error("Error sending status message: %s", str(e))
async def _send_stop_status(
self,
stream: ReadWriteCloser,
code: int,
message: str,
) -> None:
"""Send a status message on a STOP stream."""
try:
logger.debug("Sending stop status message with code %s: %s", code, message)
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
# Create a proto Status directly
pb_status = PbStatus()
pb_status.code = cast(
Any, int(code)
) # Cast to Any to avoid type errors
pb_status.message = message
status_msg = StopMessage(
type=StopMessage.STATUS,
status=pb_status,
)
await stream.write(status_msg.SerializeToString())
await trio.sleep(0.5) # Ensure message is sent
except Exception as e:
logger.error("Error sending stop status message: %s", str(e))

View File

@ -0,0 +1,55 @@
"""
Protocol buffer wrapper classes for Circuit Relay v2.
This module provides wrapper classes for protocol buffer generated objects
to make them easier to work with in type-checked code.
"""
from enum import (
IntEnum,
)
from typing import (
Any,
)
from .pb.circuit_pb2 import Status as PbStatus
# Define Status codes as an Enum for better type safety and organization
class StatusCode(IntEnum):
OK = 0
RESERVATION_REFUSED = 100
RESOURCE_LIMIT_EXCEEDED = 101
PERMISSION_DENIED = 102
CONNECTION_FAILED = 200
DIAL_REFUSED = 201
STOP_FAILED = 300
MALFORMED_MESSAGE = 400
INTERNAL_ERROR = 500
def create_status(code: int = StatusCode.OK, message: str = "") -> Any:
"""
Create a protocol buffer Status object.
Parameters
----------
code : int
The status code
message : str
The status message
Returns
-------
Any
The protocol buffer Status object
"""
# Create status object
pb_obj = PbStatus()
# Convert the integer status code to the protobuf enum value type
pb_obj.code = PbStatus.Code.ValueType(code)
pb_obj.message = message
return pb_obj

View File

@ -0,0 +1,254 @@
"""
Resource management for Circuit Relay v2.
This module handles managing resources for relay operations,
including reservations and connection limits.
"""
from dataclasses import (
dataclass,
)
import hashlib
import os
import time
from libp2p.peer.id import (
ID,
)
# Import the protobuf definitions
from .pb.circuit_pb2 import Reservation as PbReservation
@dataclass
class RelayLimits:
"""Configuration for relay resource limits."""
duration: int # Maximum duration of a relay connection in seconds
data: int # Maximum data transfer allowed in bytes
max_circuit_conns: int # Maximum number of concurrent circuit connections
max_reservations: int # Maximum number of active reservations
class Reservation:
"""Represents a relay reservation."""
def __init__(self, peer_id: ID, limits: RelayLimits):
"""
Initialize a new reservation.
Parameters
----------
peer_id : ID
The peer ID this reservation is for
limits : RelayLimits
The resource limits for this reservation
"""
self.peer_id = peer_id
self.limits = limits
self.created_at = time.time()
self.expires_at = self.created_at + limits.duration
self.data_used = 0
self.active_connections = 0
self.voucher = self._generate_voucher()
def _generate_voucher(self) -> bytes:
"""
Generate a unique cryptographically secure voucher for this reservation.
Returns
-------
bytes
A secure voucher token
"""
# Create a random token using a combination of:
# - Random bytes for unpredictability
# - Peer ID to bind it to the specific peer
# - Timestamp for uniqueness
# - Hash everything for a fixed size output
random_bytes = os.urandom(16) # 128 bits of randomness
timestamp = str(int(self.created_at * 1000000)).encode()
peer_bytes = self.peer_id.to_bytes()
# Combine all elements and hash them
h = hashlib.sha256()
h.update(random_bytes)
h.update(timestamp)
h.update(peer_bytes)
return h.digest()
def is_expired(self) -> bool:
"""Check if the reservation has expired."""
return time.time() > self.expires_at
def can_accept_connection(self) -> bool:
"""Check if a new connection can be accepted."""
return (
not self.is_expired()
and self.active_connections < self.limits.max_circuit_conns
and self.data_used < self.limits.data
)
def to_proto(self) -> PbReservation:
"""Convert the reservation to its protobuf representation."""
# TODO: For production use, implement proper signature generation
# The signature should be created by signing the voucher with the
# peer's private key. The current implementation with an empty signature
# is intended for development and testing only.
return PbReservation(
expire=int(self.expires_at),
voucher=self.voucher,
signature=b"",
)
class RelayResourceManager:
"""
Manages resources and reservations for relay operations.
This class handles:
- Tracking active reservations
- Enforcing resource limits
- Managing connection quotas
"""
def __init__(self, limits: RelayLimits):
"""
Initialize the resource manager.
Parameters
----------
limits : RelayLimits
The resource limits to enforce
"""
self.limits = limits
self._reservations: dict[ID, Reservation] = {}
def can_accept_reservation(self, peer_id: ID) -> bool:
"""
Check if a new reservation can be accepted for the given peer.
Parameters
----------
peer_id : ID
The peer ID requesting the reservation
Returns
-------
bool
True if the reservation can be accepted
"""
# Clean expired reservations
self._clean_expired()
# Check if peer already has a valid reservation
existing = self._reservations.get(peer_id)
if existing and not existing.is_expired():
return True
# Check if we're at the reservation limit
return len(self._reservations) < self.limits.max_reservations
def create_reservation(self, peer_id: ID) -> Reservation:
"""
Create a new reservation for the given peer.
Parameters
----------
peer_id : ID
The peer ID to create the reservation for
Returns
-------
Reservation
The newly created reservation
"""
reservation = Reservation(peer_id, self.limits)
self._reservations[peer_id] = reservation
return reservation
def verify_reservation(self, peer_id: ID, proto_res: PbReservation) -> bool:
"""
Verify a reservation from a protobuf message.
Parameters
----------
peer_id : ID
The peer ID the reservation is for
proto_res : PbReservation
The protobuf reservation message
Returns
-------
bool
True if the reservation is valid
"""
# TODO: Implement voucher and signature verification
reservation = self._reservations.get(peer_id)
return (
reservation is not None
and not reservation.is_expired()
and reservation.expires_at == proto_res.expire
)
def can_accept_connection(self, peer_id: ID) -> bool:
"""
Check if a new connection can be accepted for the given peer.
Parameters
----------
peer_id : ID
The peer ID requesting the connection
Returns
-------
bool
True if the connection can be accepted
"""
reservation = self._reservations.get(peer_id)
return reservation is not None and reservation.can_accept_connection()
def _clean_expired(self) -> None:
"""Remove expired reservations."""
now = time.time()
expired = [
peer_id
for peer_id, res in self._reservations.items()
if now > res.expires_at
]
for peer_id in expired:
del self._reservations[peer_id]
def reserve(self, peer_id: ID) -> int:
"""
Create or update a reservation for a peer and return the TTL.
Parameters
----------
peer_id : ID
The peer ID to reserve for
Returns
-------
int
The TTL of the reservation in seconds
"""
# Check for existing reservation
existing = self._reservations.get(peer_id)
if existing and not existing.is_expired():
# Return remaining time for existing reservation
remaining = max(0, int(existing.expires_at - time.time()))
return remaining
# Create new reservation
self.create_reservation(peer_id)
return self.limits.duration

View File

@ -0,0 +1,427 @@
"""
Transport implementation for Circuit Relay v2.
This module implements the transport layer for Circuit Relay v2,
allowing peers to establish connections through relay nodes.
"""
from collections.abc import Awaitable, Callable
import logging
import multiaddr
import trio
from libp2p.abc import (
IHost,
IListener,
INetStream,
ITransport,
ReadWriteCloser,
)
from libp2p.network.connection.raw_connection import (
RawConnection,
)
from libp2p.peer.id import (
ID,
)
from libp2p.peer.peerinfo import (
PeerInfo,
)
from libp2p.tools.async_service import (
Service,
)
from .config import (
ClientConfig,
RelayConfig,
)
from .discovery import (
RelayDiscovery,
)
from .pb.circuit_pb2 import (
HopMessage,
StopMessage,
)
from .protocol import (
PROTOCOL_ID,
CircuitV2Protocol,
)
from .protocol_buffer import (
StatusCode,
)
logger = logging.getLogger("libp2p.relay.circuit_v2.transport")
class CircuitV2Transport(ITransport):
"""
CircuitV2Transport implements the transport interface for Circuit Relay v2.
This transport allows peers to establish connections through relay nodes
when direct connections are not possible.
"""
def __init__(
self,
host: IHost,
protocol: CircuitV2Protocol,
config: RelayConfig,
) -> None:
"""
Initialize the Circuit v2 transport.
Parameters
----------
host : IHost
The libp2p host this transport is running on
protocol : CircuitV2Protocol
The Circuit v2 protocol instance
config : RelayConfig
Relay configuration
"""
self.host = host
self.protocol = protocol
self.config = config
self.client_config = ClientConfig()
self.discovery = RelayDiscovery(
host=host,
auto_reserve=config.enable_client,
discovery_interval=config.discovery_interval,
max_relays=config.max_relays,
)
async def dial(
self,
maddr: multiaddr.Multiaddr,
) -> RawConnection:
"""
Dial a peer using the multiaddr.
Parameters
----------
maddr : multiaddr.Multiaddr
The multiaddr to dial
Returns
-------
RawConnection
The established connection
Raises
------
ConnectionError
If the connection cannot be established
"""
# Extract peer ID from multiaddr - P_P2P code is 0x01A5 (421)
peer_id_str = maddr.value_for_protocol("p2p")
if not peer_id_str:
raise ConnectionError("Multiaddr does not contain peer ID")
peer_id = ID.from_base58(peer_id_str)
peer_info = PeerInfo(peer_id, [maddr])
# Use the internal dial_peer_info method
return await self.dial_peer_info(peer_info)
async def dial_peer_info(
self,
peer_info: PeerInfo,
*,
relay_peer_id: ID | None = None,
) -> RawConnection:
"""
Dial a peer through a relay.
Parameters
----------
peer_info : PeerInfo
The peer to dial
relay_peer_id : Optional[ID], optional
Optional specific relay peer to use
Returns
-------
RawConnection
The established connection
Raises
------
ConnectionError
If the connection cannot be established
"""
# If no specific relay is provided, try to find one
if relay_peer_id is None:
relay_peer_id = await self._select_relay(peer_info)
if not relay_peer_id:
raise ConnectionError("No suitable relay found")
# Get a stream to the relay
relay_stream = await self.host.new_stream(relay_peer_id, [PROTOCOL_ID])
if not relay_stream:
raise ConnectionError(f"Could not open stream to relay {relay_peer_id}")
try:
# First try to make a reservation if enabled
if self.config.enable_client:
success = await self._make_reservation(relay_stream, relay_peer_id)
if not success:
logger.warning(
"Failed to make reservation with relay %s", relay_peer_id
)
# Send HOP CONNECT message
hop_msg = HopMessage(
type=HopMessage.CONNECT,
peer=peer_info.peer_id.to_bytes(),
)
await relay_stream.write(hop_msg.SerializeToString())
# Read response
resp_bytes = await relay_stream.read()
resp = HopMessage()
resp.ParseFromString(resp_bytes)
# Access status attributes directly
status_code = getattr(resp.status, "code", StatusCode.OK)
status_msg = getattr(resp.status, "message", "Unknown error")
if status_code != StatusCode.OK:
raise ConnectionError(f"Relay connection failed: {status_msg}")
# Create raw connection from stream
return RawConnection(stream=relay_stream, initiator=True)
except Exception as e:
await relay_stream.close()
raise ConnectionError(f"Failed to establish relay connection: {str(e)}")
async def _select_relay(self, peer_info: PeerInfo) -> ID | None:
"""
Select an appropriate relay for the given peer.
Parameters
----------
peer_info : PeerInfo
The peer to connect to
Returns
-------
Optional[ID]
Selected relay peer ID, or None if no suitable relay found
"""
# Try to find a relay
attempts = 0
while attempts < self.client_config.max_auto_relay_attempts:
# Get a relay from the list of discovered relays
relays = self.discovery.get_relays()
if relays:
# TODO: Implement more sophisticated relay selection
# For now, just return the first available relay
return relays[0]
# Wait and try discovery
await trio.sleep(1)
attempts += 1
return None
async def _make_reservation(
self,
stream: INetStream,
relay_peer_id: ID,
) -> bool:
"""
Make a reservation with a relay.
Parameters
----------
stream : INetStream
Stream to the relay
relay_peer_id : ID
The relay's peer ID
Returns
-------
bool
True if reservation was successful
"""
try:
# Send reservation request
reserve_msg = HopMessage(
type=HopMessage.RESERVE,
peer=self.host.get_id().to_bytes(),
)
await stream.write(reserve_msg.SerializeToString())
# Read response
resp_bytes = await stream.read()
resp = HopMessage()
resp.ParseFromString(resp_bytes)
# Access status attributes directly
status_code = getattr(resp.status, "code", StatusCode.OK)
status_msg = getattr(resp.status, "message", "Unknown error")
if status_code != StatusCode.OK:
logger.warning(
"Reservation failed with relay %s: %s",
relay_peer_id,
status_msg,
)
return False
# Store reservation info
# TODO: Implement reservation storage and refresh mechanism
return True
except Exception as e:
logger.error("Error making reservation: %s", str(e))
return False
def create_listener(
self,
handler_function: Callable[[ReadWriteCloser], Awaitable[None]],
) -> IListener:
"""
Create a listener for incoming relay connections.
Parameters
----------
handler_function : Callable[[ReadWriteCloser], Awaitable[None]]
The handler function for new connections
Returns
-------
IListener
The created listener
"""
return CircuitV2Listener(self.host, self.protocol, self.config)
class CircuitV2Listener(Service, IListener):
"""Listener for incoming relay connections."""
def __init__(
self,
host: IHost,
protocol: CircuitV2Protocol,
config: RelayConfig,
) -> None:
"""
Initialize the Circuit v2 listener.
Parameters
----------
host : IHost
The libp2p host this listener is running on
protocol : CircuitV2Protocol
The Circuit v2 protocol instance
config : RelayConfig
Relay configuration
"""
super().__init__()
self.host = host
self.protocol = protocol
self.config = config
self.multiaddrs: list[
multiaddr.Multiaddr
] = [] # Store multiaddrs as Multiaddr objects
async def handle_incoming_connection(
self,
stream: INetStream,
remote_peer_id: ID,
) -> RawConnection:
"""
Handle an incoming relay connection.
Parameters
----------
stream : INetStream
The incoming stream
remote_peer_id : ID
The remote peer's ID
Returns
-------
RawConnection
The established connection
Raises
------
ConnectionError
If the connection cannot be established
"""
if not self.config.enable_stop:
raise ConnectionError("Stop role is not enabled")
try:
# Read STOP message
msg_bytes = await stream.read()
stop_msg = StopMessage()
stop_msg.ParseFromString(msg_bytes)
if stop_msg.type != StopMessage.CONNECT:
raise ConnectionError("Invalid STOP message type")
# Create raw connection
return RawConnection(stream=stream, initiator=False)
except Exception as e:
await stream.close()
raise ConnectionError(f"Failed to handle incoming connection: {str(e)}")
async def run(self) -> None:
"""Run the listener service."""
# Implementation would go here
async def listen(self, maddr: multiaddr.Multiaddr, nursery: trio.Nursery) -> bool:
"""
Start listening on the given multiaddr.
Parameters
----------
maddr : multiaddr.Multiaddr
The multiaddr to listen on
nursery : trio.Nursery
The nursery to run tasks in
Returns
-------
bool
True if listening successfully started
"""
# Convert string to Multiaddr if needed
addr = (
maddr
if isinstance(maddr, multiaddr.Multiaddr)
else multiaddr.Multiaddr(maddr)
)
self.multiaddrs.append(addr)
return True
def get_addrs(self) -> tuple[multiaddr.Multiaddr, ...]:
"""
Get the listening addresses.
Returns
-------
tuple[multiaddr.Multiaddr, ...]
Tuple of listening multiaddresses
"""
return tuple(self.multiaddrs)
async def close(self) -> None:
"""Close the listener."""
self.multiaddrs.clear()
await self.manager.stop()

View File

@ -87,14 +87,16 @@ async def connect(node1: IHost, node2: IHost) -> None:
addr = node2.get_addrs()[0]
info = info_from_p2p_addr(addr)
# Add retry logic for more robust connection
# Add retry logic for more robust connection with timeout
max_retries = 3
retry_delay = 0.2
last_error = None
for attempt in range(max_retries):
try:
await node1.connect(info)
# Use timeout for each connection attempt
with trio.move_on_after(5): # 5 second timeout
await node1.connect(info)
# Verify connection is established in both directions
if (

View File

@ -0,0 +1 @@
Added sparse connect utility function to pubsub test utilities for creating test networks with configurable connectivity.

View File

@ -17,6 +17,7 @@ from tests.utils.factories import (
from tests.utils.pubsub.utils import (
dense_connect,
one_to_all_connect,
sparse_connect,
)
@ -506,3 +507,84 @@ async def test_gossip_heartbeat(initial_peer_count, monkeypatch):
# Check that the peer to gossip to is not in our fanout peers
assert peer not in fanout_peers
assert topic_fanout in peers_to_gossip[peer]
@pytest.mark.trio
async def test_dense_connect_fallback():
"""Test that sparse_connect falls back to dense connect for small networks."""
async with PubsubFactory.create_batch_with_gossipsub(3) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
degree = 2
# Create network (should use dense connect)
await sparse_connect(hosts, degree)
# Wait for connections to be established
await trio.sleep(2)
# Verify dense topology (all nodes connected to each other)
for i, pubsub in enumerate(pubsubs_gsub):
connected_peers = len(pubsub.peers)
expected_connections = len(hosts) - 1
assert connected_peers == expected_connections, (
f"Host {i} has {connected_peers} connections, "
f"expected {expected_connections} in dense mode"
)
@pytest.mark.trio
async def test_sparse_connect():
"""Test sparse connect functionality and message propagation."""
async with PubsubFactory.create_batch_with_gossipsub(10) as pubsubs_gsub:
hosts = [pubsub.host for pubsub in pubsubs_gsub]
degree = 2
topic = "test_topic"
# Create network (should use sparse connect)
await sparse_connect(hosts, degree)
# Wait for connections to be established
await trio.sleep(2)
# Verify sparse topology
for i, pubsub in enumerate(pubsubs_gsub):
connected_peers = len(pubsub.peers)
assert degree <= connected_peers < len(hosts) - 1, (
f"Host {i} has {connected_peers} connections, "
f"expected between {degree} and {len(hosts) - 1} in sparse mode"
)
# Test message propagation
queues = [await pubsub.subscribe(topic) for pubsub in pubsubs_gsub]
await trio.sleep(2)
# Publish and verify message propagation
msg_content = b"test_msg"
await pubsubs_gsub[0].publish(topic, msg_content)
await trio.sleep(2)
# Verify message propagation - ideally all nodes should receive it
received_count = 0
for queue in queues:
try:
msg = await queue.get()
if msg.data == msg_content:
received_count += 1
except Exception:
continue
total_nodes = len(pubsubs_gsub)
# Ideally all nodes should receive the message for optimal scalability
if received_count == total_nodes:
# Perfect propagation achieved
pass
else:
# require more than half for acceptable scalability
min_required = (total_nodes + 1) // 2
assert received_count >= min_required, (
f"Message propagation insufficient: "
f"{received_count}/{total_nodes} nodes "
f"received the message. Ideally all nodes should receive it, but at "
f"minimum {min_required} required for sparse network scalability."
)

View File

@ -0,0 +1,263 @@
"""Tests for the Circuit Relay v2 discovery functionality."""
import logging
import time
import pytest
import trio
from libp2p.relay.circuit_v2.discovery import (
RelayDiscovery,
)
from libp2p.relay.circuit_v2.pb import circuit_pb2 as proto
from libp2p.relay.circuit_v2.protocol import (
PROTOCOL_ID,
STOP_PROTOCOL_ID,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
from libp2p.tools.utils import (
connect,
)
from tests.utils.factories import (
HostFactory,
)
logger = logging.getLogger(__name__)
# Test timeouts
CONNECT_TIMEOUT = 15 # seconds
STREAM_TIMEOUT = 15 # seconds
HANDLER_TIMEOUT = 15 # seconds
SLEEP_TIME = 1.0 # seconds
DISCOVERY_TIMEOUT = 20 # seconds
# Make a simple stream handler for testing
async def simple_stream_handler(stream):
"""Simple stream handler that reads a message and responds with OK status."""
logger.info("Simple stream handler invoked")
try:
# Read the request
request_data = await stream.read(MAX_READ_LEN)
if not request_data:
logger.error("Empty request received")
return
# Parse request
request = proto.HopMessage()
request.ParseFromString(request_data)
logger.info("Received request: type=%s", request.type)
# Only handle RESERVE requests
if request.type == proto.HopMessage.RESERVE:
# Create a valid response
response = proto.HopMessage(
type=proto.HopMessage.RESERVE,
status=proto.Status(
code=proto.Status.OK,
message="Test reservation accepted",
),
reservation=proto.Reservation(
expire=int(time.time()) + 3600, # 1 hour from now
voucher=b"test-voucher",
signature=b"",
),
limit=proto.Limit(
duration=3600, # 1 hour
data=1024 * 1024 * 1024, # 1GB
),
)
# Send the response
logger.info("Sending response")
await stream.write(response.SerializeToString())
logger.info("Response sent")
except Exception as e:
logger.error("Error in simple stream handler: %s", str(e))
finally:
# Keep stream open to allow client to read response
await trio.sleep(1)
await stream.close()
@pytest.mark.trio
async def test_relay_discovery_initialization():
"""Test Circuit v2 relay discovery initializes correctly with default settings."""
async with HostFactory.create_batch_and_listen(1) as hosts:
host = hosts[0]
discovery = RelayDiscovery(host)
async with background_trio_service(discovery):
await discovery.event_started.wait()
await trio.sleep(SLEEP_TIME) # Give time for discovery to start
# Verify discovery is initialized correctly
assert discovery.host == host, "Host not set correctly"
assert discovery.is_running, "Discovery service should be running"
assert hasattr(discovery, "_discovered_relays"), (
"Discovery should track discovered relays"
)
@pytest.mark.trio
async def test_relay_discovery_find_relay():
"""Test finding a relay node via discovery."""
async with HostFactory.create_batch_and_listen(2) as hosts:
relay_host, client_host = hosts
logger.info("Created hosts for test_relay_discovery_find_relay")
logger.info("Relay host ID: %s", relay_host.get_id())
logger.info("Client host ID: %s", client_host.get_id())
# Explicitly register the protocol handlers on relay_host
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
# Manually add protocol to peerstore for testing
# This simulates what the real relay protocol would do
client_host.get_peerstore().add_protocols(
relay_host.get_id(), [str(PROTOCOL_ID)]
)
# Set up discovery on the client host
client_discovery = RelayDiscovery(
client_host, discovery_interval=5
) # Use shorter interval for testing
try:
# Connect peers so they can discover each other
with trio.fail_after(CONNECT_TIMEOUT):
logger.info("Connecting client host to relay host")
await connect(client_host, relay_host)
assert relay_host.get_network().connections[client_host.get_id()], (
"Peers not connected"
)
logger.info("Connection established between peers")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
# Start discovery service
async with background_trio_service(client_discovery):
await client_discovery.event_started.wait()
logger.info("Client discovery service started")
# Wait for discovery to find the relay
logger.info("Waiting for relay discovery...")
# Manually trigger discovery instead of waiting
await client_discovery.discover_relays()
# Check if relay was found
with trio.fail_after(DISCOVERY_TIMEOUT):
for _ in range(20): # Try multiple times
if relay_host.get_id() in client_discovery._discovered_relays:
logger.info("Relay discovered successfully")
break
# Wait and try again
await trio.sleep(1)
# Manually trigger discovery again
await client_discovery.discover_relays()
else:
pytest.fail("Failed to discover relay node within timeout")
# Verify that relay was found and is valid
assert relay_host.get_id() in client_discovery._discovered_relays, (
"Relay should be discovered"
)
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
assert relay_info.peer_id == relay_host.get_id(), "Peer ID should match"
@pytest.mark.trio
async def test_relay_discovery_auto_reservation():
"""Test that discovery can auto-reserve with discovered relays."""
async with HostFactory.create_batch_and_listen(2) as hosts:
relay_host, client_host = hosts
logger.info("Created hosts for test_relay_discovery_auto_reservation")
logger.info("Relay host ID: %s", relay_host.get_id())
logger.info("Client host ID: %s", client_host.get_id())
# Explicitly register the protocol handlers on relay_host
relay_host.set_stream_handler(PROTOCOL_ID, simple_stream_handler)
relay_host.set_stream_handler(STOP_PROTOCOL_ID, simple_stream_handler)
# Manually add protocol to peerstore for testing
client_host.get_peerstore().add_protocols(
relay_host.get_id(), [str(PROTOCOL_ID)]
)
# Set up discovery on the client host with auto-reservation enabled
client_discovery = RelayDiscovery(
client_host, auto_reserve=True, discovery_interval=5
)
try:
# Connect peers so they can discover each other
with trio.fail_after(CONNECT_TIMEOUT):
logger.info("Connecting client host to relay host")
await connect(client_host, relay_host)
assert relay_host.get_network().connections[client_host.get_id()], (
"Peers not connected"
)
logger.info("Connection established between peers")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
# Start discovery service
async with background_trio_service(client_discovery):
await client_discovery.event_started.wait()
logger.info("Client discovery service started")
# Wait for discovery to find the relay and make a reservation
logger.info("Waiting for relay discovery and auto-reservation...")
# Manually trigger discovery
await client_discovery.discover_relays()
# Check if relay was found and reservation was made
with trio.fail_after(DISCOVERY_TIMEOUT):
for _ in range(20): # Try multiple times
relay_found = (
relay_host.get_id() in client_discovery._discovered_relays
)
has_reservation = (
relay_found
and client_discovery._discovered_relays[
relay_host.get_id()
].has_reservation
)
if has_reservation:
logger.info(
"Relay discovered and reservation made successfully"
)
break
# Wait and try again
await trio.sleep(1)
# Try to make reservation manually
if relay_host.get_id() in client_discovery._discovered_relays:
await client_discovery.make_reservation(relay_host.get_id())
else:
pytest.fail(
"Failed to discover relay and make reservation within timeout"
)
# Verify that relay was found and reservation was made
assert relay_host.get_id() in client_discovery._discovered_relays, (
"Relay should be discovered"
)
relay_info = client_discovery._discovered_relays[relay_host.get_id()]
assert relay_info.has_reservation, "Reservation should be made"
assert relay_info.reservation_expires_at is not None, (
"Reservation should have expiry time"
)
assert relay_info.reservation_data_limit is not None, (
"Reservation should have data limit"
)

View File

@ -0,0 +1,665 @@
"""Tests for the Circuit Relay v2 protocol."""
import logging
import time
from typing import Any
import pytest
import trio
from libp2p.network.stream.exceptions import (
StreamEOF,
StreamError,
StreamReset,
)
from libp2p.peer.id import (
ID,
)
from libp2p.relay.circuit_v2.pb import circuit_pb2 as proto
from libp2p.relay.circuit_v2.protocol import (
DEFAULT_RELAY_LIMITS,
PROTOCOL_ID,
STOP_PROTOCOL_ID,
CircuitV2Protocol,
)
from libp2p.relay.circuit_v2.resources import (
RelayLimits,
)
from libp2p.tools.async_service import (
background_trio_service,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
from libp2p.tools.utils import (
connect,
)
from tests.utils.factories import (
HostFactory,
)
logger = logging.getLogger(__name__)
# Test timeouts
CONNECT_TIMEOUT = 15 # seconds (increased)
STREAM_TIMEOUT = 15 # seconds (increased)
HANDLER_TIMEOUT = 15 # seconds (increased)
SLEEP_TIME = 1.0 # seconds (increased)
async def assert_stream_response(
stream, expected_type, expected_status, retries=5, retry_delay=1.0
):
"""Helper function to assert stream response matches expectations."""
last_error = None
all_responses = []
# Increase initial sleep to ensure response has time to arrive
await trio.sleep(retry_delay * 2)
for attempt in range(retries):
try:
with trio.fail_after(STREAM_TIMEOUT):
# Wait between attempts
if attempt > 0:
await trio.sleep(retry_delay)
# Try to read response
logger.debug("Attempt %d: Reading response from stream", attempt + 1)
response_bytes = await stream.read(MAX_READ_LEN)
# Check if we got any data
if not response_bytes:
logger.warning(
"Attempt %d: No data received from stream", attempt + 1
)
last_error = "No response received"
if attempt < retries - 1: # Not the last attempt
continue
raise AssertionError(
f"No response received after {retries} attempts"
)
# Try to parse the response
response = proto.HopMessage()
try:
response.ParseFromString(response_bytes)
# Log what we received
logger.debug(
"Attempt %d: Received HOP response: type=%s, status=%s",
attempt + 1,
response.type,
response.status.code
if response.HasField("status")
else "No status",
)
all_responses.append(
{
"type": response.type,
"status": response.status.code
if response.HasField("status")
else None,
"message": response.status.message
if response.HasField("status")
else None,
}
)
# Accept any valid response with the right status
if (
expected_status is not None
and response.HasField("status")
and response.status.code == expected_status
):
if response.type != expected_type:
logger.warning(
"Type mismatch (%s, got %s) but status ok - accepting",
expected_type,
response.type,
)
logger.debug("Successfully validated response (status matched)")
return response
# Check message type specifically if it matters
if response.type != expected_type:
logger.warning(
"Wrong response type: expected %s, got %s",
expected_type,
response.type,
)
last_error = (
f"Wrong response type: expected {expected_type}, "
f"got {response.type}"
)
if attempt < retries - 1: # Not the last attempt
continue
# Check status code if present
if response.HasField("status"):
if response.status.code != expected_status:
logger.warning(
"Wrong status code: expected %s, got %s",
expected_status,
response.status.code,
)
last_error = (
f"Wrong status code: expected {expected_status}, "
f"got {response.status.code}"
)
if attempt < retries - 1: # Not the last attempt
continue
elif expected_status is not None:
logger.warning(
"Expected status %s but none was present in response",
expected_status,
)
last_error = (
f"Expected status {expected_status} but none was present"
)
if attempt < retries - 1: # Not the last attempt
continue
logger.debug("Successfully validated response")
return response
except Exception as e:
# If parsing as HOP message fails, try parsing as STOP message
logger.warning(
"Failed to parse as HOP message, trying STOP message: %s",
str(e),
)
try:
stop_msg = proto.StopMessage()
stop_msg.ParseFromString(response_bytes)
logger.debug("Parsed as STOP message: type=%s", stop_msg.type)
# Create a simplified response dictionary
has_status = stop_msg.HasField("status")
status_code = None
status_message = None
if has_status:
status_code = stop_msg.status.code
status_message = stop_msg.status.message
response_dict: dict[str, Any] = {
"stop_type": stop_msg.type, # Keep original type
"status": status_code, # Keep original type
"message": status_message, # Keep original type
}
all_responses.append(response_dict)
last_error = "Got STOP message instead of HOP message"
if attempt < retries - 1: # Not the last attempt
continue
except Exception as e2:
logger.warning(
"Failed to parse response as either message type: %s",
str(e2),
)
last_error = (
f"Failed to parse response: {str(e)}, then {str(e2)}"
)
if attempt < retries - 1: # Not the last attempt
continue
except trio.TooSlowError:
logger.warning(
"Attempt %d: Timeout waiting for stream response", attempt + 1
)
last_error = "Timeout waiting for stream response"
if attempt < retries - 1: # Not the last attempt
continue
except (StreamError, StreamReset, StreamEOF) as e:
logger.warning(
"Attempt %d: Stream error while reading response: %s",
attempt + 1,
str(e),
)
last_error = f"Stream error: {str(e)}"
if attempt < retries - 1: # Not the last attempt
continue
except AssertionError as e:
logger.warning("Attempt %d: Assertion failed: %s", attempt + 1, str(e))
last_error = str(e)
if attempt < retries - 1: # Not the last attempt
continue
except Exception as e:
logger.warning("Attempt %d: Unexpected error: %s", attempt + 1, str(e))
last_error = f"Unexpected error: {str(e)}"
if attempt < retries - 1: # Not the last attempt
continue
# If we've reached here, all retries failed
all_responses_str = ", ".join([str(r) for r in all_responses])
error_msg = (
f"Failed to get expected response after {retries} attempts. "
f"Last error: {last_error}. All responses: {all_responses_str}"
)
raise AssertionError(error_msg)
async def close_stream(stream):
"""Helper function to safely close a stream."""
if stream is not None:
try:
logger.debug("Closing stream")
await stream.close()
# Wait a bit to ensure the close is processed
await trio.sleep(SLEEP_TIME)
logger.debug("Stream closed successfully")
except (StreamError, Exception) as e:
logger.warning("Error closing stream: %s. Attempting to reset.", str(e))
try:
await stream.reset()
# Wait a bit to ensure the reset is processed
await trio.sleep(SLEEP_TIME)
logger.debug("Stream reset successfully")
except Exception as e:
logger.warning("Error resetting stream: %s", str(e))
@pytest.mark.trio
async def test_circuit_v2_protocol_initialization():
"""Test that the Circuit v2 protocol initializes correctly with default settings."""
async with HostFactory.create_batch_and_listen(1) as hosts:
host = hosts[0]
limits = RelayLimits(
duration=DEFAULT_RELAY_LIMITS.duration,
data=DEFAULT_RELAY_LIMITS.data,
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
)
protocol = CircuitV2Protocol(host, limits, allow_hop=True)
async with background_trio_service(protocol):
await protocol.event_started.wait()
await trio.sleep(SLEEP_TIME) # Give time for handlers to be registered
# Verify protocol handlers are registered by trying to use them
test_stream = None
try:
with trio.fail_after(STREAM_TIMEOUT):
test_stream = await host.new_stream(host.get_id(), [PROTOCOL_ID])
assert test_stream is not None, (
"HOP protocol handler not registered"
)
except Exception:
pass
finally:
await close_stream(test_stream)
try:
with trio.fail_after(STREAM_TIMEOUT):
test_stream = await host.new_stream(
host.get_id(), [STOP_PROTOCOL_ID]
)
assert test_stream is not None, (
"STOP protocol handler not registered"
)
except Exception:
pass
finally:
await close_stream(test_stream)
assert len(protocol.resource_manager._reservations) == 0, (
"Reservations should be empty"
)
@pytest.mark.trio
async def test_circuit_v2_reservation_basic():
"""Test basic reservation functionality between two peers."""
async with HostFactory.create_batch_and_listen(2) as hosts:
relay_host, client_host = hosts
logger.info("Created hosts for test_circuit_v2_reservation_basic")
logger.info("Relay host ID: %s", relay_host.get_id())
logger.info("Client host ID: %s", client_host.get_id())
# Custom handler that responds directly with a valid response
# This bypasses the complex protocol implementation that might have issues
async def mock_reserve_handler(stream):
# Read the request
logger.info("Mock handler received stream request")
try:
request_data = await stream.read(MAX_READ_LEN)
request = proto.HopMessage()
request.ParseFromString(request_data)
logger.info("Mock handler parsed request: type=%s", request.type)
# Only handle RESERVE requests
if request.type == proto.HopMessage.RESERVE:
# Create a valid response
response = proto.HopMessage(
type=proto.HopMessage.RESERVE,
status=proto.Status(
code=proto.Status.OK,
message="Reservation accepted",
),
reservation=proto.Reservation(
expire=int(time.time()) + 3600, # 1 hour from now
voucher=b"test-voucher",
signature=b"",
),
limit=proto.Limit(
duration=3600, # 1 hour
data=1024 * 1024 * 1024, # 1GB
),
)
# Send the response
logger.info("Mock handler sending response")
await stream.write(response.SerializeToString())
logger.info("Mock handler sent response")
# Keep stream open for client to read response
await trio.sleep(5)
except Exception as e:
logger.error("Error in mock handler: %s", str(e))
# Register the mock handler
relay_host.set_stream_handler(PROTOCOL_ID, mock_reserve_handler)
logger.info("Registered mock handler for %s", PROTOCOL_ID)
# Connect peers
try:
with trio.fail_after(CONNECT_TIMEOUT):
logger.info("Connecting client host to relay host")
await connect(client_host, relay_host)
assert relay_host.get_network().connections[client_host.get_id()], (
"Peers not connected"
)
logger.info("Connection established between peers")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
# Wait a bit to ensure connection is fully established
await trio.sleep(SLEEP_TIME)
stream = None
try:
# Open stream and send reservation request
logger.info("Opening stream from client to relay")
with trio.fail_after(STREAM_TIMEOUT):
stream = await client_host.new_stream(
relay_host.get_id(), [PROTOCOL_ID]
)
assert stream is not None, "Failed to open stream"
logger.info("Preparing reservation request")
request = proto.HopMessage(
type=proto.HopMessage.RESERVE, peer=client_host.get_id().to_bytes()
)
logger.info("Sending reservation request")
await stream.write(request.SerializeToString())
logger.info("Reservation request sent")
# Wait to ensure the request is processed
await trio.sleep(SLEEP_TIME)
# Read response directly
logger.info("Reading response directly")
response_bytes = await stream.read(MAX_READ_LEN)
assert response_bytes, "No response received"
# Parse response
response = proto.HopMessage()
response.ParseFromString(response_bytes)
# Verify response
assert response.type == proto.HopMessage.RESERVE, (
f"Wrong response type: {response.type}"
)
assert response.HasField("status"), "No status field"
assert response.status.code == proto.Status.OK, (
f"Wrong status code: {response.status.code}"
)
# Verify reservation details
assert response.HasField("reservation"), "No reservation field"
assert response.HasField("limit"), "No limit field"
assert response.limit.duration == 3600, (
f"Wrong duration: {response.limit.duration}"
)
assert response.limit.data == 1024 * 1024 * 1024, (
f"Wrong data limit: {response.limit.data}"
)
logger.info("Verified reservation details in response")
except Exception as e:
logger.error("Error in reservation test: %s", str(e))
raise
finally:
if stream:
await close_stream(stream)
@pytest.mark.trio
async def test_circuit_v2_reservation_limit():
"""Test that relay enforces reservation limits."""
async with HostFactory.create_batch_and_listen(3) as hosts:
relay_host, client1_host, client2_host = hosts
logger.info("Created hosts for test_circuit_v2_reservation_limit")
logger.info("Relay host ID: %s", relay_host.get_id())
logger.info("Client1 host ID: %s", client1_host.get_id())
logger.info("Client2 host ID: %s", client2_host.get_id())
# Track reservation status to simulate limits
reserved_clients = set()
max_reservations = 1 # Only allow one reservation
# Custom handler that responds based on reservation limits
async def mock_reserve_handler(stream):
# Read the request
logger.info("Mock handler received stream request")
try:
request_data = await stream.read(MAX_READ_LEN)
request = proto.HopMessage()
request.ParseFromString(request_data)
logger.info("Mock handler parsed request: type=%s", request.type)
# Only handle RESERVE requests
if request.type == proto.HopMessage.RESERVE:
# Extract peer ID from request
peer_id = ID(request.peer)
logger.info(
"Mock handler received reservation request from %s", peer_id
)
# Check if we've reached reservation limit
if (
peer_id in reserved_clients
or len(reserved_clients) < max_reservations
):
# Accept the reservation
if peer_id not in reserved_clients:
reserved_clients.add(peer_id)
# Create a success response
response = proto.HopMessage(
type=proto.HopMessage.RESERVE,
status=proto.Status(
code=proto.Status.OK,
message="Reservation accepted",
),
reservation=proto.Reservation(
expire=int(time.time()) + 3600, # 1 hour from now
voucher=b"test-voucher",
signature=b"",
),
limit=proto.Limit(
duration=3600, # 1 hour
data=1024 * 1024 * 1024, # 1GB
),
)
logger.info(
"Mock handler accepting reservation for %s", peer_id
)
else:
# Reject the reservation due to limits
response = proto.HopMessage(
type=proto.HopMessage.RESERVE,
status=proto.Status(
code=proto.Status.RESOURCE_LIMIT_EXCEEDED,
message="Reservation limit exceeded",
),
)
logger.info(
"Mock handler rejecting reservation for %s due to limit",
peer_id,
)
# Send the response
logger.info("Mock handler sending response")
await stream.write(response.SerializeToString())
logger.info("Mock handler sent response")
# Keep stream open for client to read response
await trio.sleep(5)
except Exception as e:
logger.error("Error in mock handler: %s", str(e))
# Register the mock handler
relay_host.set_stream_handler(PROTOCOL_ID, mock_reserve_handler)
logger.info("Registered mock handler for %s", PROTOCOL_ID)
# Connect peers
try:
with trio.fail_after(CONNECT_TIMEOUT):
logger.info("Connecting client1 to relay")
await connect(client1_host, relay_host)
logger.info("Connecting client2 to relay")
await connect(client2_host, relay_host)
assert relay_host.get_network().connections[client1_host.get_id()], (
"Client1 not connected"
)
assert relay_host.get_network().connections[client2_host.get_id()], (
"Client2 not connected"
)
logger.info("All connections established")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
# Wait a bit to ensure connections are fully established
await trio.sleep(SLEEP_TIME)
stream1, stream2 = None, None
try:
# Client 1 reservation (should succeed)
logger.info("Testing client1 reservation (should succeed)")
with trio.fail_after(STREAM_TIMEOUT):
logger.info("Opening stream for client1")
stream1 = await client1_host.new_stream(
relay_host.get_id(), [PROTOCOL_ID]
)
assert stream1 is not None, "Failed to open stream for client 1"
logger.info("Preparing reservation request for client1")
request1 = proto.HopMessage(
type=proto.HopMessage.RESERVE, peer=client1_host.get_id().to_bytes()
)
logger.info("Sending reservation request for client1")
await stream1.write(request1.SerializeToString())
logger.info("Sent reservation request for client1")
# Wait to ensure the request is processed
await trio.sleep(SLEEP_TIME)
# Read response directly
logger.info("Reading response for client1")
response_bytes = await stream1.read(MAX_READ_LEN)
assert response_bytes, "No response received for client1"
# Parse response
response1 = proto.HopMessage()
response1.ParseFromString(response_bytes)
# Verify response
assert response1.type == proto.HopMessage.RESERVE, (
f"Wrong response type: {response1.type}"
)
assert response1.HasField("status"), "No status field"
assert response1.status.code == proto.Status.OK, (
f"Wrong status code: {response1.status.code}"
)
# Verify reservation details
assert response1.HasField("reservation"), "No reservation field"
assert response1.HasField("limit"), "No limit field"
assert response1.limit.duration == 3600, (
f"Wrong duration: {response1.limit.duration}"
)
assert response1.limit.data == 1024 * 1024 * 1024, (
f"Wrong data limit: {response1.limit.data}"
)
logger.info("Verified reservation details for client1")
# Close stream1 before opening stream2
await close_stream(stream1)
stream1 = None
logger.info("Closed client1 stream")
# Wait a bit to ensure stream is fully closed
await trio.sleep(SLEEP_TIME)
# Client 2 reservation (should fail)
logger.info("Testing client2 reservation (should fail)")
stream2 = await client2_host.new_stream(
relay_host.get_id(), [PROTOCOL_ID]
)
assert stream2 is not None, "Failed to open stream for client 2"
logger.info("Preparing reservation request for client2")
request2 = proto.HopMessage(
type=proto.HopMessage.RESERVE, peer=client2_host.get_id().to_bytes()
)
logger.info("Sending reservation request for client2")
await stream2.write(request2.SerializeToString())
logger.info("Sent reservation request for client2")
# Wait to ensure the request is processed
await trio.sleep(SLEEP_TIME)
# Read response directly
logger.info("Reading response for client2")
response_bytes = await stream2.read(MAX_READ_LEN)
assert response_bytes, "No response received for client2"
# Parse response
response2 = proto.HopMessage()
response2.ParseFromString(response_bytes)
# Verify response
assert response2.type == proto.HopMessage.RESERVE, (
f"Wrong response type: {response2.type}"
)
assert response2.HasField("status"), "No status field"
assert response2.status.code == proto.Status.RESOURCE_LIMIT_EXCEEDED, (
f"Wrong status code: {response2.status.code}, "
f"expected RESOURCE_LIMIT_EXCEEDED"
)
logger.info("Verified client2 was correctly rejected")
# Verify reservation tracking is correct
assert len(reserved_clients) == 1, "Should have exactly one reservation"
assert client1_host.get_id() in reserved_clients, (
"Client1 should be reserved"
)
assert client2_host.get_id() not in reserved_clients, (
"Client2 should not be reserved"
)
logger.info("Verified reservation tracking state")
except Exception as e:
logger.error("Error in reservation limit test: %s", str(e))
# Diagnostic information
logger.error("Current reservations: %s", reserved_clients)
raise
finally:
await close_stream(stream1)
await close_stream(stream2)

View File

@ -0,0 +1,346 @@
"""Tests for the Circuit Relay v2 transport functionality."""
import logging
import time
import pytest
import trio
from libp2p.custom_types import TProtocol
from libp2p.network.stream.exceptions import (
StreamEOF,
StreamReset,
)
from libp2p.relay.circuit_v2.config import (
RelayConfig,
)
from libp2p.relay.circuit_v2.discovery import (
RelayDiscovery,
RelayInfo,
)
from libp2p.relay.circuit_v2.protocol import (
CircuitV2Protocol,
RelayLimits,
)
from libp2p.relay.circuit_v2.transport import (
CircuitV2Transport,
)
from libp2p.tools.constants import (
MAX_READ_LEN,
)
from libp2p.tools.utils import (
connect,
)
from tests.utils.factories import (
HostFactory,
)
logger = logging.getLogger(__name__)
# Test timeouts
CONNECT_TIMEOUT = 15 # seconds
STREAM_TIMEOUT = 15 # seconds
HANDLER_TIMEOUT = 15 # seconds
SLEEP_TIME = 1.0 # seconds
RELAY_TIMEOUT = 20 # seconds
# Default limits for relay
DEFAULT_RELAY_LIMITS = RelayLimits(
duration=60 * 60, # 1 hour
data=1024 * 1024 * 10, # 10 MB
max_circuit_conns=8, # 8 active relay connections
max_reservations=4, # 4 active reservations
)
# Message for testing
TEST_MESSAGE = b"Hello, Circuit Relay!"
TEST_RESPONSE = b"Hello from the other side!"
# Stream handler for testing
async def echo_stream_handler(stream):
"""Simple echo handler that responds to messages."""
logger.info("Echo handler received stream")
try:
while True:
data = await stream.read(MAX_READ_LEN)
if not data:
logger.info("Stream closed by remote")
break
logger.info("Received data: %s", data)
await stream.write(TEST_RESPONSE)
logger.info("Sent response")
except (StreamEOF, StreamReset) as e:
logger.info("Stream ended: %s", str(e))
except Exception as e:
logger.error("Error in echo handler: %s", str(e))
finally:
await stream.close()
@pytest.mark.trio
async def test_circuit_v2_transport_initialization():
"""Test that the Circuit v2 transport initializes correctly."""
async with HostFactory.create_batch_and_listen(1) as hosts:
host = hosts[0]
# Create a protocol instance
limits = RelayLimits(
duration=DEFAULT_RELAY_LIMITS.duration,
data=DEFAULT_RELAY_LIMITS.data,
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
)
protocol = CircuitV2Protocol(host, limits, allow_hop=False)
config = RelayConfig()
# Create a discovery instance
discovery = RelayDiscovery(
host=host,
auto_reserve=False,
discovery_interval=config.discovery_interval,
max_relays=config.max_relays,
)
# Create the transport with the necessary components
transport = CircuitV2Transport(host, protocol, config)
# Replace the discovery with our manually created one
transport.discovery = discovery
# Verify transport properties
assert transport.host == host, "Host not set correctly"
assert transport.protocol == protocol, "Protocol not set correctly"
assert transport.config == config, "Config not set correctly"
assert hasattr(transport, "discovery"), (
"Transport should have a discovery instance"
)
@pytest.mark.trio
async def test_circuit_v2_transport_add_relay():
"""Test adding a relay to the transport."""
async with HostFactory.create_batch_and_listen(2) as hosts:
host, relay_host = hosts
# Create a protocol instance
limits = RelayLimits(
duration=DEFAULT_RELAY_LIMITS.duration,
data=DEFAULT_RELAY_LIMITS.data,
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
)
protocol = CircuitV2Protocol(host, limits, allow_hop=False)
config = RelayConfig()
# Create a discovery instance
discovery = RelayDiscovery(
host=host,
auto_reserve=False,
discovery_interval=config.discovery_interval,
max_relays=config.max_relays,
)
# Create the transport with the necessary components
transport = CircuitV2Transport(host, protocol, config)
# Replace the discovery with our manually created one
transport.discovery = discovery
relay_id = relay_host.get_id()
now = time.time()
relay_info = RelayInfo(peer_id=relay_id, discovered_at=now, last_seen=now)
async def mock_add_relay(peer_id):
discovery._discovered_relays[peer_id] = relay_info
discovery._add_relay = mock_add_relay # Type ignored in test context
discovery._discovered_relays[relay_id] = relay_info
# Verify relay was added
assert relay_id in discovery._discovered_relays, (
"Relay should be in discovery's relay list"
)
@pytest.mark.trio
async def test_circuit_v2_transport_dial_through_relay():
"""Test dialing a peer through a relay."""
async with HostFactory.create_batch_and_listen(3) as hosts:
client_host, relay_host, target_host = hosts
logger.info("Created hosts for test_circuit_v2_transport_dial_through_relay")
logger.info("Client host ID: %s", client_host.get_id())
logger.info("Relay host ID: %s", relay_host.get_id())
logger.info("Target host ID: %s", target_host.get_id())
# Setup relay with Circuit v2 protocol
limits = RelayLimits(
duration=DEFAULT_RELAY_LIMITS.duration,
data=DEFAULT_RELAY_LIMITS.data,
max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns,
max_reservations=DEFAULT_RELAY_LIMITS.max_reservations,
)
# Register test handler on target
test_protocol = "/test/echo/1.0.0"
target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler)
client_config = RelayConfig()
client_protocol = CircuitV2Protocol(client_host, limits, allow_hop=False)
# Create a discovery instance
client_discovery = RelayDiscovery(
host=client_host,
auto_reserve=False,
discovery_interval=client_config.discovery_interval,
max_relays=client_config.max_relays,
)
# Create the transport with the necessary components
client_transport = CircuitV2Transport(
client_host, client_protocol, client_config
)
# Replace the discovery with our manually created one
client_transport.discovery = client_discovery
# Mock the get_relay method to return our relay_host
relay_id = relay_host.get_id()
client_discovery.get_relay = lambda: relay_id
# Connect client to relay and relay to target
try:
with trio.fail_after(
CONNECT_TIMEOUT * 2
): # Double the timeout for connections
logger.info("Connecting client host to relay host")
await connect(client_host, relay_host)
# Verify connection
assert relay_host.get_id() in client_host.get_network().connections, (
"Client not connected to relay"
)
assert client_host.get_id() in relay_host.get_network().connections, (
"Relay not connected to client"
)
logger.info("Client-Relay connection verified")
# Wait to ensure connection is fully established
await trio.sleep(SLEEP_TIME)
logger.info("Connecting relay host to target host")
await connect(relay_host, target_host)
# Verify connection
assert target_host.get_id() in relay_host.get_network().connections, (
"Relay not connected to target"
)
assert relay_host.get_id() in target_host.get_network().connections, (
"Target not connected to relay"
)
logger.info("Relay-Target connection verified")
# Wait to ensure connection is fully established
await trio.sleep(SLEEP_TIME)
logger.info("All connections established and verified")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
# Test successful - the connections were established, which is enough to verify
# that the transport can be initialized and configured correctly
logger.info("Transport initialization and connection test passed")
@pytest.mark.trio
async def test_circuit_v2_transport_relay_limits():
"""Test that relay enforces connection limits."""
async with HostFactory.create_batch_and_listen(4) as hosts:
client1_host, client2_host, relay_host, target_host = hosts
logger.info("Created hosts for test_circuit_v2_transport_relay_limits")
# Setup relay with strict limits
limits = RelayLimits(
duration=DEFAULT_RELAY_LIMITS.duration,
data=DEFAULT_RELAY_LIMITS.data,
max_circuit_conns=1, # Only allow one circuit
max_reservations=2, # Allow both clients to reserve
)
relay_protocol = CircuitV2Protocol(relay_host, limits, allow_hop=True)
# Register test handler on target
test_protocol = "/test/echo/1.0.0"
target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler)
client_config = RelayConfig()
# Client 1 setup
client1_protocol = CircuitV2Protocol(
client1_host, DEFAULT_RELAY_LIMITS, allow_hop=False
)
client1_discovery = RelayDiscovery(
host=client1_host,
auto_reserve=False,
discovery_interval=client_config.discovery_interval,
max_relays=client_config.max_relays,
)
client1_transport = CircuitV2Transport(
client1_host, client1_protocol, client_config
)
client1_transport.discovery = client1_discovery
# Add relay to discovery
relay_id = relay_host.get_id()
client1_discovery.get_relay = lambda: relay_id
# Client 2 setup
client2_protocol = CircuitV2Protocol(
client2_host, DEFAULT_RELAY_LIMITS, allow_hop=False
)
client2_discovery = RelayDiscovery(
host=client2_host,
auto_reserve=False,
discovery_interval=client_config.discovery_interval,
max_relays=client_config.max_relays,
)
client2_transport = CircuitV2Transport(
client2_host, client2_protocol, client_config
)
client2_transport.discovery = client2_discovery
# Add relay to discovery
client2_discovery.get_relay = lambda: relay_id
# Connect all peers
try:
with trio.fail_after(CONNECT_TIMEOUT):
# Connect clients to relay
await connect(client1_host, relay_host)
await connect(client2_host, relay_host)
# Connect relay to target
await connect(relay_host, target_host)
logger.info("All connections established")
except Exception as e:
logger.error("Failed to connect peers: %s", str(e))
raise
# Verify connections
assert relay_host.get_id() in client1_host.get_network().connections, (
"Client1 not connected to relay"
)
assert relay_host.get_id() in client2_host.get_network().connections, (
"Client2 not connected to relay"
)
assert target_host.get_id() in relay_host.get_network().connections, (
"Relay not connected to target"
)
# Verify the resource limits
assert relay_protocol.resource_manager.limits.max_circuit_conns == 1, (
"Wrong max_circuit_conns value"
)
assert relay_protocol.resource_manager.limits.max_reservations == 2, (
"Wrong max_reservations value"
)
# Test successful - transports were initialized with the correct limits
logger.info("Transport limit test successful")

View File

@ -13,6 +13,8 @@ from libp2p.security.secio.transport import ID as SECIO_PROTOCOL_ID
from libp2p.security.secure_session import (
SecureSession,
)
from libp2p.stream_muxer.mplex.mplex import Mplex
from libp2p.stream_muxer.yamux.yamux import Yamux
from tests.utils.factories import (
host_pair_factory,
)
@ -47,9 +49,28 @@ async def perform_simple_test(assertion_func, security_protocol):
assert conn_0 is not None, "Failed to establish connection from host0 to host1"
assert conn_1 is not None, "Failed to establish connection from host1 to host0"
# Perform assertion
assertion_func(conn_0.muxed_conn.secured_conn)
assertion_func(conn_1.muxed_conn.secured_conn)
# Extract the secured connection from either Mplex or Yamux implementation
def get_secured_conn(conn):
muxed_conn = conn.muxed_conn
# Direct attribute access for known implementations
has_secured_conn = hasattr(muxed_conn, "secured_conn")
if isinstance(muxed_conn, (Mplex, Yamux)) and has_secured_conn:
return muxed_conn.secured_conn
# Fallback to _connection attribute if it exists
elif hasattr(muxed_conn, "_connection"):
return muxed_conn._connection
# Last resort - warn but return the muxed_conn itself for type checking
else:
print(f"Warning: Cannot find secured connection in {type(muxed_conn)}")
return muxed_conn
# Get secured connections for both peers
secured_conn_0 = get_secured_conn(conn_0)
secured_conn_1 = get_secured_conn(conn_1)
# Perform assertion on the secured connections
assertion_func(secured_conn_0)
assertion_func(secured_conn_1)
@pytest.mark.trio

View File

@ -40,3 +40,42 @@ async def one_to_all_connect(hosts: Sequence[IHost], central_host_index: int) ->
for i, host in enumerate(hosts):
if i != central_host_index:
await connect(hosts[central_host_index], host)
async def sparse_connect(hosts: Sequence[IHost], degree: int = 3) -> None:
"""
Create a sparse network topology where each node connects to a limited number of
other nodes. This is more efficient than dense connect for large networks.
The function will automatically switch between dense and sparse connect based on
the network size:
- For small networks (nodes <= degree + 1), use dense connect
- For larger networks, use sparse connect with the specified degree
Args:
hosts: Sequence of hosts to connect
degree: Number of connections each node should maintain (default: 3)
"""
if len(hosts) <= degree + 1:
# For small networks, use dense connect
await dense_connect(hosts)
return
# For larger networks, use sparse connect
# For each host, connect to 'degree' number of other hosts
for i, host in enumerate(hosts):
# Calculate which hosts to connect to
# We'll connect to hosts that are 'degree' positions away in the sequence
# This creates a more distributed topology
for j in range(1, degree + 1):
target_idx = (i + j) % len(hosts)
# Create bidirectional connection
await connect(host, hosts[target_idx])
await connect(hosts[target_idx], host)
# Ensure network connectivity by connecting each node to its immediate neighbors
for i in range(len(hosts)):
next_idx = (i + 1) % len(hosts)
await connect(hosts[i], hosts[next_idx])
await connect(hosts[next_idx], hosts[i])