diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index e6b98244..91f872d4 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,5 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from types import ( TracebackType, ) @@ -36,13 +38,21 @@ class ReadWriteLock: def __init__(self) -> None: self._readers = 0 self._readers_lock = trio.Lock() # Protects readers count - self._writer_lock = trio.Semaphore(1) # Acts like a task-transferable lock + self._writer_lock = trio.Semaphore(1) # Ensures mutual exclusion for writers async def acquire_read(self) -> None: - async with self._readers_lock: - self._readers += 1 - if self._readers == 1: - await self._writer_lock.acquire() + try: + async with self._readers_lock: + self._readers += 1 + if self._readers == 1: + await self._writer_lock.acquire() + except trio.Cancelled: + async with self._readers_lock: + if self._readers > 0: + self._readers -= 1 + if self._readers == 0: + self._writer_lock.release() + raise async def release_read(self) -> None: async with self._readers_lock: @@ -51,11 +61,30 @@ class ReadWriteLock: self._writer_lock.release() async def acquire_write(self) -> None: - await self._writer_lock.acquire() + try: + await self._writer_lock.acquire() + except trio.Cancelled: + raise def release_write(self) -> None: self._writer_lock.release() + @asynccontextmanager + async def read_lock(self) -> AsyncGenerator[None, None]: + await self.acquire_read() + try: + yield + finally: + await self.release_read() + + @asynccontextmanager + async def write_lock(self) -> AsyncGenerator[None, None]: + await self.acquire_write() + try: + yield + finally: + self.release_write() + class MplexStream(IMuxedStream): """