mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 15:40:54 +00:00
Merge branch 'main' into feat/619-store-pubkey-peerid-peerstore
This commit is contained in:
28
libp2p/relay/__init__.py
Normal file
28
libp2p/relay/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""
|
||||
Relay module for libp2p.
|
||||
|
||||
This package includes implementations of circuit relay protocols
|
||||
for enabling connectivity between peers behind NATs or firewalls.
|
||||
"""
|
||||
|
||||
# Import the circuit_v2 module to make it accessible
|
||||
# through the relay package
|
||||
from libp2p.relay.circuit_v2 import (
|
||||
PROTOCOL_ID,
|
||||
CircuitV2Protocol,
|
||||
CircuitV2Transport,
|
||||
RelayDiscovery,
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
Reservation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CircuitV2Protocol",
|
||||
"CircuitV2Transport",
|
||||
"PROTOCOL_ID",
|
||||
"RelayDiscovery",
|
||||
"RelayLimits",
|
||||
"RelayResourceManager",
|
||||
"Reservation",
|
||||
]
|
||||
32
libp2p/relay/circuit_v2/__init__.py
Normal file
32
libp2p/relay/circuit_v2/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""
|
||||
Circuit Relay v2 implementation for libp2p.
|
||||
|
||||
This package implements the Circuit Relay v2 protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||
"""
|
||||
|
||||
from .discovery import (
|
||||
RelayDiscovery,
|
||||
)
|
||||
from .protocol import (
|
||||
PROTOCOL_ID,
|
||||
CircuitV2Protocol,
|
||||
)
|
||||
from .resources import (
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
Reservation,
|
||||
)
|
||||
from .transport import (
|
||||
CircuitV2Transport,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CircuitV2Protocol",
|
||||
"PROTOCOL_ID",
|
||||
"RelayLimits",
|
||||
"Reservation",
|
||||
"RelayResourceManager",
|
||||
"CircuitV2Transport",
|
||||
"RelayDiscovery",
|
||||
]
|
||||
92
libp2p/relay/circuit_v2/config.py
Normal file
92
libp2p/relay/circuit_v2/config.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""
|
||||
Configuration management for Circuit Relay v2.
|
||||
|
||||
This module handles configuration for relay roles, resource limits,
|
||||
and discovery settings.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
|
||||
from .resources import (
|
||||
RelayLimits,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayConfig:
|
||||
"""Configuration for Circuit Relay v2."""
|
||||
|
||||
# Role configuration
|
||||
enable_hop: bool = False # Whether to act as a relay (hop)
|
||||
enable_stop: bool = True # Whether to accept relayed connections (stop)
|
||||
enable_client: bool = True # Whether to use relays for dialing
|
||||
|
||||
# Resource limits
|
||||
limits: RelayLimits | None = None
|
||||
|
||||
# Discovery configuration
|
||||
bootstrap_relays: list[PeerInfo] = field(default_factory=list)
|
||||
min_relays: int = 3
|
||||
max_relays: int = 20
|
||||
discovery_interval: int = 300 # seconds
|
||||
|
||||
# Connection configuration
|
||||
reservation_ttl: int = 3600 # seconds
|
||||
max_circuit_duration: int = 3600 # seconds
|
||||
max_circuit_bytes: int = 1024 * 1024 * 1024 # 1GB
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize default values."""
|
||||
if self.limits is None:
|
||||
self.limits = RelayLimits(
|
||||
duration=self.max_circuit_duration,
|
||||
data=self.max_circuit_bytes,
|
||||
max_circuit_conns=8,
|
||||
max_reservations=4,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HopConfig:
|
||||
"""Configuration specific to relay (hop) nodes."""
|
||||
|
||||
# Resource limits per IP
|
||||
max_reservations_per_ip: int = 8
|
||||
max_circuits_per_ip: int = 16
|
||||
|
||||
# Rate limiting
|
||||
reservation_rate_per_ip: int = 4 # per minute
|
||||
circuit_rate_per_ip: int = 8 # per minute
|
||||
|
||||
# Resource quotas
|
||||
max_circuits_total: int = 64
|
||||
max_reservations_total: int = 32
|
||||
|
||||
# Bandwidth limits
|
||||
max_bandwidth_per_circuit: int = 1024 * 1024 # 1MB/s
|
||||
max_bandwidth_total: int = 10 * 1024 * 1024 # 10MB/s
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
"""Configuration specific to relay clients."""
|
||||
|
||||
# Relay selection
|
||||
min_relay_score: float = 0.5
|
||||
max_relay_latency: float = 1.0 # seconds
|
||||
|
||||
# Auto-relay settings
|
||||
enable_auto_relay: bool = True
|
||||
auto_relay_timeout: int = 30 # seconds
|
||||
max_auto_relay_attempts: int = 3
|
||||
|
||||
# Reservation management
|
||||
reservation_refresh_threshold: float = 0.8 # Refresh at 80% of TTL
|
||||
max_concurrent_reservations: int = 2
|
||||
537
libp2p/relay/circuit_v2/discovery.py
Normal file
537
libp2p/relay/circuit_v2/discovery.py
Normal file
@ -0,0 +1,537 @@
|
||||
"""
|
||||
Discovery module for Circuit Relay v2.
|
||||
|
||||
This module handles discovering and tracking relay nodes in the network.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
)
|
||||
from .protocol import (
|
||||
PROTOCOL_ID,
|
||||
)
|
||||
from .protocol_buffer import (
|
||||
StatusCode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2.discovery")
|
||||
|
||||
# Constants
|
||||
MAX_RELAYS_TO_TRACK = 10
|
||||
DEFAULT_DISCOVERY_INTERVAL = 60 # seconds
|
||||
STREAM_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
# Extended interfaces for type checking
|
||||
@runtime_checkable
|
||||
class IHostWithMultiselect(TypingProtocol):
|
||||
"""Extended host interface with multiselect attribute."""
|
||||
|
||||
@property
|
||||
def multiselect(self) -> Any:
|
||||
"""Get the multiselect component."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayInfo:
|
||||
"""Information about a discovered relay."""
|
||||
|
||||
peer_id: ID
|
||||
discovered_at: float
|
||||
last_seen: float
|
||||
has_reservation: bool = False
|
||||
reservation_expires_at: float | None = None
|
||||
reservation_data_limit: int | None = None
|
||||
|
||||
|
||||
class RelayDiscovery(Service):
|
||||
"""
|
||||
Discovery service for Circuit Relay v2 nodes.
|
||||
|
||||
This service discovers and keeps track of available relay nodes, and optionally
|
||||
makes reservations with them.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
auto_reserve: bool = False,
|
||||
discovery_interval: int = DEFAULT_DISCOVERY_INTERVAL,
|
||||
max_relays: int = MAX_RELAYS_TO_TRACK,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the discovery service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this discovery service is running on
|
||||
auto_reserve : bool
|
||||
Whether to automatically make reservations with discovered relays
|
||||
discovery_interval : int
|
||||
How often to run discovery, in seconds
|
||||
max_relays : int
|
||||
Maximum number of relays to track
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.auto_reserve = auto_reserve
|
||||
self.discovery_interval = discovery_interval
|
||||
self.max_relays = max_relays
|
||||
self._discovered_relays: dict[ID, RelayInfo] = {}
|
||||
self._protocol_cache: dict[
|
||||
ID, set[str]
|
||||
] = {} # Cache protocol info to reduce queries
|
||||
self.event_started = trio.Event()
|
||||
self.is_running = False
|
||||
|
||||
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
|
||||
"""Run the discovery service."""
|
||||
try:
|
||||
self.is_running = True
|
||||
self.event_started.set()
|
||||
task_status.started()
|
||||
|
||||
# Main discovery loop
|
||||
async with trio.open_nursery() as nursery:
|
||||
# Run initial discovery
|
||||
nursery.start_soon(self.discover_relays)
|
||||
|
||||
# Set up periodic discovery
|
||||
while True:
|
||||
await trio.sleep(self.discovery_interval)
|
||||
if not self.manager.is_running:
|
||||
break
|
||||
nursery.start_soon(self.discover_relays)
|
||||
|
||||
# Cleanup expired relays and reservations
|
||||
await self._cleanup_expired()
|
||||
|
||||
finally:
|
||||
self.is_running = False
|
||||
|
||||
async def discover_relays(self) -> None:
|
||||
r"""
|
||||
Discover relay nodes in the network.
|
||||
|
||||
This method queries the network for peers that support the
|
||||
Circuit Relay v2 protocol.
|
||||
"""
|
||||
logger.debug("Starting relay discovery")
|
||||
|
||||
try:
|
||||
# Get connected peers
|
||||
connected_peers = self.host.get_connected_peers()
|
||||
logger.debug(
|
||||
"Checking %d connected peers for relay support", len(connected_peers)
|
||||
)
|
||||
|
||||
# Check each peer if they support the relay protocol
|
||||
for peer_id in connected_peers:
|
||||
if peer_id == self.host.get_id():
|
||||
continue # Skip ourselves
|
||||
|
||||
if peer_id in self._discovered_relays:
|
||||
# Update last seen time for existing relay
|
||||
self._discovered_relays[peer_id].last_seen = time.time()
|
||||
continue
|
||||
|
||||
# Check if peer supports the relay protocol
|
||||
with trio.move_on_after(5): # Don't wait too long for protocol info
|
||||
if await self._supports_relay_protocol(peer_id):
|
||||
await self._add_relay(peer_id)
|
||||
|
||||
# Limit number of relays we track
|
||||
if len(self._discovered_relays) > self.max_relays:
|
||||
# Sort by last seen time and keep only the most recent ones
|
||||
sorted_relays = sorted(
|
||||
self._discovered_relays.items(),
|
||||
key=lambda x: x[1].last_seen,
|
||||
reverse=True,
|
||||
)
|
||||
to_remove = sorted_relays[self.max_relays :]
|
||||
for peer_id, _ in to_remove:
|
||||
del self._discovered_relays[peer_id]
|
||||
|
||||
logger.debug(
|
||||
"Discovery completed, tracking %d relays", len(self._discovered_relays)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error during relay discovery: %s", str(e))
|
||||
|
||||
async def _supports_relay_protocol(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer supports the relay protocol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the peer to check
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the peer supports the relay protocol, False otherwise
|
||||
|
||||
"""
|
||||
# Check cache first
|
||||
if peer_id in self._protocol_cache:
|
||||
return PROTOCOL_ID in self._protocol_cache[peer_id]
|
||||
|
||||
# Method 1: Try peerstore
|
||||
result = await self._check_via_peerstore(peer_id)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Method 2: Try direct stream connection
|
||||
result = await self._check_via_direct_connection(peer_id)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Method 3: Try protocols from mux
|
||||
result = await self._check_via_mux(peer_id)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Default: Cannot determine, assume false
|
||||
return False
|
||||
|
||||
async def _check_via_peerstore(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via peerstore."""
|
||||
try:
|
||||
peerstore = self.host.get_peerstore()
|
||||
proto_getter = peerstore.get_protocols
|
||||
|
||||
if not callable(proto_getter):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Try to get protocols
|
||||
proto_result = proto_getter(peer_id)
|
||||
|
||||
# Get protocols list
|
||||
protocols_list = []
|
||||
if hasattr(proto_result, "__await__"):
|
||||
protocols_list = await cast(Any, proto_result)
|
||||
else:
|
||||
protocols_list = proto_result
|
||||
|
||||
# Check result
|
||||
if protocols_list is not None:
|
||||
protocols = set(protocols_list)
|
||||
self._protocol_cache[peer_id] = protocols
|
||||
return PROTOCOL_ID in protocols
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("Error getting protocols: %s", str(e))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug("Error accessing peerstore: %s", str(e))
|
||||
return None
|
||||
|
||||
async def _check_via_direct_connection(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via direct connection."""
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if stream:
|
||||
await stream.close()
|
||||
self._protocol_cache[peer_id] = {PROTOCOL_ID}
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to open relay protocol stream to %s: %s", peer_id, str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
async def _check_via_mux(self, peer_id: ID) -> bool | None:
|
||||
"""Check protocol support via mux protocols."""
|
||||
try:
|
||||
if not (hasattr(self.host, "get_mux") and self.host.get_mux() is not None):
|
||||
return None
|
||||
|
||||
mux = self.host.get_mux()
|
||||
if not hasattr(mux, "protocols"):
|
||||
return None
|
||||
|
||||
peer_protocols = set()
|
||||
# Get protocols from mux with proper type safety
|
||||
available_protocols = []
|
||||
if hasattr(mux, "get_protocols"):
|
||||
# Get protocols with proper typing
|
||||
mux_protocols = mux.get_protocols()
|
||||
if isinstance(mux_protocols, (list, tuple)):
|
||||
available_protocols = list(mux_protocols)
|
||||
|
||||
for protocol in available_protocols:
|
||||
try:
|
||||
with trio.fail_after(2): # Quick check
|
||||
# Ensure we have a proper protocol object
|
||||
# Use string representation since we can't use isinstance
|
||||
is_tprotocol = str(type(protocol)) == str(type(TProtocol))
|
||||
protocol_obj = (
|
||||
protocol if is_tprotocol else TProtocol(str(protocol))
|
||||
)
|
||||
stream = await self.host.new_stream(peer_id, [protocol_obj])
|
||||
if stream:
|
||||
peer_protocols.add(str(protocol_obj))
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing the stream
|
||||
|
||||
self._protocol_cache[peer_id] = peer_protocols
|
||||
protocol_str = str(PROTOCOL_ID)
|
||||
for protocol in peer_protocols:
|
||||
if protocol == protocol_str:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("Error checking protocols via mux: %s", str(e))
|
||||
return None
|
||||
|
||||
async def _add_relay(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Add a peer as a relay and optionally make a reservation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the peer to add as a relay
|
||||
|
||||
"""
|
||||
now = time.time()
|
||||
relay_info = RelayInfo(
|
||||
peer_id=peer_id,
|
||||
discovered_at=now,
|
||||
last_seen=now,
|
||||
)
|
||||
self._discovered_relays[peer_id] = relay_info
|
||||
logger.debug("Added relay %s to discovered relays", peer_id)
|
||||
|
||||
# If auto-reserve is enabled, make a reservation with this relay
|
||||
if self.auto_reserve:
|
||||
await self.make_reservation(peer_id)
|
||||
|
||||
async def make_reservation(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Make a reservation with a relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the relay to make a reservation with
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if reservation succeeded, False otherwise
|
||||
|
||||
"""
|
||||
if peer_id not in self._discovered_relays:
|
||||
logger.error("Cannot make reservation with unknown relay %s", peer_id)
|
||||
return False
|
||||
|
||||
stream = None
|
||||
try:
|
||||
logger.debug("Making reservation with relay %s", peer_id)
|
||||
|
||||
# Open a stream to the relay with timeout
|
||||
try:
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
stream = await self.host.new_stream(peer_id, [PROTOCOL_ID])
|
||||
if not stream:
|
||||
logger.error("Failed to open stream to relay %s", peer_id)
|
||||
return False
|
||||
except trio.TooSlowError:
|
||||
logger.error("Timeout opening stream to relay %s", peer_id)
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create and send reservation request
|
||||
request = HopMessage(
|
||||
type=HopMessage.RESERVE,
|
||||
peer=self.host.get_id().to_bytes(),
|
||||
)
|
||||
|
||||
with trio.fail_after(STREAM_TIMEOUT):
|
||||
await stream.write(request.SerializeToString())
|
||||
|
||||
# Wait for response
|
||||
response_bytes = await stream.read()
|
||||
if not response_bytes:
|
||||
logger.error("No response received from relay %s", peer_id)
|
||||
return False
|
||||
|
||||
# Parse response
|
||||
response = HopMessage()
|
||||
response.ParseFromString(response_bytes)
|
||||
|
||||
# Check if reservation was successful
|
||||
if response.type == HopMessage.RESERVE and response.HasField(
|
||||
"status"
|
||||
):
|
||||
# Access status code directly from protobuf object
|
||||
status_code = getattr(response.status, "code", StatusCode.OK)
|
||||
|
||||
if status_code == StatusCode.OK:
|
||||
# Update relay info with reservation details
|
||||
relay_info = self._discovered_relays[peer_id]
|
||||
relay_info.has_reservation = True
|
||||
|
||||
if response.HasField("reservation") and response.HasField(
|
||||
"limit"
|
||||
):
|
||||
relay_info.reservation_expires_at = (
|
||||
response.reservation.expire
|
||||
)
|
||||
relay_info.reservation_data_limit = response.limit.data
|
||||
|
||||
logger.debug(
|
||||
"Successfully made reservation with relay %s", peer_id
|
||||
)
|
||||
return True
|
||||
|
||||
# Reservation failed
|
||||
error_message = "Unknown error"
|
||||
if response.HasField("status"):
|
||||
# Access message directly from protobuf object
|
||||
error_message = getattr(response.status, "message", "")
|
||||
|
||||
logger.warning(
|
||||
"Reservation request rejected by relay %s: %s",
|
||||
peer_id,
|
||||
error_message,
|
||||
)
|
||||
return False
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.error(
|
||||
"Timeout during reservation process with relay %s", peer_id
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error making reservation with relay %s: %s", peer_id, str(e))
|
||||
return False
|
||||
finally:
|
||||
# Always close the stream
|
||||
if stream:
|
||||
try:
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass # Ignore errors when closing the stream
|
||||
|
||||
return False
|
||||
|
||||
async def _cleanup_expired(self) -> None:
|
||||
"""Clean up expired relays and reservations."""
|
||||
now = time.time()
|
||||
to_remove = []
|
||||
|
||||
for peer_id, relay_info in self._discovered_relays.items():
|
||||
# Check if relay hasn't been seen in a while (3x discovery interval)
|
||||
if now - relay_info.last_seen > self.discovery_interval * 3:
|
||||
to_remove.append(peer_id)
|
||||
continue
|
||||
|
||||
# Check if reservation has expired
|
||||
if (
|
||||
relay_info.has_reservation
|
||||
and relay_info.reservation_expires_at
|
||||
and now > relay_info.reservation_expires_at
|
||||
):
|
||||
relay_info.has_reservation = False
|
||||
relay_info.reservation_expires_at = None
|
||||
relay_info.reservation_data_limit = None
|
||||
|
||||
# If auto-reserve is enabled, try to renew
|
||||
if self.auto_reserve:
|
||||
await self.make_reservation(peer_id)
|
||||
|
||||
# Remove expired relays
|
||||
for peer_id in to_remove:
|
||||
del self._discovered_relays[peer_id]
|
||||
if peer_id in self._protocol_cache:
|
||||
del self._protocol_cache[peer_id]
|
||||
|
||||
def get_relays(self) -> list[ID]:
|
||||
"""
|
||||
Get a list of discovered relay peer IDs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ID]
|
||||
List of discovered relay peer IDs
|
||||
|
||||
"""
|
||||
return list(self._discovered_relays.keys())
|
||||
|
||||
def get_relay_info(self, peer_id: ID) -> RelayInfo | None:
|
||||
"""
|
||||
Get information about a specific relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The ID of the relay to get information about
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[RelayInfo]
|
||||
Information about the relay, or None if not found
|
||||
|
||||
"""
|
||||
return self._discovered_relays.get(peer_id)
|
||||
|
||||
def get_relay(self) -> ID | None:
|
||||
"""
|
||||
Get a single relay peer ID for connection purposes.
|
||||
Prioritizes relays with active reservations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[ID]
|
||||
ID of a discovered relay, or None if no relays found
|
||||
|
||||
"""
|
||||
if not self._discovered_relays:
|
||||
return None
|
||||
|
||||
# First try to find a relay with an active reservation
|
||||
for peer_id, relay_info in self._discovered_relays.items():
|
||||
if relay_info and relay_info.has_reservation:
|
||||
return peer_id
|
||||
|
||||
return next(iter(self._discovered_relays.keys()), None)
|
||||
16
libp2p/relay/circuit_v2/pb/__init__.py
Normal file
16
libp2p/relay/circuit_v2/pb/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Protocol buffer package for circuit_v2.
|
||||
|
||||
Contains generated protobuf code for circuit_v2 relay protocol.
|
||||
"""
|
||||
|
||||
# Import the classes to be accessible directly from the package
|
||||
from .circuit_pb2 import (
|
||||
HopMessage,
|
||||
Limit,
|
||||
Reservation,
|
||||
Status,
|
||||
StopMessage,
|
||||
)
|
||||
|
||||
__all__ = ["HopMessage", "Limit", "Reservation", "Status", "StopMessage"]
|
||||
55
libp2p/relay/circuit_v2/pb/circuit.proto
Normal file
55
libp2p/relay/circuit_v2/pb/circuit.proto
Normal file
@ -0,0 +1,55 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package circuit.pb.v2;
|
||||
|
||||
// Circuit v2 message types
|
||||
message HopMessage {
|
||||
enum Type {
|
||||
RESERVE = 0;
|
||||
CONNECT = 1;
|
||||
STATUS = 2;
|
||||
}
|
||||
|
||||
Type type = 1;
|
||||
bytes peer = 2;
|
||||
Reservation reservation = 3;
|
||||
Limit limit = 4;
|
||||
Status status = 5;
|
||||
}
|
||||
|
||||
message StopMessage {
|
||||
enum Type {
|
||||
CONNECT = 0;
|
||||
STATUS = 1;
|
||||
}
|
||||
|
||||
Type type = 1;
|
||||
bytes peer = 2;
|
||||
Status status = 3;
|
||||
}
|
||||
|
||||
message Reservation {
|
||||
bytes voucher = 1;
|
||||
bytes signature = 2;
|
||||
int64 expire = 3;
|
||||
}
|
||||
|
||||
message Limit {
|
||||
int64 duration = 1;
|
||||
int64 data = 2;
|
||||
}
|
||||
|
||||
message Status {
|
||||
enum Code {
|
||||
OK = 0;
|
||||
RESERVATION_REFUSED = 100;
|
||||
RESOURCE_LIMIT_EXCEEDED = 101;
|
||||
PERMISSION_DENIED = 102;
|
||||
CONNECTION_FAILED = 200;
|
||||
DIAL_REFUSED = 201;
|
||||
STOP_FAILED = 300;
|
||||
MALFORMED_MESSAGE = 400;
|
||||
}
|
||||
Code code = 1;
|
||||
string message = 2;
|
||||
}
|
||||
37
libp2p/relay/circuit_v2/pb/circuit_pb2.py
Normal file
37
libp2p/relay/circuit_v2/pb/circuit_pb2.py
Normal file
@ -0,0 +1,37 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: libp2p/relay/circuit_v2/pb/circuit.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(libp2p/relay/circuit_v2/pb/circuit.proto\x12\rcircuit.pb.v2\"\xf3\x01\n\nHopMessage\x12,\n\x04type\x18\x01 \x01(\x0e\x32\x1e.circuit.pb.v2.HopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12/\n\x0breservation\x18\x03 \x01(\x0b\x32\x1a.circuit.pb.v2.Reservation\x12#\n\x05limit\x18\x04 \x01(\x0b\x32\x14.circuit.pb.v2.Limit\x12%\n\x06status\x18\x05 \x01(\x0b\x32\x15.circuit.pb.v2.Status\",\n\x04Type\x12\x0b\n\x07RESERVE\x10\x00\x12\x0b\n\x07\x43ONNECT\x10\x01\x12\n\n\x06STATUS\x10\x02\"\x92\x01\n\x0bStopMessage\x12-\n\x04type\x18\x01 \x01(\x0e\x32\x1f.circuit.pb.v2.StopMessage.Type\x12\x0c\n\x04peer\x18\x02 \x01(\x0c\x12%\n\x06status\x18\x03 \x01(\x0b\x32\x15.circuit.pb.v2.Status\"\x1f\n\x04Type\x12\x0b\n\x07\x43ONNECT\x10\x00\x12\n\n\x06STATUS\x10\x01\"A\n\x0bReservation\x12\x0f\n\x07voucher\x18\x01 \x01(\x0c\x12\x11\n\tsignature\x18\x02 \x01(\x0c\x12\x0e\n\x06\x65xpire\x18\x03 \x01(\x03\"\'\n\x05Limit\x12\x10\n\x08\x64uration\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x03\"\xf6\x01\n\x06Status\x12(\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1a.circuit.pb.v2.Status.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\"\xb0\x01\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\x17\n\x13RESERVATION_REFUSED\x10\x64\x12\x1b\n\x17RESOURCE_LIMIT_EXCEEDED\x10\x65\x12\x15\n\x11PERMISSION_DENIED\x10\x66\x12\x16\n\x11\x43ONNECTION_FAILED\x10\xc8\x01\x12\x11\n\x0c\x44IAL_REFUSED\x10\xc9\x01\x12\x10\n\x0bSTOP_FAILED\x10\xac\x02\x12\x16\n\x11MALFORMED_MESSAGE\x10\x90\x03\x62\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'libp2p.relay.circuit_v2.pb.circuit_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
DESCRIPTOR._options = None
|
||||
_HOPMESSAGE._serialized_start=60
|
||||
_HOPMESSAGE._serialized_end=303
|
||||
_HOPMESSAGE_TYPE._serialized_start=259
|
||||
_HOPMESSAGE_TYPE._serialized_end=303
|
||||
_STOPMESSAGE._serialized_start=306
|
||||
_STOPMESSAGE._serialized_end=452
|
||||
_STOPMESSAGE_TYPE._serialized_start=421
|
||||
_STOPMESSAGE_TYPE._serialized_end=452
|
||||
_RESERVATION._serialized_start=454
|
||||
_RESERVATION._serialized_end=519
|
||||
_LIMIT._serialized_start=521
|
||||
_LIMIT._serialized_end=560
|
||||
_STATUS._serialized_start=563
|
||||
_STATUS._serialized_end=809
|
||||
_STATUS_CODE._serialized_start=633
|
||||
_STATUS_CODE._serialized_end=809
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
184
libp2p/relay/circuit_v2/pb/circuit_pb2.pyi
Normal file
184
libp2p/relay/circuit_v2/pb/circuit_pb2.pyi
Normal file
@ -0,0 +1,184 @@
|
||||
"""
|
||||
@generated by mypy-protobuf. Do not edit manually!
|
||||
isort:skip_file
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import google.protobuf.descriptor
|
||||
import google.protobuf.internal.enum_type_wrapper
|
||||
import google.protobuf.message
|
||||
import sys
|
||||
import typing
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import typing as typing_extensions
|
||||
else:
|
||||
import typing_extensions
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
|
||||
|
||||
@typing.final
|
||||
class HopMessage(google.protobuf.message.Message):
|
||||
"""Circuit v2 message types"""
|
||||
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _Type:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HopMessage._Type.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
RESERVE: HopMessage._Type.ValueType # 0
|
||||
CONNECT: HopMessage._Type.ValueType # 1
|
||||
STATUS: HopMessage._Type.ValueType # 2
|
||||
|
||||
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
|
||||
RESERVE: HopMessage.Type.ValueType # 0
|
||||
CONNECT: HopMessage.Type.ValueType # 1
|
||||
STATUS: HopMessage.Type.ValueType # 2
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
PEER_FIELD_NUMBER: builtins.int
|
||||
RESERVATION_FIELD_NUMBER: builtins.int
|
||||
LIMIT_FIELD_NUMBER: builtins.int
|
||||
STATUS_FIELD_NUMBER: builtins.int
|
||||
type: global___HopMessage.Type.ValueType
|
||||
peer: builtins.bytes
|
||||
@property
|
||||
def reservation(self) -> global___Reservation: ...
|
||||
@property
|
||||
def limit(self) -> global___Limit: ...
|
||||
@property
|
||||
def status(self) -> global___Status: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___HopMessage.Type.ValueType = ...,
|
||||
peer: builtins.bytes = ...,
|
||||
reservation: global___Reservation | None = ...,
|
||||
limit: global___Limit | None = ...,
|
||||
status: global___Status | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["limit", b"limit", "reservation", b"reservation", "status", b"status"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["limit", b"limit", "peer", b"peer", "reservation", b"reservation", "status", b"status", "type", b"type"]) -> None: ...
|
||||
|
||||
global___HopMessage = HopMessage
|
||||
|
||||
@typing.final
|
||||
class StopMessage(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _Type:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[StopMessage._Type.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
CONNECT: StopMessage._Type.ValueType # 0
|
||||
STATUS: StopMessage._Type.ValueType # 1
|
||||
|
||||
class Type(_Type, metaclass=_TypeEnumTypeWrapper): ...
|
||||
CONNECT: StopMessage.Type.ValueType # 0
|
||||
STATUS: StopMessage.Type.ValueType # 1
|
||||
|
||||
TYPE_FIELD_NUMBER: builtins.int
|
||||
PEER_FIELD_NUMBER: builtins.int
|
||||
STATUS_FIELD_NUMBER: builtins.int
|
||||
type: global___StopMessage.Type.ValueType
|
||||
peer: builtins.bytes
|
||||
@property
|
||||
def status(self) -> global___Status: ...
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
type: global___StopMessage.Type.ValueType = ...,
|
||||
peer: builtins.bytes = ...,
|
||||
status: global___Status | None = ...,
|
||||
) -> None: ...
|
||||
def HasField(self, field_name: typing.Literal["status", b"status"]) -> builtins.bool: ...
|
||||
def ClearField(self, field_name: typing.Literal["peer", b"peer", "status", b"status", "type", b"type"]) -> None: ...
|
||||
|
||||
global___StopMessage = StopMessage
|
||||
|
||||
@typing.final
|
||||
class Reservation(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
VOUCHER_FIELD_NUMBER: builtins.int
|
||||
SIGNATURE_FIELD_NUMBER: builtins.int
|
||||
EXPIRE_FIELD_NUMBER: builtins.int
|
||||
voucher: builtins.bytes
|
||||
signature: builtins.bytes
|
||||
expire: builtins.int
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
voucher: builtins.bytes = ...,
|
||||
signature: builtins.bytes = ...,
|
||||
expire: builtins.int = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["expire", b"expire", "signature", b"signature", "voucher", b"voucher"]) -> None: ...
|
||||
|
||||
global___Reservation = Reservation
|
||||
|
||||
@typing.final
|
||||
class Limit(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
DURATION_FIELD_NUMBER: builtins.int
|
||||
DATA_FIELD_NUMBER: builtins.int
|
||||
duration: builtins.int
|
||||
data: builtins.int
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
duration: builtins.int = ...,
|
||||
data: builtins.int = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["data", b"data", "duration", b"duration"]) -> None: ...
|
||||
|
||||
global___Limit = Limit
|
||||
|
||||
@typing.final
|
||||
class Status(google.protobuf.message.Message):
|
||||
DESCRIPTOR: google.protobuf.descriptor.Descriptor
|
||||
|
||||
class _Code:
|
||||
ValueType = typing.NewType("ValueType", builtins.int)
|
||||
V: typing_extensions.TypeAlias = ValueType
|
||||
|
||||
class _CodeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[Status._Code.ValueType], builtins.type):
|
||||
DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
|
||||
OK: Status._Code.ValueType # 0
|
||||
RESERVATION_REFUSED: Status._Code.ValueType # 100
|
||||
RESOURCE_LIMIT_EXCEEDED: Status._Code.ValueType # 101
|
||||
PERMISSION_DENIED: Status._Code.ValueType # 102
|
||||
CONNECTION_FAILED: Status._Code.ValueType # 200
|
||||
DIAL_REFUSED: Status._Code.ValueType # 201
|
||||
STOP_FAILED: Status._Code.ValueType # 300
|
||||
MALFORMED_MESSAGE: Status._Code.ValueType # 400
|
||||
|
||||
class Code(_Code, metaclass=_CodeEnumTypeWrapper): ...
|
||||
OK: Status.Code.ValueType # 0
|
||||
RESERVATION_REFUSED: Status.Code.ValueType # 100
|
||||
RESOURCE_LIMIT_EXCEEDED: Status.Code.ValueType # 101
|
||||
PERMISSION_DENIED: Status.Code.ValueType # 102
|
||||
CONNECTION_FAILED: Status.Code.ValueType # 200
|
||||
DIAL_REFUSED: Status.Code.ValueType # 201
|
||||
STOP_FAILED: Status.Code.ValueType # 300
|
||||
MALFORMED_MESSAGE: Status.Code.ValueType # 400
|
||||
|
||||
CODE_FIELD_NUMBER: builtins.int
|
||||
MESSAGE_FIELD_NUMBER: builtins.int
|
||||
code: global___Status.Code.ValueType
|
||||
message: builtins.str
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
code: global___Status.Code.ValueType = ...,
|
||||
message: builtins.str = ...,
|
||||
) -> None: ...
|
||||
def ClearField(self, field_name: typing.Literal["code", b"code", "message", b"message"]) -> None: ...
|
||||
|
||||
global___Status = Status
|
||||
800
libp2p/relay/circuit_v2/protocol.py
Normal file
800
libp2p/relay/circuit_v2/protocol.py
Normal file
@ -0,0 +1,800 @@
|
||||
"""
|
||||
Circuit Relay v2 protocol implementation.
|
||||
|
||||
This module implements the Circuit Relay v2 protocol as specified in:
|
||||
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Protocol as TypingProtocol,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.io.abc import (
|
||||
ReadWriteCloser,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.stream_muxer.mplex.exceptions import (
|
||||
MplexStreamEOF,
|
||||
MplexStreamReset,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
Limit,
|
||||
Reservation,
|
||||
Status as PbStatus,
|
||||
StopMessage,
|
||||
)
|
||||
from .protocol_buffer import (
|
||||
StatusCode,
|
||||
create_status,
|
||||
)
|
||||
from .resources import (
|
||||
RelayLimits,
|
||||
RelayResourceManager,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2")
|
||||
|
||||
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
|
||||
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")
|
||||
|
||||
# Default limits for relay resources
|
||||
DEFAULT_RELAY_LIMITS = RelayLimits(
|
||||
duration=60 * 60, # 1 hour
|
||||
data=1024 * 1024 * 1024, # 1GB
|
||||
max_circuit_conns=8,
|
||||
max_reservations=4,
|
||||
)
|
||||
|
||||
# Stream operation timeouts
|
||||
STREAM_READ_TIMEOUT = 15 # seconds
|
||||
STREAM_WRITE_TIMEOUT = 15 # seconds
|
||||
STREAM_CLOSE_TIMEOUT = 10 # seconds
|
||||
MAX_READ_RETRIES = 5 # Maximum number of read retries
|
||||
|
||||
|
||||
# Extended interfaces for type checking
|
||||
@runtime_checkable
|
||||
class IHostWithStreamHandlers(TypingProtocol):
|
||||
"""Extended host interface with stream handler methods."""
|
||||
|
||||
def remove_stream_handler(self, protocol_id: TProtocol) -> None:
|
||||
"""Remove a stream handler for a protocol."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class INetStreamWithExtras(TypingProtocol):
|
||||
"""Extended net stream interface with additional methods."""
|
||||
|
||||
def get_remote_peer_id(self) -> ID:
|
||||
"""Get the remote peer ID."""
|
||||
...
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if the stream is open."""
|
||||
...
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the stream is closed."""
|
||||
...
|
||||
|
||||
|
||||
class CircuitV2Protocol(Service):
|
||||
"""
|
||||
CircuitV2Protocol implements the Circuit Relay v2 protocol.
|
||||
|
||||
This protocol allows peers to establish connections through relay nodes
|
||||
when direct connections are not possible (e.g., due to NAT).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
limits: RelayLimits | None = None,
|
||||
allow_hop: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a Circuit Relay v2 protocol instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host instance
|
||||
limits : RelayLimits | None
|
||||
Resource limits for the relay
|
||||
allow_hop : bool
|
||||
Whether to allow this node to act as a relay
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.limits = limits or DEFAULT_RELAY_LIMITS
|
||||
self.allow_hop = allow_hop
|
||||
self.resource_manager = RelayResourceManager(self.limits)
|
||||
self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}
|
||||
self.event_started = trio.Event()
|
||||
|
||||
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
|
||||
"""Run the protocol service."""
|
||||
try:
|
||||
# Register protocol handlers
|
||||
if self.allow_hop:
|
||||
logger.debug("Registering stream handlers for relay protocol")
|
||||
self.host.set_stream_handler(PROTOCOL_ID, self._handle_hop_stream)
|
||||
self.host.set_stream_handler(STOP_PROTOCOL_ID, self._handle_stop_stream)
|
||||
logger.debug("Stream handlers registered successfully")
|
||||
|
||||
# Signal that we're ready
|
||||
self.event_started.set()
|
||||
task_status.started()
|
||||
logger.debug("Protocol service started")
|
||||
|
||||
# Wait for service to be stopped
|
||||
await self.manager.wait_finished()
|
||||
finally:
|
||||
# Clean up any active relay connections
|
||||
for src_stream, dst_stream in self._active_relays.values():
|
||||
await self._close_stream(src_stream)
|
||||
await self._close_stream(dst_stream)
|
||||
self._active_relays.clear()
|
||||
|
||||
# Unregister protocol handlers
|
||||
if self.allow_hop:
|
||||
try:
|
||||
# Cast host to extended interface with remove_stream_handler
|
||||
host_with_handlers = cast(IHostWithStreamHandlers, self.host)
|
||||
host_with_handlers.remove_stream_handler(PROTOCOL_ID)
|
||||
host_with_handlers.remove_stream_handler(STOP_PROTOCOL_ID)
|
||||
except Exception as e:
|
||||
logger.error("Error unregistering stream handlers: %s", str(e))
|
||||
|
||||
async def _close_stream(self, stream: INetStream | None) -> None:
|
||||
"""Helper function to safely close a stream."""
|
||||
if stream is None:
|
||||
return
|
||||
|
||||
try:
|
||||
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
|
||||
await stream.close()
|
||||
except Exception:
|
||||
try:
|
||||
await stream.reset()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _read_stream_with_retry(
|
||||
self,
|
||||
stream: INetStream,
|
||||
max_retries: int = MAX_READ_RETRIES,
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Helper function to read from a stream with retries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stream : INetStream
|
||||
The stream to read from
|
||||
max_retries : int
|
||||
Maximum number of read retries
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[bytes]
|
||||
The data read from the stream, or None if the stream is closed/reset
|
||||
|
||||
Raises
|
||||
------
|
||||
trio.TooSlowError
|
||||
If read timeout occurs after all retries
|
||||
Exception
|
||||
For other unexpected errors
|
||||
|
||||
"""
|
||||
retries = 0
|
||||
last_error: Any = None
|
||||
backoff_time = 0.2 # Base backoff time in seconds
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
# Try reading with timeout
|
||||
logger.debug(
|
||||
"Attempting to read from stream (attempt %d/%d)",
|
||||
retries + 1,
|
||||
max_retries,
|
||||
)
|
||||
data = await stream.read()
|
||||
if not data: # EOF
|
||||
logger.debug("Stream EOF detected")
|
||||
return None
|
||||
|
||||
logger.debug("Successfully read %d bytes from stream", len(data))
|
||||
return data
|
||||
except trio.WouldBlock:
|
||||
# Just retry immediately if we would block
|
||||
retries += 1
|
||||
logger.debug(
|
||||
"Stream would block (attempt %d/%d), retrying...",
|
||||
retries,
|
||||
max_retries,
|
||||
)
|
||||
await trio.sleep(backoff_time * retries) # Increased backoff time
|
||||
continue
|
||||
except (MplexStreamEOF, MplexStreamReset):
|
||||
# Stream closed/reset - no point retrying
|
||||
logger.debug("Stream closed/reset during read")
|
||||
return None
|
||||
except trio.TooSlowError as e:
|
||||
last_error = e
|
||||
retries += 1
|
||||
logger.debug(
|
||||
"Read timeout (attempt %d/%d), retrying...", retries, max_retries
|
||||
)
|
||||
if retries < max_retries:
|
||||
# Wait longer before retry with increasing backoff
|
||||
await trio.sleep(backoff_time * retries) # Increased backoff
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error reading from stream: %s", str(e))
|
||||
last_error = e
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
await trio.sleep(backoff_time * retries) # Increased backoff
|
||||
continue
|
||||
raise
|
||||
|
||||
if last_error:
|
||||
if isinstance(last_error, trio.TooSlowError):
|
||||
logger.error("Read timed out after %d retries", max_retries)
|
||||
raise last_error
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_hop_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle incoming HOP streams.
|
||||
|
||||
This handler processes relay requests from other peers.
|
||||
"""
|
||||
try:
|
||||
# Try to get peer ID first
|
||||
try:
|
||||
# Cast to extended interface with get_remote_peer_id
|
||||
stream_with_peer_id = cast(INetStreamWithExtras, stream)
|
||||
remote_peer_id = stream_with_peer_id.get_remote_peer_id()
|
||||
remote_id = str(remote_peer_id)
|
||||
except Exception:
|
||||
# Fall back to address if peer ID not available
|
||||
remote_addr = stream.get_remote_address()
|
||||
remote_id = f"peer at {remote_addr}" if remote_addr else "unknown peer"
|
||||
|
||||
logger.debug("Handling hop stream from %s", remote_id)
|
||||
|
||||
# First, handle the read timeout gracefully
|
||||
try:
|
||||
with trio.fail_after(
|
||||
STREAM_READ_TIMEOUT * 2
|
||||
): # Double the timeout for reading
|
||||
msg_bytes = await stream.read()
|
||||
if not msg_bytes:
|
||||
logger.error(
|
||||
"Empty read from stream from %s",
|
||||
remote_id,
|
||||
)
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
|
||||
pb_status.message = "Empty message received"
|
||||
|
||||
response = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=pb_status,
|
||||
)
|
||||
await stream.write(response.SerializeToString())
|
||||
await trio.sleep(0.5) # Longer wait to ensure message is sent
|
||||
return
|
||||
except trio.TooSlowError:
|
||||
logger.error(
|
||||
"Timeout reading from hop stream from %s",
|
||||
remote_id,
|
||||
)
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(Any, int(StatusCode.CONNECTION_FAILED))
|
||||
pb_status.message = "Stream read timeout"
|
||||
|
||||
response = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=pb_status,
|
||||
)
|
||||
await stream.write(response.SerializeToString())
|
||||
await trio.sleep(0.5) # Longer wait to ensure the message is sent
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error reading from hop stream from %s: %s",
|
||||
remote_id,
|
||||
str(e),
|
||||
)
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
|
||||
pb_status.message = f"Read error: {str(e)}"
|
||||
|
||||
response = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=pb_status,
|
||||
)
|
||||
await stream.write(response.SerializeToString())
|
||||
await trio.sleep(0.5) # Longer wait to ensure the message is sent
|
||||
return
|
||||
|
||||
# Parse the message
|
||||
try:
|
||||
hop_msg = HopMessage()
|
||||
hop_msg.ParseFromString(msg_bytes)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error parsing hop message from %s: %s",
|
||||
remote_id,
|
||||
str(e),
|
||||
)
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(Any, int(StatusCode.MALFORMED_MESSAGE))
|
||||
pb_status.message = f"Parse error: {str(e)}"
|
||||
|
||||
response = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=pb_status,
|
||||
)
|
||||
await stream.write(response.SerializeToString())
|
||||
await trio.sleep(0.5) # Longer wait to ensure the message is sent
|
||||
return
|
||||
|
||||
# Process based on message type
|
||||
if hop_msg.type == HopMessage.RESERVE:
|
||||
logger.debug("Handling RESERVE message from %s", remote_id)
|
||||
await self._handle_reserve(stream, hop_msg)
|
||||
# For RESERVE requests, let the client close the stream
|
||||
return
|
||||
elif hop_msg.type == HopMessage.CONNECT:
|
||||
logger.debug("Handling CONNECT message from %s", remote_id)
|
||||
await self._handle_connect(stream, hop_msg)
|
||||
else:
|
||||
logger.error("Invalid message type %d from %s", hop_msg.type, remote_id)
|
||||
# Send a nice error response using _send_status method
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.MALFORMED_MESSAGE,
|
||||
f"Invalid message type: {hop_msg.type}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error handling hop stream from %s: %s", remote_id, str(e)
|
||||
)
|
||||
try:
|
||||
# Send a nice error response using _send_status method
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.MALFORMED_MESSAGE,
|
||||
f"Internal error: {str(e)}",
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error(
|
||||
"Failed to send error response to %s: %s", remote_id, str(e2)
|
||||
)
|
||||
|
||||
async def _handle_stop_stream(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Handle incoming STOP streams.
|
||||
|
||||
This handler processes incoming relay connections from the destination side.
|
||||
"""
|
||||
try:
|
||||
# Read the incoming message with timeout
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
msg_bytes = await stream.read()
|
||||
stop_msg = StopMessage()
|
||||
stop_msg.ParseFromString(msg_bytes)
|
||||
|
||||
if stop_msg.type != StopMessage.CONNECT:
|
||||
# Use direct attribute access to create status object for error response
|
||||
await self._send_stop_status(
|
||||
stream,
|
||||
StatusCode.MALFORMED_MESSAGE,
|
||||
"Invalid message type",
|
||||
)
|
||||
await self._close_stream(stream)
|
||||
return
|
||||
|
||||
# Get the source stream from active relays
|
||||
peer_id = ID(stop_msg.peer)
|
||||
if peer_id not in self._active_relays:
|
||||
# Use direct attribute access to create status object for error response
|
||||
await self._send_stop_status(
|
||||
stream,
|
||||
StatusCode.CONNECTION_FAILED,
|
||||
"No pending relay connection",
|
||||
)
|
||||
await self._close_stream(stream)
|
||||
return
|
||||
|
||||
src_stream, _ = self._active_relays[peer_id]
|
||||
self._active_relays[peer_id] = (src_stream, stream)
|
||||
|
||||
# Send success status to both sides
|
||||
await self._send_status(
|
||||
src_stream,
|
||||
StatusCode.OK,
|
||||
"Connection established",
|
||||
)
|
||||
await self._send_stop_status(
|
||||
stream,
|
||||
StatusCode.OK,
|
||||
"Connection established",
|
||||
)
|
||||
|
||||
# Start relaying data
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._relay_data, src_stream, stream, peer_id)
|
||||
nursery.start_soon(self._relay_data, stream, src_stream, peer_id)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.error("Timeout reading from stop stream")
|
||||
await self._send_stop_status(
|
||||
stream,
|
||||
StatusCode.CONNECTION_FAILED,
|
||||
"Stream read timeout",
|
||||
)
|
||||
await self._close_stream(stream)
|
||||
except Exception as e:
|
||||
logger.error("Error handling stop stream: %s", str(e))
|
||||
try:
|
||||
await self._send_stop_status(
|
||||
stream,
|
||||
StatusCode.MALFORMED_MESSAGE,
|
||||
str(e),
|
||||
)
|
||||
await self._close_stream(stream)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle_reserve(self, stream: INetStream, msg: Any) -> None:
|
||||
"""Handle a reservation request."""
|
||||
peer_id = None
|
||||
try:
|
||||
peer_id = ID(msg.peer)
|
||||
logger.debug("Handling reservation request from peer %s", peer_id)
|
||||
|
||||
# Check if we can accept more reservations
|
||||
if not self.resource_manager.can_accept_reservation(peer_id):
|
||||
logger.debug("Reservation limit exceeded for peer %s", peer_id)
|
||||
# Send status message with STATUS type
|
||||
status = create_status(
|
||||
code=StatusCode.RESOURCE_LIMIT_EXCEEDED,
|
||||
message="Reservation limit exceeded",
|
||||
)
|
||||
|
||||
status_msg = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=status.to_pb(),
|
||||
)
|
||||
await stream.write(status_msg.SerializeToString())
|
||||
return
|
||||
|
||||
# Accept reservation
|
||||
logger.debug("Accepting reservation from peer %s", peer_id)
|
||||
ttl = self.resource_manager.reserve(peer_id)
|
||||
|
||||
# Send reservation success response
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
status = create_status(
|
||||
code=StatusCode.OK, message="Reservation accepted"
|
||||
)
|
||||
|
||||
response = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=status.to_pb(),
|
||||
reservation=Reservation(
|
||||
expire=int(time.time() + ttl),
|
||||
voucher=b"", # We don't use vouchers yet
|
||||
signature=b"", # We don't use signatures yet
|
||||
),
|
||||
limit=Limit(
|
||||
duration=self.limits.duration,
|
||||
data=self.limits.data,
|
||||
),
|
||||
)
|
||||
|
||||
# Log the response message details for debugging
|
||||
logger.debug(
|
||||
"Sending reservation response: type=%s, status=%s, ttl=%d",
|
||||
response.type,
|
||||
getattr(response.status, "code", "unknown"),
|
||||
ttl,
|
||||
)
|
||||
|
||||
# Send the response with increased timeout
|
||||
await stream.write(response.SerializeToString())
|
||||
|
||||
# Add a small wait to ensure the message is fully sent
|
||||
await trio.sleep(0.1)
|
||||
|
||||
logger.debug("Reservation response sent successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling reservation request: %s", str(e))
|
||||
if cast(INetStreamWithExtras, stream).is_open():
|
||||
try:
|
||||
# Send error response
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.INTERNAL_ERROR,
|
||||
f"Failed to process reservation: {str(e)}",
|
||||
)
|
||||
except Exception as send_err:
|
||||
logger.error("Failed to send error response: %s", str(send_err))
|
||||
finally:
|
||||
# Always close the stream when done with reservation
|
||||
if cast(INetStreamWithExtras, stream).is_open():
|
||||
try:
|
||||
with trio.fail_after(STREAM_CLOSE_TIMEOUT):
|
||||
await stream.close()
|
||||
except Exception as close_err:
|
||||
logger.error("Error closing stream: %s", str(close_err))
|
||||
|
||||
async def _handle_connect(self, stream: INetStream, msg: Any) -> None:
|
||||
"""Handle a connect request."""
|
||||
peer_id = ID(msg.peer)
|
||||
dst_stream: INetStream | None = None
|
||||
|
||||
# Verify reservation if provided
|
||||
if msg.HasField("reservation"):
|
||||
if not self.resource_manager.verify_reservation(peer_id, msg.reservation):
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.PERMISSION_DENIED,
|
||||
"Invalid reservation",
|
||||
)
|
||||
await stream.reset()
|
||||
return
|
||||
|
||||
# Check resource limits
|
||||
if not self.resource_manager.can_accept_connection(peer_id):
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.RESOURCE_LIMIT_EXCEEDED,
|
||||
"Connection limit exceeded",
|
||||
)
|
||||
await stream.reset()
|
||||
return
|
||||
|
||||
try:
|
||||
# Store the source stream with properly typed None
|
||||
self._active_relays[peer_id] = (stream, None)
|
||||
|
||||
# Try to connect to the destination with timeout
|
||||
with trio.fail_after(STREAM_READ_TIMEOUT):
|
||||
dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
|
||||
if not dst_stream:
|
||||
raise ConnectionError("Could not connect to destination")
|
||||
|
||||
# Send STOP CONNECT message
|
||||
stop_msg = StopMessage(
|
||||
type=StopMessage.CONNECT,
|
||||
# Cast to extended interface with get_remote_peer_id
|
||||
peer=cast(INetStreamWithExtras, stream)
|
||||
.get_remote_peer_id()
|
||||
.to_bytes(),
|
||||
)
|
||||
await dst_stream.write(stop_msg.SerializeToString())
|
||||
|
||||
# Wait for response from destination
|
||||
resp_bytes = await dst_stream.read()
|
||||
resp = StopMessage()
|
||||
resp.ParseFromString(resp_bytes)
|
||||
|
||||
# Handle status attributes from the response
|
||||
if resp.HasField("status"):
|
||||
# Get code and message attributes with defaults
|
||||
status_code = getattr(resp.status, "code", StatusCode.OK)
|
||||
# Get message with default
|
||||
status_msg = getattr(resp.status, "message", "Unknown error")
|
||||
else:
|
||||
status_code = StatusCode.OK
|
||||
status_msg = "No status provided"
|
||||
|
||||
if status_code != StatusCode.OK:
|
||||
raise ConnectionError(
|
||||
f"Destination rejected connection: {status_msg}"
|
||||
)
|
||||
|
||||
# Update active relays with destination stream
|
||||
self._active_relays[peer_id] = (stream, dst_stream)
|
||||
|
||||
# Update reservation connection count
|
||||
reservation = self.resource_manager._reservations.get(peer_id)
|
||||
if reservation:
|
||||
reservation.active_connections += 1
|
||||
|
||||
# Send success status
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.OK,
|
||||
"Connection established",
|
||||
)
|
||||
|
||||
# Start relaying data
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
|
||||
nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)
|
||||
|
||||
except (trio.TooSlowError, ConnectionError) as e:
|
||||
logger.error("Error establishing relay connection: %s", str(e))
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.CONNECTION_FAILED,
|
||||
str(e),
|
||||
)
|
||||
if peer_id in self._active_relays:
|
||||
del self._active_relays[peer_id]
|
||||
# Clean up reservation connection count on failure
|
||||
reservation = self.resource_manager._reservations.get(peer_id)
|
||||
if reservation:
|
||||
reservation.active_connections -= 1
|
||||
await stream.reset()
|
||||
if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
|
||||
await dst_stream.reset()
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error in connect handler: %s", str(e))
|
||||
await self._send_status(
|
||||
stream,
|
||||
StatusCode.CONNECTION_FAILED,
|
||||
"Internal error",
|
||||
)
|
||||
if peer_id in self._active_relays:
|
||||
del self._active_relays[peer_id]
|
||||
await stream.reset()
|
||||
if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
|
||||
await dst_stream.reset()
|
||||
|
||||
async def _relay_data(
|
||||
self,
|
||||
src_stream: INetStream,
|
||||
dst_stream: INetStream,
|
||||
peer_id: ID,
|
||||
) -> None:
|
||||
"""
|
||||
Relay data between two streams.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
src_stream : INetStream
|
||||
Source stream to read from
|
||||
dst_stream : INetStream
|
||||
Destination stream to write to
|
||||
peer_id : ID
|
||||
ID of the peer being relayed
|
||||
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
# Read data with retries
|
||||
data = await self._read_stream_with_retry(src_stream)
|
||||
if not data:
|
||||
logger.info("Source stream closed/reset")
|
||||
break
|
||||
|
||||
# Write data with timeout
|
||||
try:
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT):
|
||||
await dst_stream.write(data)
|
||||
except trio.TooSlowError:
|
||||
logger.error("Timeout writing to destination stream")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error writing to destination stream: %s", str(e))
|
||||
break
|
||||
|
||||
# Update resource usage
|
||||
reservation = self.resource_manager._reservations.get(peer_id)
|
||||
if reservation:
|
||||
reservation.data_used += len(data)
|
||||
if reservation.data_used >= reservation.limits.data:
|
||||
logger.warning("Data limit exceeded for peer %s", peer_id)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error relaying data: %s", str(e))
|
||||
finally:
|
||||
# Clean up streams and remove from active relays
|
||||
await src_stream.reset()
|
||||
await dst_stream.reset()
|
||||
if peer_id in self._active_relays:
|
||||
del self._active_relays[peer_id]
|
||||
|
||||
async def _send_status(
|
||||
self,
|
||||
stream: ReadWriteCloser,
|
||||
code: int,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""Send a status message."""
|
||||
try:
|
||||
logger.debug("Sending status message with code %s: %s", code, message)
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(
|
||||
Any, int(code)
|
||||
) # Cast to Any to avoid type errors
|
||||
pb_status.message = message
|
||||
|
||||
status_msg = HopMessage(
|
||||
type=HopMessage.STATUS,
|
||||
status=pb_status,
|
||||
)
|
||||
|
||||
msg_bytes = status_msg.SerializeToString()
|
||||
logger.debug("Status message serialized (%d bytes)", len(msg_bytes))
|
||||
|
||||
await stream.write(msg_bytes)
|
||||
logger.debug("Status message sent, waiting for processing")
|
||||
|
||||
# Wait longer to ensure the message is sent
|
||||
await trio.sleep(1.5)
|
||||
logger.debug("Status message sending completed")
|
||||
except trio.TooSlowError:
|
||||
logger.error(
|
||||
"Timeout sending status message: code=%s, message=%s", code, message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error sending status message: %s", str(e))
|
||||
|
||||
async def _send_stop_status(
|
||||
self,
|
||||
stream: ReadWriteCloser,
|
||||
code: int,
|
||||
message: str,
|
||||
) -> None:
|
||||
"""Send a status message on a STOP stream."""
|
||||
try:
|
||||
logger.debug("Sending stop status message with code %s: %s", code, message)
|
||||
with trio.fail_after(STREAM_WRITE_TIMEOUT * 2): # Double the timeout
|
||||
# Create a proto Status directly
|
||||
pb_status = PbStatus()
|
||||
pb_status.code = cast(
|
||||
Any, int(code)
|
||||
) # Cast to Any to avoid type errors
|
||||
pb_status.message = message
|
||||
|
||||
status_msg = StopMessage(
|
||||
type=StopMessage.STATUS,
|
||||
status=pb_status,
|
||||
)
|
||||
await stream.write(status_msg.SerializeToString())
|
||||
await trio.sleep(0.5) # Ensure message is sent
|
||||
except Exception as e:
|
||||
logger.error("Error sending stop status message: %s", str(e))
|
||||
55
libp2p/relay/circuit_v2/protocol_buffer.py
Normal file
55
libp2p/relay/circuit_v2/protocol_buffer.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""
|
||||
Protocol buffer wrapper classes for Circuit Relay v2.
|
||||
|
||||
This module provides wrapper classes for protocol buffer generated objects
|
||||
to make them easier to work with in type-checked code.
|
||||
"""
|
||||
|
||||
from enum import (
|
||||
IntEnum,
|
||||
)
|
||||
from typing import (
|
||||
Any,
|
||||
)
|
||||
|
||||
from .pb.circuit_pb2 import Status as PbStatus
|
||||
|
||||
|
||||
# Define Status codes as an Enum for better type safety and organization
|
||||
class StatusCode(IntEnum):
|
||||
OK = 0
|
||||
RESERVATION_REFUSED = 100
|
||||
RESOURCE_LIMIT_EXCEEDED = 101
|
||||
PERMISSION_DENIED = 102
|
||||
CONNECTION_FAILED = 200
|
||||
DIAL_REFUSED = 201
|
||||
STOP_FAILED = 300
|
||||
MALFORMED_MESSAGE = 400
|
||||
INTERNAL_ERROR = 500
|
||||
|
||||
|
||||
def create_status(code: int = StatusCode.OK, message: str = "") -> Any:
|
||||
"""
|
||||
Create a protocol buffer Status object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code : int
|
||||
The status code
|
||||
message : str
|
||||
The status message
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
The protocol buffer Status object
|
||||
|
||||
"""
|
||||
# Create status object
|
||||
pb_obj = PbStatus()
|
||||
|
||||
# Convert the integer status code to the protobuf enum value type
|
||||
pb_obj.code = PbStatus.Code.ValueType(code)
|
||||
pb_obj.message = message
|
||||
|
||||
return pb_obj
|
||||
254
libp2p/relay/circuit_v2/resources.py
Normal file
254
libp2p/relay/circuit_v2/resources.py
Normal file
@ -0,0 +1,254 @@
|
||||
"""
|
||||
Resource management for Circuit Relay v2.
|
||||
|
||||
This module handles managing resources for relay operations,
|
||||
including reservations and connection limits.
|
||||
"""
|
||||
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
|
||||
# Import the protobuf definitions
|
||||
from .pb.circuit_pb2 import Reservation as PbReservation
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelayLimits:
|
||||
"""Configuration for relay resource limits."""
|
||||
|
||||
duration: int # Maximum duration of a relay connection in seconds
|
||||
data: int # Maximum data transfer allowed in bytes
|
||||
max_circuit_conns: int # Maximum number of concurrent circuit connections
|
||||
max_reservations: int # Maximum number of active reservations
|
||||
|
||||
|
||||
class Reservation:
|
||||
"""Represents a relay reservation."""
|
||||
|
||||
def __init__(self, peer_id: ID, limits: RelayLimits):
|
||||
"""
|
||||
Initialize a new reservation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID this reservation is for
|
||||
limits : RelayLimits
|
||||
The resource limits for this reservation
|
||||
|
||||
"""
|
||||
self.peer_id = peer_id
|
||||
self.limits = limits
|
||||
self.created_at = time.time()
|
||||
self.expires_at = self.created_at + limits.duration
|
||||
self.data_used = 0
|
||||
self.active_connections = 0
|
||||
self.voucher = self._generate_voucher()
|
||||
|
||||
def _generate_voucher(self) -> bytes:
|
||||
"""
|
||||
Generate a unique cryptographically secure voucher for this reservation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bytes
|
||||
A secure voucher token
|
||||
|
||||
"""
|
||||
# Create a random token using a combination of:
|
||||
# - Random bytes for unpredictability
|
||||
# - Peer ID to bind it to the specific peer
|
||||
# - Timestamp for uniqueness
|
||||
# - Hash everything for a fixed size output
|
||||
random_bytes = os.urandom(16) # 128 bits of randomness
|
||||
timestamp = str(int(self.created_at * 1000000)).encode()
|
||||
peer_bytes = self.peer_id.to_bytes()
|
||||
|
||||
# Combine all elements and hash them
|
||||
h = hashlib.sha256()
|
||||
h.update(random_bytes)
|
||||
h.update(timestamp)
|
||||
h.update(peer_bytes)
|
||||
|
||||
return h.digest()
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the reservation has expired."""
|
||||
return time.time() > self.expires_at
|
||||
|
||||
def can_accept_connection(self) -> bool:
|
||||
"""Check if a new connection can be accepted."""
|
||||
return (
|
||||
not self.is_expired()
|
||||
and self.active_connections < self.limits.max_circuit_conns
|
||||
and self.data_used < self.limits.data
|
||||
)
|
||||
|
||||
def to_proto(self) -> PbReservation:
|
||||
"""Convert the reservation to its protobuf representation."""
|
||||
# TODO: For production use, implement proper signature generation
|
||||
# The signature should be created by signing the voucher with the
|
||||
# peer's private key. The current implementation with an empty signature
|
||||
# is intended for development and testing only.
|
||||
return PbReservation(
|
||||
expire=int(self.expires_at),
|
||||
voucher=self.voucher,
|
||||
signature=b"",
|
||||
)
|
||||
|
||||
|
||||
class RelayResourceManager:
|
||||
"""
|
||||
Manages resources and reservations for relay operations.
|
||||
|
||||
This class handles:
|
||||
- Tracking active reservations
|
||||
- Enforcing resource limits
|
||||
- Managing connection quotas
|
||||
"""
|
||||
|
||||
def __init__(self, limits: RelayLimits):
|
||||
"""
|
||||
Initialize the resource manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limits : RelayLimits
|
||||
The resource limits to enforce
|
||||
|
||||
"""
|
||||
self.limits = limits
|
||||
self._reservations: dict[ID, Reservation] = {}
|
||||
|
||||
def can_accept_reservation(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a new reservation can be accepted for the given peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID requesting the reservation
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the reservation can be accepted
|
||||
|
||||
"""
|
||||
# Clean expired reservations
|
||||
self._clean_expired()
|
||||
|
||||
# Check if peer already has a valid reservation
|
||||
existing = self._reservations.get(peer_id)
|
||||
if existing and not existing.is_expired():
|
||||
return True
|
||||
|
||||
# Check if we're at the reservation limit
|
||||
return len(self._reservations) < self.limits.max_reservations
|
||||
|
||||
def create_reservation(self, peer_id: ID) -> Reservation:
|
||||
"""
|
||||
Create a new reservation for the given peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to create the reservation for
|
||||
|
||||
Returns
|
||||
-------
|
||||
Reservation
|
||||
The newly created reservation
|
||||
|
||||
"""
|
||||
reservation = Reservation(peer_id, self.limits)
|
||||
self._reservations[peer_id] = reservation
|
||||
return reservation
|
||||
|
||||
def verify_reservation(self, peer_id: ID, proto_res: PbReservation) -> bool:
|
||||
"""
|
||||
Verify a reservation from a protobuf message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID the reservation is for
|
||||
proto_res : PbReservation
|
||||
The protobuf reservation message
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the reservation is valid
|
||||
|
||||
"""
|
||||
# TODO: Implement voucher and signature verification
|
||||
reservation = self._reservations.get(peer_id)
|
||||
return (
|
||||
reservation is not None
|
||||
and not reservation.is_expired()
|
||||
and reservation.expires_at == proto_res.expire
|
||||
)
|
||||
|
||||
def can_accept_connection(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a new connection can be accepted for the given peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID requesting the connection
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the connection can be accepted
|
||||
|
||||
"""
|
||||
reservation = self._reservations.get(peer_id)
|
||||
return reservation is not None and reservation.can_accept_connection()
|
||||
|
||||
def _clean_expired(self) -> None:
|
||||
"""Remove expired reservations."""
|
||||
now = time.time()
|
||||
expired = [
|
||||
peer_id
|
||||
for peer_id, res in self._reservations.items()
|
||||
if now > res.expires_at
|
||||
]
|
||||
for peer_id in expired:
|
||||
del self._reservations[peer_id]
|
||||
|
||||
def reserve(self, peer_id: ID) -> int:
|
||||
"""
|
||||
Create or update a reservation for a peer and return the TTL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_id : ID
|
||||
The peer ID to reserve for
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The TTL of the reservation in seconds
|
||||
|
||||
"""
|
||||
# Check for existing reservation
|
||||
existing = self._reservations.get(peer_id)
|
||||
if existing and not existing.is_expired():
|
||||
# Return remaining time for existing reservation
|
||||
remaining = max(0, int(existing.expires_at - time.time()))
|
||||
return remaining
|
||||
|
||||
# Create new reservation
|
||||
self.create_reservation(peer_id)
|
||||
return self.limits.duration
|
||||
427
libp2p/relay/circuit_v2/transport.py
Normal file
427
libp2p/relay/circuit_v2/transport.py
Normal file
@ -0,0 +1,427 @@
|
||||
"""
|
||||
Transport implementation for Circuit Relay v2.
|
||||
|
||||
This module implements the transport layer for Circuit Relay v2,
|
||||
allowing peers to establish connections through relay nodes.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import (
|
||||
IHost,
|
||||
IListener,
|
||||
INetStream,
|
||||
ITransport,
|
||||
ReadWriteCloser,
|
||||
)
|
||||
from libp2p.network.connection.raw_connection import (
|
||||
RawConnection,
|
||||
)
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
PeerInfo,
|
||||
)
|
||||
from libp2p.tools.async_service import (
|
||||
Service,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
ClientConfig,
|
||||
RelayConfig,
|
||||
)
|
||||
from .discovery import (
|
||||
RelayDiscovery,
|
||||
)
|
||||
from .pb.circuit_pb2 import (
|
||||
HopMessage,
|
||||
StopMessage,
|
||||
)
|
||||
from .protocol import (
|
||||
PROTOCOL_ID,
|
||||
CircuitV2Protocol,
|
||||
)
|
||||
from .protocol_buffer import (
|
||||
StatusCode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.relay.circuit_v2.transport")
|
||||
|
||||
|
||||
class CircuitV2Transport(ITransport):
|
||||
"""
|
||||
CircuitV2Transport implements the transport interface for Circuit Relay v2.
|
||||
|
||||
This transport allows peers to establish connections through relay nodes
|
||||
when direct connections are not possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
protocol: CircuitV2Protocol,
|
||||
config: RelayConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Circuit v2 transport.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this transport is running on
|
||||
protocol : CircuitV2Protocol
|
||||
The Circuit v2 protocol instance
|
||||
config : RelayConfig
|
||||
Relay configuration
|
||||
|
||||
"""
|
||||
self.host = host
|
||||
self.protocol = protocol
|
||||
self.config = config
|
||||
self.client_config = ClientConfig()
|
||||
self.discovery = RelayDiscovery(
|
||||
host=host,
|
||||
auto_reserve=config.enable_client,
|
||||
discovery_interval=config.discovery_interval,
|
||||
max_relays=config.max_relays,
|
||||
)
|
||||
|
||||
async def dial(
|
||||
self,
|
||||
maddr: multiaddr.Multiaddr,
|
||||
) -> RawConnection:
|
||||
"""
|
||||
Dial a peer using the multiaddr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maddr : multiaddr.Multiaddr
|
||||
The multiaddr to dial
|
||||
|
||||
Returns
|
||||
-------
|
||||
RawConnection
|
||||
The established connection
|
||||
|
||||
Raises
|
||||
------
|
||||
ConnectionError
|
||||
If the connection cannot be established
|
||||
|
||||
"""
|
||||
# Extract peer ID from multiaddr - P_P2P code is 0x01A5 (421)
|
||||
peer_id_str = maddr.value_for_protocol("p2p")
|
||||
if not peer_id_str:
|
||||
raise ConnectionError("Multiaddr does not contain peer ID")
|
||||
|
||||
peer_id = ID.from_base58(peer_id_str)
|
||||
peer_info = PeerInfo(peer_id, [maddr])
|
||||
|
||||
# Use the internal dial_peer_info method
|
||||
return await self.dial_peer_info(peer_info)
|
||||
|
||||
async def dial_peer_info(
|
||||
self,
|
||||
peer_info: PeerInfo,
|
||||
*,
|
||||
relay_peer_id: ID | None = None,
|
||||
) -> RawConnection:
|
||||
"""
|
||||
Dial a peer through a relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_info : PeerInfo
|
||||
The peer to dial
|
||||
relay_peer_id : Optional[ID], optional
|
||||
Optional specific relay peer to use
|
||||
|
||||
Returns
|
||||
-------
|
||||
RawConnection
|
||||
The established connection
|
||||
|
||||
Raises
|
||||
------
|
||||
ConnectionError
|
||||
If the connection cannot be established
|
||||
|
||||
"""
|
||||
# If no specific relay is provided, try to find one
|
||||
if relay_peer_id is None:
|
||||
relay_peer_id = await self._select_relay(peer_info)
|
||||
if not relay_peer_id:
|
||||
raise ConnectionError("No suitable relay found")
|
||||
|
||||
# Get a stream to the relay
|
||||
relay_stream = await self.host.new_stream(relay_peer_id, [PROTOCOL_ID])
|
||||
if not relay_stream:
|
||||
raise ConnectionError(f"Could not open stream to relay {relay_peer_id}")
|
||||
|
||||
try:
|
||||
# First try to make a reservation if enabled
|
||||
if self.config.enable_client:
|
||||
success = await self._make_reservation(relay_stream, relay_peer_id)
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Failed to make reservation with relay %s", relay_peer_id
|
||||
)
|
||||
|
||||
# Send HOP CONNECT message
|
||||
hop_msg = HopMessage(
|
||||
type=HopMessage.CONNECT,
|
||||
peer=peer_info.peer_id.to_bytes(),
|
||||
)
|
||||
await relay_stream.write(hop_msg.SerializeToString())
|
||||
|
||||
# Read response
|
||||
resp_bytes = await relay_stream.read()
|
||||
resp = HopMessage()
|
||||
resp.ParseFromString(resp_bytes)
|
||||
|
||||
# Access status attributes directly
|
||||
status_code = getattr(resp.status, "code", StatusCode.OK)
|
||||
status_msg = getattr(resp.status, "message", "Unknown error")
|
||||
|
||||
if status_code != StatusCode.OK:
|
||||
raise ConnectionError(f"Relay connection failed: {status_msg}")
|
||||
|
||||
# Create raw connection from stream
|
||||
return RawConnection(stream=relay_stream, initiator=True)
|
||||
|
||||
except Exception as e:
|
||||
await relay_stream.close()
|
||||
raise ConnectionError(f"Failed to establish relay connection: {str(e)}")
|
||||
|
||||
async def _select_relay(self, peer_info: PeerInfo) -> ID | None:
|
||||
"""
|
||||
Select an appropriate relay for the given peer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
peer_info : PeerInfo
|
||||
The peer to connect to
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[ID]
|
||||
Selected relay peer ID, or None if no suitable relay found
|
||||
|
||||
"""
|
||||
# Try to find a relay
|
||||
attempts = 0
|
||||
while attempts < self.client_config.max_auto_relay_attempts:
|
||||
# Get a relay from the list of discovered relays
|
||||
relays = self.discovery.get_relays()
|
||||
if relays:
|
||||
# TODO: Implement more sophisticated relay selection
|
||||
# For now, just return the first available relay
|
||||
return relays[0]
|
||||
|
||||
# Wait and try discovery
|
||||
await trio.sleep(1)
|
||||
attempts += 1
|
||||
|
||||
return None
|
||||
|
||||
async def _make_reservation(
|
||||
self,
|
||||
stream: INetStream,
|
||||
relay_peer_id: ID,
|
||||
) -> bool:
|
||||
"""
|
||||
Make a reservation with a relay.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stream : INetStream
|
||||
Stream to the relay
|
||||
relay_peer_id : ID
|
||||
The relay's peer ID
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if reservation was successful
|
||||
|
||||
"""
|
||||
try:
|
||||
# Send reservation request
|
||||
reserve_msg = HopMessage(
|
||||
type=HopMessage.RESERVE,
|
||||
peer=self.host.get_id().to_bytes(),
|
||||
)
|
||||
await stream.write(reserve_msg.SerializeToString())
|
||||
|
||||
# Read response
|
||||
resp_bytes = await stream.read()
|
||||
resp = HopMessage()
|
||||
resp.ParseFromString(resp_bytes)
|
||||
|
||||
# Access status attributes directly
|
||||
status_code = getattr(resp.status, "code", StatusCode.OK)
|
||||
status_msg = getattr(resp.status, "message", "Unknown error")
|
||||
|
||||
if status_code != StatusCode.OK:
|
||||
logger.warning(
|
||||
"Reservation failed with relay %s: %s",
|
||||
relay_peer_id,
|
||||
status_msg,
|
||||
)
|
||||
return False
|
||||
|
||||
# Store reservation info
|
||||
# TODO: Implement reservation storage and refresh mechanism
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error making reservation: %s", str(e))
|
||||
return False
|
||||
|
||||
def create_listener(
|
||||
self,
|
||||
handler_function: Callable[[ReadWriteCloser], Awaitable[None]],
|
||||
) -> IListener:
|
||||
"""
|
||||
Create a listener for incoming relay connections.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_function : Callable[[ReadWriteCloser], Awaitable[None]]
|
||||
The handler function for new connections
|
||||
|
||||
Returns
|
||||
-------
|
||||
IListener
|
||||
The created listener
|
||||
|
||||
"""
|
||||
return CircuitV2Listener(self.host, self.protocol, self.config)
|
||||
|
||||
|
||||
class CircuitV2Listener(Service, IListener):
|
||||
"""Listener for incoming relay connections."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: IHost,
|
||||
protocol: CircuitV2Protocol,
|
||||
config: RelayConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Circuit v2 listener.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : IHost
|
||||
The libp2p host this listener is running on
|
||||
protocol : CircuitV2Protocol
|
||||
The Circuit v2 protocol instance
|
||||
config : RelayConfig
|
||||
Relay configuration
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.protocol = protocol
|
||||
self.config = config
|
||||
self.multiaddrs: list[
|
||||
multiaddr.Multiaddr
|
||||
] = [] # Store multiaddrs as Multiaddr objects
|
||||
|
||||
async def handle_incoming_connection(
|
||||
self,
|
||||
stream: INetStream,
|
||||
remote_peer_id: ID,
|
||||
) -> RawConnection:
|
||||
"""
|
||||
Handle an incoming relay connection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stream : INetStream
|
||||
The incoming stream
|
||||
remote_peer_id : ID
|
||||
The remote peer's ID
|
||||
|
||||
Returns
|
||||
-------
|
||||
RawConnection
|
||||
The established connection
|
||||
|
||||
Raises
|
||||
------
|
||||
ConnectionError
|
||||
If the connection cannot be established
|
||||
|
||||
"""
|
||||
if not self.config.enable_stop:
|
||||
raise ConnectionError("Stop role is not enabled")
|
||||
|
||||
try:
|
||||
# Read STOP message
|
||||
msg_bytes = await stream.read()
|
||||
stop_msg = StopMessage()
|
||||
stop_msg.ParseFromString(msg_bytes)
|
||||
|
||||
if stop_msg.type != StopMessage.CONNECT:
|
||||
raise ConnectionError("Invalid STOP message type")
|
||||
|
||||
# Create raw connection
|
||||
return RawConnection(stream=stream, initiator=False)
|
||||
|
||||
except Exception as e:
|
||||
await stream.close()
|
||||
raise ConnectionError(f"Failed to handle incoming connection: {str(e)}")
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the listener service."""
|
||||
# Implementation would go here
|
||||
|
||||
async def listen(self, maddr: multiaddr.Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
"""
|
||||
Start listening on the given multiaddr.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maddr : multiaddr.Multiaddr
|
||||
The multiaddr to listen on
|
||||
nursery : trio.Nursery
|
||||
The nursery to run tasks in
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if listening successfully started
|
||||
|
||||
"""
|
||||
# Convert string to Multiaddr if needed
|
||||
addr = (
|
||||
maddr
|
||||
if isinstance(maddr, multiaddr.Multiaddr)
|
||||
else multiaddr.Multiaddr(maddr)
|
||||
)
|
||||
self.multiaddrs.append(addr)
|
||||
return True
|
||||
|
||||
def get_addrs(self) -> tuple[multiaddr.Multiaddr, ...]:
|
||||
"""
|
||||
Get the listening addresses.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[multiaddr.Multiaddr, ...]
|
||||
Tuple of listening multiaddresses
|
||||
|
||||
"""
|
||||
return tuple(self.multiaddrs)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the listener."""
|
||||
self.multiaddrs.clear()
|
||||
await self.manager.stop()
|
||||
@ -87,14 +87,16 @@ async def connect(node1: IHost, node2: IHost) -> None:
|
||||
addr = node2.get_addrs()[0]
|
||||
info = info_from_p2p_addr(addr)
|
||||
|
||||
# Add retry logic for more robust connection
|
||||
# Add retry logic for more robust connection with timeout
|
||||
max_retries = 3
|
||||
retry_delay = 0.2
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
await node1.connect(info)
|
||||
# Use timeout for each connection attempt
|
||||
with trio.move_on_after(5): # 5 second timeout
|
||||
await node1.connect(info)
|
||||
|
||||
# Verify connection is established in both directions
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user