mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2025-12-31 20:36:24 +00:00)
Merge branch 'main' into todo/handletimeout
@@ -113,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify:
 
 def identify_handler_for(
-    host: IHost, use_varint_format: bool = False
+    host: IHost, use_varint_format: bool = True
 ) -> StreamHandlerFn:
     async def handle_identify(stream: INetStream) -> None:
         # get observed address from ``stream``
@@ -28,7 +28,7 @@ from libp2p.utils import (
     varint,
 )
 from libp2p.utils.varint import (
-    decode_varint_from_bytes,
+    read_length_prefixed_protobuf,
 )
 
 from ..identify.identify import (
@@ -66,49 +66,8 @@ def identify_push_handler_for(
         peer_id = stream.muxed_conn.peer_id
 
         try:
-            if use_varint_format:
-                # Read length-prefixed identify message from the stream
-                # First read the varint length prefix
-                length_bytes = b""
-                while True:
-                    b = await stream.read(1)
-                    if not b:
-                        break
-                    length_bytes += b
-                    if b[0] & 0x80 == 0:
-                        break
-
-                if not length_bytes:
-                    logger.warning("No length prefix received from peer %s", peer_id)
-                    return
-
-                msg_length = decode_varint_from_bytes(length_bytes)
-
-                # Read the protobuf message
-                data = await stream.read(msg_length)
-                if len(data) != msg_length:
-                    logger.warning("Incomplete message received from peer %s", peer_id)
-                    return
-            else:
-                # Read raw protobuf message from the stream
-                # For raw format, we need to read all data before the stream is closed
-                data = b""
-                try:
-                    # Read all available data in a single operation
-                    data = await stream.read()
-                except StreamClosed:
-                    # Try to read any remaining data
-                    try:
-                        data = await stream.read()
-                    except Exception:
-                        pass
-
-                # If we got no data, log a warning and return
-                if not data:
-                    logger.warning(
-                        "No data received in raw format from peer %s", peer_id
-                    )
-                    return
+            # Use the utility function to read the protobuf message
+            data = await read_length_prefixed_protobuf(stream, use_varint_format)
 
             identify_msg = Identify()
             identify_msg.ParseFromString(data)
@@ -119,6 +78,11 @@ def identify_push_handler_for(
             )
 
             logger.debug("Successfully processed identify/push from peer %s", peer_id)
+
+            # Send acknowledgment to indicate successful processing
+            # This ensures the sender knows the message was received before closing
+            await stream.write(b"OK")
+
         except StreamClosed:
             logger.debug(
                 "Stream closed while processing identify/push from %s", peer_id
@@ -127,7 +91,10 @@ def identify_push_handler_for(
             logger.error("Error processing identify/push from %s: %s", peer_id, e)
         finally:
             # Close the stream after processing
-            await stream.close()
+            try:
+                await stream.close()
+            except Exception:
+                pass  # Ignore errors when closing
 
     return handle_identify_push
@@ -226,7 +193,20 @@ async def push_identify_to_peer(
             # Send raw protobuf message
             await stream.write(response)
 
-        # Close the stream
+        # Wait for acknowledgment from the receiver with timeout
+        # This ensures the message was processed before closing
+        try:
+            with trio.move_on_after(1.0):  # 1 second timeout
+                ack = await stream.read(2)  # Read "OK" acknowledgment
+                if ack != b"OK":
+                    logger.warning(
+                        "Unexpected acknowledgment from peer %s: %s", peer_id, ack
+                    )
+        except Exception as e:
+            logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
+            # Continue anyway, as the message might have been processed
+
+        # Close the stream after acknowledgment (or timeout)
         await stream.close()
 
         logger.debug("Successfully pushed identify to peer %s", peer_id)
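
A minimal sketch of the acknowledgment handshake above, with trio memory channels standing in for the libp2p stream; the channel plumbing and task names are illustrative, while the b"OK" payload and the 1-second bound mirror the diff:

import trio


async def receiver(recv: trio.MemoryReceiveChannel, send: trio.MemorySendChannel) -> None:
    msg = await recv.receive()  # the pushed identify payload
    assert msg  # processing elided in this sketch
    await send.send(b"OK")  # acknowledge before the sender closes


async def sender(send: trio.MemorySendChannel, recv: trio.MemoryReceiveChannel) -> None:
    await send.send(b"identify-payload")
    ack = b""
    with trio.move_on_after(1.0):  # bound the wait, as in push_identify_to_peer
        ack = await recv.receive()
    if ack != b"OK":
        print("no acknowledgment; closing anyway")


async def main() -> None:
    to_receiver_send, to_receiver_recv = trio.open_memory_channel(0)
    to_sender_send, to_sender_recv = trio.open_memory_channel(0)
    async with trio.open_nursery() as nursery:
        nursery.start_soon(receiver, to_receiver_recv, to_sender_send)
        nursery.start_soon(sender, to_receiver_send, to_sender_recv)


trio.run(main)
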
@@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
     is_async: bool
 
 
+MAX_CONCURRENT_VALIDATORS = 10
+
+
 class Pubsub(Service, IPubsub):
     host: IHost
@@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
 
     peer_receive_channel: trio.MemoryReceiveChannel[ID]
     dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
+    _validator_semaphore: trio.Semaphore
 
     seen_messages: LastSeenCache
@@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
         msg_id_constructor: Callable[
             [rpc_pb2.Message], bytes
         ] = get_peer_and_seqno_msg_id,
+        max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
         # Therefore, we can only close from the receive side.
         self.peer_receive_channel = peer_receive
         self.dead_peer_receive_channel = dead_peer_receive
+        self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
         # Register a notifee
         self.host.get_network().register_notifee(
             PubsubNotifee(peer_send, dead_peer_send)
@@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
 
         logger.debug("successfully published message %s", msg)
 
-    async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
+    async def validate_msg(
+        self,
+        msg_forwarder: ID,
+        msg: rpc_pb2.Message,
+    ) -> None:
         """
         Validate the received message.
@@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
             if not validator(msg_forwarder, msg):
                 raise ValidationError(f"Validation failed for msg={msg}")
 
-        # TODO: Implement throttle on async validators
-
         if len(async_topic_validators) > 0:
             # Appends to lists are thread safe in CPython
-            results = []
-
-            async def run_async_validator(func: AsyncValidatorFn) -> None:
-                result = await func(msg_forwarder, msg)
-                results.append(result)
+            results: list[bool] = []
 
             async with trio.open_nursery() as nursery:
                 for async_validator in async_topic_validators:
-                    nursery.start_soon(run_async_validator, async_validator)
+                    nursery.start_soon(
+                        self._run_async_validator,
+                        async_validator,
+                        msg_forwarder,
+                        msg,
+                        results,
+                    )
 
             if not all(results):
                 raise ValidationError(f"Validation failed for msg={msg}")
 
+    async def _run_async_validator(
+        self,
+        func: AsyncValidatorFn,
+        msg_forwarder: ID,
+        msg: rpc_pb2.Message,
+        results: list[bool],
+    ) -> None:
+        async with self._validator_semaphore:
+            result = await func(msg_forwarder, msg)
+            results.append(result)
+
     async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
         """
         Push a pubsub message to others.
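
The throttle introduced above, as a self-contained sketch: a shared trio.Semaphore caps how many async validators run concurrently inside the nursery. The capacity of 3 and the sleep-based validator body are illustrative; the commit's default is MAX_CONCURRENT_VALIDATORS = 10.

import trio

limit = trio.Semaphore(3)  # stands in for Pubsub._validator_semaphore
results: list[bool] = []


async def run_validator(i: int) -> None:
    async with limit:  # at most 3 validators run at once
        await trio.sleep(0.01)  # stand-in for real async validation work
        results.append(True)


async def main() -> None:
    async with trio.open_nursery() as nursery:
        for i in range(10):
            nursery.start_soon(run_validator, i)
    print(all(results))  # overall verdict, as in validate_msg


trio.run(main)
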
@@ -1,3 +1,5 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from types import (
     TracebackType,
 )
@@ -32,6 +34,72 @@ if TYPE_CHECKING:
     )
 
 
+class ReadWriteLock:
+    """
+    A read-write lock that allows multiple concurrent readers
+    or one exclusive writer, implemented using Trio primitives.
+    """
+
+    def __init__(self) -> None:
+        self._readers = 0
+        self._readers_lock = trio.Lock()  # Protects access to _readers count
+        self._writer_lock = trio.Semaphore(1)  # Allows only one writer at a time
+
+    async def acquire_read(self) -> None:
+        """Acquire a read lock. Multiple readers can hold it simultaneously."""
+        try:
+            async with self._readers_lock:
+                if self._readers == 0:
+                    await self._writer_lock.acquire()
+                self._readers += 1
+        except trio.Cancelled:
+            raise
+
+    async def release_read(self) -> None:
+        """Release a read lock."""
+        async with self._readers_lock:
+            if self._readers == 1:
+                self._writer_lock.release()
+            self._readers -= 1
+
+    async def acquire_write(self) -> None:
+        """Acquire an exclusive write lock."""
+        try:
+            await self._writer_lock.acquire()
+        except trio.Cancelled:
+            raise
+
+    def release_write(self) -> None:
+        """Release the exclusive write lock."""
+        self._writer_lock.release()
+
+    @asynccontextmanager
+    async def read_lock(self) -> AsyncGenerator[None, None]:
+        """Context manager for acquiring and releasing a read lock safely."""
+        acquired = False
+        try:
+            await self.acquire_read()
+            acquired = True
+            yield
+        finally:
+            if acquired:
+                with trio.CancelScope() as scope:
+                    scope.shield = True
+                    await self.release_read()
+
+    @asynccontextmanager
+    async def write_lock(self) -> AsyncGenerator[None, None]:
+        """Context manager for acquiring and releasing a write lock safely."""
+        acquired = False
+        try:
+            await self.acquire_write()
+            acquired = True
+            yield
+        finally:
+            if acquired:
+                self.release_write()
+
+
 class MplexStream(IMuxedStream):
     """
     reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
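
A usage sketch for the ReadWriteLock added above, assuming the class is in scope: the two readers may hold the lock together, while the writer waits for exclusive access. Task bodies and names are illustrative.

import trio


async def reader(lock: ReadWriteLock, name: str) -> None:
    async with lock.read_lock():  # shared: concurrent readers are fine
        print(f"{name} reading")
        await trio.sleep(0.01)


async def writer(lock: ReadWriteLock) -> None:
    async with lock.write_lock():  # exclusive: blocks until readers drain
        print("writer writing")
        await trio.sleep(0.01)


async def main() -> None:
    lock = ReadWriteLock()
    async with trio.open_nursery() as nursery:
        nursery.start_soon(reader, lock, "r1")
        nursery.start_soon(reader, lock, "r2")
        nursery.start_soon(writer, lock)


trio.run(main)
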
@@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
     read_deadline: int | None
     write_deadline: int | None
 
-    # TODO: Add lock for read/write to avoid interleaving receiving messages?
+    rw_lock: ReadWriteLock
     close_lock: trio.Lock
 
     # NOTE: `dataIn` is size of 8 in Go implementation.
@@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
         self.event_remote_closed = trio.Event()
         self.event_reset = trio.Event()
         self.close_lock = trio.Lock()
+        self.rw_lock = ReadWriteLock()
         self.incoming_data_channel = incoming_data_channel
         self._buf = bytearray()
@@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
         :param n: number of bytes to read
         :return: bytes actually read
         """
-        if n is not None and n < 0:
-            raise ValueError(
-                "the number of bytes to read `n` must be non-negative or "
-                f"`None` to indicate read until EOF, got n={n}"
-            )
-        if self.event_reset.is_set():
-            raise MplexStreamReset
-        if n is None:
-            return await self._read_until_eof()
-        if len(self._buf) == 0:
-            data: bytes
-            # Peek whether there is data available. If yes, we just read until there is
-            # no data, then return.
-            try:
-                data = self.incoming_data_channel.receive_nowait()
-                self._buf.extend(data)
-            except trio.EndOfChannel:
-                raise MplexStreamEOF
-            except trio.WouldBlock:
-                # We know `receive` will be blocked here. Wait for data here with
-                # `receive` and catch all kinds of errors here.
-                try:
-                    data = await self.incoming_data_channel.receive()
-                    self._buf.extend(data)
-                except trio.EndOfChannel:
-                    if self.event_reset.is_set():
-                        raise MplexStreamReset
-                    if self.event_remote_closed.is_set():
-                        raise MplexStreamEOF
-                except trio.ClosedResourceError as error:
-                    # Probably `incoming_data_channel` is closed in `reset` when we are
-                    # waiting for `receive`.
-                    if self.event_reset.is_set():
-                        raise MplexStreamReset
-                    raise Exception(
-                        "`incoming_data_channel` is closed but stream is not reset. "
-                        "This should never happen."
-                    ) from error
-        self._buf.extend(self._read_return_when_blocked())
-        payload = self._buf[:n]
-        self._buf = self._buf[len(payload) :]
-        return bytes(payload)
+        async with self.rw_lock.read_lock():
+            if n is not None and n < 0:
+                raise ValueError(
+                    "the number of bytes to read `n` must be non-negative or "
+                    f"`None` to indicate read until EOF, got n={n}"
+                )
+            if self.event_reset.is_set():
+                raise MplexStreamReset
+            if n is None:
+                return await self._read_until_eof()
+            if len(self._buf) == 0:
+                data: bytes
+                # Peek whether there is data available. If yes, we just read until
+                # there is no data, then return.
+                try:
+                    data = self.incoming_data_channel.receive_nowait()
+                    self._buf.extend(data)
+                except trio.EndOfChannel:
+                    raise MplexStreamEOF
+                except trio.WouldBlock:
+                    # We know `receive` will be blocked here. Wait for data here with
+                    # `receive` and catch all kinds of errors here.
+                    try:
+                        data = await self.incoming_data_channel.receive()
+                        self._buf.extend(data)
+                    except trio.EndOfChannel:
+                        if self.event_reset.is_set():
+                            raise MplexStreamReset
+                        if self.event_remote_closed.is_set():
+                            raise MplexStreamEOF
+                    except trio.ClosedResourceError as error:
+                        # Probably `incoming_data_channel` is closed in `reset` when
+                        # we are waiting for `receive`.
+                        if self.event_reset.is_set():
+                            raise MplexStreamReset
+                        raise Exception(
+                            "`incoming_data_channel` is closed but stream is not reset. "
+                            "This should never happen."
+                        ) from error
+            self._buf.extend(self._read_return_when_blocked())
+            payload = self._buf[:n]
+            self._buf = self._buf[len(payload) :]
+            return bytes(payload)
 
     async def write(self, data: bytes) -> None:
         """
@@ -162,14 +232,15 @@ class MplexStream(IMuxedStream):
 
         :return: number of bytes written
         """
-        if self.event_local_closed.is_set():
-            raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
-        flag = (
-            HeaderTags.MessageInitiator
-            if self.is_initiator
-            else HeaderTags.MessageReceiver
-        )
-        await self.muxed_conn.send_message(flag, data, self.stream_id)
+        async with self.rw_lock.write_lock():
+            if self.event_local_closed.is_set():
+                raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
+            flag = (
+                HeaderTags.MessageInitiator
+                if self.is_initiator
+                else HeaderTags.MessageReceiver
+            )
+            await self.muxed_conn.send_message(flag, data, self.stream_id)
 
     async def close(self) -> None:
         """
@@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
     MuxedStreamReset,
 )
 
+# Configure logger for this module
+logger = logging.getLogger("libp2p.stream_muxer.yamux")
+
 PROTOCOL_ID = "/yamux/1.0.0"
 TYPE_DATA = 0x0
 TYPE_WINDOW_UPDATE = 0x1
@@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
         # Flow control: Check if we have enough send window
         total_len = len(data)
         sent = 0
-        logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes")
+        logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes")
         while sent < total_len:
             # Wait for available window with timeout
             timeout = False
             async with self.window_lock:
                 if self.send_window == 0:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.stream_id}: Window is zero, waiting for update"
                     )
                     # Release lock and wait with timeout
@@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
         """
         if increment <= 0:
             # If increment is zero or negative, skip sending update
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Skipping window update "
                 f"(increment={increment})"
             )
             return
-        logging.debug(
+        logger.debug(
             f"Stream {self.stream_id}: Sending window update with increment={increment}"
         )
@@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
 
         # If the stream is closed for receiving and the buffer is empty, raise EOF
         if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
             )
             raise MuxedStreamEOF("Stream is closed for receiving")
@@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
 
         # If buffer is not available, check if stream is closed
         if buffer is None:
-            logging.debug(f"Stream {self.stream_id}: No buffer available")
+            logger.debug(f"Stream {self.stream_id}: No buffer available")
             raise MuxedStreamEOF("Stream buffer closed")
 
         # If we have data in buffer, process it
@@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
                 # Send window update for the chunk we just read
                 async with self.window_lock:
                     self.recv_window += len(chunk)
-                    logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
+                    logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
                     await self.send_window_update(len(chunk), skip_lock=True)
 
             # If stream is closed (FIN received) and buffer is empty, break
             if self.recv_closed and len(buffer) == 0:
-                logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
+                logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
                 break
 
             # If stream was reset, raise reset error
             if self.reset_received:
-                logging.debug(f"Stream {self.stream_id}: Stream was reset")
+                logger.debug(f"Stream {self.stream_id}: Stream was reset")
                 raise MuxedStreamReset("Stream was reset")
 
             # Wait for more data or stream closure
-            logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
+            logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
             await self.conn.stream_events[self.stream_id].wait()
             self.conn.stream_events[self.stream_id] = trio.Event()
 
         # After loop exit, first check if we have data to return
         if data:
-            logging.debug(
+            logger.debug(
                 f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
             )
             return data
 
         # No data accumulated, now check why we exited the loop
        if self.conn.event_shutting_down.is_set():
-            logging.debug(f"Stream {self.stream_id}: Connection shutting down")
+            logger.debug(f"Stream {self.stream_id}: Connection shutting down")
             raise MuxedStreamEOF("Connection shut down")
 
         # Return empty data
@@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
             data = await self.conn.read_stream(self.stream_id, n)
             async with self.window_lock:
                 self.recv_window += len(data)
-                logging.debug(
+                logger.debug(
                     f"Stream {self.stream_id}: Sending window update after read, "
                     f"increment={len(data)}"
                 )
@@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
 
     async def close(self) -> None:
         if not self.send_closed:
-            logging.debug(f"Half-closing stream {self.stream_id} (local end)")
+            logger.debug(f"Half-closing stream {self.stream_id} (local end)")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
             )
@@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
 
     async def reset(self) -> None:
         if not self.closed:
-            logging.debug(f"Resetting stream {self.stream_id}")
+            logger.debug(f"Resetting stream {self.stream_id}")
             header = struct.pack(
                 YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
             )
@@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
         self._nursery: Nursery | None = None
 
     async def start(self) -> None:
-        logging.debug(f"Starting Yamux for {self.peer_id}")
+        logger.debug(f"Starting Yamux for {self.peer_id}")
         if self.event_started.is_set():
             return
         async with trio.open_nursery() as nursery:
@@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
         return self.is_initiator_value
 
     async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
-        logging.debug(f"Closing Yamux connection with code {error_code}")
+        logger.debug(f"Closing Yamux connection with code {error_code}")
         async with self.streams_lock:
             if not self.event_shutting_down.is_set():
                 try:
@@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
                     )
                     await self.secured_conn.write(header)
                 except Exception as e:
-                    logging.debug(f"Failed to send GO_AWAY: {e}")
+                    logger.debug(f"Failed to send GO_AWAY: {e}")
             self.event_shutting_down.set()
             for stream in self.streams.values():
                 stream.closed = True
@@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
             self.stream_events.clear()
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
         except Exception as e:
-            logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
+            logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
         self.event_closed.set()
         if self.on_close:
-            logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
             if inspect.iscoroutinefunction(self.on_close):
                 if self.on_close is not None:
                     await self.on_close()
@@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
                 header = struct.pack(
                     YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
                 )
-                logging.debug(f"Sending SYN header for stream {stream_id}")
+                logger.debug(f"Sending SYN header for stream {stream_id}")
                 await self.secured_conn.write(header)
             return stream
         except Exception as e:
@@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
             raise e
 
     async def accept_stream(self) -> IMuxedStream:
-        logging.debug("Waiting for new stream")
+        logger.debug("Waiting for new stream")
         try:
             stream = await self.new_stream_receive_channel.receive()
-            logging.debug(f"Received stream {stream.stream_id}")
+            logger.debug(f"Received stream {stream.stream_id}")
             return stream
         except trio.EndOfChannel:
             raise MuxedStreamError("No new streams available")
 
     async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
-        logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
+        logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
         if n is None:
             n = -1
 
         while True:
             async with self.streams_lock:
                 if stream_id not in self.streams:
-                    logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
+                    logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
                     raise MuxedStreamEOF("Stream closed")
                 if self.event_shutting_down.is_set():
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: connection shutting down"
                     )
                     raise MuxedStreamEOF("Connection shut down")
                 stream = self.streams[stream_id]
                 buffer = self.stream_buffers.get(stream_id)
-                logging.debug(
+                logger.debug(
                     f"Stream {self.peer_id}:{stream_id}: "
                     f"closed={stream.closed}, "
                     f"recv_closed={stream.recv_closed}, "
@@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
                     f"buffer_len={len(buffer) if buffer else 0}"
                 )
                 if buffer is None:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: "
                         f"Buffer gone, assuming closed"
                     )
@@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes "
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
                     return data
                 # If reset received and buffer is empty, raise reset
                 if stream.reset_received:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: "
                         f"reset_received=True, raising MuxedStreamReset"
                     )
@@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
                     else:
                         data = bytes(buffer[:n])
                         del buffer[:n]
-                    logging.debug(
+                    logger.debug(
                         f"Returning {len(data)} bytes "
                         f"from stream {self.peer_id}:{stream_id}, "
                         f"buffer_len={len(buffer)}"
@@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
                     return data
                 # Check if stream is closed
                 if stream.closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: "
                         f"closed=True, raising MuxedStreamReset"
                     )
                     raise MuxedStreamReset("Stream is reset or closed")
                 # Check if recv_closed and buffer empty
                 if stream.recv_closed:
-                    logging.debug(
+                    logger.debug(
                         f"Stream {self.peer_id}:{stream_id}: "
                         f"recv_closed=True, buffer empty, raising EOF"
                     )
                     raise MuxedStreamEOF("Stream is closed for receiving")
 
             # Wait for data if stream is still open
-            logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
+            logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
             try:
                 await self.stream_events[stream_id].wait()
                 self.stream_events[stream_id] = trio.Event()
@@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
             try:
                 header = await self.secured_conn.read(HEADER_SIZE)
                 if not header or len(header) < HEADER_SIZE:
-                    logging.debug(
+                    logger.debug(
                         f"Connection closed or incomplete header for peer {self.peer_id}"
                     )
                     self.event_shutting_down.set()
@@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
                 version, typ, flags, stream_id, length = struct.unpack(
                     YAMUX_HEADER_FORMAT, header
                 )
-                logging.debug(
+                logger.debug(
                     f"Received header for peer {self.peer_id}: "
                     f"type={typ}, flags={flags}, stream_id={stream_id}, "
                     f"length={length}"
@@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
                             0,
                         )
                         await self.secured_conn.write(ack_header)
-                        logging.debug(
+                        logger.debug(
                             f"Sending stream {stream_id} "
                             f"to channel for peer {self.peer_id}"
                         )
@@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
                 elif typ == TYPE_DATA and flags & FLAG_RST:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Resetting stream {stream_id} for peer {self.peer_id}"
                             )
                             self.streams[stream_id].closed = True
@@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
                 elif typ == TYPE_DATA and flags & FLAG_ACK:
                     async with self.streams_lock:
                         if stream_id in self.streams:
-                            logging.debug(
+                            logger.debug(
                                 f"Received ACK for stream "
                                 f"{stream_id} for peer {self.peer_id}"
                             )
                 elif typ == TYPE_GO_AWAY:
                     error_code = length
                     if error_code == GO_AWAY_NORMAL:
-                        logging.debug(
+                        logger.debug(
                             f"Received GO_AWAY for peer "
                             f"{self.peer_id}: Normal termination"
                         )
                     elif error_code == GO_AWAY_PROTOCOL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}: Protocol error"
                         )
                     elif error_code == GO_AWAY_INTERNAL_ERROR:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id}: Internal error"
                         )
                     else:
-                        logging.error(
+                        logger.error(
                             f"Received GO_AWAY for peer {self.peer_id} "
                             f"with unknown error code: {error_code}"
                         )
@@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
                     break
                 elif typ == TYPE_PING:
                     if flags & FLAG_SYN:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping request with value "
                             f"{length} for peer {self.peer_id}"
                         )
@@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
                         )
                         await self.secured_conn.write(ping_header)
                     elif flags & FLAG_ACK:
-                        logging.debug(
+                        logger.debug(
                             f"Received ping response with value "
                             f"{length} for peer {self.peer_id}"
                         )
@@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
                             self.stream_buffers[stream_id].extend(data)
                             self.stream_events[stream_id].set()
                             if flags & FLAG_FIN:
-                                logging.debug(
+                                logger.debug(
                                     f"Received FIN for stream {self.peer_id}:"
                                     f"{stream_id}, marking recv_closed"
                                 )
@@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
                                 if self.streams[stream_id].send_closed:
                                     self.streams[stream_id].closed = True
                     except Exception as e:
-                        logging.error(f"Error reading data for stream {stream_id}: {e}")
+                        logger.error(f"Error reading data for stream {stream_id}: {e}")
                         # Mark stream as closed on read error
                         async with self.streams_lock:
                             if stream_id in self.streams:
@@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
                     if stream_id in self.streams:
                         stream = self.streams[stream_id]
                         async with stream.window_lock:
-                            logging.debug(
+                            logger.debug(
                                 f"Received window update for stream "
                                 f"{self.peer_id}:{stream_id}, "
                                 f"increment: {increment}"
@@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
                     and details.get("requested_count") == 2
                     and details.get("received_count") == 0
                 ):
-                    logging.info(
+                    logger.info(
                         f"Stream closed cleanly for peer {self.peer_id}"
                         + f" (IncompleteReadError: {details})"
                     )
@@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
                     await self._cleanup_on_error()
                     break
                 else:
-                    logging.error(
+                    logger.error(
                         f"Error in handle_incoming for peer {self.peer_id}: "
                         + f"{type(e).__name__}: {str(e)}"
                     )
             else:
-                logging.error(
-                    f"Error in handle_incoming for peer {self.peer_id}: "
-                    + f"{type(e).__name__}: {str(e)}"
-                )
+                # Handle RawConnError with more nuance
+                if isinstance(e, RawConnError):
+                    error_msg = str(e)
+                    # If RawConnError is empty, it's likely normal cleanup
+                    if not error_msg.strip():
+                        logger.info(
+                            f"RawConnError (empty) during cleanup for peer "
+                            f"{self.peer_id} (normal connection shutdown)"
+                        )
+                    else:
+                        # Log non-empty RawConnError as warning
+                        logger.warning(
+                            f"RawConnError during connection handling for peer "
+                            f"{self.peer_id}: {error_msg}"
+                        )
+                else:
+                    # Log all other errors normally
+                    logger.error(
+                        f"Error in handle_incoming for peer {self.peer_id}: "
+                        + f"{type(e).__name__}: {str(e)}"
+                    )
             # Don't crash the whole connection for temporary errors
             if self.event_shutting_down.is_set() or isinstance(
                 e, (RawConnError, OSError)
@@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
         # Close the secured connection
         try:
             await self.secured_conn.close()
-            logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
+            logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
        except Exception as close_error:
-            logging.error(
+            logger.error(
                 f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
             )
@@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
 
         # Call on_close callback if provided
         if self.on_close:
-            logging.debug(f"Calling on_close for peer {self.peer_id}")
+            logger.debug(f"Calling on_close for peer {self.peer_id}")
             try:
                 if inspect.iscoroutinefunction(self.on_close):
                     await self.on_close()
                 else:
                     self.on_close()
             except Exception as callback_error:
-                logging.error(f"Error in on_close callback: {callback_error}")
+                logger.error(f"Error in on_close callback: {callback_error}")
 
         # Cancel nursery tasks
         if self._nursery:
@@ -9,6 +9,7 @@ from libp2p.utils.varint import (
     read_varint_prefixed_bytes,
     decode_varint_from_bytes,
     decode_varint_with_size,
+    read_length_prefixed_protobuf,
 )
 from libp2p.utils.version import (
     get_agent_version,
@@ -24,4 +25,5 @@ __all__ = [
     "read_varint_prefixed_bytes",
     "decode_varint_from_bytes",
     "decode_varint_with_size",
+    "read_length_prefixed_protobuf",
 ]
@@ -1,7 +1,9 @@
 import itertools
 import logging
 import math
+from typing import BinaryIO
 
+from libp2p.abc import INetStream
 from libp2p.exceptions import (
     ParseError,
 )
@@ -25,42 +27,41 @@ HIGH_MASK = 2**7
 SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
 
 
-def encode_uvarint(number: int) -> bytes:
-    """Pack `number` into varint bytes."""
-    buf = b""
-    while True:
-        towrite = number & 0x7F
-        number >>= 7
-        if number:
-            buf += bytes((towrite | 0x80,))
-        else:
-            buf += bytes((towrite,))
-            break
-    return buf
+def encode_uvarint(value: int) -> bytes:
+    """Encode an unsigned integer as a varint."""
+    if value < 0:
+        raise ValueError("Cannot encode negative value as uvarint")
+
+    result = bytearray()
+    while value >= 0x80:
+        result.append((value & 0x7F) | 0x80)
+        value >>= 7
+    result.append(value & 0x7F)
+    return bytes(result)
+
+
+def decode_uvarint(data: bytes) -> int:
+    """Decode a varint from bytes."""
+    if not data:
+        raise ParseError("Unexpected end of data")
+
+    result = 0
+    shift = 0
+
+    for byte in data:
+        result |= (byte & 0x7F) << shift
+        if (byte & 0x80) == 0:
+            break
+        shift += 7
+        if shift >= 64:
+            raise ValueError("Varint too long")
+
+    return result
 
 
 def decode_varint_from_bytes(data: bytes) -> int:
-    """
-    Decode a varint from bytes and return the value.
-
-    This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
-    """
-    res = 0
-    for shift in itertools.count(0, 7):
-        if shift > SHIFT_64_BIT_MAX:
-            raise ParseError("Integer is too large...")
-
-        if not data:
-            raise ParseError("Unexpected end of data")
-
-        value = data[0]
-        data = data[1:]
-
-        res += (value & LOW_MASK) << shift
-
-        if not value & HIGH_MASK:
-            break
-    return res
+    """Decode a varint from bytes (alias for decode_uvarint for backward compat)."""
+    return decode_uvarint(data)
 
 
 async def decode_uvarint_from_stream(reader: Reader) -> int:
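
A quick round-trip check for the rewritten helpers, assuming they are importable from libp2p.utils.varint as this diff suggests; the sample values are arbitrary:

from libp2p.utils.varint import decode_uvarint, encode_uvarint

for value in (0, 1, 127, 128, 300, 2**32):
    assert decode_uvarint(encode_uvarint(value)) == value

# 300 = 0b1_0010_1100 encodes low 7 bits first, with the high bit as continuation:
print(encode_uvarint(300).hex())  # "ac02"
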
@@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
 
 def decode_varint_with_size(data: bytes) -> tuple[int, int]:
     """
-    Decode a varint from bytes and return (value, bytes_consumed).
-    Returns (0, 0) if the data doesn't start with a valid varint.
+    Decode a varint from bytes and return both the value and the number of bytes
+    consumed.
+
+    Returns:
+        Tuple[int, int]: (value, bytes_consumed)
+
     """
-    try:
-        # Calculate how many bytes the varint consumes
-        varint_size = 0
-        for i, byte in enumerate(data):
-            varint_size += 1
-            if (byte & 0x80) == 0:
-                break
-
-        if varint_size == 0:
-            return 0, 0
-
-        # Extract just the varint bytes
-        varint_bytes = data[:varint_size]
-
-        # Decode the varint
-        value = decode_varint_from_bytes(varint_bytes)
-
-        return value, varint_size
-    except Exception:
-        return 0, 0
+    result = 0
+    shift = 0
+    bytes_consumed = 0
+
+    for byte in data:
+        result |= (byte & 0x7F) << shift
+        bytes_consumed += 1
+        if (byte & 0x80) == 0:
+            break
+        shift += 7
+        if shift >= 64:
+            raise ValueError("Varint too long")
+
+    return result, bytes_consumed
 
 
-def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
-    varint_len = encode_uvarint(len(msg_bytes))
-    return varint_len + msg_bytes
+def encode_varint_prefixed(data: bytes) -> bytes:
+    """Encode data with a varint length prefix."""
+    length_bytes = encode_uvarint(len(data))
+    return length_bytes + data
 
 
 async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
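
The framing pair in action, under the same import assumption: encode_varint_prefixed prepends the length, and decode_varint_with_size reports how many prefix bytes to skip before the payload.

from libp2p.utils.varint import decode_varint_with_size, encode_varint_prefixed

payload = b"x" * 300
framed = encode_varint_prefixed(payload)
length, consumed = decode_varint_with_size(framed)
assert (length, consumed) == (300, 2)  # 300 needs a two-byte varint prefix
assert framed[consumed:consumed + length] == payload
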
@@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
             f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
         )
     return msg_bytes[:-1]
+
+
+def read_varint_prefixed_bytes_sync(
+    stream: BinaryIO, max_length: int = 1024 * 1024
+) -> bytes:
+    """
+    Read varint-prefixed bytes from a stream.
+
+    Args:
+        stream: A stream-like object with a read() method
+        max_length: Maximum allowed data length to prevent memory exhaustion
+
+    Returns:
+        bytes: The data without the length prefix
+
+    Raises:
+        ValueError: If the length prefix is invalid or too large
+        EOFError: If the stream ends unexpectedly
+
+    """
+    # Read the varint length prefix
+    length_bytes = b""
+    while True:
+        byte_data = stream.read(1)
+        if not byte_data:
+            raise EOFError("Stream ended while reading varint length prefix")
+
+        length_bytes += byte_data
+        if byte_data[0] & 0x80 == 0:
+            break
+
+    # Decode the length
+    length = decode_uvarint(length_bytes)
+
+    if length > max_length:
+        raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
+
+    # Read the data
+    data = stream.read(length)
+    if len(data) != length:
+        raise EOFError(f"Expected {length} bytes, got {len(data)}")
+
+    return data
+
+
+async def read_length_prefixed_protobuf(
+    stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
+) -> bytes:
+    """Read a protobuf message from a stream, handling both formats."""
+    if use_varint_format:
+        # Read length-prefixed protobuf message from the stream
+        # First read the varint length prefix
+        length_bytes = b""
+        while True:
+            b = await stream.read(1)
+            if not b:
+                raise Exception("No length prefix received")
+
+            length_bytes += b
+            if b[0] & 0x80 == 0:
+                break
+
+        msg_length = decode_varint_from_bytes(length_bytes)
+
+        if msg_length > max_length:
+            raise Exception(
+                f"Message length {msg_length} exceeds maximum allowed {max_length}"
+            )
+
+        # Read the protobuf message
+        data = await stream.read(msg_length)
+        if len(data) != msg_length:
+            raise Exception(
+                f"Incomplete message: expected {msg_length}, got {len(data)}"
+            )
+
+        return data
+    else:
+        # Read raw protobuf message from the stream
+        # For raw format, read all available data in one go
+        data = await stream.read()
+
+        # If we got no data, raise an exception
+        if not data:
+            raise Exception("No data received in raw format")
+
+        if len(data) > max_length:
+            raise Exception(
+                f"Message length {len(data)} exceeds maximum allowed {max_length}"
+            )
+
+        return data
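
The two wire formats that read_length_prefixed_protobuf distinguishes, shown on plain bytes rather than a live stream; the message bytes here are a hypothetical serialized protobuf (field 1, length-delimited, value "hello"):

from libp2p.utils.varint import decode_varint_with_size, encode_varint_prefixed

msg = b"\x0a\x05hello"  # hypothetical serialized protobuf message
varint_framed = encode_varint_prefixed(msg)  # what use_varint_format=True expects
raw = msg  # what use_varint_format=False expects, delimited by stream close

length, consumed = decode_varint_with_size(varint_framed)
assert varint_framed[consumed:consumed + length] == raw
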