From 5653b3f6049524d8f7963c10771edaa97c41e9ca Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 12 Sep 2019 17:07:41 +0800 Subject: [PATCH] Add "closed" and "shutting_down" events --- libp2p/stream_muxer/exceptions.py | 6 +++- libp2p/stream_muxer/mplex/exceptions.py | 9 ++++-- libp2p/stream_muxer/mplex/mplex.py | 39 +++++++++++++++++++++---- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py index 861319a4..8db5cdc2 100644 --- a/libp2p/stream_muxer/exceptions.py +++ b/libp2p/stream_muxer/exceptions.py @@ -5,7 +5,11 @@ class MuxedConnError(BaseLibp2pError): pass -class MuxedConnShutdown(MuxedConnError): +class MuxedConnShuttingDown(MuxedConnError): + pass + + +class MuxedConnClosed(MuxedConnError): pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 154c3719..6ff6cf20 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,6 +1,7 @@ from libp2p.stream_muxer.exceptions import ( MuxedConnError, - MuxedConnShutdown, + MuxedConnShuttingDown, + MuxedConnClosed, MuxedStreamClosed, MuxedStreamEOF, MuxedStreamReset, @@ -11,7 +12,11 @@ class MplexError(MuxedConnError): pass -class MplexShutdown(MuxedConnShutdown): +class MplexShuttingDown(MuxedConnShuttingDown): + pass + + +class MplexClosed(MuxedConnClosed): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index afbf288b..13de4942 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -14,6 +14,7 @@ from libp2p.utils import ( ) from .constants import HeaderTags +from .exceptions import MplexClosed, MplexShuttingDown from .datastructures import StreamID from .mplex_stream import MplexStream @@ -34,7 +35,8 @@ class Mplex(IMuxedConn): streams: Dict[StreamID, MplexStream] streams_lock: asyncio.Lock new_stream_queue: "asyncio.Queue[IMuxedStream]" - shutdown: asyncio.Event + event_shutting_down: asyncio.Event + event_closed: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -58,7 +60,8 @@ class Mplex(IMuxedConn): self.streams = {} self.streams_lock = asyncio.Lock() self.new_stream_queue = asyncio.Queue() - self.shutdown = asyncio.Event() + self.event_shutting_down = asyncio.Event() + self.event_closed = asyncio.Event() self._tasks = [] @@ -73,16 +76,20 @@ class Mplex(IMuxedConn): """ close the stream muxer and underlying secured connection """ - for task in self._tasks: - task.cancel() + # for task in self._tasks: + # task.cancel() await self.secured_conn.close() + # Set the `event_shutting_down`, to allow graceful shutdown. + self.event_shutting_down.set() + # Blocked until `close` is finally set. + # await self.event_closed.wait() def is_closed(self) -> bool: """ check connection is fully closed :return: true if successful """ - raise NotImplementedError() + return self.event_closed.is_set() def _get_next_channel_id(self) -> int: """ @@ -112,11 +119,31 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream + async def _wait_until_closed(self, coro) -> Any: + task_coro = asyncio.ensure_future(coro) + task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) + done, pending = await asyncio.wait( + [task_coro, task_wait_closed], return_when=asyncio.FIRST_COMPLETED + ) + if task_wait_closed in done: + raise MplexClosed + return task_coro.result() + + async def _wait_until_shutting_down(self, coro) -> Any: + task_coro = asyncio.ensure_future(coro) + task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) + done, pending = await asyncio.wait( + [task_coro, task_wait_shutting_down], return_when=asyncio.FIRST_COMPLETED + ) + if task_wait_shutting_down in done: + raise MplexShuttingDown + return task_coro.result() + async def accept_stream(self) -> IMuxedStream: """ accepts a muxed stream opened by the other end """ - return await self.new_stream_queue.get() + return await self._wait_until_closed(self.new_stream_queue.get()) async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID