Merge branch 'main' into todo/handletimeout

Manu Sheel Gupta
2025-07-21 08:17:08 -07:00
committed by GitHub
24 changed files with 2307 additions and 766 deletions

View File

@ -113,7 +113,7 @@ def parse_identify_response(response: bytes) -> Identify:
def identify_handler_for(
host: IHost, use_varint_format: bool = False
host: IHost, use_varint_format: bool = True
) -> StreamHandlerFn:
async def handle_identify(stream: INetStream) -> None:
# get observed address from ``stream``

View File

@ -28,7 +28,7 @@ from libp2p.utils import (
varint,
)
from libp2p.utils.varint import (
decode_varint_from_bytes,
read_length_prefixed_protobuf,
)
from ..identify.identify import (
@ -66,49 +66,8 @@ def identify_push_handler_for(
peer_id = stream.muxed_conn.peer_id
try:
if use_varint_format:
# Read length-prefixed identify message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
break
length_bytes += b
if b[0] & 0x80 == 0:
break
if not length_bytes:
logger.warning("No length prefix received from peer %s", peer_id)
return
msg_length = decode_varint_from_bytes(length_bytes)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
logger.warning("Incomplete message received from peer %s", peer_id)
return
else:
# Read raw protobuf message from the stream
# For raw format, we need to read all data before the stream is closed
data = b""
try:
# Read all available data in a single operation
data = await stream.read()
except StreamClosed:
# Try to read any remaining data
try:
data = await stream.read()
except Exception:
pass
# If we got no data, log a warning and return
if not data:
logger.warning(
"No data received in raw format from peer %s", peer_id
)
return
# Use the utility function to read the protobuf message
data = await read_length_prefixed_protobuf(stream, use_varint_format)
identify_msg = Identify()
identify_msg.ParseFromString(data)
@ -119,6 +78,11 @@ def identify_push_handler_for(
)
logger.debug("Successfully processed identify/push from peer %s", peer_id)
# Send acknowledgment to indicate successful processing
# This ensures the sender knows the message was received before closing
await stream.write(b"OK")
except StreamClosed:
logger.debug(
"Stream closed while processing identify/push from %s", peer_id
@ -127,7 +91,10 @@ def identify_push_handler_for(
logger.error("Error processing identify/push from %s: %s", peer_id, e)
finally:
# Close the stream after processing
await stream.close()
try:
await stream.close()
except Exception:
pass # Ignore errors when closing
return handle_identify_push
@ -226,7 +193,20 @@ async def push_identify_to_peer(
# Send raw protobuf message
await stream.write(response)
# Close the stream
# Wait for acknowledgment from the receiver with timeout
# This ensures the message was processed before closing
try:
with trio.move_on_after(1.0): # 1 second timeout
ack = await stream.read(2) # Read "OK" acknowledgment
if ack != b"OK":
logger.warning(
"Unexpected acknowledgment from peer %s: %s", peer_id, ack
)
except Exception as e:
logger.debug("No acknowledgment received from peer %s: %s", peer_id, e)
# Continue anyway, as the message might have been processed
# Close the stream after acknowledgment (or timeout)
await stream.close()
logger.debug("Successfully pushed identify to peer %s", peer_id)

View File

@ -102,6 +102,9 @@ class TopicValidator(NamedTuple):
is_async: bool
MAX_CONCURRENT_VALIDATORS = 10
class Pubsub(Service, IPubsub):
host: IHost
@ -109,6 +112,7 @@ class Pubsub(Service, IPubsub):
peer_receive_channel: trio.MemoryReceiveChannel[ID]
dead_peer_receive_channel: trio.MemoryReceiveChannel[ID]
_validator_semaphore: trio.Semaphore
seen_messages: LastSeenCache
@ -143,6 +147,7 @@ class Pubsub(Service, IPubsub):
msg_id_constructor: Callable[
[rpc_pb2.Message], bytes
] = get_peer_and_seqno_msg_id,
max_concurrent_validator_count: int = MAX_CONCURRENT_VALIDATORS,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -168,6 +173,7 @@ class Pubsub(Service, IPubsub):
# Therefore, we can only close from the receive side.
self.peer_receive_channel = peer_receive
self.dead_peer_receive_channel = dead_peer_receive
self._validator_semaphore = trio.Semaphore(max_concurrent_validator_count)
# Register a notifee
self.host.get_network().register_notifee(
PubsubNotifee(peer_send, dead_peer_send)
@ -657,7 +663,11 @@ class Pubsub(Service, IPubsub):
logger.debug("successfully published message %s", msg)
async def validate_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
async def validate_msg(
self,
msg_forwarder: ID,
msg: rpc_pb2.Message,
) -> None:
"""
Validate the received message.
@ -680,23 +690,34 @@ class Pubsub(Service, IPubsub):
if not validator(msg_forwarder, msg):
raise ValidationError(f"Validation failed for msg={msg}")
# TODO: Implement throttle on async validators
if len(async_topic_validators) > 0:
# Appends to lists are thread safe in CPython
results = []
async def run_async_validator(func: AsyncValidatorFn) -> None:
result = await func(msg_forwarder, msg)
results.append(result)
results: list[bool] = []
async with trio.open_nursery() as nursery:
for async_validator in async_topic_validators:
nursery.start_soon(run_async_validator, async_validator)
nursery.start_soon(
self._run_async_validator,
async_validator,
msg_forwarder,
msg,
results,
)
if not all(results):
raise ValidationError(f"Validation failed for msg={msg}")
async def _run_async_validator(
self,
func: AsyncValidatorFn,
msg_forwarder: ID,
msg: rpc_pb2.Message,
results: list[bool],
) -> None:
async with self._validator_semaphore:
result = await func(msg_forwarder, msg)
results.append(result)
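The semaphore is what actually implements the old "TODO: Implement throttle on async validators"; a self-contained sketch of the pattern (the validator body and counts are illustrative, only the semaphore logic mirrors the code above):

    import trio

    async def throttle_demo() -> None:
        semaphore = trio.Semaphore(10)  # mirrors MAX_CONCURRENT_VALIDATORS
        results: list[bool] = []

        async def run_validator(n: int) -> None:
            async with semaphore:  # at most 10 validators run concurrently
                await trio.sleep(0.01)  # stand-in for an async validator call
                results.append(True)

        async with trio.open_nursery() as nursery:
            for n in range(50):
                nursery.start_soon(run_validator, n)
        assert len(results) == 50 and all(results)

    trio.run(throttle_demo)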
async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
"""
Push a pubsub message to others.

View File

@ -1,3 +1,5 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -32,6 +34,72 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
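A minimal usage sketch of this lock (the nursery scaffolding is illustrative; only ReadWriteLock comes from the code above). Note that read_lock shields its release path, so a cancelled reader still decrements the reader count:

    import trio

    async def rw_demo() -> None:
        lock = ReadWriteLock()

        async def reader() -> None:
            async with lock.read_lock():  # many readers may hold the lock at once
                await trio.sleep(0.1)

        async def writer() -> None:
            async with lock.write_lock():  # waits until the last reader releases
                await trio.sleep(0.1)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(reader)
            nursery.start_soon(reader)
            nursery.start_soon(writer)

    trio.run(rw_demo)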
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
read_deadline: int | None
write_deadline: int | None
# TODO: Add lock for read/write to avoid interleaving receiving messages?
rw_lock: ReadWriteLock
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
data = await self.incoming_data_channel.receive()
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when
# we are waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset."
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> None:
"""
@ -162,14 +232,15 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async def close(self) -> None:
"""

View File

@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamReset,
)
# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
PROTOCOL_ID = "/yamux/1.0.0"
TYPE_DATA = 0x0
TYPE_WINDOW_UPDATE = 0x1
@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
"""
if increment <= 0:
# If increment is zero or negative, skip sending update
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Skipping window update"
f"(increment={increment})"
)
return
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Sending window update with increment={increment}"
)
@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
# If the stream is closed for receiving and the buffer is empty, raise EOF
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
)
raise MuxedStreamEOF("Stream is closed for receiving")
@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
# If buffer is not available, check if stream is closed
if buffer is None:
logging.debug(f"Stream {self.stream_id}: No buffer available")
logger.debug(f"Stream {self.stream_id}: No buffer available")
raise MuxedStreamEOF("Stream buffer closed")
# If we have data in buffer, process it
@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
# Send window update for the chunk we just read
async with self.window_lock:
self.recv_window += len(chunk)
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
await self.send_window_update(len(chunk), skip_lock=True)
# If stream is closed (FIN received) and buffer is empty, break
if self.recv_closed and len(buffer) == 0:
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
break
# If stream was reset, raise reset error
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
logger.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
# Wait for more data or stream closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop exit, first check if we have data to return
if data:
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
)
return data
# No data accumulated, now check why we exited the loop
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
logger.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
# Return empty data
@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
data = await self.conn.read_stream(self.stream_id, n)
async with self.window_lock:
self.recv_window += len(data)
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Sending window update after read, "
f"increment={len(data)}"
)
@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
async def close(self) -> None:
if not self.send_closed:
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
async def reset(self) -> None:
if not self.closed:
logging.debug(f"Resetting stream {self.stream_id}")
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
self._nursery: Nursery | None = None
async def start(self) -> None:
logging.debug(f"Starting Yamux for {self.peer_id}")
logger.debug(f"Starting Yamux for {self.peer_id}")
if self.event_started.is_set():
return
async with trio.open_nursery() as nursery:
@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
return self.is_initiator_value
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
logging.debug(f"Closing Yamux connection with code {error_code}")
logger.debug(f"Closing Yamux connection with code {error_code}")
async with self.streams_lock:
if not self.event_shutting_down.is_set():
try:
@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(header)
except Exception as e:
logging.debug(f"Failed to send GO_AWAY: {e}")
logger.debug(f"Failed to send GO_AWAY: {e}")
self.event_shutting_down.set()
for stream in self.streams.values():
stream.closed = True
@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
self.stream_events.clear()
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as e:
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
self.event_closed.set()
if self.on_close:
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
if inspect.iscoroutinefunction(self.on_close):
if self.on_close is not None:
await self.on_close()
@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
)
logging.debug(f"Sending SYN header for stream {stream_id}")
logger.debug(f"Sending SYN header for stream {stream_id}")
await self.secured_conn.write(header)
return stream
except Exception as e:
@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
raise e
async def accept_stream(self) -> IMuxedStream:
logging.debug("Waiting for new stream")
logger.debug("Waiting for new stream")
try:
stream = await self.new_stream_receive_channel.receive()
logging.debug(f"Received stream {stream.stream_id}")
logger.debug(f"Received stream {stream.stream_id}")
return stream
except trio.EndOfChannel:
raise MuxedStreamError("No new streams available")
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
if n is None:
n = -1
while True:
async with self.streams_lock:
if stream_id not in self.streams:
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
raise MuxedStreamEOF("Stream closed")
if self.event_shutting_down.is_set():
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
)
raise MuxedStreamEOF("Connection shut down")
stream = self.streams[stream_id]
buffer = self.stream_buffers.get(stream_id)
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}: "
f"closed={stream.closed}, "
f"recv_closed={stream.recv_closed}, "
@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
f"buffer_len={len(buffer) if buffer else 0}"
)
if buffer is None:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"Buffer gone, assuming closed"
)
@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
logger.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
return data
# If reset received and buffer is empty, raise reset
if stream.reset_received:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"reset_received=True, raising MuxedStreamReset"
)
@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
logger.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
return data
# Check if stream is closed
if stream.closed:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"closed=True, raising MuxedStreamReset"
)
raise MuxedStreamReset("Stream is reset or closed")
# Check if recv_closed and buffer empty
if stream.recv_closed:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"recv_closed=True, buffer empty, raising EOF"
)
raise MuxedStreamEOF("Stream is closed for receiving")
# Wait for data if stream is still open
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
try:
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
try:
header = await self.secured_conn.read(HEADER_SIZE)
if not header or len(header) < HEADER_SIZE:
logging.debug(
logger.debug(
f"Connection closed orincomplete header for peer {self.peer_id}"
)
self.event_shutting_down.set()
@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
version, typ, flags, stream_id, length = struct.unpack(
YAMUX_HEADER_FORMAT, header
)
logging.debug(
logger.debug(
f"Received header for peer {self.peer_id}:"
f"type={typ}, flags={flags}, stream_id={stream_id},"
f"length={length}"
@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
0,
)
await self.secured_conn.write(ack_header)
logging.debug(
logger.debug(
f"Sending stream {stream_id}"
f"to channel for peer {self.peer_id}"
)
@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_RST:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
logger.debug(
f"Resetting stream {stream_id} for peer {self.peer_id}"
)
self.streams[stream_id].closed = True
@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_ACK:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
logger.debug(
f"Received ACK for stream"
f"{stream_id} for peer {self.peer_id}"
)
elif typ == TYPE_GO_AWAY:
error_code = length
if error_code == GO_AWAY_NORMAL:
logging.debug(
logger.debug(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Normal termination"
)
elif error_code == GO_AWAY_PROTOCOL_ERROR:
logging.error(
logger.error(
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
)
elif error_code == GO_AWAY_INTERNAL_ERROR:
logging.error(
logger.error(
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
)
else:
logging.error(
logger.error(
f"Received GO_AWAY for peer {self.peer_id}"
f"with unknown error code: {error_code}"
)
@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
break
elif typ == TYPE_PING:
if flags & FLAG_SYN:
logging.debug(
logger.debug(
f"Received ping request with value"
f"{length} for peer {self.peer_id}"
)
@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(ping_header)
elif flags & FLAG_ACK:
logging.debug(
logger.debug(
f"Received ping response with value"
f"{length} for peer {self.peer_id}"
)
@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
self.stream_buffers[stream_id].extend(data)
self.stream_events[stream_id].set()
if flags & FLAG_FIN:
logging.debug(
logger.debug(
f"Received FIN for stream {self.peer_id}:"
f"{stream_id}, marking recv_closed"
)
@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
if self.streams[stream_id].send_closed:
self.streams[stream_id].closed = True
except Exception as e:
logging.error(f"Error reading data for stream {stream_id}: {e}")
logger.error(f"Error reading data for stream {stream_id}: {e}")
# Mark stream as closed on read error
async with self.streams_lock:
if stream_id in self.streams:
@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
if stream_id in self.streams:
stream = self.streams[stream_id]
async with stream.window_lock:
logging.debug(
logger.debug(
f"Received window update for stream"
f"{self.peer_id}:{stream_id},"
f" increment: {increment}"
@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
and details.get("requested_count") == 2
and details.get("received_count") == 0
):
logging.info(
logger.info(
f"Stream closed cleanly for peer {self.peer_id}"
+ f" (IncompleteReadError: {details})"
)
@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
await self._cleanup_on_error()
break
else:
logging.error(
logger.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
else:
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Handle RawConnError with more nuance
if isinstance(e, RawConnError):
error_msg = str(e)
# If RawConnError is empty, it's likely normal cleanup
if not error_msg.strip():
logger.info(
f"RawConnError (empty) during cleanup for peer "
f"{self.peer_id} (normal connection shutdown)"
)
else:
# Log non-empty RawConnError as warning
logger.warning(
f"RawConnError during connection handling for peer "
f"{self.peer_id}: {error_msg}"
)
else:
# Log all other errors normally
logger.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Don't crash the whole connection for temporary errors
if self.event_shutting_down.is_set() or isinstance(
e, (RawConnError, OSError)
@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
# Close the secured connection
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as close_error:
logging.error(
logger.error(
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
)
@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
# Call on_close callback if provided
if self.on_close:
logging.debug(f"Calling on_close for peer {self.peer_id}")
logger.debug(f"Calling on_close for peer {self.peer_id}")
try:
if inspect.iscoroutinefunction(self.on_close):
await self.on_close()
else:
self.on_close()
except Exception as callback_error:
logging.error(f"Error in on_close callback: {callback_error}")
logger.error(f"Error in on_close callback: {callback_error}")
# Cancel nursery tasks
if self._nursery:

View File

@ -9,6 +9,7 @@ from libp2p.utils.varint import (
read_varint_prefixed_bytes,
decode_varint_from_bytes,
decode_varint_with_size,
read_length_prefixed_protobuf,
)
from libp2p.utils.version import (
get_agent_version,
@ -24,4 +25,5 @@ __all__ = [
"read_varint_prefixed_bytes",
"decode_varint_from_bytes",
"decode_varint_with_size",
"read_length_prefixed_protobuf",
]

View File

@ -1,7 +1,9 @@
import itertools
import logging
import math
from typing import BinaryIO
from libp2p.abc import INetStream
from libp2p.exceptions import (
ParseError,
)
@ -25,42 +27,41 @@ HIGH_MASK = 2**7
SHIFT_64_BIT_MAX = int(math.ceil(64 / 7)) * 7
def encode_uvarint(number: int) -> bytes:
    """Pack `number` into varint bytes."""
    buf = b""
    while True:
        towrite = number & 0x7F
        number >>= 7
        if number:
            buf += bytes((towrite | 0x80,))
        else:
            buf += bytes((towrite,))
            break
    return buf
def encode_uvarint(value: int) -> bytes:
    """Encode an unsigned integer as a varint."""
    if value < 0:
        raise ValueError("Cannot encode negative value as uvarint")
    result = bytearray()
    while value >= 0x80:
        result.append((value & 0x7F) | 0x80)
        value >>= 7
    result.append(value & 0x7F)
    return bytes(result)
def decode_uvarint(data: bytes) -> int:
    """Decode a varint from bytes."""
    if not data:
        raise ParseError("Unexpected end of data")
    result = 0
    shift = 0
    for byte in data:
        result |= (byte & 0x7F) << shift
        if (byte & 0x80) == 0:
            break
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")
    return result
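As a quick wire-format check of the new pair: 300 is 0b1_0010_1100, so it encodes to two bytes, low seven bits first with the continuation bit set:

    assert encode_uvarint(300) == b"\xac\x02"  # 0xAC = 0x2C | 0x80, 0x02 = 300 >> 7
    assert decode_uvarint(b"\xac\x02") == 300
    assert encode_uvarint(1) == b"\x01"  # values under 128 encode as a single byte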
def decode_varint_from_bytes(data: bytes) -> int:
"""
Decode a varint from bytes and return the value.
This is a synchronous version of decode_uvarint_from_stream for already-read bytes.
"""
res = 0
for shift in itertools.count(0, 7):
if shift > SHIFT_64_BIT_MAX:
raise ParseError("Integer is too large...")
if not data:
raise ParseError("Unexpected end of data")
value = data[0]
data = data[1:]
res += (value & LOW_MASK) << shift
if not value & HIGH_MASK:
break
return res
"""Decode a varint from bytes (alias for decode_uvarint for backward comp)."""
return decode_uvarint(data)
async def decode_uvarint_from_stream(reader: Reader) -> int:
@ -84,34 +85,33 @@ async def decode_uvarint_from_stream(reader: Reader) -> int:
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
    """
    Decode a varint from bytes and return (value, bytes_consumed).
    Returns (0, 0) if the data doesn't start with a valid varint.
    """
    try:
        # Calculate how many bytes the varint consumes
        varint_size = 0
        for i, byte in enumerate(data):
            varint_size += 1
            if (byte & 0x80) == 0:
                break
        if varint_size == 0:
            return 0, 0
        # Extract just the varint bytes
        varint_bytes = data[:varint_size]
        # Decode the varint
        value = decode_varint_from_bytes(varint_bytes)
        return value, varint_size
    except Exception:
        return 0, 0
def decode_varint_with_size(data: bytes) -> tuple[int, int]:
    """
    Decode a varint from bytes and return both the value and the number of bytes
    consumed.
    Returns:
        Tuple[int, int]: (value, bytes_consumed)
    """
    result = 0
    shift = 0
    bytes_consumed = 0
    for byte in data:
        result |= (byte & 0x7F) << shift
        bytes_consumed += 1
        if (byte & 0x80) == 0:
            break
        shift += 7
        if shift >= 64:
            raise ValueError("Varint too long")
    return result, bytes_consumed
def encode_varint_prefixed(msg_bytes: bytes) -> bytes:
varint_len = encode_uvarint(len(msg_bytes))
return varint_len + msg_bytes
def encode_varint_prefixed(data: bytes) -> bytes:
"""Encode data with a varint length prefix."""
length_bytes = encode_uvarint(len(data))
return length_bytes + data
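Tying the prefix writer to the prefix reader, a round trip through the helpers defined above:

    framed = encode_varint_prefixed(b"hello")
    assert framed == b"\x05hello"  # one-byte length prefix, then the payload
    length, consumed = decode_varint_with_size(framed)
    assert (length, consumed) == (5, 1)
    assert framed[consumed:consumed + length] == b"hello"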
async def read_varint_prefixed_bytes(reader: Reader) -> bytes:
@ -138,3 +138,95 @@ async def read_delim(reader: Reader) -> bytes:
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
)
return msg_bytes[:-1]
def read_varint_prefixed_bytes_sync(
stream: BinaryIO, max_length: int = 1024 * 1024
) -> bytes:
"""
Read varint-prefixed bytes from a stream.
Args:
stream: A stream-like object with a read() method
max_length: Maximum allowed data length to prevent memory exhaustion
Returns:
bytes: The data without the length prefix
Raises:
ValueError: If the length prefix is invalid or too large
EOFError: If the stream ends unexpectedly
"""
# Read the varint length prefix
length_bytes = b""
while True:
byte_data = stream.read(1)
if not byte_data:
raise EOFError("Stream ended while reading varint length prefix")
length_bytes += byte_data
if byte_data[0] & 0x80 == 0:
break
# Decode the length
length = decode_uvarint(length_bytes)
if length > max_length:
raise ValueError(f"Data length {length} exceeds maximum allowed {max_length}")
# Read the data
data = stream.read(length)
if len(data) != length:
raise EOFError(f"Expected {length} bytes, got {len(data)}")
return data
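Because the function only needs a blocking read(), an in-memory io.BytesIO is enough to exercise it:

    import io

    buf = io.BytesIO(encode_varint_prefixed(b"identify-payload"))
    assert read_varint_prefixed_bytes_sync(buf) == b"identify-payload"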
async def read_length_prefixed_protobuf(
stream: INetStream, use_varint_format: bool = True, max_length: int = 1024 * 1024
) -> bytes:
"""Read a protobuf message from a stream, handling both formats."""
if use_varint_format:
# Read length-prefixed protobuf message from the stream
# First read the varint length prefix
length_bytes = b""
while True:
b = await stream.read(1)
if not b:
raise Exception("No length prefix received")
length_bytes += b
if b[0] & 0x80 == 0:
break
msg_length = decode_varint_from_bytes(length_bytes)
if msg_length > max_length:
raise Exception(
f"Message length {msg_length} exceeds maximum allowed {max_length}"
)
# Read the protobuf message
data = await stream.read(msg_length)
if len(data) != msg_length:
raise Exception(
f"Incomplete message: expected {msg_length}, got {len(data)}"
)
return data
else:
# Read raw protobuf message from the stream
# For raw format, read all available data in one go
data = await stream.read()
# If we got no data, raise an exception
if not data:
raise Exception("No data received in raw format")
if len(data) > max_length:
raise Exception(
f"Message length {len(data)} exceeds maximum allowed {max_length}"
)
return data
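A hedged sketch of how a handler consumes this helper, mirroring the simplified identify/push handler earlier in this diff (Identify and INetStream as imported there; error handling trimmed):

    async def handle_identify_push(stream: INetStream) -> None:
        # One call covers both wire formats; the flag selects whether a
        # varint length prefix is expected before the protobuf payload.
        data = await read_length_prefixed_protobuf(stream, use_varint_format=True)
        identify_msg = Identify()
        identify_msg.ParseFromString(data)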