mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
add test for read-write-lock
This commit is contained in:
@ -35,26 +35,26 @@ if TYPE_CHECKING:
|
||||
class ReadWriteLock:
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._lock = trio.Lock() # Protects _readers
|
||||
self._write_lock = trio.Lock()
|
||||
self._readers_lock = trio.Lock() # Protects readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
async with self._lock:
|
||||
async with self._readers_lock:
|
||||
self._readers += 1
|
||||
if self._readers == 1:
|
||||
await self._write_lock.acquire()
|
||||
await self._writer_lock.acquire()
|
||||
|
||||
async def release_read(self) -> None:
|
||||
async with self._lock:
|
||||
async with self._readers_lock:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._write_lock.release()
|
||||
self._writer_lock.release()
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
await self._write_lock.acquire()
|
||||
await self._writer_lock.acquire()
|
||||
|
||||
def release_write(self) -> None:
|
||||
self._write_lock.release()
|
||||
self._writer_lock.release()
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
|
||||
203
libp2p/stream_muxer/mplex/test_read_write_lock.py
Normal file
203
libp2p/stream_muxer/mplex/test_read_write_lock.py
Normal file
@ -0,0 +1,203 @@
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import trio
|
||||
|
||||
from libp2p.stream_muxer.mplex.mplex_stream import MplexStream, StreamID
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream_with_lock() -> tuple[MplexStream, trio.MemorySendChannel[bytes]]:
|
||||
muxed_conn = MagicMock()
|
||||
muxed_conn.send_message = AsyncMock()
|
||||
muxed_conn.streams_lock = trio.Lock()
|
||||
muxed_conn.streams = {}
|
||||
muxed_conn.get_remote_address = MagicMock(return_value=("127.0.0.1", 8000))
|
||||
|
||||
send_chan: trio.MemorySendChannel[bytes]
|
||||
recv_chan: trio.MemoryReceiveChannel[bytes]
|
||||
send_chan, recv_chan = trio.open_memory_channel(0)
|
||||
|
||||
dummy_stream_id = MagicMock(spec=StreamID)
|
||||
dummy_stream_id.is_initiator = True # mock read-only property
|
||||
|
||||
stream = MplexStream(
|
||||
name="test",
|
||||
stream_id=dummy_stream_id,
|
||||
muxed_conn=muxed_conn,
|
||||
incoming_data_channel=recv_chan,
|
||||
)
|
||||
return stream, send_chan
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_writing_blocked_if_read_in_progress(
|
||||
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
|
||||
) -> None:
|
||||
stream, _ = stream_with_lock
|
||||
log: list[str] = []
|
||||
|
||||
async def reader() -> None:
|
||||
await stream.rw_lock.acquire_read()
|
||||
log.append("read_acquired")
|
||||
await trio.sleep(0.3)
|
||||
log.append("read_released")
|
||||
await stream.rw_lock.release_read()
|
||||
|
||||
async def writer() -> None:
|
||||
await stream.rw_lock.acquire_write()
|
||||
log.append("write_acquired")
|
||||
await trio.sleep(0.1)
|
||||
log.append("write_released")
|
||||
stream.rw_lock.release_write()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(reader)
|
||||
await trio.sleep(0.05)
|
||||
nursery.start_soon(writer)
|
||||
|
||||
assert log == [
|
||||
"read_acquired",
|
||||
"read_released",
|
||||
"write_acquired",
|
||||
"write_released",
|
||||
], f"Unexpected order: {log}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_reading_blocked_if_write_in_progress(
|
||||
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
|
||||
) -> None:
|
||||
stream, _ = stream_with_lock
|
||||
log: list[str] = []
|
||||
|
||||
async def writer() -> None:
|
||||
await stream.rw_lock.acquire_write()
|
||||
log.append("write_acquired")
|
||||
await trio.sleep(0.3)
|
||||
log.append("write_released")
|
||||
stream.rw_lock.release_write()
|
||||
|
||||
async def reader() -> None:
|
||||
await stream.rw_lock.acquire_read()
|
||||
log.append("read_acquired")
|
||||
await trio.sleep(0.1)
|
||||
log.append("read_released")
|
||||
await stream.rw_lock.release_read()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(writer)
|
||||
await trio.sleep(0.05)
|
||||
nursery.start_soon(reader)
|
||||
|
||||
assert log == [
|
||||
"write_acquired",
|
||||
"write_released",
|
||||
"read_acquired",
|
||||
"read_released",
|
||||
], f"Unexpected order: {log}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_multiple_reads_allowed_concurrently(
|
||||
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
|
||||
) -> None:
|
||||
stream, _ = stream_with_lock
|
||||
log: list[str] = []
|
||||
|
||||
async def read_task(i: int) -> None:
|
||||
await stream.rw_lock.acquire_read()
|
||||
log.append(f"read_{i}_acquired")
|
||||
await trio.sleep(0.2)
|
||||
log.append(f"read_{i}_released")
|
||||
await stream.rw_lock.release_read()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(read_task, i)
|
||||
|
||||
acquires = [entry for entry in log if "acquired" in entry]
|
||||
releases = [entry for entry in log if "released" in entry]
|
||||
|
||||
assert len(acquires) == 5 and len(releases) == 5, "Not all reads executed"
|
||||
assert all(
|
||||
log.index(acq) < min(log.index(rel) for rel in releases) for acq in acquires
|
||||
), f"Reads didn't overlap properly: {log}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_only_one_write_allowed(
|
||||
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
|
||||
) -> None:
|
||||
stream, _ = stream_with_lock
|
||||
log: list[str] = []
|
||||
|
||||
async def write_task(i: int) -> None:
|
||||
await stream.rw_lock.acquire_write()
|
||||
log.append(f"write_{i}_acquired")
|
||||
await trio.sleep(0.2)
|
||||
log.append(f"write_{i}_released")
|
||||
stream.rw_lock.release_write()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(write_task, i)
|
||||
|
||||
active = 0
|
||||
for entry in log:
|
||||
if "acquired" in entry:
|
||||
active += 1
|
||||
elif "released" in entry:
|
||||
active -= 1
|
||||
assert active <= 1, f"More than one write active: {log}"
|
||||
assert active == 0, f"Write locks not properly released: {log}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_interleaved_read_write_behavior(
|
||||
stream_with_lock: tuple[MplexStream, trio.MemorySendChannel[bytes]],
|
||||
) -> None:
|
||||
stream, _ = stream_with_lock
|
||||
log: list[str] = []
|
||||
|
||||
async def read(i: int) -> None:
|
||||
await stream.rw_lock.acquire_read()
|
||||
log.append(f"read_{i}_acquired")
|
||||
await trio.sleep(0.15)
|
||||
log.append(f"read_{i}_released")
|
||||
await stream.rw_lock.release_read()
|
||||
|
||||
async def write(i: int) -> None:
|
||||
await stream.rw_lock.acquire_write()
|
||||
log.append(f"write_{i}_acquired")
|
||||
await trio.sleep(0.2)
|
||||
log.append(f"write_{i}_released")
|
||||
stream.rw_lock.release_write()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(read, 1)
|
||||
await trio.sleep(0.05)
|
||||
nursery.start_soon(read, 2)
|
||||
await trio.sleep(0.05)
|
||||
nursery.start_soon(write, 1)
|
||||
await trio.sleep(0.05)
|
||||
nursery.start_soon(read, 3)
|
||||
await trio.sleep(0.05)
|
||||
nursery.start_soon(write, 2)
|
||||
|
||||
read1_released = log.index("read_1_released")
|
||||
read2_released = log.index("read_2_released")
|
||||
write1_acquired = log.index("write_1_acquired")
|
||||
assert write1_acquired > read1_released and write1_acquired > read2_released, (
|
||||
f"write_1 acquired too early: {log}"
|
||||
)
|
||||
|
||||
read3_acquired = log.index("read_3_acquired")
|
||||
read3_released = log.index("read_3_released")
|
||||
write1_released = log.index("write_1_released")
|
||||
assert read3_released < write1_acquired or read3_acquired > write1_released, (
|
||||
f"read_3 improperly overlapped with write_1: {log}"
|
||||
)
|
||||
|
||||
write2_acquired = log.index("write_2_acquired")
|
||||
assert write2_acquired > write1_released, f"write_2 started too early: {log}"
|
||||
Reference in New Issue
Block a user