Fix close behavior

This commit is contained in:
mhchia
2019-09-09 15:45:35 +08:00
parent b2146c5268
commit be2c0f122a
8 changed files with 149 additions and 21 deletions

View File

@ -0,0 +1,17 @@
from libp2p.exceptions import BaseLibp2pError
class StreamError(BaseLibp2pError):
pass
class StreamEOF(StreamError, EOFError):
pass
class StreamReset(StreamError):
pass
class StreamClosed(StreamError):
pass

View File

@ -1,9 +1,18 @@
from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream
from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
)
from libp2p.typing import TProtocol
from .exceptions import StreamClosed, StreamEOF, StreamReset
from .net_stream_interface import INetStream
# TODO: Handle exceptions from `muxed_stream`
# TODO: Add stream state
# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501
class NetStream(INetStream):
muxed_stream: IMuxedStream
@ -35,14 +44,22 @@ class NetStream(INetStream):
:param n: number of bytes to read
:return: bytes of input
"""
return await self.muxed_stream.read(n)
try:
return await self.muxed_stream.read(n)
except MuxedStreamEOF as error:
raise StreamEOF from error
except MuxedStreamReset as error:
raise StreamReset from error
async def write(self, data: bytes) -> int:
"""
write to stream
:return: number of bytes written
"""
return await self.muxed_stream.write(data)
try:
return await self.muxed_stream.write(data)
except MuxedStreamClosed as error:
raise StreamClosed from error
async def close(self) -> None:
"""
@ -51,5 +68,5 @@ class NetStream(INetStream):
"""
await self.muxed_stream.close()
async def reset(self) -> bool:
return await self.muxed_stream.reset()
async def reset(self) -> None:
await self.muxed_stream.reset()

View File

@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser):
"""
@abstractmethod
async def reset(self) -> bool:
async def reset(self) -> None:
"""
Close both ends of the stream.
"""

View File

@ -0,0 +1,25 @@
from libp2p.exceptions import BaseLibp2pError
class MuxedConnError(BaseLibp2pError):
pass
class MuxedConnShutdown(MuxedConnError):
pass
class MuxedStreamError(BaseLibp2pError):
pass
class MuxedStreamReset(MuxedStreamError):
pass
class MuxedStreamEOF(MuxedStreamError, EOFError):
pass
class MuxedStreamClosed(MuxedStreamError):
pass

View File

@ -1,17 +1,27 @@
from libp2p.exceptions import BaseLibp2pError
from libp2p.stream_muxer.exceptions import (
MuxedConnError,
MuxedConnShutdown,
MuxedStreamClosed,
MuxedStreamEOF,
MuxedStreamReset,
)
class MplexError(BaseLibp2pError):
class MplexError(MuxedConnError):
pass
class MplexStreamReset(MplexError):
class MplexShutdown(MuxedConnShutdown):
pass
class MplexStreamEOF(MplexError, EOFError):
class MplexStreamReset(MuxedStreamReset):
pass
class MplexShutdown(MplexError):
class MplexStreamEOF(MuxedStreamEOF):
pass
class MplexStreamClosed(MuxedStreamClosed):
pass

View File

@ -188,6 +188,10 @@ class Mplex(IMuxedConn):
# before. It is abnormal. Possibly disconnect?
# TODO: Warn and emit logs about this.
continue
async with stream.close_lock:
if stream.event_remote_closed.is_set():
# TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501
continue
await stream.incoming_data.put(message)
elif flag in (
HeaderTags.CloseInitiator.value,

View File

@ -1,11 +1,11 @@
import asyncio
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from libp2p.stream_muxer.abc import IMuxedStream
from .constants import HeaderTags
from .datastructures import StreamID
from .exceptions import MplexStreamEOF, MplexStreamReset
from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset
if TYPE_CHECKING:
from libp2p.stream_muxer.mplex.mplex import Mplex
@ -58,20 +58,24 @@ class MplexStream(IMuxedStream):
done, pending = await asyncio.wait( # type: ignore
[
self.event_reset.wait(),
self.event_remote_closed.wait(),
self.incoming_data.get(),
self.event_remote_closed.wait(),
],
return_when=asyncio.FIRST_COMPLETED,
)
for fut in pending:
fut.cancel()
if self.event_reset.is_set():
raise MplexStreamReset
done_task = tuple(done)[0]
if done_task._coro.__qualname__ == "Queue.get":
data = done_task.result()
self._buf.extend(data)
return
if self.event_remote_closed.is_set():
raise MplexStreamEOF
# TODO: Handle timeout when deadline is used.
data = tuple(done)[0].result()
self._buf.extend(data)
async def _read_until_eof(self) -> bytes:
while True:
try:
@ -99,13 +103,15 @@ class MplexStream(IMuxedStream):
raise MplexStreamReset
if n == -1:
return await self._read_until_eof()
if len(self._buf) == 0:
if len(self._buf) == 0 and self.incoming_data.empty():
await self._wait_for_data()
# Read up to `n` bytes.
# Either `buf` is not empty or `incoming_data` is not empty now.
# Try to put enough incoming data into `self._buf`.
while len(self._buf) < n:
if self.incoming_data.empty() or self.event_remote_closed.is_set():
try:
self._buf.extend(self.incoming_data.get_nowait())
except asyncio.QueueEmpty:
break
self._buf.extend(await self.incoming_data.get())
payload = self._buf[:n]
self._buf = self._buf[len(payload) :]
return bytes(payload)
@ -115,6 +121,8 @@ class MplexStream(IMuxedStream):
write to stream
:return: number of bytes written
"""
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data}")
flag = (
HeaderTags.MessageInitiator
if self.is_initiator