Fix MplexStream.read

This commit is contained in:
mhchia
2019-09-06 17:26:40 +08:00
parent 95926b7376
commit 649a230776
8 changed files with 116 additions and 143 deletions

View File

@ -55,7 +55,6 @@ class MplexStream(IMuxedStream):
return self.stream_id.is_initiator
async def _wait_for_data(self) -> None:
print("!@# _wait_for_data: 0")
done, pending = await asyncio.wait(
[
self.event_reset.wait(),
@ -64,16 +63,25 @@ class MplexStream(IMuxedStream):
],
return_when=asyncio.FIRST_COMPLETED,
)
print("!@# _wait_for_data: 1")
if self.event_reset.is_set():
raise MplexStreamReset
if self.event_remote_closed.is_set():
while not self.incoming_data.empty():
self._buf.extend(await self.incoming_data.get())
raise MplexStreamEOF
# TODO: Handle timeout when deadline is used.
data = tuple(done)[0].result()
self._buf.extend(data)
async def _read_until_eof(self) -> bytes:
while True:
try:
await self._wait_for_data()
except MplexStreamEOF:
break
payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)
async def read(self, n: int = -1) -> bytes:
"""
Read up to n bytes. Read possibly returns fewer than `n` bytes,
@ -87,22 +95,18 @@ class MplexStream(IMuxedStream):
raise ValueError(
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF"
)
# FIXME: If `n == -1`, we should blocking read until EOF, instead of returning when
# no message is available.
# If `n >= 0`, read up to `n` bytes.
# Else, read until no message is available.
while len(self._buf) < n or n == -1:
# new_bytes = await self.incoming_data.get()
try:
await self._wait_for_data()
except MplexStreamEOF:
break
payload: bytearray
if self.event_reset.is_set():
raise MplexStreamReset
if n == -1:
payload = self._buf
else:
payload = self._buf[:n]
return await self._read_until_eof()
if len(self._buf) == 0:
await self._wait_for_data()
# Read up to `n` bytes.
while len(self._buf) < n:
if self.incoming_data.empty() or self.event_remote_closed.is_set():
break
self._buf.extend(await self.incoming_data.get())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)