mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
* feat: implemented setup of circuit relay and test cases * chore: remove test files to be rewritten * added 1 test suite for protocol * added 1 test suite for discovery * fixed protocol timeouts and message types to handle reservations and stream operations. * Resolved merge conflict in libp2p/tools/utils.py by combining timeout approach with retry mechanism * fix: linting issues * docs: updated documentation with circuit-relay * chore: added enums, improved typing, security and examples * fix: created proper __init__ file to ensure importability * fix: replace transport_opt with listen_addrs in examples, fixed typing and improved code * fix type checking issues across relay module and test suite * regenerated circuit_pb2 file protobuf version 3 * fixed circuit relay example and moved imports to top in test_security_multistream * chore: moved imports to the top * chore: fixed linting of test_circuit_v2_transport.py --------- Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
801 lines
29 KiB
Python
801 lines
29 KiB
Python
"""
|
|
Circuit Relay v2 protocol implementation.
|
|
|
|
This module implements the Circuit Relay v2 protocol as specified in:
|
|
https://github.com/libp2p/specs/blob/master/relay/circuit-v2.md
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from typing import (
|
|
Any,
|
|
Protocol as TypingProtocol,
|
|
cast,
|
|
runtime_checkable,
|
|
)
|
|
|
|
import trio
|
|
|
|
from libp2p.abc import (
|
|
IHost,
|
|
INetStream,
|
|
)
|
|
from libp2p.custom_types import (
|
|
TProtocol,
|
|
)
|
|
from libp2p.io.abc import (
|
|
ReadWriteCloser,
|
|
)
|
|
from libp2p.peer.id import (
|
|
ID,
|
|
)
|
|
from libp2p.stream_muxer.mplex.exceptions import (
|
|
MplexStreamEOF,
|
|
MplexStreamReset,
|
|
)
|
|
from libp2p.tools.async_service import (
|
|
Service,
|
|
)
|
|
|
|
from .pb.circuit_pb2 import (
|
|
HopMessage,
|
|
Limit,
|
|
Reservation,
|
|
Status as PbStatus,
|
|
StopMessage,
|
|
)
|
|
from .protocol_buffer import (
|
|
StatusCode,
|
|
create_status,
|
|
)
|
|
from .resources import (
|
|
RelayLimits,
|
|
RelayResourceManager,
|
|
)
|
|
|
|
# Logger shared by every code path in this module.
logger = logging.getLogger("libp2p.relay.circuit_v2")

# Protocol identifiers: the first is registered for the hop handler, the
# second for the stop handler (and is the protocol opened towards a
# destination peer when a circuit is being set up).
PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0")
STOP_PROTOCOL_ID = TProtocol("/libp2p/circuit/relay/2.0.0/stop")

# Limits applied when the caller does not pass its own RelayLimits.
DEFAULT_RELAY_LIMITS = RelayLimits(
    duration=3600,  # 1 hour (60 * 60 seconds)
    data=1024**3,  # 1 GiB (1024 * 1024 * 1024 bytes)
    max_circuit_conns=8,
    max_reservations=4,
)

# Timeouts for individual stream operations, all expressed in seconds.
STREAM_READ_TIMEOUT = 15
STREAM_WRITE_TIMEOUT = 15
STREAM_CLOSE_TIMEOUT = 10

# Upper bound on the number of read attempts made per relayed chunk.
MAX_READ_RETRIES = 5
|
|
|
|
# Extended interfaces for type checking
|
|
@runtime_checkable
class IHostWithStreamHandlers(TypingProtocol):
    """
    Structural type for a host that can unregister stream handlers.

    The relay service casts its host to this protocol when it needs to call
    ``remove_stream_handler`` during shutdown.
    """

    def remove_stream_handler(self, protocol_id: TProtocol) -> None:
        """Unregister the stream handler bound to ``protocol_id``."""
|
|
|
|
@runtime_checkable
class INetStreamWithExtras(TypingProtocol):
    """
    Structural type for a net stream exposing peer and liveness queries.

    The relay code casts streams to this protocol before calling
    ``get_remote_peer_id``, ``is_open`` or ``is_closed``.
    """

    def get_remote_peer_id(self) -> ID:
        """Return the ID of the peer at the other end of the stream."""

    def is_open(self) -> bool:
        """Return True while the stream is open."""

    def is_closed(self) -> bool:
        """Return True once the stream has been closed."""
|
|
|
|
class CircuitV2Protocol(Service):
|
|
"""
|
|
CircuitV2Protocol implements the Circuit Relay v2 protocol.
|
|
|
|
This protocol allows peers to establish connections through relay nodes
|
|
when direct connections are not possible (e.g., due to NAT).
|
|
"""
|
|
|
|
def __init__(
    self,
    host: IHost,
    limits: RelayLimits | None = None,
    allow_hop: bool = False,
) -> None:
    """
    Create the protocol service and its bookkeeping state.

    Parameters
    ----------
    host : IHost
        Host on which the stream handlers are registered and through which
        streams to destination peers are opened
    limits : RelayLimits | None
        Resource limits for the relay; ``DEFAULT_RELAY_LIMITS`` is used when
        no limits are given
    allow_hop : bool
        When True, ``run`` registers the hop and stop stream handlers so this
        node acts as a relay

    """
    self.host = host
    self.allow_hop = allow_hop

    # A falsy ``limits`` value selects the module-level defaults.
    self.limits = limits if limits else DEFAULT_RELAY_LIMITS
    self.resource_manager = RelayResourceManager(self.limits)

    # peer ID -> (source stream, destination stream). The destination slot
    # stays None until the second leg of the circuit has been attached.
    self._active_relays: dict[ID, tuple[INetStream, INetStream | None]] = {}

    # Set by ``run`` once the handlers are registered.
    self.event_started = trio.Event()
|
|
async def run(self, *, task_status: Any = trio.TASK_STATUS_IGNORED) -> None:
    """
    Run the protocol service until its manager reports completion.

    When ``allow_hop`` is set, the hop and stop stream handlers are
    registered on the host. Afterwards ``event_started`` is set and
    ``task_status.started()`` is called. On exit every stream tracked in
    ``_active_relays`` is closed and the handlers are removed again.

    Parameters
    ----------
    task_status : Any
        Trio task-status object used to signal readiness to the caller

    """
    try:
        # Register protocol handlers
        if self.allow_hop:
            logger.debug("Registering stream handlers for relay protocol")
            self.host.set_stream_handler(PROTOCOL_ID, self._handle_hop_stream)
            self.host.set_stream_handler(STOP_PROTOCOL_ID, self._handle_stop_stream)
            logger.debug("Stream handlers registered successfully")

        # Signal that we're ready
        self.event_started.set()
        task_status.started()
        logger.debug("Protocol service started")

        # Wait for service to be stopped
        await self.manager.wait_finished()
    finally:
        try:
            # Iterate over a snapshot: closing a stream yields to the
            # scheduler, and relay tasks delete their own entries from
            # ``_active_relays`` while this coroutine is suspended. Walking
            # the live view would raise "dictionary changed size during
            # iteration".
            for src_stream, dst_stream in list(self._active_relays.values()):
                await self._close_stream(src_stream)
                await self._close_stream(dst_stream)
            self._active_relays.clear()
        finally:
            # Handler removal is synchronous; keeping it in its own
            # ``finally`` lets it run even when the stream cleanup above is
            # interrupted (for example by cancellation).
            if self.allow_hop:
                try:
                    # Cast host to extended interface with remove_stream_handler
                    host_with_handlers = cast(IHostWithStreamHandlers, self.host)
                    host_with_handlers.remove_stream_handler(PROTOCOL_ID)
                    host_with_handlers.remove_stream_handler(STOP_PROTOCOL_ID)
                except Exception as e:
                    logger.error("Error unregistering stream handlers: %s", str(e))
|
|
async def _close_stream(self, stream: INetStream | None) -> None:
    """
    Close ``stream`` on a best-effort basis; ``None`` is accepted and ignored.

    A graceful close is attempted within ``STREAM_CLOSE_TIMEOUT`` seconds. If
    that fails or times out the stream is reset instead, and any error from
    the reset is discarded.
    """
    if stream is not None:
        try:
            with trio.fail_after(STREAM_CLOSE_TIMEOUT):
                await stream.close()
        except Exception:
            # Graceful close did not work out: fall back to a hard reset.
            try:
                await stream.reset()
            except Exception:
                # Nothing more can be done for this stream.
                pass
|
|
async def _read_stream_with_retry(
    self,
    stream: INetStream,
    max_retries: int = MAX_READ_RETRIES,
) -> bytes | None:
    """
    Read once from ``stream``, retrying on timeouts and transient errors.

    Each attempt is bounded by ``STREAM_READ_TIMEOUT``. Between attempts the
    coroutine sleeps for a linearly growing backoff (0.2 seconds times the
    attempt number). No sleep follows the final attempt.

    Parameters
    ----------
    stream : INetStream
        The stream to read from
    max_retries : int
        Maximum number of read attempts

    Returns
    -------
    bytes | None
        The data read from the stream. None is returned when the read yields
        no data (EOF), when the stream is closed or reset, when the attempts
        run out without any timeout or unexpected error having been recorded
        (for example every attempt raised ``trio.WouldBlock``), or when
        ``max_retries`` is not positive

    Raises
    ------
    trio.TooSlowError
        If the attempts are exhausted and the most recently recorded error
        was a read timeout
    Exception
        The most recently recorded unexpected error once the attempts are
        exhausted

    """
    last_error: Any = None
    backoff_time = 0.2  # Base backoff time in seconds

    for attempt in range(1, max_retries + 1):
        attempts_left = attempt < max_retries
        try:
            with trio.fail_after(STREAM_READ_TIMEOUT):
                logger.debug(
                    "Attempting to read from stream (attempt %d/%d)",
                    attempt,
                    max_retries,
                )
                data = await stream.read()
        except trio.WouldBlock:
            # Not recorded as an error: if every attempt ends here the
            # caller simply gets None.
            logger.debug(
                "Stream would block (attempt %d/%d), retrying...",
                attempt,
                max_retries,
            )
        except (MplexStreamEOF, MplexStreamReset):
            # Stream closed/reset - no point retrying
            logger.debug("Stream closed/reset during read")
            return None
        except trio.TooSlowError as e:
            last_error = e
            logger.debug(
                "Read timeout (attempt %d/%d), retrying...", attempt, max_retries
            )
        except Exception as e:
            logger.error("Unexpected error reading from stream: %s", str(e))
            last_error = e
            if not attempts_left:
                raise
        else:
            if not data:  # EOF
                logger.debug("Stream EOF detected")
                return None

            logger.debug("Successfully read %d bytes from stream", len(data))
            return data

        # Back off only when another attempt will actually follow, so that
        # giving up is not delayed by a pointless final sleep.
        if attempts_left:
            await trio.sleep(backoff_time * attempt)

    if last_error:
        if isinstance(last_error, trio.TooSlowError):
            logger.error("Read timed out after %d retries", max_retries)
        raise last_error

    return None
|
|
async def _handle_hop_stream(self, stream: INetStream) -> None:
    """
    Handle incoming HOP streams.

    Reads one message from the stream (bounded by twice
    ``STREAM_READ_TIMEOUT``), parses it as a ``HopMessage`` and dispatches
    RESERVE and CONNECT requests to their handlers. Empty reads, read
    failures, parse failures and unknown message types are answered with a
    STATUS message carrying an error code.

    Parameters
    ----------
    stream : INetStream
        The stream opened by the requesting peer

    """

    async def reply_error(code: int, message: str) -> None:
        """Write a STATUS HopMessage, then pause so it can be flushed."""
        pb_status = PbStatus()
        pb_status.code = cast(Any, int(code))
        pb_status.message = message

        response = HopMessage(
            type=HopMessage.STATUS,
            status=pb_status,
        )
        await stream.write(response.SerializeToString())
        await trio.sleep(0.5)  # Longer wait to ensure the message is sent

    # Bound before anything can fail so the outermost handler can always
    # name the peer in its log lines, even if resolving the label raises.
    remote_id = "unknown peer"
    try:
        # Try to get peer ID first
        try:
            # Cast to extended interface with get_remote_peer_id
            stream_with_peer_id = cast(INetStreamWithExtras, stream)
            remote_id = str(stream_with_peer_id.get_remote_peer_id())
        except Exception:
            # Fall back to address if peer ID not available
            remote_addr = stream.get_remote_address()
            remote_id = f"peer at {remote_addr}" if remote_addr else "unknown peer"

        logger.debug("Handling hop stream from %s", remote_id)

        # First, handle the read timeout gracefully
        try:
            with trio.fail_after(
                STREAM_READ_TIMEOUT * 2
            ):  # Double the timeout for reading
                msg_bytes = await stream.read()
                if not msg_bytes:
                    logger.error(
                        "Empty read from stream from %s",
                        remote_id,
                    )
                    await reply_error(
                        StatusCode.MALFORMED_MESSAGE, "Empty message received"
                    )
                    return
        except trio.TooSlowError:
            logger.error(
                "Timeout reading from hop stream from %s",
                remote_id,
            )
            await reply_error(StatusCode.CONNECTION_FAILED, "Stream read timeout")
            return
        except Exception as e:
            logger.error(
                "Error reading from hop stream from %s: %s",
                remote_id,
                str(e),
            )
            await reply_error(StatusCode.MALFORMED_MESSAGE, f"Read error: {str(e)}")
            return

        # Parse the message
        try:
            hop_msg = HopMessage()
            hop_msg.ParseFromString(msg_bytes)
        except Exception as e:
            logger.error(
                "Error parsing hop message from %s: %s",
                remote_id,
                str(e),
            )
            await reply_error(StatusCode.MALFORMED_MESSAGE, f"Parse error: {str(e)}")
            return

        # Process based on message type
        if hop_msg.type == HopMessage.RESERVE:
            logger.debug("Handling RESERVE message from %s", remote_id)
            await self._handle_reserve(stream, hop_msg)
            return
        elif hop_msg.type == HopMessage.CONNECT:
            logger.debug("Handling CONNECT message from %s", remote_id)
            await self._handle_connect(stream, hop_msg)
        else:
            logger.error("Invalid message type %d from %s", hop_msg.type, remote_id)
            # Send a nice error response using _send_status method
            await self._send_status(
                stream,
                StatusCode.MALFORMED_MESSAGE,
                f"Invalid message type: {hop_msg.type}",
            )

    except Exception as e:
        logger.error(
            "Unexpected error handling hop stream from %s: %s", remote_id, str(e)
        )
        try:
            # Send a nice error response using _send_status method
            await self._send_status(
                stream,
                StatusCode.MALFORMED_MESSAGE,
                f"Internal error: {str(e)}",
            )
        except Exception as e2:
            logger.error(
                "Failed to send error response to %s: %s", remote_id, str(e2)
            )
|
|
async def _handle_stop_stream(self, stream: INetStream) -> None:
    """
    Handle incoming STOP streams.

    Reads one ``StopMessage`` (bounded by ``STREAM_READ_TIMEOUT``). A CONNECT
    message naming a peer that has an entry in ``_active_relays`` attaches
    this stream as the destination leg, reports OK on both legs and relays
    data in both directions. Anything else is answered with an error status
    and the stream is closed.

    Parameters
    ----------
    stream : INetStream
        The incoming stop-protocol stream

    """

    async def refuse(code: int, reason: str) -> None:
        """Report ``code``/``reason`` on the stop stream, then close it."""
        await self._send_stop_status(stream, code, reason)
        await self._close_stream(stream)

    try:
        # Only the read itself is subject to the timeout.
        with trio.fail_after(STREAM_READ_TIMEOUT):
            raw = await stream.read()

        request = StopMessage()
        request.ParseFromString(raw)

        if request.type != StopMessage.CONNECT:
            await refuse(StatusCode.MALFORMED_MESSAGE, "Invalid message type")
            return

        # The named peer must already have a source leg on record.
        peer_id = ID(request.peer)
        if peer_id not in self._active_relays:
            await refuse(StatusCode.CONNECTION_FAILED, "No pending relay connection")
            return

        # Attach this stream as the destination leg of the stored pair.
        src_stream, _ = self._active_relays[peer_id]
        self._active_relays[peer_id] = (src_stream, stream)

        # Tell both legs that the circuit is up.
        await self._send_status(src_stream, StatusCode.OK, "Connection established")
        await self._send_stop_status(stream, StatusCode.OK, "Connection established")

        # One relay task per direction; the nursery waits for both.
        async with trio.open_nursery() as nursery:
            for reader, writer in ((src_stream, stream), (stream, src_stream)):
                nursery.start_soon(self._relay_data, reader, writer, peer_id)

    except trio.TooSlowError:
        logger.error("Timeout reading from stop stream")
        await refuse(StatusCode.CONNECTION_FAILED, "Stream read timeout")
    except Exception as e:
        logger.error("Error handling stop stream: %s", str(e))
        try:
            await refuse(StatusCode.MALFORMED_MESSAGE, str(e))
        except Exception:
            # Best effort only: the stream may already be unusable.
            pass
|
|
async def _handle_reserve(self, stream: INetStream, msg: Any) -> None:
    """
    Answer a reservation request received on a hop stream.

    If the resource manager refuses the peer, a RESOURCE_LIMIT_EXCEEDED
    status is written. Otherwise the reservation is recorded and an OK status
    carrying the expiry time and the relay limits is written within
    ``STREAM_WRITE_TIMEOUT``. Any failure is answered with INTERNAL_ERROR
    while the stream is still open, and the stream is closed at the end in
    every case.

    Parameters
    ----------
    stream : INetStream
        The hop stream the request arrived on
    msg : Any
        The parsed hop message; its ``peer`` field identifies the requester

    """
    # Same object, viewed through the interface that offers ``is_open``.
    stream_extras = cast(INetStreamWithExtras, stream)

    try:
        peer_id = ID(msg.peer)
        logger.debug("Handling reservation request from peer %s", peer_id)

        if self.resource_manager.can_accept_reservation(peer_id):
            logger.debug("Accepting reservation from peer %s", peer_id)
            ttl = self.resource_manager.reserve(peer_id)

            with trio.fail_after(STREAM_WRITE_TIMEOUT):
                accepted = create_status(
                    code=StatusCode.OK, message="Reservation accepted"
                )
                granted = Reservation(
                    expire=int(time.time() + ttl),
                    voucher=b"",  # vouchers are not used yet
                    signature=b"",  # signatures are not used yet
                )
                relay_limit = Limit(
                    duration=self.limits.duration,
                    data=self.limits.data,
                )
                response = HopMessage(
                    type=HopMessage.STATUS,
                    status=accepted.to_pb(),
                    reservation=granted,
                    limit=relay_limit,
                )

                logger.debug(
                    "Sending reservation response: type=%s, status=%s, ttl=%d",
                    response.type,
                    getattr(response.status, "code", "unknown"),
                    ttl,
                )

                await stream.write(response.SerializeToString())

                # Short pause so the message can be flushed before closing.
                await trio.sleep(0.1)

                logger.debug("Reservation response sent successfully")
        else:
            logger.debug("Reservation limit exceeded for peer %s", peer_id)
            refused = create_status(
                code=StatusCode.RESOURCE_LIMIT_EXCEEDED,
                message="Reservation limit exceeded",
            )
            refusal = HopMessage(
                type=HopMessage.STATUS,
                status=refused.to_pb(),
            )
            await stream.write(refusal.SerializeToString())

    except Exception as e:
        logger.error("Error handling reservation request: %s", str(e))
        if stream_extras.is_open():
            try:
                await self._send_status(
                    stream,
                    StatusCode.INTERNAL_ERROR,
                    f"Failed to process reservation: {str(e)}",
                )
            except Exception as send_err:
                logger.error("Failed to send error response: %s", str(send_err))
    finally:
        # The reservation exchange is finished either way: close the stream.
        if stream_extras.is_open():
            try:
                with trio.fail_after(STREAM_CLOSE_TIMEOUT):
                    await stream.close()
            except Exception as close_err:
                logger.error("Error closing stream: %s", str(close_err))
|
|
async def _handle_connect(self, stream: INetStream, msg: Any) -> None:
    """
    Handle a connect request received on a hop stream.

    Verifies an attached reservation and the connection limits, opens a
    stop-protocol stream to the peer named in ``msg.peer``, sends it a
    CONNECT message and, once that peer answers with an OK status (or no
    status at all), relays data in both directions until either side ends.

    The reservation's ``active_connections`` counter is incremented only
    after the destination has accepted, and is decremented again whenever
    this handler exits after having incremented it.

    Parameters
    ----------
    stream : INetStream
        The hop stream opened by the requesting peer
    msg : Any
        The parsed hop message; ``peer`` names the destination and
        ``reservation`` is optional

    """
    peer_id = ID(msg.peer)
    dst_stream: INetStream | None = None
    # Set to the reservation whose counter this call incremented, if any.
    counted_reservation: Any = None

    async def abort(reason: str) -> None:
        """Report failure to the requester and tear down both legs."""
        await self._send_status(stream, StatusCode.CONNECTION_FAILED, reason)
        self._active_relays.pop(peer_id, None)
        await stream.reset()
        if dst_stream and not cast(INetStreamWithExtras, dst_stream).is_closed():
            await dst_stream.reset()

    # Verify reservation if provided
    if msg.HasField("reservation"):
        if not self.resource_manager.verify_reservation(peer_id, msg.reservation):
            await self._send_status(
                stream,
                StatusCode.PERMISSION_DENIED,
                "Invalid reservation",
            )
            await stream.reset()
            return

    # Check resource limits
    if not self.resource_manager.can_accept_connection(peer_id):
        await self._send_status(
            stream,
            StatusCode.RESOURCE_LIMIT_EXCEEDED,
            "Connection limit exceeded",
        )
        await stream.reset()
        return

    try:
        # Store the source stream; the destination leg is not known yet.
        self._active_relays[peer_id] = (stream, None)

        # Try to connect to the destination with timeout
        with trio.fail_after(STREAM_READ_TIMEOUT):
            dst_stream = await self.host.new_stream(peer_id, [STOP_PROTOCOL_ID])
            if not dst_stream:
                raise ConnectionError("Could not connect to destination")

        # Send STOP CONNECT message
        stop_msg = StopMessage(
            type=StopMessage.CONNECT,
            # Cast to extended interface with get_remote_peer_id
            peer=cast(INetStreamWithExtras, stream)
            .get_remote_peer_id()
            .to_bytes(),
        )
        await dst_stream.write(stop_msg.SerializeToString())

        # Wait for response from destination
        # NOTE(review): this read has no timeout of its own; confirm whether
        # an unresponsive destination should be bounded here.
        resp_bytes = await dst_stream.read()
        resp = StopMessage()
        resp.ParseFromString(resp_bytes)

        # A response without a status field is treated as acceptance.
        if resp.HasField("status"):
            status_code = getattr(resp.status, "code", StatusCode.OK)
            status_msg = getattr(resp.status, "message", "Unknown error")
        else:
            status_code = StatusCode.OK
            status_msg = "No status provided"

        if status_code != StatusCode.OK:
            raise ConnectionError(
                f"Destination rejected connection: {status_msg}"
            )

        # Update active relays with destination stream
        self._active_relays[peer_id] = (stream, dst_stream)

        # Update reservation connection count, remembering that we did so.
        reservation = self.resource_manager._reservations.get(peer_id)
        if reservation:
            reservation.active_connections += 1
            counted_reservation = reservation

        # Send success status
        await self._send_status(
            stream,
            StatusCode.OK,
            "Connection established",
        )

        # Start relaying data; the nursery exits when both directions end.
        async with trio.open_nursery() as nursery:
            nursery.start_soon(self._relay_data, stream, dst_stream, peer_id)
            nursery.start_soon(self._relay_data, dst_stream, stream, peer_id)

    except (trio.TooSlowError, ConnectionError) as e:
        logger.error("Error establishing relay connection: %s", str(e))
        await abort(str(e))
    except Exception as e:
        logger.error("Unexpected error in connect handler: %s", str(e))
        await abort("Internal error")
    finally:
        # Release the slot only if this call took one. Failures before the
        # increment must not drive the counter down, and a relay that has
        # finished must not keep its slot.
        # NOTE(review): assumes no other component decrements this counter
        # for the same circuit - confirm against RelayResourceManager.
        if counted_reservation is not None:
            counted_reservation.active_connections -= 1
|
|
async def _relay_data(
    self,
    src_stream: INetStream,
    dst_stream: INetStream,
    peer_id: ID,
) -> None:
    """
    Relay data in one direction, from ``src_stream`` to ``dst_stream``.

    Copies chunks until the source yields no data, a write fails or times
    out, or the reservation recorded for ``peer_id`` reaches its data limit.
    On exit both streams are reset and the ``_active_relays`` entry for
    ``peer_id`` is removed. Errors are logged rather than raised.

    Parameters
    ----------
    src_stream : INetStream
        Source stream to read from
    dst_stream : INetStream
        Destination stream to write to
    peer_id : ID
        ID of the peer being relayed

    """
    try:
        while True:
            # Read data with retries
            data = await self._read_stream_with_retry(src_stream)
            if not data:
                logger.info("Source stream closed/reset")
                break

            # Write data with timeout
            try:
                with trio.fail_after(STREAM_WRITE_TIMEOUT):
                    await dst_stream.write(data)
            except trio.TooSlowError:
                logger.error("Timeout writing to destination stream")
                break
            except Exception as e:
                logger.error("Error writing to destination stream: %s", str(e))
                break

            # Update resource usage
            reservation = self.resource_manager._reservations.get(peer_id)
            if reservation:
                reservation.data_used += len(data)
                if reservation.data_used >= reservation.limits.data:
                    logger.warning("Data limit exceeded for peer %s", peer_id)
                    break

    except Exception as e:
        logger.error("Error relaying data: %s", str(e))
    finally:
        try:
            # Each reset is guarded on its own: the opposite direction runs
            # the same cleanup on the same two streams, so one of them may
            # already be torn down, and a failure on the first stream must
            # not prevent the second from being reset.
            for relayed in (src_stream, dst_stream):
                try:
                    await relayed.reset()
                except Exception as e:
                    logger.debug(
                        "Error resetting stream during relay cleanup: %s", str(e)
                    )
        finally:
            # Synchronous, so it runs even if the resets are interrupted.
            self._active_relays.pop(peer_id, None)
|
|
async def _send_status(
    self,
    stream: ReadWriteCloser,
    code: int,
    message: str,
) -> None:
    """
    Write a STATUS ``HopMessage`` carrying ``code`` and ``message``.

    The write and the pause that follows it are bounded by twice
    ``STREAM_WRITE_TIMEOUT``. Timeouts and other errors are logged and not
    propagated to the caller.

    Parameters
    ----------
    stream : ReadWriteCloser
        The stream to write the status to
    code : int
        Status code to report
    message : str
        Human-readable status text

    """
    try:
        logger.debug("Sending status message with code %s: %s", code, message)
        with trio.fail_after(STREAM_WRITE_TIMEOUT * 2):
            # Build the protobuf status by hand; the cast only quiets the
            # type checker about assigning a plain int to the enum field.
            status_pb = PbStatus()
            status_pb.code = cast(Any, int(code))
            status_pb.message = message

            payload = HopMessage(
                type=HopMessage.STATUS,
                status=status_pb,
            ).SerializeToString()
            logger.debug("Status message serialized (%d bytes)", len(payload))

            await stream.write(payload)
            logger.debug("Status message sent, waiting for processing")

            # Pause before returning so the message has time to go out.
            await trio.sleep(1.5)
            logger.debug("Status message sending completed")
    except trio.TooSlowError:
        logger.error(
            "Timeout sending status message: code=%s, message=%s", code, message
        )
    except Exception as exc:
        logger.error("Error sending status message: %s", str(exc))
|
|
async def _send_stop_status(
    self,
    stream: ReadWriteCloser,
    code: int,
    message: str,
) -> None:
    """
    Write a STATUS ``StopMessage`` carrying ``code`` and ``message``.

    The write and the short pause after it are bounded by twice
    ``STREAM_WRITE_TIMEOUT``. Every error, timeouts included, is logged and
    not propagated to the caller.

    Parameters
    ----------
    stream : ReadWriteCloser
        The stop stream to write the status to
    code : int
        Status code to report
    message : str
        Human-readable status text

    """
    try:
        logger.debug("Sending stop status message with code %s: %s", code, message)
        with trio.fail_after(STREAM_WRITE_TIMEOUT * 2):
            # The cast only quiets the type checker about assigning a plain
            # int to the enum field.
            status_pb = PbStatus()
            status_pb.code = cast(Any, int(code))
            status_pb.message = message

            payload = StopMessage(
                type=StopMessage.STATUS,
                status=status_pb,
            ).SerializeToString()
            await stream.write(payload)

            # Pause briefly so the message has time to go out.
            await trio.sleep(0.5)
    except Exception as exc:
        logger.error("Error sending stop status message: %s", str(exc))