diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index dc65ac5f..3b640df1 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -46,9 +46,8 @@ class MplexStream(IMuxedStream): read_deadline: int | None write_deadline: int | None + # TODO: Add lock for read/write to avoid interleaving receiving messages? close_lock: trio.Lock - read_lock: trio.Lock - write_lock: trio.Lock # NOTE: `dataIn` is size of 8 in Go implementation. incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" @@ -81,8 +80,6 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() - self.read_lock = trio.Lock() - self.write_lock = trio.Lock() self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @@ -116,49 +113,48 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - async with self.read_lock: - if n is not None and n < 0: - raise ValueError( - "the number of bytes to read `n` must be non-negative or " - f"`None` to indicate read until EOF, got n={n}" - ) - if self.event_reset.is_set(): - raise MplexStreamReset - if n is None: - return await self._read_until_eof() - if len(self._buf) == 0: - data: bytes - # Peek whether there is data available. If yes, we just read until - # there is no data, then return. + if n is not None and n < 0: + raise ValueError( + "the number of bytes to read `n` must be non-negative or " + f"`None` to indicate read until EOF, got n={n}" + ) + if self.event_reset.is_set(): + raise MplexStreamReset + if n is None: + return await self._read_until_eof() + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until there is + # no data, then return. + try: + data = self.incoming_data_channel.receive_nowait() + self._buf.extend(data) + except trio.EndOfChannel: + raise MplexStreamEOF + except trio.WouldBlock: + # We know `receive` will be blocked here. Wait for data here with + # `receive` and catch all kinds of errors here. try: - data = self.incoming_data_channel.receive_nowait() + data = await self.incoming_data_channel.receive() self._buf.extend(data) except trio.EndOfChannel: - raise MplexStreamEOF - except trio.WouldBlock: - # We know `receive` will be blocked here. Wait for data here with - # `receive` and catch all kinds of errors here. - try: - data = await self.incoming_data_channel.receive() - self._buf.extend(data) - except trio.EndOfChannel: - if self.event_reset.is_set(): - raise MplexStreamReset - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - except trio.ClosedResourceError as error: - # Probably `incoming_data_channel` is closed in `reset` when - # we are waiting for `receive`. - if self.event_reset.is_set(): - raise MplexStreamReset - raise Exception( - "`incoming_data_channel` is closed but stream is not reset." - "This should never happen." - ) from error - self._buf.extend(self._read_return_when_blocked()) - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably `incoming_data_channel` is closed in `reset` when we are + # waiting for `receive`. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "`incoming_data_channel` is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(self._read_return_when_blocked()) + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) async def write(self, data: bytes) -> None: """ @@ -166,15 +162,14 @@ class MplexStream(IMuxedStream): :return: number of bytes written """ - async with self.write_lock: - if self.event_local_closed.is_set(): - raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") - flag = ( - HeaderTags.MessageInitiator - if self.is_initiator - else HeaderTags.MessageReceiver - ) - await self.muxed_conn.send_message(flag, data, self.stream_id) + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") + flag = ( + HeaderTags.MessageInitiator + if self.is_initiator + else HeaderTags.MessageReceiver + ) + await self.muxed_conn.send_message(flag, data, self.stream_id) async def close(self) -> None: """ diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 023251ed..f90ba9a1 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -77,8 +77,6 @@ class YamuxStream(IMuxedStream): self.send_window = DEFAULT_WINDOW_SIZE self.recv_window = DEFAULT_WINDOW_SIZE self.window_lock = trio.Lock() - self.read_lock = trio.Lock() - self.write_lock = trio.Lock() async def __aenter__(self) -> "YamuxStream": """Enter the async context manager.""" diff --git a/newsfragments/639.feature.rst b/newsfragments/639.feature.rst index 3f3d7510..93476b68 100644 --- a/newsfragments/639.feature.rst +++ b/newsfragments/639.feature.rst @@ -1 +1,6 @@ -Added separate read and write locks to the `MplexStream` & `YamuxStream` class.This ensures thread-safe access and data integrity when multiple coroutines interact with the same MplexStream instance. +Fixed several flow-control and concurrency issues in the `YamuxStream` class. Previously, stress-testing revealed that transferring data over `DEFAULT_WINDOW_SIZE` would break the stream due to inconsistent window update handling and lock management. The fixes include: + +- Removed sending of window updates during writes to maintain correct flow-control. +- Added proper timeout handling when releasing and acquiring locks to prevent concurrency errors. +- Corrected the `read` function to properly handle window updates for both `read_until_EOF` and `read_n_bytes`. +- Added event logging at `send_window_updates` and `waiting_for_window_updates` for better observability. diff --git a/tests/core/stream_muxer/test_mplex_read_write_lock.py b/tests/core/stream_muxer/test_mplex_read_write_lock.py deleted file mode 100644 index d00d5b8e..00000000 --- a/tests/core/stream_muxer/test_mplex_read_write_lock.py +++ /dev/null @@ -1,127 +0,0 @@ -import pytest -import trio - -from libp2p.abc import IMuxedStream, ISecureConn -from libp2p.crypto.keys import PrivateKey, PublicKey -from libp2p.peer.id import ID -from libp2p.stream_muxer.mplex.constants import ( - HeaderTags, -) -from libp2p.stream_muxer.mplex.datastructures import ( - StreamID, -) -from libp2p.stream_muxer.mplex.mplex import ( - Mplex, -) -from libp2p.stream_muxer.mplex.mplex_stream import ( - MplexStream, -) - - -class DummySecureConn(ISecureConn): - """A minimal implementation of ISecureConn for testing.""" - - async def write(self, data: bytes) -> None: - pass - - async def read(self, n: int | None = -1) -> bytes: - return b"" - - async def close(self) -> None: - pass - - def get_remote_address(self) -> tuple[str, int] | None: - return None - - def get_local_peer(self) -> ID: - return ID(b"local") - - def get_local_private_key(self) -> PrivateKey: - return PrivateKey() # Dummy key for testing - - def get_remote_peer(self) -> ID: - return ID(b"remote") - - def get_remote_public_key(self) -> PublicKey: - return PublicKey() # Dummy key for testing - - -class DummyMuxedConn(Mplex): - """A minimal mock of Mplex for testing read/write locks.""" - - def __init__(self) -> None: - self.secured_conn = DummySecureConn() - self.peer_id = ID(b"dummy") - self.streams = {} - self.streams_lock = trio.Lock() - self.event_shutting_down = trio.Event() - self.event_closed = trio.Event() - self.event_started = trio.Event() - self.stream_backlog_limit = 256 - self.stream_backlog_semaphore = trio.Semaphore(256) - # Use IMuxedStream for type consistency with Mplex - channels = trio.open_memory_channel[IMuxedStream](0) - self.new_stream_send_channel, self.new_stream_receive_channel = channels - - async def send_message( - self, flag: HeaderTags, data: bytes | None, stream_id: StreamID - ) -> int: - await trio.sleep(0.01) - return 0 - - -@pytest.mark.trio -async def test_concurrent_writes_are_serialized(): - stream_id = StreamID(1, True) - send_log = [] - - class LoggingMuxedConn(DummyMuxedConn): - async def send_message( - self, flag: HeaderTags, data: bytes | None, stream_id: StreamID - ) -> int: - send_log.append(data) - await trio.sleep(0.01) - return 0 - - memory_send, memory_recv = trio.open_memory_channel(8) - stream = MplexStream( - name="test", - stream_id=stream_id, - muxed_conn=LoggingMuxedConn(), - incoming_data_channel=memory_recv, - ) - - async def writer(data): - await stream.write(data) - - async with trio.open_nursery() as nursery: - for i in range(5): - nursery.start_soon(writer, f"msg-{i}".encode()) - # Order doesn't matter due to concurrent execution - assert sorted(send_log) == sorted([f"msg-{i}".encode() for i in range(5)]) - - -@pytest.mark.trio -async def test_concurrent_reads_are_serialized(): - stream_id = StreamID(2, True) - muxed_conn = DummyMuxedConn() - memory_send, memory_recv = trio.open_memory_channel(8) - results = [] - stream = MplexStream( - name="test", - stream_id=stream_id, - muxed_conn=muxed_conn, - incoming_data_channel=memory_recv, - ) - for i in range(5): - await memory_send.send(f"data-{i}".encode()) - await memory_send.aclose() - - async def reader(): - data = await stream.read(6) - results.append(data) - - async with trio.open_nursery() as nursery: - for _ in range(5): - nursery.start_soon(reader) - assert sorted(results) == [f"data-{i}".encode() for i in range(5)] diff --git a/tests/core/stream_muxer/test_yamux.py b/tests/core/stream_muxer/test_yamux.py index bfd8eb5a..81d05676 100644 --- a/tests/core/stream_muxer/test_yamux.py +++ b/tests/core/stream_muxer/test_yamux.py @@ -224,16 +224,14 @@ async def test_yamux_stream_reset(yamux_pair): await client_stream.reset() # After reset, reading should raise MuxedStreamReset or MuxedStreamEOF try: - while True: - await server_stream.read() + await server_stream.read() except (MuxedStreamEOF, MuxedStreamError): pass else: pytest.fail("Expected MuxedStreamEOF or MuxedStreamError") # Verify subsequent operations fail with StreamReset or EOF with pytest.raises(MuxedStreamError): - while True: - await server_stream.read() + await server_stream.read() with pytest.raises(MuxedStreamError): await server_stream.write(b"test") logging.debug("test_yamux_stream_reset complete") diff --git a/tests/core/stream_muxer/test_yamux_read_write_lock.py b/tests/core/stream_muxer/test_yamux_read_write.py similarity index 100% rename from tests/core/stream_muxer/test_yamux_read_write_lock.py rename to tests/core/stream_muxer/test_yamux_read_write.py