mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Fix close behavior
This commit is contained in:
@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from libp2p.stream_muxer.abc import IMuxedStream
|
||||
|
||||
from .constants import HeaderTags
|
||||
from .datastructures import StreamID
|
||||
from .exceptions import MplexStreamEOF, MplexStreamReset
|
||||
from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.stream_muxer.mplex.mplex import Mplex
|
||||
@ -58,20 +58,24 @@ class MplexStream(IMuxedStream):
|
||||
done, pending = await asyncio.wait( # type: ignore
|
||||
[
|
||||
self.event_reset.wait(),
|
||||
self.event_remote_closed.wait(),
|
||||
self.incoming_data.get(),
|
||||
self.event_remote_closed.wait(),
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for fut in pending:
|
||||
fut.cancel()
|
||||
if self.event_reset.is_set():
|
||||
raise MplexStreamReset
|
||||
done_task = tuple(done)[0]
|
||||
if done_task._coro.__qualname__ == "Queue.get":
|
||||
data = done_task.result()
|
||||
self._buf.extend(data)
|
||||
return
|
||||
if self.event_remote_closed.is_set():
|
||||
raise MplexStreamEOF
|
||||
# TODO: Handle timeout when deadline is used.
|
||||
|
||||
data = tuple(done)[0].result()
|
||||
self._buf.extend(data)
|
||||
|
||||
async def _read_until_eof(self) -> bytes:
|
||||
while True:
|
||||
try:
|
||||
@ -99,13 +103,15 @@ class MplexStream(IMuxedStream):
|
||||
raise MplexStreamReset
|
||||
if n == -1:
|
||||
return await self._read_until_eof()
|
||||
if len(self._buf) == 0:
|
||||
if len(self._buf) == 0 and self.incoming_data.empty():
|
||||
await self._wait_for_data()
|
||||
# Read up to `n` bytes.
|
||||
# Either `buf` is not empty or `incoming_data` is not empty now.
|
||||
# Try to put enough incoming data into `self._buf`.
|
||||
while len(self._buf) < n:
|
||||
if self.incoming_data.empty() or self.event_remote_closed.is_set():
|
||||
try:
|
||||
self._buf.extend(self.incoming_data.get_nowait())
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
self._buf.extend(await self.incoming_data.get())
|
||||
payload = self._buf[:n]
|
||||
self._buf = self._buf[len(payload) :]
|
||||
return bytes(payload)
|
||||
@ -115,6 +121,8 @@ class MplexStream(IMuxedStream):
|
||||
write to stream
|
||||
:return: number of bytes written
|
||||
"""
|
||||
if self.event_local_closed.is_set():
|
||||
raise MplexStreamClosed(f"cannot write to closed stream: data={data}")
|
||||
flag = (
|
||||
HeaderTags.MessageInitiator
|
||||
if self.is_initiator
|
||||
|
||||
Reference in New Issue
Block a user