diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 3b640df1..8f45495f 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -31,6 +31,29 @@ if TYPE_CHECKING: Mplex, ) +class ReadWriteLock: + def __init__(self): + self._readers = 0 + self._lock = trio.Lock() # Protects _readers + self._write_lock = trio.Lock() + + async def acquire_read(self): + async with self._lock: + self._readers += 1 + if self._readers == 1: + await self._write_lock.acquire() + + async def release_read(self): + async with self._lock: + self._readers -= 1 + if self._readers == 0: + self._write_lock.release() + + async def acquire_write(self): + await self._write_lock.acquire() + + def release_write(self): + self._write_lock.release() class MplexStream(IMuxedStream): """ @@ -39,17 +62,17 @@ class MplexStream(IMuxedStream): name: str stream_id: StreamID - # NOTE: All methods used here are part of `Mplex` which is a derived + # NOTE: All methods used here are part of Mplex which is a derived # class of IMuxedConn. Ignoring this type assignment should not pose # any risk. muxed_conn: "Mplex" # type: ignore[assignment] read_deadline: int | None write_deadline: int | None - # TODO: Add lock for read/write to avoid interleaving receiving messages? + rw_lock: ReadWriteLock close_lock: trio.Lock - # NOTE: `dataIn` is size of 8 in Go implementation. + # NOTE: dataIn is size of 8 in Go implementation. incoming_data_channel: "trio.MemoryReceiveChannel[bytes]" event_local_closed: trio.Event @@ -80,6 +103,7 @@ class MplexStream(IMuxedStream): self.event_remote_closed = trio.Event() self.event_reset = trio.Event() self.close_lock = trio.Lock() + self.rw_lock = ReadWriteLock() self.incoming_data_channel = incoming_data_channel self._buf = bytearray() @@ -106,55 +130,59 @@ class MplexStream(IMuxedStream): async def read(self, n: int | None = None) -> bytes: """ - Read up to n bytes. Read possibly returns fewer than `n` bytes, if - there are not enough bytes in the Mplex buffer. If `n is None`, read + Read up to n bytes. Read possibly returns fewer than n bytes, if + there are not enough bytes in the Mplex buffer. If n is None, read until EOF. :param n: number of bytes to read :return: bytes actually read """ - if n is not None and n < 0: - raise ValueError( - "the number of bytes to read `n` must be non-negative or " - f"`None` to indicate read until EOF, got n={n}" - ) - if self.event_reset.is_set(): - raise MplexStreamReset - if n is None: - return await self._read_until_eof() - if len(self._buf) == 0: - data: bytes - # Peek whether there is data available. If yes, we just read until there is - # no data, then return. - try: - data = self.incoming_data_channel.receive_nowait() - self._buf.extend(data) - except trio.EndOfChannel: - raise MplexStreamEOF - except trio.WouldBlock: - # We know `receive` will be blocked here. Wait for data here with - # `receive` and catch all kinds of errors here. + await self.rw_lock.acquire_read() + try: + if n is not None and n < 0: + raise ValueError( + "the number of bytes to read n must be non-negative or " + f"None to indicate read until EOF, got n={n}" + ) + if self.event_reset.is_set(): + raise MplexStreamReset + if n is None: + return await self._read_until_eof() + if len(self._buf) == 0: + data: bytes + # Peek whether there is data available. If yes, we just read until there is + # no data, then return. try: - data = await self.incoming_data_channel.receive() + data = self.incoming_data_channel.receive_nowait() self._buf.extend(data) except trio.EndOfChannel: - if self.event_reset.is_set(): - raise MplexStreamReset - if self.event_remote_closed.is_set(): - raise MplexStreamEOF - except trio.ClosedResourceError as error: - # Probably `incoming_data_channel` is closed in `reset` when we are - # waiting for `receive`. - if self.event_reset.is_set(): - raise MplexStreamReset - raise Exception( - "`incoming_data_channel` is closed but stream is not reset. " - "This should never happen." - ) from error - self._buf.extend(self._read_return_when_blocked()) - payload = self._buf[:n] - self._buf = self._buf[len(payload) :] - return bytes(payload) + raise MplexStreamEOF + except trio.WouldBlock: + # We know receive will be blocked here. Wait for data here with + # receive and catch all kinds of errors here. + try: + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + except trio.EndOfChannel: + if self.event_reset.is_set(): + raise MplexStreamReset + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + except trio.ClosedResourceError as error: + # Probably incoming_data_channel is closed in reset when we are + # waiting for receive. + if self.event_reset.is_set(): + raise MplexStreamReset + raise Exception( + "incoming_data_channel is closed but stream is not reset. " + "This should never happen." + ) from error + self._buf.extend(self._read_return_when_blocked()) + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] + return bytes(payload) + finally: + await self.rw_lock.release_read() async def write(self, data: bytes) -> None: """ @@ -162,14 +190,18 @@ class MplexStream(IMuxedStream): :return: number of bytes written """ - if self.event_local_closed.is_set(): - raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") - flag = ( - HeaderTags.MessageInitiator - if self.is_initiator - else HeaderTags.MessageReceiver - ) - await self.muxed_conn.send_message(flag, data, self.stream_id) + await self.rw_lock.acquire_write() + try: + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") + flag = ( + HeaderTags.MessageInitiator + if self.is_initiator + else HeaderTags.MessageReceiver + ) + await self.muxed_conn.send_message(flag, data, self.stream_id) + finally: + self.rw_lock.release_write() async def close(self) -> None: """ @@ -185,7 +217,7 @@ class MplexStream(IMuxedStream): flag = ( HeaderTags.CloseInitiator if self.is_initiator else HeaderTags.CloseReceiver ) - # TODO: Raise when `muxed_conn.send_message` fails and `Mplex` isn't shutdown. + # TODO: Raise when muxed_conn.send_message fails and Mplex isn't shutdown. await self.muxed_conn.send_message(flag, None, self.stream_id) _is_remote_closed: bool