mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-08 14:10:54 +00:00
make readwrite more safe
This commit is contained in:
@ -35,55 +35,69 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
self._readers += 1
|
||||
if self._readers == 1:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
async with self._readers_lock:
|
||||
if self._readers > 0:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._writer_lock.release()
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
await self.acquire_read()
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
await self.release_read()
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
await self.acquire_write()
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
self.release_write()
|
||||
if acquire:
|
||||
self.release_write()
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
@ -168,9 +182,7 @@ class MplexStream(IMuxedStream):
|
||||
:param n: number of bytes to read
|
||||
:return: bytes actually read
|
||||
"""
|
||||
await self.rw_lock.acquire_read()
|
||||
payload: bytes = b""
|
||||
try:
|
||||
async with self.rw_lock.read_lock():
|
||||
if n is not None and n < 0:
|
||||
raise ValueError(
|
||||
"the number of bytes to read n must be non-negative or "
|
||||
@ -210,12 +222,9 @@ class MplexStream(IMuxedStream):
|
||||
"This should never happen."
|
||||
) from error
|
||||
self._buf.extend(self._read_return_when_blocked())
|
||||
chunk = self._buf[:n]
|
||||
self._buf = self._buf[len(chunk) :]
|
||||
payload = bytes(chunk)
|
||||
finally:
|
||||
await self.rw_lock.release_read()
|
||||
return payload
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
@ -223,8 +232,7 @@ class MplexStream(IMuxedStream):
|
||||
|
||||
:return: number of bytes written
|
||||
"""
|
||||
await self.rw_lock.acquire_write()
|
||||
try:
|
||||
async with self.rw_lock.write_lock():
|
||||
if self.event_local_closed.is_set():
|
||||
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
|
||||
flag = (
|
||||
@ -233,8 +241,6 @@ class MplexStream(IMuxedStream):
|
||||
else HeaderTags.MessageReceiver
|
||||
)
|
||||
await self.muxed_conn.send_message(flag, data, self.stream_id)
|
||||
finally:
|
||||
self.rw_lock.release_write()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user