From 209deffc8aaf438caea7f18fd0f4615d90044cbf Mon Sep 17 00:00:00 2001 From: kaneki003 Date: Sat, 21 Jun 2025 13:39:03 +0530 Subject: [PATCH] resolved recv_window updates,added support for read_EOF --- libp2p/stream_muxer/yamux/yamux.py | 105 +++++++----------- .../test_mplex_read_write_lock.py | 15 ++- 2 files changed, 52 insertions(+), 68 deletions(-) diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index faf24b29..f58e98c4 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -141,9 +141,7 @@ class YamuxStream(IMuxedStream): await self.conn.secured_conn.write(header + chunk) sent += to_send - async def send_window_update( - self, increment: int | None, skip_lock: bool = False - ) -> None: + async def send_window_update(self, increment: int, skip_lock: bool = False) -> None: """ Send a window update to peer. @@ -154,12 +152,7 @@ class YamuxStream(IMuxedStream): This should only be used when calling from a context that already holds the lock. """ - increment_value = 0 - if increment is None: - increment_value = DEFAULT_WINDOW_SIZE - self.recv_window - else: - increment_value = increment - if increment_value <= 0: + if increment <= 0: # If increment is zero or negative, skip sending update logging.debug( f"Stream {self.stream_id}: Skipping window update" @@ -171,14 +164,13 @@ class YamuxStream(IMuxedStream): ) async def _do_window_update() -> None: - self.recv_window += increment_value header = struct.pack( YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, - increment_value, + increment, ) await self.conn.secured_conn.write(header) @@ -188,6 +180,22 @@ class YamuxStream(IMuxedStream): async with self.window_lock: await _do_window_update() + async def read_EOF(self) -> bytes: + """ + To read data from stream until it is closed. + """ + data = b"" + try: + while True: + recv = await self.read() + if recv: + data += recv + except MuxedStreamEOF: + logging.debug( + f"Stream {self.stream_id}:EOF reached,total data read:{len(data)} bytes" + ) + return data + async def read(self, n: int | None = -1) -> bytes: # Handle None value for n by converting it to -1 if n is None: @@ -202,61 +210,34 @@ class YamuxStream(IMuxedStream): # If reading until EOF (n == -1), block until stream is closed if n == -1: - while not self.recv_closed and not self.conn.event_shutting_down.is_set(): - # Check if there's data in the buffer - buffer = self.conn.stream_buffers.get(self.stream_id) - if buffer and len(buffer) > 0: - # Wait for closure even if data is available - logging.debug( - f"Stream {self.stream_id}:Waiting for FIN before returning data" - ) - await self.conn.stream_events[self.stream_id].wait() - self.conn.stream_events[self.stream_id] = trio.Event() - else: - # No data, wait for data or closure - logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN") - await self.conn.stream_events[self.stream_id].wait() - self.conn.stream_events[self.stream_id] = trio.Event() - - # After loop, check if stream is closed or shutting down - async with self.conn.streams_lock: - if self.conn.event_shutting_down.is_set(): - logging.debug(f"Stream {self.stream_id}: Connection shutting down") - raise MuxedStreamEOF("Connection shut down") - if self.closed: - if self.reset_received: - logging.debug(f"Stream {self.stream_id}: Stream was reset") - raise MuxedStreamReset("Stream was reset") - else: - logging.debug( - f"Stream {self.stream_id}: Stream closed cleanly (EOF)" - ) - raise MuxedStreamEOF("Stream closed cleanly (EOF)") - buffer = self.conn.stream_buffers.get(self.stream_id) - if buffer is None: - logging.debug( - f"Stream {self.stream_id}: Buffer gone, assuming closed" - ) - raise MuxedStreamEOF("Stream buffer closed") - if self.recv_closed and len(buffer) == 0: - logging.debug(f"Stream {self.stream_id}: EOF reached") - raise MuxedStreamEOF("Stream is closed for receiving") - # Return all buffered data + # Check if there's data in the buffer + buffer = self.conn.stream_buffers.get(self.stream_id) + size = len(buffer) if buffer else 0 + if size > 0: + # If any data is available,return it immediately + assert buffer is not None data = bytes(buffer) buffer.clear() - return data - - data = await self.conn.read_stream(self.stream_id, n) - async with self.window_lock: - self.recv_window -= len(data) - # Automatically send a window update if recv_window is low - if self.recv_window <= DEFAULT_WINDOW_SIZE // 2: + async with self.window_lock: + self.recv_window += len(data) + await self.send_window_update(len(data), skip_lock=True) + return data + # Otherwise,wait for data or FIN + if self.recv_closed: + raise MuxedStreamEOF("Stream is closed for receiving") + await self.conn.stream_events[self.stream_id].wait() + self.conn.stream_events[self.stream_id] = trio.Event() + return b"" + else: + data = await self.conn.read_stream(self.stream_id, n) + async with self.window_lock: + self.recv_window += len(data) logging.debug( - f"Stream {self.stream_id}: " - f"Low recv_window ({self.recv_window}), sending update" + f"Stream {self.stream_id}: Sending window update after read, " + f"increment={len(data)}" ) - await self.send_window_update(None, skip_lock=True) - return data + await self.send_window_update(len(data), skip_lock=True) + return data async def close(self) -> None: if not self.send_closed: diff --git a/tests/core/stream_muxer/test_mplex_read_write_lock.py b/tests/core/stream_muxer/test_mplex_read_write_lock.py index afc197ac..d00d5b8e 100644 --- a/tests/core/stream_muxer/test_mplex_read_write_lock.py +++ b/tests/core/stream_muxer/test_mplex_read_write_lock.py @@ -1,7 +1,7 @@ import pytest import trio -from libp2p.abc import ISecureConn +from libp2p.abc import IMuxedStream, ISecureConn from libp2p.crypto.keys import PrivateKey, PublicKey from libp2p.peer.id import ID from libp2p.stream_muxer.mplex.constants import ( @@ -59,13 +59,15 @@ class DummyMuxedConn(Mplex): self.event_started = trio.Event() self.stream_backlog_limit = 256 self.stream_backlog_semaphore = trio.Semaphore(256) - channels = trio.open_memory_channel[MplexStream](0) + # Use IMuxedStream for type consistency with Mplex + channels = trio.open_memory_channel[IMuxedStream](0) self.new_stream_send_channel, self.new_stream_receive_channel = channels async def send_message( - self, flag: HeaderTags, data: bytes, stream_id: StreamID - ) -> None: + self, flag: HeaderTags, data: bytes | None, stream_id: StreamID + ) -> int: await trio.sleep(0.01) + return 0 @pytest.mark.trio @@ -75,10 +77,11 @@ async def test_concurrent_writes_are_serialized(): class LoggingMuxedConn(DummyMuxedConn): async def send_message( - self, flag: HeaderTags, data: bytes, stream_id: StreamID - ) -> None: + self, flag: HeaderTags, data: bytes | None, stream_id: StreamID + ) -> int: send_log.append(data) await trio.sleep(0.01) + return 0 memory_send, memory_recv = trio.open_memory_channel(8) stream = MplexStream(