make readwrite more safe

This commit is contained in:
Jinesh Jain
2025-07-13 18:37:44 +05:30
parent 3592ad308f
commit 9cd3805542
2 changed files with 593 additions and 200 deletions

View File

@ -35,55 +35,69 @@ if TYPE_CHECKING:
class ReadWriteLock:
"""
A read-write lock that allows multiple concurrent readers
or one exclusive writer, implemented using Trio primitives.
"""
def __init__(self) -> None:
self._readers = 0
self._readers_lock = trio.Lock() # Protects readers count
self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers
self._readers_lock = trio.Lock() # Protects access to _readers count
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
async def acquire_read(self) -> None:
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
try:
async with self._readers_lock:
self._readers += 1
if self._readers == 1:
if self._readers == 0:
await self._writer_lock.acquire()
self._readers += 1
except trio.Cancelled:
async with self._readers_lock:
if self._readers > 0:
self._readers -= 1
if self._readers == 0:
self._writer_lock.release()
raise
async def release_read(self) -> None:
"""Release a read lock."""
async with self._readers_lock:
self._readers -= 1
if self._readers == 0:
if self._readers == 1:
self._writer_lock.release()
self._readers -= 1
async def acquire_write(self) -> None:
"""Acquire an exclusive write lock."""
try:
await self._writer_lock.acquire()
except trio.Cancelled:
raise
def release_write(self) -> None:
"""Release the exclusive write lock."""
self._writer_lock.release()
@asynccontextmanager
async def read_lock(self) -> AsyncGenerator[None, None]:
await self.acquire_read()
"""Context manager for acquiring and releasing a read lock safely."""
acquire = False
try:
await self.acquire_read()
acquire = True
yield
finally:
await self.release_read()
if acquire:
with trio.CancelScope() as scope:
scope.shield = True
await self.release_read()
@asynccontextmanager
async def write_lock(self) -> AsyncGenerator[None, None]:
await self.acquire_write()
"""Context manager for acquiring and releasing a write lock safely."""
acquire = False
try:
await self.acquire_write()
acquire = True
yield
finally:
self.release_write()
if acquire:
self.release_write()
class MplexStream(IMuxedStream):
@ -168,9 +182,7 @@ class MplexStream(IMuxedStream):
:param n: number of bytes to read
:return: bytes actually read
"""
await self.rw_lock.acquire_read()
payload: bytes = b""
try:
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
"the number of bytes to read n must be non-negative or "
@ -210,12 +222,9 @@ class MplexStream(IMuxedStream):
"This should never happen."
) from error
self._buf.extend(self._read_return_when_blocked())
chunk = self._buf[:n]
self._buf = self._buf[len(chunk) :]
payload = bytes(chunk)
finally:
await self.rw_lock.release_read()
return payload
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def write(self, data: bytes) -> None:
"""
@ -223,8 +232,7 @@ class MplexStream(IMuxedStream):
:return: number of bytes written
"""
await self.rw_lock.acquire_write()
try:
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
flag = (
@ -233,8 +241,6 @@ class MplexStream(IMuxedStream):
else HeaderTags.MessageReceiver
)
await self.muxed_conn.send_message(flag, data, self.stream_id)
finally:
self.rw_lock.release_write()
async def close(self) -> None:
"""