mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Feat/587-circuit-relay (#611)
* feat: implemented setup of circuit relay and test cases * chore: remove test files to be rewritten * added 1 test suite for protocol * added 1 test suite for discovery * fixed protocol timeouts and message types to handle reservations and stream operations. * Resolved merge conflict in libp2p/tools/utils.py by combining timeout approach with retry mechanism * fix: linting issues * docs: updated documentation with circuit-relay * chore: added enums, improved typing, security and examples * fix: created proper __init__ file to ensure importability * fix: replace transport_opt with listen_addrs in examples, fixed typing and improved code * fix type checking issues across relay module and test suite * regenerated circuit_pb2 file protobuf version 3 * fixed circuit relay example and moved imports to top in test_security_multistream * chore: moved imports to the top * chore: fixed linting of test_circuit_v2_transport.py --------- Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
This commit is contained in:
537
libp2p/relay/circuit_v2/discovery.py
Normal file
537
libp2p/relay/circuit_v2/discovery.py
Normal file
@ -0,0 +1,537 @@
|
||||
"""
|
||||
Discovery module for Circuit Relay v2.
|
||||
|
||||
This module handles discovering and tracking relay nodes in the network.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
)
|
||||
from .protocol import (
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from .protocol_buffer import (
|
||||
StatusCode,
|
||||
)
|
||||
|
||||
# Module-level logger; the name mirrors the module's dotted import path.
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")

# Constants
# Default upper bound on how many relays a discovery service keeps track of.
MAX_RELAYS_TO_TRACK = 10
# Default pause between two discovery rounds, in seconds.
DEFAULT_DISCOVERY_INTERVAL = 60  # seconds
# Deadline applied to stream open / write / read operations, in seconds.
STREAM_TIMEOUT = 10  # seconds
|
||||
|
||||
|
||||
# Extended interfaces for type checking
|
||||
@runtime_checkable
class IHostWithMultiselect(TypingProtocol):
    """
    Structural type for hosts that expose a ``multiselect`` attribute.

    The protocol is runtime-checkable, so ``isinstance`` can be used to test
    whether an object provides the attribute.
    """

    @property
    def multiselect(self) -> Any:
        """Return the host's multiselect component."""
        ...
|
||||
|
||||
|
||||
@dataclass
class RelayInfo:
    """Bookkeeping record kept for every relay the discovery service tracks."""

    # Peer ID of the relay.
    peer_id: ID
    # ``time.time()`` timestamp taken when the relay was first added.
    discovered_at: float
    # ``time.time()`` timestamp of the most recent discovery round that saw it.
    last_seen: float
    # True once a reservation request to this relay has been accepted.
    has_reservation: bool = False
    # Expiry reported by the relay for the reservation, if any was recorded.
    reservation_expires_at: float | None = None
    # Data limit reported by the relay for the reservation, if any was recorded.
    reservation_data_limit: int | None = None
|
||||
|
||||
|
||||
class RelayDiscovery(Service):
    """
    Discovery service for Circuit Relay v2 nodes.

    This service discovers and keeps track of available relay nodes, and optionally
    makes reservations with them.

    Discovery looks at the peers the host is currently connected to and probes
    each one for support of the relay protocol. Results are kept in
    ``_discovered_relays`` (capped at ``max_relays`` entries) and the per-peer
    protocol sets found along the way are memoised in ``_protocol_cache``.
    """

    def __init__(
        self,
        host: IHost,
        auto_reserve: bool = False,
        discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
        max_relays: int = MAX_RELAYS_TO_TRACK,
    ) -> None:
        """
        Initialize the discovery service.

        Parameters
        ----------
        host : IHost
            The libp2p host this discovery service is running on
        auto_reserve : bool
            Whether to automatically make reservations with discovered relays
        discovery_interval : int
            How often to run discovery, in seconds
        max_relays : int
            Maximum number of relays to track

        """
        super().__init__()
        self.host = host
        self.auto_reserve = auto_reserve
        self.discovery_interval = discovery_interval
        self.max_relays = max_relays
        self._discovered_relays: dict[ID, RelayInfo] = {}
        # Cache protocol info to reduce queries
        self._protocol_cache: dict[ID, set[str]] = {}
        self.event_started = trio.Event()
        self.is_running = False

    async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
        """
        Run the discovery service.

        Marks the service as running, signals ``event_started`` and
        ``task_status``, performs an initial discovery round and then repeats
        discovery followed by expiry cleanup every ``discovery_interval``
        seconds until the service manager stops running.

        Parameters
        ----------
        task_status : Any
            Trio task-status object; ``started()`` is called on it once the
            service has flagged itself as running.

        """
        try:
            self.is_running = True
            self.event_started.set()
            task_status.started()

            # Main discovery loop
            async with trio.open_nursery() as nursery:
                # Run initial discovery
                nursery.start_soon(self.discover_relays)

                # Set up periodic discovery
                while True:
                    await trio.sleep(self.discovery_interval)
                    if not self.manager.is_running:
                        break
                    nursery.start_soon(self.discover_relays)

                    # Cleanup expired relays and reservations
                    await self._cleanup_expired()
        finally:
            self.is_running = False

    async def discover_relays(self) -> None:
        """
        Discover relay nodes in the network.

        This method checks the host's currently connected peers for support
        of the Circuit Relay v2 protocol. Already-tracked relays only get
        their ``last_seen`` timestamp refreshed. When more than ``max_relays``
        relays are tracked afterwards, the least recently seen ones are
        dropped. Any exception is logged and swallowed so that a failed round
        does not take down the service.
        """
        logger.debug("Starting relay discovery")

        try:
            # Get connected peers
            connected_peers = self.host.get_connected_peers()
            logger.debug(
                "Checking %d connected peers for relay support", len(connected_peers)
            )

            own_id = self.host.get_id()

            # Check each peer if they support the relay protocol
            for peer_id in connected_peers:
                if peer_id == own_id:
                    continue  # Skip ourselves

                known_relay = self._discovered_relays.get(peer_id)
                if known_relay is not None:
                    # Update last seen time for existing relay
                    known_relay.last_seen = time.time()
                    continue

                # Only the protocol probe is bounded by this deadline. The
                # optional reservation in _add_relay applies its own stream
                # timeouts; running it inside this scope would silently cancel
                # a slow handshake and the relay would never be retried.
                supports_relay = False
                with trio.move_on_after(5):  # Don't wait too long for protocol info
                    supports_relay = await self._supports_relay_protocol(peer_id)

                if supports_relay:
                    await self._add_relay(peer_id)

            # Limit number of relays we track
            if len(self._discovered_relays) > self.max_relays:
                # Sort by last seen time and keep only the most recent ones
                sorted_relays = sorted(
                    self._discovered_relays.items(),
                    key=lambda item: item[1].last_seen,
                    reverse=True,
                )
                for stale_id, _ in sorted_relays[self.max_relays :]:
                    del self._discovered_relays[stale_id]

            logger.debug(
                "Discovery completed, tracking %d relays", len(self._discovered_relays)
            )

        except Exception as e:
            logger.error("Error during relay discovery: %s", str(e))

    async def _supports_relay_protocol(self, peer_id: ID) -> bool:
        """
        Check if a peer supports the relay protocol.

        The protocol cache is consulted first. Otherwise the peerstore, a
        direct stream attempt and the mux-based probe are tried in that order;
        the first one that yields a definite answer wins.

        Parameters
        ----------
        peer_id : ID
            The ID of the peer to check

        Returns
        -------
        bool
            True if the peer supports the relay protocol, False otherwise
            (including when no check could determine an answer)

        """
        # Check cache first
        if peer_id in self._protocol_cache:
            return PROTOCOL_ID in self._protocol_cache[peer_id]

        checks = (
            self._check_via_peerstore,  # Method 1: Try peerstore
            self._check_via_direct_connection,  # Method 2: Try direct stream
            self._check_via_mux,  # Method 3: Try protocols from mux
        )
        for check in checks:
            result = await check(peer_id)
            if result is not None:
                return result

        # Default: Cannot determine, assume false
        return False

    async def _check_via_peerstore(self, peer_id: ID) -> bool | None:
        """
        Check protocol support via peerstore.

        Parameters
        ----------
        peer_id : ID
            The ID of the peer to check

        Returns
        -------
        bool | None
            True/False when the peerstore returned a result (which is also
            cached), None when the peerstore could not be queried

        """
        try:
            peerstore = self.host.get_peerstore()
            proto_getter = peerstore.get_protocols

            if not callable(proto_getter):
                return None

            try:
                # The getter may be synchronous or return an awaitable.
                proto_result = proto_getter(peer_id)
                if hasattr(proto_result, "__await__"):
                    protocols_list = await cast(Any, proto_result)
                else:
                    protocols_list = proto_result

                if protocols_list is None:
                    return False

                protocols = set(protocols_list)
                self._protocol_cache[peer_id] = protocols
                return PROTOCOL_ID in protocols
            except Exception as e:
                logger.debug("Error getting protocols: %s", str(e))
                return None
        except Exception as e:
            logger.debug("Error accessing peerstore: %s", str(e))
            return None

    async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
        """
        Check protocol support via direct connection.

        Tries to open a stream for the relay protocol within ``STREAM_TIMEOUT``
        seconds and closes it again immediately.

        Parameters
        ----------
        peer_id : ID
            The ID of the peer to check

        Returns
        -------
        bool | None
            True when a stream could be opened (the result is cached), False
            when no stream was returned, None when the attempt raised or
            timed out

        """
        try:
            with trio.fail_after(STREAM_TIMEOUT):
                stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
                if stream:
                    await stream.close()
                    self._protocol_cache[peer_id] = {PROTOCOL_ID}
                    return True
                return False
        except Exception as e:
            logger.debug(
                "Failed to open relay protocol stream to %s: %s", peer_id, str(e)
            )
            return None

    async def _check_via_mux(self, peer_id: ID) -> bool | None:
        """
        Check protocol support via mux protocols.

        Every protocol reported by the host's mux is probed by opening a
        short-lived stream to the peer; the protocols for which that worked
        are cached for the peer.

        Parameters
        ----------
        peer_id : ID
            The ID of the peer to check

        Returns
        -------
        bool | None
            True/False depending on whether the relay protocol was among the
            successfully probed protocols, None when the host exposes no
            usable mux or the check failed unexpectedly

        """
        try:
            if not hasattr(self.host, "get_mux"):
                return None

            mux = self.host.get_mux()
            if mux is None or not hasattr(mux, "protocols"):
                return None

            available_protocols: list[Any] = []
            if hasattr(mux, "get_protocols"):
                mux_protocols = mux.get_protocols()
                if isinstance(mux_protocols, (list, tuple)):
                    available_protocols = list(mux_protocols)

            peer_protocols: set[str] = set()
            for protocol in available_protocols:
                if protocol is None:
                    # Defensive: a missing entry must not be probed as the
                    # literal protocol name "None".
                    continue

                # TProtocol is a typing NewType, so it cannot be used with
                # isinstance; normalise every entry through str() instead.
                protocol_obj = TProtocol(str(protocol))
                try:
                    with trio.fail_after(2):  # Quick check
                        stream = await self.host.new_stream(peer_id, [protocol_obj])
                        if stream:
                            peer_protocols.add(str(protocol_obj))
                            await stream.close()
                except Exception as e:
                    # Best effort: a failed or slow probe just means this
                    # protocol is not recorded for the peer.
                    logger.debug(
                        "Protocol probe %s failed for %s: %s",
                        protocol_obj,
                        peer_id,
                        str(e),
                    )

            self._protocol_cache[peer_id] = peer_protocols
            return str(PROTOCOL_ID) in peer_protocols
        except Exception as e:
            logger.debug("Error checking protocols via mux: %s", str(e))
            return None

    async def _add_relay(self, peer_id: ID) -> None:
        """
        Add a peer as a relay and optionally make a reservation.

        Parameters
        ----------
        peer_id : ID
            The ID of the peer to add as a relay

        """
        now = time.time()
        self._discovered_relays[peer_id] = RelayInfo(
            peer_id=peer_id,
            discovered_at=now,
            last_seen=now,
        )
        logger.debug("Added relay %s to discovered relays", peer_id)

        # If auto-reserve is enabled, make a reservation with this relay
        if self.auto_reserve:
            await self.make_reservation(peer_id)

    async def make_reservation(self, peer_id: ID) -> bool:
        """
        Make a reservation with a relay.

        Opens a relay protocol stream, sends a ``RESERVE`` request and parses
        the reply. On success the relay's ``RelayInfo`` is updated with the
        reservation state. Errors are logged and reported through the return
        value; the stream is always closed.

        Parameters
        ----------
        peer_id : ID
            The ID of the relay to make a reservation with

        Returns
        -------
        bool
            True if reservation succeeded, False otherwise

        """
        if peer_id not in self._discovered_relays:
            logger.error("Cannot make reservation with unknown relay %s", peer_id)
            return False

        stream = None
        try:
            logger.debug("Making reservation with relay %s", peer_id)

            # Open a stream to the relay with timeout
            try:
                with trio.fail_after(STREAM_TIMEOUT):
                    stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
            except trio.TooSlowError:
                logger.error("Timeout opening stream to relay %s", peer_id)
                return False

            if not stream:
                logger.error("Failed to open stream to relay %s", peer_id)
                return False

            # Create and send reservation request, then wait for the response
            request = HopMessage(
                type=HopMessage.RESERVE,
                peer=self.host.get_id().to_bytes(),
            )
            try:
                with trio.fail_after(STREAM_TIMEOUT):
                    await stream.write(request.SerializeToString())
                    response_bytes = await stream.read()
            except trio.TooSlowError:
                logger.error(
                    "Timeout during reservation process with relay %s", peer_id
                )
                return False

            if not response_bytes:
                logger.error("No response received from relay %s", peer_id)
                return False

            # Parse response
            response = HopMessage()
            response.ParseFromString(response_bytes)

            # Check if reservation was successful
            if response.type == HopMessage.RESERVE and response.HasField("status"):
                # Access status code directly from protobuf object
                status_code = getattr(response.status, "code", StatusCode.OK)

                if status_code == StatusCode.OK:
                    # Update relay info with reservation details
                    relay_info = self._discovered_relays[peer_id]
                    relay_info.has_reservation = True

                    # Record each detail on its own: a reply without a
                    # limit must still record its expiry.
                    if response.HasField("reservation"):
                        relay_info.reservation_expires_at = (
                            response.reservation.expire
                        )
                    if response.HasField("limit"):
                        relay_info.reservation_data_limit = response.limit.data

                    logger.debug(
                        "Successfully made reservation with relay %s", peer_id
                    )
                    return True

            # Reservation failed
            error_message = "Unknown error"
            if response.HasField("status"):
                # Keep the fallback when the relay sent an empty message
                error_message = (
                    getattr(response.status, "message", "") or error_message
                )

            logger.warning(
                "Reservation request rejected by relay %s: %s",
                peer_id,
                error_message,
            )
            return False

        except Exception as e:
            logger.error("Error making reservation with relay %s: %s", peer_id, str(e))
            return False
        finally:
            # Always close the stream
            if stream:
                try:
                    await stream.close()
                except Exception:
                    pass  # Ignore errors when closing the stream

    async def _cleanup_expired(self) -> None:
        """
        Clean up expired relays and reservations.

        Relays not seen for more than three discovery intervals are removed
        together with their protocol cache entry. Reservations whose expiry
        has passed are reset and, when ``auto_reserve`` is set, renewed.
        """
        now = time.time()
        stale_after = self.discovery_interval * 3
        to_remove = []

        # Iterate over a snapshot: make_reservation awaits, and a discovery
        # round running concurrently may add or evict relays in the meantime,
        # which would break iteration over the live dict.
        for peer_id, relay_info in list(self._discovered_relays.items()):
            if self._discovered_relays.get(peer_id) is not relay_info:
                continue  # Evicted or replaced while we were awaiting

            # Check if relay hasn't been seen in a while (3x discovery interval)
            if now - relay_info.last_seen > stale_after:
                to_remove.append(peer_id)
                continue

            # Check if reservation has expired
            if (
                relay_info.has_reservation
                and relay_info.reservation_expires_at
                and now > relay_info.reservation_expires_at
            ):
                relay_info.has_reservation = False
                relay_info.reservation_expires_at = None
                relay_info.reservation_data_limit = None

                # If auto-reserve is enabled, try to renew
                if self.auto_reserve:
                    await self.make_reservation(peer_id)

        # Remove expired relays; pop() tolerates entries that disappeared
        # while a renewal above was awaiting.
        for peer_id in to_remove:
            self._discovered_relays.pop(peer_id, None)
            self._protocol_cache.pop(peer_id, None)

    def get_relays(self) -> list[ID]:
        """
        Get a list of discovered relay peer IDs.

        Returns
        -------
        list[ID]
            List of discovered relay peer IDs

        """
        return list(self._discovered_relays)

    def get_relay_info(self, peer_id: ID) -> RelayInfo | None:
        """
        Get information about a specific relay.

        Parameters
        ----------
        peer_id : ID
            The ID of the relay to get information about

        Returns
        -------
        RelayInfo | None
            Information about the relay, or None if not found

        """
        return self._discovered_relays.get(peer_id)

    def get_relay(self) -> ID | None:
        """
        Get a single relay peer ID for connection purposes.

        Prioritizes relays with active reservations.

        Returns
        -------
        ID | None
            ID of a discovered relay, or None if no relays found

        """
        # First try to find a relay with an active reservation
        for peer_id, relay_info in self._discovered_relays.items():
            if relay_info.has_reservation:
                return peer_id

        # Otherwise fall back to the first tracked relay, if any
        return next(iter(self._discovered_relays), None)
|
||||
Reference in New Issue
Block a user