add test for read-write-lock

This commit is contained in:
Jinesh Jain
2025-07-08 20:06:30 +05:30
parent 5f497c7f5d
commit 242998ae9d
2 changed files with 211 additions and 8 deletions

View File

@ -35,26 +35,26 @@ if TYPE_CHECKING:
class ReadWriteLock:
def __init__(self) -> None:
self._readers = 0
self._lock = trio.Lock() # Protects _readers
self._write_lock = trio.Lock()
self._readers_lock = trio.Lock() # Protects readers count
self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock
async def acquire_read(self) -> None:
async with self._lock:
async with self._readers_lock:
self._readers += 1
if self._readers == 1:
await self._write_lock.acquire()
await self._writer_lock.acquire()
async def release_read(self) -> None:
async with self._lock:
async with self._readers_lock:
self._readers -= 1
if self._readers == 0:
self._write_lock.release()
self._writer_lock.release()
async def acquire_write(self) -> None:
await self._write_lock.acquire()
await self._writer_lock.acquire()
def release_write(self) -> None:
self._write_lock.release()
self._writer_lock.release()
class MplexStream(IMuxedStream):

View File

@ -0,0 +1,203 @@
from unittest.mock import AsyncMock, MagicMock
import pytest
import trio
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID
@pytest.fixture
def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]:
muxed_conn = MagicMock()
muxed_conn.send_message = AsyncMock()
muxed_conn.streams_lock = trio.Lock()
muxed_conn.streams = {}
muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000))
send_chan: trio.MemorySendChannel[bytes]
recv_chan: trio.MemoryReceiveChannel[bytes]
send_chan, recv_chan = trio.open_memory_channel(0)
dummy_stream_id = MagicMock(spec=StreamID)
dummy_stream_id.is_initiator = True # mock read-only property
stream = MplexStream(
name="test",
stream_id=dummy_stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=recv_chan,
)
return stream, send_chan
@pytest.mark.trio
async def test_writing_blocked_if_read_in_progress(
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
) -> None:
stream, _ = stream_with_lock
log: list[str] = []
async def reader() -> None:
await stream.rw_lock.acquire_read()
log.append("read_acquired")
await trio.sleep(0.3)
log.append("read_released")
await stream.rw_lock.release_read()
async def writer() -> None:
await stream.rw_lock.acquire_write()
log.append("write_acquired")
await trio.sleep(0.1)
log.append("write_released")
stream.rw_lock.release_write()
async with trio.open_nursery() as nursery:
nursery.start_soon(reader)
await trio.sleep(0.05)
nursery.start_soon(writer)
assert log == [
"read_acquired",
"read_released",
"write_acquired",
"write_released",
], f"Unexpected order: {log}"
@pytest.mark.trio
async def test_reading_blocked_if_write_in_progress(
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
) -> None:
stream, _ = stream_with_lock
log: list[str] = []
async def writer() -> None:
await stream.rw_lock.acquire_write()
log.append("write_acquired")
await trio.sleep(0.3)
log.append("write_released")
stream.rw_lock.release_write()
async def reader() -> None:
await stream.rw_lock.acquire_read()
log.append("read_acquired")
await trio.sleep(0.1)
log.append("read_released")
await stream.rw_lock.release_read()
async with trio.open_nursery() as nursery:
nursery.start_soon(writer)
await trio.sleep(0.05)
nursery.start_soon(reader)
assert log == [
"write_acquired",
"write_released",
"read_acquired",
"read_released",
], f"Unexpected order: {log}"
@pytest.mark.trio
async def test_multiple_reads_allowed_concurrently(
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
) -> None:
stream, _ = stream_with_lock
log: list[str] = []
async def read_task(i: int) -> None:
await stream.rw_lock.acquire_read()
log.append(f"read_{i}_acquired")
await trio.sleep(0.2)
log.append(f"read_{i}_released")
await stream.rw_lock.release_read()
async with trio.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(read_task, i)
acquires = [entry for entry in log if "acquired" in entry]
releases = [entry for entry in log if "released" in entry]
assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed"
assert all(
log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires
), f"Reads didn't overlap properly: {log}"
@pytest.mark.trio
async def test_only_one_write_allowed(
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
) -> None:
stream, _ = stream_with_lock
log: list[str] = []
async def write_task(i: int) -> None:
await stream.rw_lock.acquire_write()
log.append(f"write_{i}_acquired")
await trio.sleep(0.2)
log.append(f"write_{i}_released")
stream.rw_lock.release_write()
async with trio.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(write_task, i)
active = 0
for entry in log:
if "acquired" in entry:
active += 1
elif "released" in entry:
active -= 1
assert active <= 1, f"More than one write active: {log}"
assert active == 0, f"Write locks not properly released: {log}"
@pytest.mark.trio
async def test_interleaved_read_write_behavior(
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
) -> None:
stream, _ = stream_with_lock
log: list[str] = []
async def read(i: int) -> None:
await stream.rw_lock.acquire_read()
log.append(f"read_{i}_acquired")
await trio.sleep(0.15)
log.append(f"read_{i}_released")
await stream.rw_lock.release_read()
async def write(i: int) -> None:
await stream.rw_lock.acquire_write()
log.append(f"write_{i}_acquired")
await trio.sleep(0.2)
log.append(f"write_{i}_released")
stream.rw_lock.release_write()
async with trio.open_nursery() as nursery:
nursery.start_soon(read, 1)
await trio.sleep(0.05)
nursery.start_soon(read, 2)
await trio.sleep(0.05)
nursery.start_soon(write, 1)
await trio.sleep(0.05)
nursery.start_soon(read, 3)
await trio.sleep(0.05)
nursery.start_soon(write, 2)
read1_released = log.index("read_1_released")
read2_released = log.index("read_2_released")
write1_acquired = log.index("write_1_acquired")
assert write1_acquired > read1_released and write1_acquired > read2_released, (
f"write_1 acquired too early: {log}"
)
read3_acquired = log.index("read_3_acquired")
read3_released = log.index("read_3_released")
write1_released = log.index("write_1_released")
assert read3_released < write1_acquired or read3_acquired > write1_released, (
f"read_3 improperly overlapped with write_1: {log}"
)
write2_acquired = log.index("write_2_acquired")
assert write2_acquired > write1_released, f"write_2 started too early: {log}"