Merge branch 'main' into todo/handletimeout

This commit is contained in:
Manu Sheel Gupta
2025-07-21 08:17:08 -07:00
committed by GitHub
24 changed files with 2307 additions and 766 deletions

View File

@ -1,3 +1,5 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from types import (
TracebackType,
)
@ -32,6 +34,72 @@ if TYPE_CHECKING:
)
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
if acquire:
self.release_write()
class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
@ -46,7 +114,7 @@ class MplexStream(IMuxedStream):
read_deadline: int | None
write_deadline: int | None
# TODO: Add lock for read/write to avoid interleaving receiving messages?
rw_lock: ReadWriteLock
close_lock: trio.Lock
# NOTE: `dataIn` is size of 8 in Go implementation.
@ -80,6 +148,7 @@ class MplexStream(IMuxedStream):
self.event_remote_closed = trio.Event()
self.event_reset = trio.Event()
self.close_lock = trio.Lock()
self.rw_lock = ReadWriteLock()
self.incoming_data_channel = incoming_data_channel
self._buf = bytearray()
@ -113,48 +182,49 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until there is
# no data, then return.
try:
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read `n` must be non-negative or "
f"`None` to indicate read until EOF, got n={n}"
)
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
if len(self._buf) == 0:
data: bytes
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
data = await self.incoming_data_channel.receive()
data = self.incoming_data_channel.receive_nowait()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when we are
# waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset. "
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
raise MplexStreamEOF
except trio.WouldBlock:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
self._buf.extend(data)
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
raise MplexStreamEOF
except trio.ClosedResourceError as error:
# Probably `incoming_data_channel` is closed in `reset` when
# we are waiting for `receive`.
if self.event_reset.is_set():
raise MplexStreamReset
raise Exception(
"`incoming_data_channel` is closed but stream is not reset."
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> None:
"""
@ -162,14 +232,15 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
async def close(self) -> None:
"""

View File

@ -45,6 +45,9 @@ from libp2p.stream_muxer.exceptions import (
MuxedStreamReset,
)
# Configure logger for this module
logger = logging.getLogger("libp2p.stream_muxer.yamux")
PROTOCOL_ID = "/yamux/1.0.0"
TYPE_DATA = 0x0
TYPE_WINDOW_UPDATE = 0x1
@ -98,13 +101,13 @@ class YamuxStream(IMuxedStream):
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
if self.send_window == 0:
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
@ -152,12 +155,12 @@ class YamuxStream(IMuxedStream):
"""
if increment <= 0:
# If increment is zero or negative, skip sending update
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Skipping window update"
f"(increment={increment})"
)
return
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Sending window update with increment={increment}"
)
@ -185,7 +188,7 @@ class YamuxStream(IMuxedStream):
# If the stream is closed for receiving and the buffer is empty, raise EOF
if self.recv_closed and not self.conn.stream_buffers.get(self.stream_id):
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Stream closed for receiving and buffer empty"
)
raise MuxedStreamEOF("Stream is closed for receiving")
@ -198,7 +201,7 @@ class YamuxStream(IMuxedStream):
# If buffer is not available, check if stream is closed
if buffer is None:
logging.debug(f"Stream {self.stream_id}: No buffer available")
logger.debug(f"Stream {self.stream_id}: No buffer available")
raise MuxedStreamEOF("Stream buffer closed")
# If we have data in buffer, process it
@ -210,34 +213,34 @@ class YamuxStream(IMuxedStream):
# Send window update for the chunk we just read
async with self.window_lock:
self.recv_window += len(chunk)
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
logger.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
await self.send_window_update(len(chunk), skip_lock=True)
# If stream is closed (FIN received) and buffer is empty, break
if self.recv_closed and len(buffer) == 0:
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
logger.debug(f"Stream {self.stream_id}: Closed with empty buffer")
break
# If stream was reset, raise reset error
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
logger.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
# Wait for more data or stream closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
logger.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop exit, first check if we have data to return
if data:
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
)
return data
# No data accumulated, now check why we exited the loop
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
logger.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
# Return empty data
@ -246,7 +249,7 @@ class YamuxStream(IMuxedStream):
data = await self.conn.read_stream(self.stream_id, n)
async with self.window_lock:
self.recv_window += len(data)
logging.debug(
logger.debug(
f"Stream {self.stream_id}: Sending window update after read, "
f"increment={len(data)}"
)
@ -255,7 +258,7 @@ class YamuxStream(IMuxedStream):
async def close(self) -> None:
if not self.send_closed:
logging.debug(f"Half-closing stream {self.stream_id} (local end)")
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
)
@ -271,7 +274,7 @@ class YamuxStream(IMuxedStream):
async def reset(self) -> None:
if not self.closed:
logging.debug(f"Resetting stream {self.stream_id}")
logger.debug(f"Resetting stream {self.stream_id}")
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
)
@ -349,7 +352,7 @@ class Yamux(IMuxedConn):
self._nursery: Nursery | None = None
async def start(self) -> None:
logging.debug(f"Starting Yamux for {self.peer_id}")
logger.debug(f"Starting Yamux for {self.peer_id}")
if self.event_started.is_set():
return
async with trio.open_nursery() as nursery:
@ -362,7 +365,7 @@ class Yamux(IMuxedConn):
return self.is_initiator_value
async def close(self, error_code: int = GO_AWAY_NORMAL) -> None:
logging.debug(f"Closing Yamux connection with code {error_code}")
logger.debug(f"Closing Yamux connection with code {error_code}")
async with self.streams_lock:
if not self.event_shutting_down.is_set():
try:
@ -371,7 +374,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(header)
except Exception as e:
logging.debug(f"Failed to send GO_AWAY: {e}")
logger.debug(f"Failed to send GO_AWAY: {e}")
self.event_shutting_down.set()
for stream in self.streams.values():
stream.closed = True
@ -382,12 +385,12 @@ class Yamux(IMuxedConn):
self.stream_events.clear()
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as e:
logging.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
logger.debug(f"Error closing secured_conn for peer {self.peer_id}: {e}")
self.event_closed.set()
if self.on_close:
logging.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
logger.debug(f"Calling on_close in Yamux.close for peer {self.peer_id}")
if inspect.iscoroutinefunction(self.on_close):
if self.on_close is not None:
await self.on_close()
@ -416,7 +419,7 @@ class Yamux(IMuxedConn):
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_SYN, stream_id, 0
)
logging.debug(f"Sending SYN header for stream {stream_id}")
logger.debug(f"Sending SYN header for stream {stream_id}")
await self.secured_conn.write(header)
return stream
except Exception as e:
@ -424,32 +427,32 @@ class Yamux(IMuxedConn):
raise e
async def accept_stream(self) -> IMuxedStream:
logging.debug("Waiting for new stream")
logger.debug("Waiting for new stream")
try:
stream = await self.new_stream_receive_channel.receive()
logging.debug(f"Received stream {stream.stream_id}")
logger.debug(f"Received stream {stream.stream_id}")
return stream
except trio.EndOfChannel:
raise MuxedStreamError("No new streams available")
async def read_stream(self, stream_id: int, n: int = -1) -> bytes:
logging.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
logger.debug(f"Reading from stream {self.peer_id}:{stream_id}, n={n}")
if n is None:
n = -1
while True:
async with self.streams_lock:
if stream_id not in self.streams:
logging.debug(f"Stream {self.peer_id}:{stream_id} unknown")
logger.debug(f"Stream {self.peer_id}:{stream_id} unknown")
raise MuxedStreamEOF("Stream closed")
if self.event_shutting_down.is_set():
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}: connection shutting down"
)
raise MuxedStreamEOF("Connection shut down")
stream = self.streams[stream_id]
buffer = self.stream_buffers.get(stream_id)
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}: "
f"closed={stream.closed}, "
f"recv_closed={stream.recv_closed}, "
@ -457,7 +460,7 @@ class Yamux(IMuxedConn):
f"buffer_len={len(buffer) if buffer else 0}"
)
if buffer is None:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"Buffer gone, assuming closed"
)
@ -470,7 +473,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
logger.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -478,7 +481,7 @@ class Yamux(IMuxedConn):
return data
# If reset received and buffer is empty, raise reset
if stream.reset_received:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"reset_received=True, raising MuxedStreamReset"
)
@ -491,7 +494,7 @@ class Yamux(IMuxedConn):
else:
data = bytes(buffer[:n])
del buffer[:n]
logging.debug(
logger.debug(
f"Returning {len(data)} bytes"
f"from stream {self.peer_id}:{stream_id}, "
f"buffer_len={len(buffer)}"
@ -499,21 +502,21 @@ class Yamux(IMuxedConn):
return data
# Check if stream is closed
if stream.closed:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"closed=True, raising MuxedStreamReset"
)
raise MuxedStreamReset("Stream is reset or closed")
# Check if recv_closed and buffer empty
if stream.recv_closed:
logging.debug(
logger.debug(
f"Stream {self.peer_id}:{stream_id}:"
f"recv_closed=True, buffer empty, raising EOF"
)
raise MuxedStreamEOF("Stream is closed for receiving")
# Wait for data if stream is still open
logging.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
logger.debug(f"Waiting for data on stream {self.peer_id}:{stream_id}")
try:
await self.stream_events[stream_id].wait()
self.stream_events[stream_id] = trio.Event()
@ -528,7 +531,7 @@ class Yamux(IMuxedConn):
try:
header = await self.secured_conn.read(HEADER_SIZE)
if not header or len(header) < HEADER_SIZE:
logging.debug(
logger.debug(
f"Connection closed orincomplete header for peer {self.peer_id}"
)
self.event_shutting_down.set()
@ -537,7 +540,7 @@ class Yamux(IMuxedConn):
version, typ, flags, stream_id, length = struct.unpack(
YAMUX_HEADER_FORMAT, header
)
logging.debug(
logger.debug(
f"Received header for peer {self.peer_id}:"
f"type={typ}, flags={flags}, stream_id={stream_id},"
f"length={length}"
@ -558,7 +561,7 @@ class Yamux(IMuxedConn):
0,
)
await self.secured_conn.write(ack_header)
logging.debug(
logger.debug(
f"Sending stream {stream_id}"
f"to channel for peer {self.peer_id}"
)
@ -576,7 +579,7 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_RST:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
logger.debug(
f"Resetting stream {stream_id} for peer {self.peer_id}"
)
self.streams[stream_id].closed = True
@ -585,27 +588,27 @@ class Yamux(IMuxedConn):
elif typ == TYPE_DATA and flags & FLAG_ACK:
async with self.streams_lock:
if stream_id in self.streams:
logging.debug(
logger.debug(
f"Received ACK for stream"
f"{stream_id} for peer {self.peer_id}"
)
elif typ == TYPE_GO_AWAY:
error_code = length
if error_code == GO_AWAY_NORMAL:
logging.debug(
logger.debug(
f"Received GO_AWAY for peer"
f"{self.peer_id}: Normal termination"
)
elif error_code == GO_AWAY_PROTOCOL_ERROR:
logging.error(
logger.error(
f"Received GO_AWAY for peer{self.peer_id}: Protocol error"
)
elif error_code == GO_AWAY_INTERNAL_ERROR:
logging.error(
logger.error(
f"Received GO_AWAY for peer {self.peer_id}: Internal error"
)
else:
logging.error(
logger.error(
f"Received GO_AWAY for peer {self.peer_id}"
f"with unknown error code: {error_code}"
)
@ -614,7 +617,7 @@ class Yamux(IMuxedConn):
break
elif typ == TYPE_PING:
if flags & FLAG_SYN:
logging.debug(
logger.debug(
f"Received ping request with value"
f"{length} for peer {self.peer_id}"
)
@ -623,7 +626,7 @@ class Yamux(IMuxedConn):
)
await self.secured_conn.write(ping_header)
elif flags & FLAG_ACK:
logging.debug(
logger.debug(
f"Received ping response with value"
f"{length} for peer {self.peer_id}"
)
@ -637,7 +640,7 @@ class Yamux(IMuxedConn):
self.stream_buffers[stream_id].extend(data)
self.stream_events[stream_id].set()
if flags & FLAG_FIN:
logging.debug(
logger.debug(
f"Received FIN for stream {self.peer_id}:"
f"{stream_id}, marking recv_closed"
)
@ -645,7 +648,7 @@ class Yamux(IMuxedConn):
if self.streams[stream_id].send_closed:
self.streams[stream_id].closed = True
except Exception as e:
logging.error(f"Error reading data for stream {stream_id}: {e}")
logger.error(f"Error reading data for stream {stream_id}: {e}")
# Mark stream as closed on read error
async with self.streams_lock:
if stream_id in self.streams:
@ -659,7 +662,7 @@ class Yamux(IMuxedConn):
if stream_id in self.streams:
stream = self.streams[stream_id]
async with stream.window_lock:
logging.debug(
logger.debug(
f"Received window update for stream"
f"{self.peer_id}:{stream_id},"
f" increment: {increment}"
@ -674,7 +677,7 @@ class Yamux(IMuxedConn):
and details.get("requested_count") == 2
and details.get("received_count") == 0
):
logging.info(
logger.info(
f"Stream closed cleanly for peer {self.peer_id}"
+ f" (IncompleteReadError: {details})"
)
@ -682,15 +685,32 @@ class Yamux(IMuxedConn):
await self._cleanup_on_error()
break
else:
logging.error(
logger.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
else:
logging.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Handle RawConnError with more nuance
if isinstance(e, RawConnError):
error_msg = str(e)
# If RawConnError is empty, it's likely normal cleanup
if not error_msg.strip():
logger.info(
f"RawConnError (empty) during cleanup for peer "
f"{self.peer_id} (normal connection shutdown)"
)
else:
# Log non-empty RawConnError as warning
logger.warning(
f"RawConnError during connection handling for peer "
f"{self.peer_id}: {error_msg}"
)
else:
# Log all other errors normally
logger.error(
f"Error in handle_incoming for peer {self.peer_id}: "
+ f"{type(e).__name__}: {str(e)}"
)
# Don't crash the whole connection for temporary errors
if self.event_shutting_down.is_set() or isinstance(
e, (RawConnError, OSError)
@ -720,9 +740,9 @@ class Yamux(IMuxedConn):
# Close the secured connection
try:
await self.secured_conn.close()
logging.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
logger.debug(f"Successfully closed secured_conn for peer {self.peer_id}")
except Exception as close_error:
logging.error(
logger.error(
f"Error closing secured_conn for peer {self.peer_id}: {close_error}"
)
@ -731,14 +751,14 @@ class Yamux(IMuxedConn):
# Call on_close callback if provided
if self.on_close:
logging.debug(f"Calling on_close for peer {self.peer_id}")
logger.debug(f"Calling on_close for peer {self.peer_id}")
try:
if inspect.iscoroutinefunction(self.on_close):
await self.on_close()
else:
self.on_close()
except Exception as callback_error:
logging.error(f"Error in on_close callback: {callback_error}")
logger.error(f"Error in on_close callback: {callback_error}")
# Cancel nursery tasks
if self._nursery: