diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index ecddedd9..00df44c2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,5 +1,5 @@ import asyncio -from typing import Dict, Tuple, Optional +from typing import Dict, Optional, Tuple from multiaddr import Multiaddr diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index afc49ee6..7b8c1e53 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -39,28 +39,39 @@ class MplexStream(IMuxedStream): self.stream_lock = asyncio.Lock() self._buf = b"" - async def read(self, n) -> bytes: + async def read(self, n: int = -1) -> bytes: """ Read up to n bytes. Read possibly returns fewer than `n` bytes, if there are not enough bytes in the Mplex buffer. + If `n == -1`, read until EOF. :param n: number of bytes to read :return: bytes actually read """ + if n < 0 and n != -1: + raise ValueError("`n` can only be -1 if it is negative") # If the buffer is empty at first, blocking wait for data. if len(self._buf) == 0: self._buf = await self.mplex_conn.read_buffer(self.stream_id) - # Here, `self._buf` should never be `None`. + # Sanity check: `self._buf` should never be empty here. if self._buf is None or len(self._buf) == 0: - raise Exception("start to `read_buffer_nonblocking` only when there are bytes read.") + raise Exception("`self._buf` should never be empty here") - while len(self._buf) < n: + # FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when + # no message is available. + # If `n >= 0`, read up to `n` bytes. + # Else, read until no message is available. + while len(self._buf) < n or n == -1: new_bytes = await self.mplex_conn.read_buffer_nonblocking(self.stream_id) if new_bytes is None: # Nothing to read in the `MplexConn` buffer break self._buf += new_bytes - payload = self._buf[:n] - self._buf = self._buf[n:] + payload: bytes + if n == -1: + payload = self._buf + else: + payload = self._buf[:n] + self._buf = self._buf[len(payload) :] return payload async def write(self, data: bytes) -> int: