diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index 586bbc2d..eba0156e 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -98,16 +98,32 @@ class YamuxStream(IMuxedStream): # Flow control: Check if we have enough send window total_len = len(data) sent = 0 - + logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ") while sent < total_len: + # Wait for available window with timeout + timeout = False async with self.window_lock: - # Wait for available window - while self.send_window == 0 and not self.closed: - # Release lock while waiting + if self.send_window == 0: + logging.debug( + f"Stream {self.stream_id}: Window is zero, waiting for update" + ) + # Release lock and wait with timeout self.window_lock.release() - await trio.sleep(0.01) + # To avoid re-acquiring the lock immediately, + with trio.move_on_after(5.0) as cancel_scope: + while self.send_window == 0 and not self.closed: + await trio.sleep(0.01) + # If we timed out, cancel the scope + timeout = cancel_scope.cancelled_caught + # Re-acquire lock await self.window_lock.acquire() + # If we timed out waiting for window update, raise an error + if timeout: + raise MuxedStreamError( + "Timed out waiting for window update after 5 seconds." + ) + if self.closed: raise MuxedStreamError("Stream is closed") @@ -123,25 +139,45 @@ class YamuxStream(IMuxedStream): await self.conn.secured_conn.write(header + chunk) sent += to_send - # If window is getting low, consider updating - if self.send_window < DEFAULT_WINDOW_SIZE // 2: - await self.send_window_update() - - async def send_window_update(self, increment: int | None = None) -> None: - """Send a window update to peer.""" - if increment is None: - increment = DEFAULT_WINDOW_SIZE - self.recv_window + async def send_window_update(self, increment: int, skip_lock: bool = False) -> None: + """ + Send a window update to peer. + param:increment: The amount to increment the window size by. + If None, uses the difference between DEFAULT_WINDOW_SIZE + and current receive window. + param:skip_lock (bool): If True, skips acquiring window_lock. + This should only be used when calling from a context + that already holds the lock. + """ if increment <= 0: + # If increment is zero or negative, skip sending update + logging.debug( + f"Stream {self.stream_id}: Skipping window update" + f"(increment={increment})" + ) return + logging.debug( + f"Stream {self.stream_id}: Sending window update with increment={increment}" + ) - async with self.window_lock: - self.recv_window += increment + async def _do_window_update() -> None: header = struct.pack( - YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment + YAMUX_HEADER_FORMAT, + 0, + TYPE_WINDOW_UPDATE, + 0, + self.stream_id, + increment, ) await self.conn.secured_conn.write(header) + if skip_lock: + await _do_window_update() + else: + async with self.window_lock: + await _do_window_update() + async def read(self, n: int | None = -1) -> bytes: # Handle None value for n by converting it to -1 if n is None: @@ -154,55 +190,68 @@ class YamuxStream(IMuxedStream): ) raise MuxedStreamEOF("Stream is closed for receiving") - # If reading until EOF (n == -1), block until stream is closed if n == -1: - while not self.recv_closed and not self.conn.event_shutting_down.is_set(): + data = b"" + while not self.conn.event_shutting_down.is_set(): # Check if there's data in the buffer buffer = self.conn.stream_buffers.get(self.stream_id) - if buffer and len(buffer) > 0: - # Wait for closure even if data is available - logging.debug( - f"Stream {self.stream_id}:Waiting for FIN before returning data" - ) - await self.conn.stream_events[self.stream_id].wait() - self.conn.stream_events[self.stream_id] = trio.Event() - else: - # No data, wait for data or closure - logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN") - await self.conn.stream_events[self.stream_id].wait() - self.conn.stream_events[self.stream_id] = trio.Event() - # After loop, check if stream is closed or shutting down - async with self.conn.streams_lock: - if self.conn.event_shutting_down.is_set(): - logging.debug(f"Stream {self.stream_id}: Connection shutting down") - raise MuxedStreamEOF("Connection shut down") - if self.closed: - if self.reset_received: - logging.debug(f"Stream {self.stream_id}: Stream was reset") - raise MuxedStreamReset("Stream was reset") - else: - logging.debug( - f"Stream {self.stream_id}: Stream closed cleanly (EOF)" - ) - raise MuxedStreamEOF("Stream closed cleanly (EOF)") - buffer = self.conn.stream_buffers.get(self.stream_id) + # If buffer is not available, check if stream is closed if buffer is None: - logging.debug( - f"Stream {self.stream_id}: Buffer gone, assuming closed" - ) + logging.debug(f"Stream {self.stream_id}: No buffer available") raise MuxedStreamEOF("Stream buffer closed") + + # If we have data in buffer, process it + if len(buffer) > 0: + chunk = bytes(buffer) + buffer.clear() + data += chunk + + # Send window update for the chunk we just read + async with self.window_lock: + self.recv_window += len(chunk) + logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}") + await self.send_window_update(len(chunk), skip_lock=True) + + # If stream is closed (FIN received) and buffer is empty, break if self.recv_closed and len(buffer) == 0: - logging.debug(f"Stream {self.stream_id}: EOF reached") - raise MuxedStreamEOF("Stream is closed for receiving") - # Return all buffered data - data = bytes(buffer) - buffer.clear() - logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes") + logging.debug(f"Stream {self.stream_id}: Closed with empty buffer") + break + + # If stream was reset, raise reset error + if self.reset_received: + logging.debug(f"Stream {self.stream_id}: Stream was reset") + raise MuxedStreamReset("Stream was reset") + + # Wait for more data or stream closure + logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN") + await self.conn.stream_events[self.stream_id].wait() + self.conn.stream_events[self.stream_id] = trio.Event() + + # After loop exit, first check if we have data to return + if data: + logging.debug( + f"Stream {self.stream_id}: Returning {len(data)} bytes after loop" + ) return data - # For specific size read (n > 0), return available data immediately - return await self.conn.read_stream(self.stream_id, n) + # No data accumulated, now check why we exited the loop + if self.conn.event_shutting_down.is_set(): + logging.debug(f"Stream {self.stream_id}: Connection shutting down") + raise MuxedStreamEOF("Connection shut down") + + # Return empty data + return b"" + else: + data = await self.conn.read_stream(self.stream_id, n) + async with self.window_lock: + self.recv_window += len(data) + logging.debug( + f"Stream {self.stream_id}: Sending window update after read, " + f"increment={len(data)}" + ) + await self.send_window_update(len(data), skip_lock=True) + return data async def close(self) -> None: if not self.send_closed: diff --git a/newsfragments/639.feature.rst b/newsfragments/639.feature.rst new file mode 100644 index 00000000..93476b68 --- /dev/null +++ b/newsfragments/639.feature.rst @@ -0,0 +1,6 @@ +Fixed several flow-control and concurrency issues in the `YamuxStream` class. Previously, stress-testing revealed that transferring data over `DEFAULT_WINDOW_SIZE` would break the stream due to inconsistent window update handling and lock management. The fixes include: + +- Removed sending of window updates during writes to maintain correct flow-control. +- Added proper timeout handling when releasing and acquiring locks to prevent concurrency errors. +- Corrected the `read` function to properly handle window updates for both `read_until_EOF` and `read_n_bytes`. +- Added event logging at `send_window_updates` and `waiting_for_window_updates` for better observability. diff --git a/tests/core/stream_muxer/test_yamux_interleaving.py b/tests/core/stream_muxer/test_yamux_interleaving.py new file mode 100644 index 00000000..1ce62952 --- /dev/null +++ b/tests/core/stream_muxer/test_yamux_interleaving.py @@ -0,0 +1,199 @@ +import logging + +import pytest +import trio +from trio.testing import ( + memory_stream_pair, +) + +from libp2p.abc import IRawConnection +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.security.insecure.transport import ( + InsecureTransport, +) +from libp2p.stream_muxer.yamux.yamux import ( + Yamux, + YamuxStream, +) + + +class TrioStreamAdapter(IRawConnection): + """Adapter to make trio memory streams work with libp2p.""" + + def __init__(self, send_stream, receive_stream, is_initiator=False): + self.send_stream = send_stream + self.receive_stream = receive_stream + self.is_initiator = is_initiator + + async def write(self, data: bytes) -> None: + logging.debug(f"Attempting to write {len(data)} bytes") + with trio.move_on_after(2): + await self.send_stream.send_all(data) + + async def read(self, n: int | None = None) -> bytes: + if n is None or n <= 0: + raise ValueError("Reading unbounded or zero bytes not supported") + logging.debug(f"Attempting to read {n} bytes") + with trio.move_on_after(2): + data = await self.receive_stream.receive_some(n) + logging.debug(f"Read {len(data)} bytes") + return data + + async def close(self) -> None: + logging.debug("Closing stream") + await self.send_stream.aclose() + await self.receive_stream.aclose() + + def get_remote_address(self) -> tuple[str, int] | None: + """Return None since this is a test adapter without real network info.""" + return None + + +@pytest.fixture +def key_pair(): + return create_new_key_pair() + + +@pytest.fixture +def peer_id(key_pair): + return ID.from_pubkey(key_pair.public_key) + + +@pytest.fixture +async def secure_conn_pair(key_pair, peer_id): + """Create a pair of secure connections for testing.""" + logging.debug("Setting up secure_conn_pair") + client_send, server_receive = memory_stream_pair() + server_send, client_receive = memory_stream_pair() + + client_rw = TrioStreamAdapter(client_send, client_receive) + server_rw = TrioStreamAdapter(server_send, server_receive) + + insecure_transport = InsecureTransport(key_pair) + + async def run_outbound(nursery_results): + with trio.move_on_after(5): + client_conn = await insecure_transport.secure_outbound(client_rw, peer_id) + logging.debug("Outbound handshake complete") + nursery_results["client"] = client_conn + + async def run_inbound(nursery_results): + with trio.move_on_after(5): + server_conn = await insecure_transport.secure_inbound(server_rw) + logging.debug("Inbound handshake complete") + nursery_results["server"] = server_conn + + nursery_results = {} + async with trio.open_nursery() as nursery: + nursery.start_soon(run_outbound, nursery_results) + nursery.start_soon(run_inbound, nursery_results) + await trio.sleep(0.1) # Give tasks a chance to finish + + client_conn = nursery_results.get("client") + server_conn = nursery_results.get("server") + + if client_conn is None or server_conn is None: + raise RuntimeError("Handshake failed: client_conn or server_conn is None") + + logging.debug("secure_conn_pair setup complete") + return client_conn, server_conn + + +@pytest.fixture +async def yamux_pair(secure_conn_pair, peer_id): + """Create a pair of Yamux multiplexers for testing.""" + logging.debug("Setting up yamux_pair") + client_conn, server_conn = secure_conn_pair + client_yamux = Yamux(client_conn, peer_id, is_initiator=True) + server_yamux = Yamux(server_conn, peer_id, is_initiator=False) + async with trio.open_nursery() as nursery: + with trio.move_on_after(5): + nursery.start_soon(client_yamux.start) + nursery.start_soon(server_yamux.start) + await trio.sleep(0.1) + logging.debug("yamux_pair started") + yield client_yamux, server_yamux + logging.debug("yamux_pair cleanup") + + +@pytest.mark.trio +async def test_yamux_race_condition_without_locks(yamux_pair): + """ + Test for race-around/interleaving in Yamux streams,when reading in + segments of data. + This launches concurrent writers/readers on both sides of a stream. + If there is no proper locking, the received data may be interleaved + or corrupted. + + The test creates structured messages and verifies they are received + intact and in order. + Without proper locking, concurrent read/write operations could cause + data corruption + or message interleaving, which this test will catch. + """ + client_yamux, server_yamux = yamux_pair + client_stream: YamuxStream = await client_yamux.open_stream() + server_stream: YamuxStream = await server_yamux.accept_stream() + MSG_COUNT = 10 + MSG_SIZE = 256 * 1024 # At max,only DEFAULT_WINDOW_SIZE bytes can be read + client_msgs = [ + f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT) + ] + server_msgs = [ + f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT) + ] + client_received = [] + server_received = [] + + async def writer(stream, msgs, name): + """Write messages with minimal delays to encourage race conditions.""" + for i, msg in enumerate(msgs): + await stream.write(msg) + # Yield control frequently to encourage interleaving + if i % 5 == 0: + await trio.sleep(0.005) + + async def reader(stream, received, name): + """Read messages and store them for verification.""" + for i in range(MSG_COUNT): + data = await stream.read(MSG_SIZE) + received.append(data) + if i % 3 == 0: + await trio.sleep(0.001) + + # Running all operations concurrently + async with trio.open_nursery() as nursery: + nursery.start_soon(writer, client_stream, client_msgs, "client") + nursery.start_soon(writer, server_stream, server_msgs, "server") + nursery.start_soon(reader, client_stream, client_received, "client") + nursery.start_soon(reader, server_stream, server_received, "server") + + assert len(client_received) == MSG_COUNT, ( + f"Client received {len(client_received)} messages, expected {MSG_COUNT}" + ) + assert len(server_received) == MSG_COUNT, ( + f"Server received {len(server_received)} messages, expected {MSG_COUNT}" + ) + assert client_received == server_msgs, ( + "Client did not receive server messages in order or intact!" + ) + assert server_received == client_msgs, ( + "Server did not receive client messages in order or intact!" + ) + for i, msg in enumerate(client_received): + assert len(msg) == MSG_SIZE, ( + f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}" + ) + + for i, msg in enumerate(server_received): + assert len(msg) == MSG_SIZE, ( + f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}" + ) + + await client_stream.close() + await server_stream.close() diff --git a/tests/core/stream_muxer/test_yamux_interleaving_EOF.py b/tests/core/stream_muxer/test_yamux_interleaving_EOF.py new file mode 100644 index 00000000..23d2c2b4 --- /dev/null +++ b/tests/core/stream_muxer/test_yamux_interleaving_EOF.py @@ -0,0 +1,195 @@ +import logging + +import pytest +import trio +from trio.testing import ( + memory_stream_pair, +) + +from libp2p.abc import IRawConnection +from libp2p.crypto.ed25519 import ( + create_new_key_pair, +) +from libp2p.peer.id import ( + ID, +) +from libp2p.security.insecure.transport import ( + InsecureTransport, +) +from libp2p.stream_muxer.exceptions import MuxedStreamEOF +from libp2p.stream_muxer.yamux.yamux import ( + Yamux, + YamuxStream, +) + + +class TrioStreamAdapter(IRawConnection): + """Adapter to make trio memory streams work with libp2p.""" + + def __init__(self, send_stream, receive_stream, is_initiator=False): + self.send_stream = send_stream + self.receive_stream = receive_stream + self.is_initiator = is_initiator + + async def write(self, data: bytes) -> None: + logging.debug(f"Attempting to write {len(data)} bytes") + with trio.move_on_after(2): + await self.send_stream.send_all(data) + + async def read(self, n: int | None = None) -> bytes: + if n is None or n <= 0: + raise ValueError("Reading unbounded or zero bytes not supported") + logging.debug(f"Attempting to read {n} bytes") + with trio.move_on_after(2): + data = await self.receive_stream.receive_some(n) + logging.debug(f"Read {len(data)} bytes") + return data + + async def close(self) -> None: + logging.debug("Closing stream") + await self.send_stream.aclose() + await self.receive_stream.aclose() + + def get_remote_address(self) -> tuple[str, int] | None: + """Return None since this is a test adapter without real network info.""" + return None + + +@pytest.fixture +def key_pair(): + return create_new_key_pair() + + +@pytest.fixture +def peer_id(key_pair): + return ID.from_pubkey(key_pair.public_key) + + +@pytest.fixture +async def secure_conn_pair(key_pair, peer_id): + """Create a pair of secure connections for testing.""" + logging.debug("Setting up secure_conn_pair") + client_send, server_receive = memory_stream_pair() + server_send, client_receive = memory_stream_pair() + + client_rw = TrioStreamAdapter(client_send, client_receive) + server_rw = TrioStreamAdapter(server_send, server_receive) + + insecure_transport = InsecureTransport(key_pair) + + async def run_outbound(nursery_results): + with trio.move_on_after(5): + client_conn = await insecure_transport.secure_outbound(client_rw, peer_id) + logging.debug("Outbound handshake complete") + nursery_results["client"] = client_conn + + async def run_inbound(nursery_results): + with trio.move_on_after(5): + server_conn = await insecure_transport.secure_inbound(server_rw) + logging.debug("Inbound handshake complete") + nursery_results["server"] = server_conn + + nursery_results = {} + async with trio.open_nursery() as nursery: + nursery.start_soon(run_outbound, nursery_results) + nursery.start_soon(run_inbound, nursery_results) + await trio.sleep(0.1) # Give tasks a chance to finish + + client_conn = nursery_results.get("client") + server_conn = nursery_results.get("server") + + if client_conn is None or server_conn is None: + raise RuntimeError("Handshake failed: client_conn or server_conn is None") + + logging.debug("secure_conn_pair setup complete") + return client_conn, server_conn + + +@pytest.fixture +async def yamux_pair(secure_conn_pair, peer_id): + """Create a pair of Yamux multiplexers for testing.""" + logging.debug("Setting up yamux_pair") + client_conn, server_conn = secure_conn_pair + client_yamux = Yamux(client_conn, peer_id, is_initiator=True) + server_yamux = Yamux(server_conn, peer_id, is_initiator=False) + async with trio.open_nursery() as nursery: + with trio.move_on_after(5): + nursery.start_soon(client_yamux.start) + nursery.start_soon(server_yamux.start) + await trio.sleep(0.1) + logging.debug("yamux_pair started") + yield client_yamux, server_yamux + logging.debug("yamux_pair cleanup") + + +@pytest.mark.trio +async def test_yamux_race_condition_without_locks(yamux_pair): + """ + Test for race-around/interleaving in Yamux streams,when reading till + EOF is being used. + This launches concurrent writers/readers on both sides of a stream. + If there is no proper locking, the received data may be interleaved + or corrupted. + + The test creates structured messages and verifies they are received + intact and in order. + Without proper locking, concurrent read/write operations could cause + data corruption + or message interleaving, which this test will catch. + """ + client_yamux, server_yamux = yamux_pair + client_stream: YamuxStream = await client_yamux.open_stream() + server_stream: YamuxStream = await server_yamux.accept_stream() + MSG_COUNT = 1 + MSG_SIZE = 512 * 1024 + client_msgs = [ + f"CLIENT-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"C") for i in range(MSG_COUNT) + ] + server_msgs = [ + f"SERVER-MSG-{i:03d}-".encode().ljust(MSG_SIZE, b"S") for i in range(MSG_COUNT) + ] + client_received = [] + server_received = [] + + async def writer(stream, msgs, name): + """Write messages with minimal delays to encourage race conditions.""" + for i, msg in enumerate(msgs): + await stream.write(msg) + # Yield control frequently to encourage interleaving + if i % 5 == 0: + await trio.sleep(0.005) + + async def reader(stream, received, name): + """Read messages and store them for verification.""" + try: + data = await stream.read() + if data: + received.append(data) + except MuxedStreamEOF: + pass + + # Running all operations concurrently + async with trio.open_nursery() as nursery: + nursery.start_soon(writer, client_stream, client_msgs, "client") + nursery.start_soon(writer, server_stream, server_msgs, "server") + nursery.start_soon(reader, client_stream, client_received, "client") + nursery.start_soon(reader, server_stream, server_received, "server") + + assert client_received == server_msgs, ( + "Client did not receive server messages in order or intact!" + ) + assert server_received == client_msgs, ( + "Server did not receive client messages in order or intact!" + ) + for i, msg in enumerate(client_received): + assert len(msg) == MSG_SIZE, ( + f"Client message {i} has wrong size: {len(msg)} != {MSG_SIZE}" + ) + + for i, msg in enumerate(server_received): + assert len(msg) == MSG_SIZE, ( + f"Server message {i} has wrong size: {len(msg)} != {MSG_SIZE}" + ) + + await client_stream.close() + await server_stream.close()