diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 91f872d4..7e439160 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -35,55 +35,69 @@ if TYPE_CHECKING: class ReadWriteLock: + """ + A read-write lock that allows multiple concurrent readers + or one exclusive writer, implemented using Trio primitives. + """ + def __init__(self) -> None: self._readers = 0 - self._readers_lock = trio.Lock() # Protects readers count - self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers + self._readers_lock = trio.Lock() # Protects access to _readers count + self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time async def acquire_read(self) -> None: + """Acquire a read lock. Multiple readers can hold it simultaneously.""" try: async with self._readers_lock: - self._readers += 1 - if self._readers == 1: + if self._readers == 0: await self._writer_lock.acquire() + self._readers += 1 except trio.Cancelled: - async with self._readers_lock: - if self._readers > 0: - self._readers -= 1 - if self._readers == 0: - self._writer_lock.release() raise async def release_read(self) -> None: + """Release a read lock.""" async with self._readers_lock: - self._readers -= 1 - if self._readers == 0: + if self._readers == 1: self._writer_lock.release() + self._readers -= 1 async def acquire_write(self) -> None: + """Acquire an exclusive write lock.""" try: await self._writer_lock.acquire() except trio.Cancelled: raise def release_write(self) -> None: + """Release the exclusive write lock.""" self._writer_lock.release() @asynccontextmanager async def read_lock(self) -> AsyncGenerator[None, None]: - await self.acquire_read() + """Context manager for acquiring and releasing a read lock safely.""" + acquire = False try: + await self.acquire_read() + acquire = True yield finally: - await self.release_read() + if acquire: + with trio.CancelScope() as scope: + scope.shield = True + await self.release_read() @asynccontextmanager async def write_lock(self) -> AsyncGenerator[None, None]: - await self.acquire_write() + """Context manager for acquiring and releasing a write lock safely.""" + acquire = False try: + await self.acquire_write() + acquire = True yield finally: - self.release_write() + if acquire: + self.release_write() class MplexStream(IMuxedStream): @@ -168,9 +182,7 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - await self.rw_lock.acquire_read() - payload: bytes = b"" - try: + async with self.rw_lock.read_lock(): if n is not None and n < 0: raise ValueError( "the number of bytes to read n must be non-negative or " @@ -210,12 +222,9 @@ class MplexStream(IMuxedStream): "This should never happen." ) from error self._buf.extend(self._read_return_when_blocked()) - chunk = self._buf[:n] - self._buf = self._buf[len(chunk) :] - payload = bytes(chunk) - finally: - await self.rw_lock.release_read() - return payload + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) async def write(self, data: bytes) -> None: """ @@ -223,8 +232,7 @@ class MplexStream(IMuxedStream): :return: number of bytes written """ - await self.rw_lock.acquire_write() - try: + async with self.rw_lock.write_lock(): if self.event_local_closed.is_set(): raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") flag = ( @@ -233,8 +241,6 @@ class MplexStream(IMuxedStream): else HeaderTags.MessageReceiver ) await self.muxed_conn.send_message(flag, data, self.stream_id) - finally: - self.rw_lock.release_write() async def close(self) -> None: """ diff --git a/tests/core/stream_muxer/test_read_write_lock.py b/tests/core/stream_muxer/test_read_write_lock.py index c52aa36c..621f3841 100644 --- a/tests/core/stream_muxer/test_read_write_lock.py +++ b/tests/core/stream_muxer/test_read_write_lock.py @@ -1,203 +1,590 @@ -from unittest.mock import AsyncMock, MagicMock +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, patch import pytest import trio +from trio.testing import wait_all_tasks_blocked -from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID +from libp2p.stream_muxer.exceptions import ( + MuxedConnUnavailable, +) +from libp2p.stream_muxer.mplex.constants import HeaderTags +from libp2p.stream_muxer.mplex.datastructures import StreamID +from libp2p.stream_muxer.mplex.exceptions import ( + MplexStreamClosed, + MplexStreamEOF, + MplexStreamReset, +) +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream + + +class MockMuxedConn: + """A mock Mplex connection for testing purposes.""" + + def __init__(self): + self.sent_messages = [] + self.streams: dict[StreamID, MplexStream] = {} + self.streams_lock = trio.Lock() + self.is_unavailable = False + + async def send_message( + self, flag: HeaderTags, data: bytes | None, stream_id: StreamID + ) -> None: + """Mocks sending a message over the connection.""" + if self.is_unavailable: + raise MuxedConnUnavailable("Connection is unavailable") + self.sent_messages.append((flag, data, stream_id)) + # Yield to allow other tasks to run + await trio.lowlevel.checkpoint() + + def get_remote_address(self) -> tuple[str, int]: + """Mocks getting the remote address.""" + return "127.0.0.1", 4001 @pytest.fixture -def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]: - muxed_conn = MagicMock() - muxed_conn.send_message = AsyncMock() - muxed_conn.streams_lock = trio.Lock() - muxed_conn.streams = {} - muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000)) +async def mplex_stream(): + """Provides a fully initialized MplexStream and its communication channels.""" + # Use a buffered channel to prevent deadlocks in simple tests + send_chan, recv_chan = trio.open_memory_channel(10) + stream_id = StreamID(1, is_initiator=True) + muxed_conn = MockMuxedConn() + stream = MplexStream("test-stream", stream_id, cast(Any, muxed_conn), recv_chan) + muxed_conn.streams[stream_id] = stream - send_chan: trio.MemorySendChannel[bytes] - recv_chan: trio.MemoryReceiveChannel[bytes] - send_chan, recv_chan = trio.open_memory_channel(0) + yield stream, send_chan, muxed_conn - dummy_stream_id = MagicMock(spec=StreamID) - dummy_stream_id.is_initiator = True # mock read-only property + # Cleanup: Close channels and reset stream state + await send_chan.aclose() + await recv_chan.aclose() + # Reset stream state to prevent cross-test contamination + stream.event_local_closed = trio.Event() + stream.event_remote_closed = trio.Event() + stream.event_reset = trio.Event() - stream = MplexStream( - name="test", - stream_id=dummy_stream_id, - muxed_conn=muxed_conn, - incoming_data_channel=recv_chan, - ) - return stream, send_chan + +# =============================================== +# 1. Tests for Stream-Level Lock Integration +# =============================================== @pytest.mark.trio -async def test_writing_blocked_if_read_in_progress( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] +async def test_stream_write_is_protected_by_rwlock(mplex_stream): + """Verify that stream.write() acquires and releases the write lock.""" + stream, _, muxed_conn = mplex_stream - async def reader() -> None: - await stream.rw_lock.acquire_read() - log.append("read_acquired") - await trio.sleep(0.3) - log.append("read_released") - await stream.rw_lock.release_read() + # Mock lock methods + original_acquire = stream.rw_lock.acquire_write + original_release = stream.rw_lock.release_write - async def writer() -> None: - await stream.rw_lock.acquire_write() - log.append("write_acquired") - await trio.sleep(0.1) - log.append("write_released") - stream.rw_lock.release_write() + stream.rw_lock.acquire_write = AsyncMock(wraps=original_acquire) + stream.rw_lock.release_write = MagicMock(wraps=original_release) + + await stream.write(b"test data") + + stream.rw_lock.acquire_write.assert_awaited_once() + stream.rw_lock.release_write.assert_called_once() + + # Verify the message was actually sent + assert len(muxed_conn.sent_messages) == 1 + flag, data, stream_id = muxed_conn.sent_messages[0] + assert flag == HeaderTags.MessageInitiator + assert data == b"test data" + assert stream_id == stream.stream_id + + +@pytest.mark.trio +async def test_stream_read_is_protected_by_rwlock(mplex_stream): + """Verify that stream.read() acquires and releases the read lock.""" + stream, send_chan, _ = mplex_stream + + # Mock lock methods + original_acquire = stream.rw_lock.acquire_read + original_release = stream.rw_lock.release_read + + stream.rw_lock.acquire_read = AsyncMock(wraps=original_acquire) + stream.rw_lock.release_read = AsyncMock(wraps=original_release) + + await send_chan.send(b"hello") + result = await stream.read(5) + + stream.rw_lock.acquire_read.assert_awaited_once() + stream.rw_lock.release_read.assert_awaited_once() + assert result == b"hello" + + +@pytest.mark.trio +async def test_multiple_readers_can_coexist(mplex_stream): + """Verify multiple readers can operate concurrently.""" + stream, send_chan, _ = mplex_stream + + # Send enough data for both reads + await send_chan.send(b"data1") + await send_chan.send(b"data2") + + # Track lock acquisition order + acquisition_order = [] + release_order = [] + + # Patch lock methods to track concurrency + original_acquire = stream.rw_lock.acquire_read + original_release = stream.rw_lock.release_read + + async def tracked_acquire(): + nonlocal acquisition_order + acquisition_order.append("start") + await original_acquire() + acquisition_order.append("acquired") + + async def tracked_release(): + nonlocal release_order + release_order.append("start") + await original_release() + release_order.append("released") + + with ( + patch.object( + stream.rw_lock, "acquire_read", side_effect=tracked_acquire, autospec=True + ), + patch.object( + stream.rw_lock, "release_read", side_effect=tracked_release, autospec=True + ), + ): + # Execute concurrent reads + async with trio.open_nursery() as nursery: + nursery.start_soon(stream.read, 5) + nursery.start_soon(stream.read, 5) + + # Verify both reads happened + assert acquisition_order.count("start") == 2 + assert acquisition_order.count("acquired") == 2 + assert release_order.count("start") == 2 + assert release_order.count("released") == 2 + + +@pytest.mark.trio +async def test_writer_blocks_readers(mplex_stream): + """Verify that a writer blocks all readers and new readers queue behind.""" + stream, send_chan, _ = mplex_stream + + writer_acquired = trio.Event() + readers_ready = trio.Event() + writer_finished = trio.Event() + all_readers_started = trio.Event() + all_readers_done = trio.Event() + + counters = {"reader_start_count": 0, "reader_done_count": 0} + reader_target = 3 + reader_start_lock = trio.Lock() + + # Patch write lock to control test flow + original_acquire_write = stream.rw_lock.acquire_write + original_release_write = stream.rw_lock.release_write + + async def tracked_acquire_write(): + await original_acquire_write() + writer_acquired.set() + # Wait for readers to queue up + await readers_ready.wait() + + # Must be synchronous since real release_write is sync + def tracked_release_write(): + original_release_write() + writer_finished.set() + + with ( + patch.object( + stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write + ), + patch.object( + stream.rw_lock, "release_write", side_effect=tracked_release_write + ), + ): + async with trio.open_nursery() as nursery: + # Start writer + nursery.start_soon(stream.write, b"test") + await writer_acquired.wait() + + # Start readers + async def reader_task(): + async with reader_start_lock: + counters["reader_start_count"] += 1 + if counters["reader_start_count"] == reader_target: + all_readers_started.set() + + try: + # This will block until data is available + await stream.read(5) + except (MplexStreamReset, MplexStreamEOF): + pass + finally: + async with reader_start_lock: + counters["reader_done_count"] += 1 + if counters["reader_done_count"] == reader_target: + all_readers_done.set() + + for _ in range(reader_target): + nursery.start_soon(reader_task) + + # Wait until all readers are started + await all_readers_started.wait() + + # Let the writer finish and release the lock + readers_ready.set() + await writer_finished.wait() + + # Send data to unblock the readers + for i in range(reader_target): + await send_chan.send(b"data" + str(i).encode()) + + # Wait for all readers to finish + await all_readers_done.wait() + + +@pytest.mark.trio +async def test_writer_waits_for_readers(mplex_stream): + """Verify a writer waits for existing readers to complete.""" + stream, send_chan, _ = mplex_stream + readers_started = trio.Event() + writer_entered = trio.Event() + writer_acquiring = trio.Event() + readers_finished = trio.Event() + + # Send data for readers + await send_chan.send(b"data1") + await send_chan.send(b"data2") + + # Patch read lock to control test flow + original_acquire_read = stream.rw_lock.acquire_read + + async def tracked_acquire_read(): + await original_acquire_read() + readers_started.set() + # Wait until readers are allowed to finish + await readers_finished.wait() + + # Patch write lock to detect when writer is blocked + original_acquire_write = stream.rw_lock.acquire_write + + async def tracked_acquire_write(): + writer_acquiring.set() + await original_acquire_write() + writer_entered.set() + + with ( + patch.object(stream.rw_lock, "acquire_read", side_effect=tracked_acquire_read), + patch.object( + stream.rw_lock, "acquire_write", side_effect=tracked_acquire_write + ), + ): + async with trio.open_nursery() as nursery: + # Start readers + nursery.start_soon(stream.read, 5) + nursery.start_soon(stream.read, 5) + + # Wait for at least one reader to acquire the lock + await readers_started.wait() + + # Start writer (should block) + nursery.start_soon(stream.write, b"test") + + # Wait for writer to start acquiring lock + await writer_acquiring.wait() + + # Verify writer hasn't entered critical section + assert not writer_entered.is_set() + + # Allow readers to finish + readers_finished.set() + + # Verify writer can proceed + await writer_entered.wait() + + +@pytest.mark.trio +async def test_lock_behavior_during_cancellation(mplex_stream): + """Verify that a lock is released when a task holding it is cancelled.""" + stream, _, _ = mplex_stream + + reader_acquired_lock = trio.Event() + + async def cancellable_reader(task_status): + async with stream.rw_lock.read_lock(): + reader_acquired_lock.set() + task_status.started() + # Wait indefinitely until cancelled. + await trio.sleep_forever() + + async with trio.open_nursery() as nursery: + # Start the reader and wait for it to acquire the lock. + await nursery.start(cancellable_reader) + await reader_acquired_lock.wait() + + # Now that the reader has the lock, cancel the nursery. + # This will cancel the reader task, and its lock should be released. + nursery.cancel_scope.cancel() + + # After the nursery is cancelled, the reader should have released the lock. + # To verify, we try to acquire a write lock. If the read lock was not + # released, this will time out. + with trio.move_on_after(1) as cancel_scope: + async with stream.rw_lock.write_lock(): + pass + if cancel_scope.cancelled_caught: + pytest.fail( + "Write lock could not be acquired after a cancelled reader, " + "indicating the read lock was not released." + ) + + +@pytest.mark.trio +async def test_concurrent_read_write_sequence(mplex_stream): + """Verify complex sequence of interleaved reads and writes.""" + stream, send_chan, _ = mplex_stream + results = [] + # Use a mock to intercept writes and feed them back to the read channel + original_write = stream.write + + reader1_finished = trio.Event() + writer1_finished = trio.Event() + reader2_finished = trio.Event() + + async def mocked_write(data: bytes) -> None: + await original_write(data) + # Simulate the other side receiving the data and sending a response + # by putting data into the read channel. + await send_chan.send(data) + + with patch.object(stream, "write", wraps=mocked_write) as patched_write: + async with trio.open_nursery() as nursery: + # Test scenario: + # 1. Reader 1 starts, waits for data. + # 2. Writer 1 writes, which gets fed back to the stream. + # 3. Reader 2 starts, reads what Writer 1 wrote. + # 4. Writer 2 writes. + + async def reader1(): + nonlocal results + results.append("R1 start") + data = await stream.read(5) + results.append(data) + results.append("R1 done") + reader1_finished.set() + + async def writer1(): + nonlocal results + await reader1_finished.wait() + results.append("W1 start") + await stream.write(b"write1") + results.append("W1 done") + writer1_finished.set() + + async def reader2(): + nonlocal results + await writer1_finished.wait() + # This will read the data from writer1 + results.append("R2 start") + data = await stream.read(6) + results.append(data) + results.append("R2 done") + reader2_finished.set() + + async def writer2(): + nonlocal results + await reader2_finished.wait() + results.append("W2 start") + await stream.write(b"write2") + results.append("W2 done") + + # Execute sequence + nursery.start_soon(reader1) + nursery.start_soon(writer1) + nursery.start_soon(reader2) + nursery.start_soon(writer2) + + await send_chan.send(b"data1") + + # Verify sequence and that write was called + assert patched_write.call_count == 2 + assert results == [ + "R1 start", + b"data1", + "R1 done", + "W1 start", + "W1 done", + "R2 start", + b"write1", + "R2 done", + "W2 start", + "W2 done", + ] + + +# =============================================== +# 2. Tests for Reset, EOF, and Close Interactions +# =============================================== + + +@pytest.mark.trio +async def test_read_after_remote_close_triggers_eof(mplex_stream): + """Verify reading from a remotely closed stream returns EOF correctly.""" + stream, send_chan, _ = mplex_stream + + # Send some data that can be read first + await send_chan.send(b"data") + # Close the channel to signify no more data will ever arrive + await send_chan.aclose() + + # Mark the stream as remotely closed + stream.event_remote_closed.set() + + # The first read should succeed, consuming the buffered data + data = await stream.read(4) + assert data == b"data" + + # Now that the buffer is empty and the channel is closed, this should raise EOF + with pytest.raises(MplexStreamEOF): + await stream.read(1) + + +@pytest.mark.trio +async def test_read_on_closed_stream_raises_eof(mplex_stream): + """Test that reading from a closed stream with no data raises EOF.""" + stream, send_chan, _ = mplex_stream + stream.event_remote_closed.set() + await send_chan.aclose() # Ensure the channel is closed + + # Reading from a stream that is closed and has no data should raise EOF + with pytest.raises(MplexStreamEOF): + await stream.read(100) + + +@pytest.mark.trio +async def test_write_to_locally_closed_stream_raises(mplex_stream): + """Verify writing to a locally closed stream raises MplexStreamClosed.""" + stream, _, _ = mplex_stream + stream.event_local_closed.set() + + with pytest.raises(MplexStreamClosed): + await stream.write(b"this should fail") + + +@pytest.mark.trio +async def test_read_from_reset_stream_raises(mplex_stream): + """Verify reading from a reset stream raises MplexStreamReset.""" + stream, _, _ = mplex_stream + stream.event_reset.set() + + with pytest.raises(MplexStreamReset): + await stream.read(10) + + +@pytest.mark.trio +async def test_write_to_reset_stream_raises(mplex_stream): + """Verify writing to a reset stream raises MplexStreamClosed.""" + stream, _, _ = mplex_stream + # A stream reset implies it's also locally closed. + await stream.reset() + + # The `write` method checks `event_local_closed`, which `reset` sets. + with pytest.raises(MplexStreamClosed): + await stream.write(b"this should also fail") + + +@pytest.mark.trio +async def test_stream_reset_cleans_up_resources(mplex_stream): + """Verify reset() cleans up stream state and resources.""" + stream, _, muxed_conn = mplex_stream + stream_id = stream.stream_id + + assert stream_id in muxed_conn.streams + await stream.reset() + + assert stream.event_reset.is_set() + assert stream.event_local_closed.is_set() + assert stream.event_remote_closed.is_set() + assert (HeaderTags.ResetInitiator, None, stream_id) in muxed_conn.sent_messages + assert stream_id not in muxed_conn.streams + # Verify the underlying data channel is closed + with pytest.raises(trio.ClosedResourceError): + await stream.incoming_data_channel.receive() + + +# =============================================== +# 3. Rigorous Concurrency Tests with Events +# =============================================== + + +@pytest.mark.trio +async def test_writer_is_blocked_by_reader_using_events(mplex_stream): + """Verify a writer must wait for a reader using trio.Event for synchronization.""" + stream, _, _ = mplex_stream + + reader_has_lock = trio.Event() + writer_finished = trio.Event() + + async def reader(): + async with stream.rw_lock.read_lock(): + reader_has_lock.set() + # Hold the lock until the writer has finished its attempt + await writer_finished.wait() + + async def writer(): + await reader_has_lock.wait() + # This call will now block until the reader releases the lock + await stream.write(b"data") + writer_finished.set() async with trio.open_nursery() as nursery: nursery.start_soon(reader) - await trio.sleep(0.05) nursery.start_soon(writer) - assert log == [ - "read_acquired", - "read_released", - "write_acquired", - "write_released", - ], f"Unexpected order: {log}" + # Verify writer is blocked + await wait_all_tasks_blocked() + assert not writer_finished.is_set() + + # Signal the reader to finish + writer_finished.set() @pytest.mark.trio -async def test_reading_blocked_if_write_in_progress( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] +async def test_multiple_readers_can_read_concurrently_using_events(mplex_stream): + """Verify that multiple readers can acquire a read lock simultaneously.""" + stream, _, _ = mplex_stream - async def writer() -> None: - await stream.rw_lock.acquire_write() - log.append("write_acquired") - await trio.sleep(0.3) - log.append("write_released") - stream.rw_lock.release_write() + counters = {"readers_in_critical_section": 0} + lock = trio.Lock() # To safely mutate the counter - async def reader() -> None: - await stream.rw_lock.acquire_read() - log.append("read_acquired") - await trio.sleep(0.1) - log.append("read_released") - await stream.rw_lock.release_read() + reader1_acquired = trio.Event() + reader2_acquired = trio.Event() + all_readers_finished = trio.Event() + + async def concurrent_reader(event_to_set: trio.Event): + async with stream.rw_lock.read_lock(): + async with lock: + counters["readers_in_critical_section"] += 1 + event_to_set.set() + # Wait until all readers have finished before exiting the lock context + await all_readers_finished.wait() + async with lock: + counters["readers_in_critical_section"] -= 1 async with trio.open_nursery() as nursery: - nursery.start_soon(writer) - await trio.sleep(0.05) - nursery.start_soon(reader) + nursery.start_soon(concurrent_reader, reader1_acquired) + nursery.start_soon(concurrent_reader, reader2_acquired) - assert log == [ - "write_acquired", - "write_released", - "read_acquired", - "read_released", - ], f"Unexpected order: {log}" + # Wait for both readers to acquire their locks + await reader1_acquired.wait() + await reader2_acquired.wait() + # Check that both were in the critical section at the same time + async with lock: + assert counters["readers_in_critical_section"] == 2 -@pytest.mark.trio -async def test_multiple_reads_allowed_concurrently( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] + # Signal for all readers to finish + all_readers_finished.set() - async def read_task(i: int) -> None: - await stream.rw_lock.acquire_read() - log.append(f"read_{i}_acquired") - await trio.sleep(0.2) - log.append(f"read_{i}_released") - await stream.rw_lock.release_read() - - async with trio.open_nursery() as nursery: - for i in range(5): - nursery.start_soon(read_task, i) - - acquires = [entry for entry in log if "acquired" in entry] - releases = [entry for entry in log if "released" in entry] - - assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed" - assert all( - log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires - ), f"Reads didn't overlap properly: {log}" - - -@pytest.mark.trio -async def test_only_one_write_allowed( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] - - async def write_task(i: int) -> None: - await stream.rw_lock.acquire_write() - log.append(f"write_{i}_acquired") - await trio.sleep(0.2) - log.append(f"write_{i}_released") - stream.rw_lock.release_write() - - async with trio.open_nursery() as nursery: - for i in range(5): - nursery.start_soon(write_task, i) - - active = 0 - for entry in log: - if "acquired" in entry: - active += 1 - elif "released" in entry: - active -= 1 - assert active <= 1, f"More than one write active: {log}" - assert active == 0, f"Write locks not properly released: {log}" - - -@pytest.mark.trio -async def test_interleaved_read_write_behavior( - stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], -) -> None: - stream, _ = stream_with_lock - log: list[str] = [] - - async def read(i: int) -> None: - await stream.rw_lock.acquire_read() - log.append(f"read_{i}_acquired") - await trio.sleep(0.15) - log.append(f"read_{i}_released") - await stream.rw_lock.release_read() - - async def write(i: int) -> None: - await stream.rw_lock.acquire_write() - log.append(f"write_{i}_acquired") - await trio.sleep(0.2) - log.append(f"write_{i}_released") - stream.rw_lock.release_write() - - async with trio.open_nursery() as nursery: - nursery.start_soon(read, 1) - await trio.sleep(0.05) - nursery.start_soon(read, 2) - await trio.sleep(0.05) - nursery.start_soon(write, 1) - await trio.sleep(0.05) - nursery.start_soon(read, 3) - await trio.sleep(0.05) - nursery.start_soon(write, 2) - - read1_released = log.index("read_1_released") - read2_released = log.index("read_2_released") - write1_acquired = log.index("write_1_acquired") - assert write1_acquired > read1_released and write1_acquired > read2_released, ( - f"write_1 acquired too early: {log}" - ) - - read3_acquired = log.index("read_3_acquired") - read3_released = log.index("read_3_released") - write1_released = log.index("write_1_released") - assert read3_released < write1_acquired or read3_acquired > write1_released, ( - f"read_3 improperly overlapped with write_1: {log}" - ) - - write2_acquired = log.index("write_2_acquired") - assert write2_acquired > write1_released, f"write_2 started too early: {log}" + # Verify they exit cleanly + await wait_all_tasks_blocked() + async with lock: + assert counters["readers_in_critical_section"] == 0