diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 9c5f04da..e6b98244 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -35,26 +35,26 @@ if TYPE_CHECKING: class ReadWriteLock: def __init__(self) -> None: self._readers = 0 - self._lock = trio.Lock() # Protects _readers - self._write_lock = trio.Lock() + self._readers_lock = trio.Lock() # Protects readers count + self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock async def acquire_read(self) -> None: - async with self._lock: + async with self._readers_lock: self._readers += 1 if self._readers == 1: - await self._write_lock.acquire() + await self._writer_lock.acquire() async def release_read(self) -> None: - async with self._lock: + async with self._readers_lock: self._readers -= 1 if self._readers == 0: - self._write_lock.release() + self._writer_lock.release() async def acquire_write(self) -> None: - await self._write_lock.acquire() + await self._writer_lock.acquire() def release_write(self) -> None: - self._write_lock.release() + self._writer_lock.release() class MplexStream(IMuxedStream): diff --git a/libp2p/stream_muxer/mplex/test_read_write_lock.py b/libp2p/stream_muxer/mplex/test_read_write_lock.py new file mode 100644 index 00000000..c52aa36c --- /dev/null +++ b/libp2p/stream_muxer/mplex/test_read_write_lock.py @@ -0,0 +1,203 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +import trio + +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID + + +@pytest.fixture +def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]: + muxed_conn = MagicMock() + muxed_conn.send_message = AsyncMock() + muxed_conn.streams_lock = trio.Lock() + muxed_conn.streams = {} + muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000)) + + send_chan: trio.MemorySendChannel[bytes] + recv_chan: trio.MemoryReceiveChannel[bytes] + send_chan, recv_chan = trio.open_memory_channel(0) + + dummy_stream_id = MagicMock(spec=StreamID) + dummy_stream_id.is_initiator = True # mock read-only property + + stream = MplexStream( + name="test", + stream_id=dummy_stream_id, + muxed_conn=muxed_conn, + incoming_data_channel=recv_chan, + ) + return stream, send_chan + + +@pytest.mark.trio +async def test_writing_blocked_if_read_in_progress( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def reader() -> None: + await stream.rw_lock.acquire_read() + log.append("read_acquired") + await trio.sleep(0.3) + log.append("read_released") + await stream.rw_lock.release_read() + + async def writer() -> None: + await stream.rw_lock.acquire_write() + log.append("write_acquired") + await trio.sleep(0.1) + log.append("write_released") + stream.rw_lock.release_write() + + async with trio.open_nursery() as nursery: + nursery.start_soon(reader) + await trio.sleep(0.05) + nursery.start_soon(writer) + + assert log == [ + "read_acquired", + "read_released", + "write_acquired", + "write_released", + ], f"Unexpected order: {log}" + + +@pytest.mark.trio +async def test_reading_blocked_if_write_in_progress( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def writer() -> None: + await stream.rw_lock.acquire_write() + log.append("write_acquired") + await trio.sleep(0.3) + log.append("write_released") + stream.rw_lock.release_write() + + async def reader() -> None: + await stream.rw_lock.acquire_read() + log.append("read_acquired") + await trio.sleep(0.1) + log.append("read_released") + await stream.rw_lock.release_read() + + async with trio.open_nursery() as nursery: + nursery.start_soon(writer) + await trio.sleep(0.05) + nursery.start_soon(reader) + + assert log == [ + "write_acquired", + "write_released", + "read_acquired", + "read_released", + ], f"Unexpected order: {log}" + + +@pytest.mark.trio +async def test_multiple_reads_allowed_concurrently( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def read_task(i: int) -> None: + await stream.rw_lock.acquire_read() + log.append(f"read_{i}_acquired") + await trio.sleep(0.2) + log.append(f"read_{i}_released") + await stream.rw_lock.release_read() + + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(read_task, i) + + acquires = [entry for entry in log if "acquired" in entry] + releases = [entry for entry in log if "released" in entry] + + assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed" + assert all( + log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires + ), f"Reads didn't overlap properly: {log}" + + +@pytest.mark.trio +async def test_only_one_write_allowed( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def write_task(i: int) -> None: + await stream.rw_lock.acquire_write() + log.append(f"write_{i}_acquired") + await trio.sleep(0.2) + log.append(f"write_{i}_released") + stream.rw_lock.release_write() + + async with trio.open_nursery() as nursery: + for i in range(5): + nursery.start_soon(write_task, i) + + active = 0 + for entry in log: + if "acquired" in entry: + active += 1 + elif "released" in entry: + active -= 1 + assert active <= 1, f"More than one write active: {log}" + assert active == 0, f"Write locks not properly released: {log}" + + +@pytest.mark.trio +async def test_interleaved_read_write_behavior( + stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]], +) -> None: + stream, _ = stream_with_lock + log: list[str] = [] + + async def read(i: int) -> None: + await stream.rw_lock.acquire_read() + log.append(f"read_{i}_acquired") + await trio.sleep(0.15) + log.append(f"read_{i}_released") + await stream.rw_lock.release_read() + + async def write(i: int) -> None: + await stream.rw_lock.acquire_write() + log.append(f"write_{i}_acquired") + await trio.sleep(0.2) + log.append(f"write_{i}_released") + stream.rw_lock.release_write() + + async with trio.open_nursery() as nursery: + nursery.start_soon(read, 1) + await trio.sleep(0.05) + nursery.start_soon(read, 2) + await trio.sleep(0.05) + nursery.start_soon(write, 1) + await trio.sleep(0.05) + nursery.start_soon(read, 3) + await trio.sleep(0.05) + nursery.start_soon(write, 2) + + read1_released = log.index("read_1_released") + read2_released = log.index("read_2_released") + write1_acquired = log.index("write_1_acquired") + assert write1_acquired > read1_released and write1_acquired > read2_released, ( + f"write_1 acquired too early: {log}" + ) + + read3_acquired = log.index("read_3_acquired") + read3_released = log.index("read_3_released") + write1_released = log.index("write_1_released") + assert read3_released < write1_acquired or read3_acquired > write1_released, ( + f"read_3 improperly overlapped with write_1: {log}" + ) + + write2_acquired = log.index("write_2_acquired") + assert write2_acquired > write1_released, f"write_2 started too early: {log}"