From 5653b3f6049524d8f7963c10771edaa97c41e9ca Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 12 Sep 2019 17:07:41 +0800 Subject: [PATCH 01/12] Add "closed" and "shutting_down" events --- libp2p/stream_muxer/exceptions.py | 6 +++- libp2p/stream_muxer/mplex/exceptions.py | 9 ++++-- libp2p/stream_muxer/mplex/mplex.py | 39 +++++++++++++++++++++---- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py index 861319a4..8db5cdc2 100644 --- a/libp2p/stream_muxer/exceptions.py +++ b/libp2p/stream_muxer/exceptions.py @@ -5,7 +5,11 @@ class MuxedConnError(BaseLibp2pError): pass -class MuxedConnShutdown(MuxedConnError): +class MuxedConnShuttingDown(MuxedConnError): + pass + + +class MuxedConnClosed(MuxedConnError): pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 154c3719..6ff6cf20 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,6 +1,7 @@ from libp2p.stream_muxer.exceptions import ( MuxedConnError, - MuxedConnShutdown, + MuxedConnShuttingDown, + MuxedConnClosed, MuxedStreamClosed, MuxedStreamEOF, MuxedStreamReset, @@ -11,7 +12,11 @@ class MplexError(MuxedConnError): pass -class MplexShutdown(MuxedConnShutdown): +class MplexShuttingDown(MuxedConnShuttingDown): + pass + + +class MplexClosed(MuxedConnClosed): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index afbf288b..13de4942 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -14,6 +14,7 @@ from libp2p.utils import ( ) from .constants import HeaderTags +from .exceptions import MplexClosed, MplexShuttingDown from .datastructures import StreamID from .mplex_stream import MplexStream @@ -34,7 +35,8 @@ class Mplex(IMuxedConn): streams: Dict[StreamID, MplexStream] streams_lock: asyncio.Lock new_stream_queue: "asyncio.Queue[IMuxedStream]" - shutdown: asyncio.Event + event_shutting_down: asyncio.Event + event_closed: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -58,7 +60,8 @@ class Mplex(IMuxedConn): self.streams = {} self.streams_lock = asyncio.Lock() self.new_stream_queue = asyncio.Queue() - self.shutdown = asyncio.Event() + self.event_shutting_down = asyncio.Event() + self.event_closed = asyncio.Event() self._tasks = [] @@ -73,16 +76,20 @@ class Mplex(IMuxedConn): """ close the stream muxer and underlying secured connection """ - for task in self._tasks: - task.cancel() + # for task in self._tasks: + # task.cancel() await self.secured_conn.close() + # Set the `event_shutting_down`, to allow graceful shutdown. + self.event_shutting_down.set() + # Blocked until `close` is finally set. 
+ # await self.event_closed.wait() def is_closed(self) -> bool: """ check connection is fully closed :return: true if successful """ - raise NotImplementedError() + return self.event_closed.is_set() def _get_next_channel_id(self) -> int: """ @@ -112,11 +119,31 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream + async def _wait_until_closed(self, coro) -> Any: + task_coro = asyncio.ensure_future(coro) + task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) + done, pending = await asyncio.wait( + [task_coro, task_wait_closed], return_when=asyncio.FIRST_COMPLETED + ) + if task_wait_closed in done: + raise MplexClosed + return task_coro.result() + + async def _wait_until_shutting_down(self, coro) -> Any: + task_coro = asyncio.ensure_future(coro) + task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) + done, pending = await asyncio.wait( + [task_coro, task_wait_shutting_down], return_when=asyncio.FIRST_COMPLETED + ) + if task_wait_shutting_down in done: + raise MplexShuttingDown + return task_coro.result() + async def accept_stream(self) -> IMuxedStream: """ accepts a muxed stream opened by the other end """ - return await self.new_stream_queue.get() + return await self._wait_until_closed(self.new_stream_queue.get()) async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID From 393b51a744636e573fc6a63b53350a2288672dd8 Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 12 Sep 2019 17:09:11 +0800 Subject: [PATCH 02/12] isort --- libp2p/stream_muxer/mplex/exceptions.py | 2 +- libp2p/stream_muxer/mplex/mplex.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 6ff6cf20..f42c561b 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,7 +1,7 @@ from libp2p.stream_muxer.exceptions import ( + MuxedConnClosed, MuxedConnError, MuxedConnShuttingDown, - MuxedConnClosed, MuxedStreamClosed, MuxedStreamEOF, MuxedStreamReset, diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 13de4942..589a623f 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -14,8 +14,8 @@ from libp2p.utils import ( ) from .constants import HeaderTags -from .exceptions import MplexClosed, MplexShuttingDown from .datastructures import StreamID +from .exceptions import MplexClosed, MplexShuttingDown from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") From 2d8e02b7eb6d22eaaf57fda04a7aa8e5e4587b18 Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 13 Sep 2019 15:29:24 +0800 Subject: [PATCH 03/12] Add detection for disconnections in mplex --- libp2p/network/connection/swarm_connection.py | 24 +++++- libp2p/network/swarm.py | 8 +- libp2p/stream_muxer/exceptions.py | 6 +- libp2p/stream_muxer/mplex/exceptions.py | 9 +-- libp2p/stream_muxer/mplex/mplex.py | 78 +++++++++++++------ libp2p/utils.py | 10 +-- tests/interop/test_bindings.py | 5 +- 7 files changed, 89 insertions(+), 51 deletions(-) diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index b72fd256..50d09e79 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple from 
libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.exceptions import MuxedConnUnavailable if TYPE_CHECKING: from libp2p.network.swarm import Swarm # noqa: F401 @@ -34,17 +35,27 @@ class SwarmConn(INetConn): if self.event_closed.is_set(): return self.event_closed.set() + self.swarm.remove_conn(self) + await self.conn.close() + + # This is just for cleaning up state. The connection has already been closed. + # We *could* optimize this but it really isn't worth it. + for stream in self.streams: + await stream.reset() + # Schedule `self._notify_disconnected` to make it execute after `close` is finished. + asyncio.ensure_future(self._notify_disconnected()) + for task in self._tasks: task.cancel() - # TODO: Reset streams for local. - # TODO: Notify closed. - async def _handle_new_streams(self) -> None: # TODO: Break the loop when anything wrong in the connection. while True: - stream = await self.conn.accept_stream() + try: + stream = await self.conn.accept_stream() + except MuxedConnUnavailable: + break # Asynchronously handle the accepted stream, to avoid blocking the next stream. await self.run_task(self._handle_muxed_stream(stream)) @@ -57,11 +68,16 @@ class SwarmConn(INetConn): async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) + self.streams.add(net_stream) # Call notifiers since event occurred for notifee in self.swarm.notifees: await notifee.opened_stream(self.swarm, net_stream) return net_stream + async def _notify_disconnected(self) -> None: + for notifee in self.swarm.notifees: + await notifee.disconnected(self.swarm, self.conn) + async def start(self) -> None: await self.run_task(self._handle_new_streams()) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index dd4ca6e9..5bbbe0a0 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -262,7 +262,6 @@ class Swarm(INetwork): if peer_id not in self.connections: return connection = self.connections[peer_id] - del self.connections[peer_id] await connection.close() logger.debug("successfully close the connection to peer %s", peer_id) @@ -277,3 +276,10 @@ class Swarm(INetwork): await notifee.connected(self, muxed_conn) await swarm_conn.start() return swarm_conn + + def remove_conn(self, swarm_conn: SwarmConn) -> None: + print(f"!@# remove_conn: {swarm_conn}") + peer_id = swarm_conn.conn.peer_id + # TODO: Should be changed to remove the exact connection, + # if we have several connections per peer in the future. 
+ del self.connections[peer_id] diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py index 8db5cdc2..ce0f92e3 100644 --- a/libp2p/stream_muxer/exceptions.py +++ b/libp2p/stream_muxer/exceptions.py @@ -5,11 +5,7 @@ class MuxedConnError(BaseLibp2pError): pass -class MuxedConnShuttingDown(MuxedConnError): - pass - - -class MuxedConnClosed(MuxedConnError): +class MuxedConnUnavailable(MuxedConnError): pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index f42c561b..a7be76ee 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,7 +1,6 @@ from libp2p.stream_muxer.exceptions import ( - MuxedConnClosed, MuxedConnError, - MuxedConnShuttingDown, + MuxedConnUnavailable, MuxedStreamClosed, MuxedStreamEOF, MuxedStreamReset, @@ -12,11 +11,7 @@ class MplexError(MuxedConnError): pass -class MplexShuttingDown(MuxedConnShuttingDown): - pass - - -class MplexClosed(MuxedConnClosed): +class MplexUnavailable(MuxedConnUnavailable): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 589a623f..7f822923 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,6 +1,6 @@ import asyncio from typing import Any # noqa: F401 -from typing import Dict, List, Optional, Tuple +from typing import Awaitable, Dict, List, Optional, Tuple from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn @@ -15,7 +15,7 @@ from libp2p.utils import ( from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import MplexClosed, MplexShuttingDown +from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") @@ -76,13 +76,13 @@ class Mplex(IMuxedConn): """ close the stream muxer and underlying secured connection """ - # for task in self._tasks: - # task.cancel() - await self.secured_conn.close() + if self.event_shutting_down.is_set(): + return # Set the `event_shutting_down`, to allow graceful shutdown. self.event_shutting_down.set() + await self.secured_conn.close() # Blocked until `close` is finally set. 
- # await self.event_closed.wait() + await self.event_closed.wait() def is_closed(self) -> bool: """ @@ -119,31 +119,29 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream - async def _wait_until_closed(self, coro) -> Any: + async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any: task_coro = asyncio.ensure_future(coro) task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) - done, pending = await asyncio.wait( - [task_coro, task_wait_closed], return_when=asyncio.FIRST_COMPLETED - ) - if task_wait_closed in done: - raise MplexClosed - return task_coro.result() - - async def _wait_until_shutting_down(self, coro) -> Any: - task_coro = asyncio.ensure_future(coro) task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) done, pending = await asyncio.wait( - [task_coro, task_wait_shutting_down], return_when=asyncio.FIRST_COMPLETED + [task_coro, task_wait_closed, task_wait_shutting_down], + return_when=asyncio.FIRST_COMPLETED, ) + for fut in pending: + fut.cancel() + if task_wait_closed in done: + raise MplexUnavailable("Mplex is closed") if task_wait_shutting_down in done: - raise MplexShuttingDown + raise MplexUnavailable("Mplex is shutting down") return task_coro.result() async def accept_stream(self) -> IMuxedStream: """ accepts a muxed stream opened by the other end """ - return await self._wait_until_closed(self.new_stream_queue.get()) + return await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.get() + ) async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -162,7 +160,9 @@ class Mplex(IMuxedConn): _bytes = header + encode_varint_prefixed(data) - return await self.write_to_stream(_bytes) + return await self._wait_until_shutting_down_or_closed( + self.write_to_stream(_bytes) + ) async def write_to_stream(self, _bytes: bytes) -> int: """ @@ -180,7 +180,13 @@ class Mplex(IMuxedConn): # TODO Deal with other types of messages using flag (currently _) while True: - channel_id, flag, message = await self.read_message() + try: + channel_id, flag, message = await self._wait_until_shutting_down_or_closed( + self.read_message() + ) + except (MplexUnavailable, ConnectionResetError) as error: + print(f"!@# handle_incoming: read_message: exception={error}") + break if channel_id is not None and flag is not None and message is not None: stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) is_stream_id_seen: bool @@ -199,8 +205,12 @@ class Mplex(IMuxedConn): mplex_stream = await self._initialize_stream( stream_id, message.decode() ) - # TODO: Check if `self` is shutdown. - await self.new_stream_queue.put(mplex_stream) + try: + await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.put(mplex_stream) + ) + except MplexUnavailable: + break elif flag in ( HeaderTags.MessageInitiator.value, HeaderTags.MessageReceiver.value, @@ -214,7 +224,12 @@ class Mplex(IMuxedConn): if stream.event_remote_closed.is_set(): # TODO: Warn "Received data from remote after stream was closed by them. 
(len = %d)" # noqa: E501 continue - await stream.incoming_data.put(message) + try: + await self._wait_until_shutting_down_or_closed( + stream.incoming_data.put(message) + ) + except MplexUnavailable: + break elif flag in ( HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value, @@ -244,7 +259,6 @@ class Mplex(IMuxedConn): continue async with stream.close_lock: if not stream.event_remote_closed.is_set(): - # TODO: Why? Only if remote is not closed before then reset. stream.event_reset.set() stream.event_remote_closed.set() @@ -260,6 +274,7 @@ class Mplex(IMuxedConn): # Force context switch await asyncio.sleep(0) + await self._cleanup() async def read_message(self) -> Tuple[int, int, bytes]: """ @@ -284,3 +299,16 @@ class Mplex(IMuxedConn): channel_id = header >> 3 return channel_id, flag, message + + async def _cleanup(self) -> None: + if not self.event_shutting_down.is_set(): + self.event_shutting_down.set() + async with self.streams_lock: + for stream in self.streams.values(): + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_remote_closed.set() + stream.event_reset.set() + stream.event_local_closed.set() + self.streams = None + self.event_closed.set() diff --git a/libp2p/utils.py b/libp2p/utils.py index c69f61b3..0c1eea8d 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int: if shift > SHIFT_64_BIT_MAX: raise ParseError("TODO: better exception msg: Integer is too large...") - byte = await reader.read(1) - - try: - value = byte[0] - except IndexError: - raise ParseError( - "Unexpected end of stream while parsing LEB128 encoded integer" - ) + byte = await read_exactly(reader, 1) + value = byte[0] res += (value & LOW_MASK) << shift diff --git a/tests/interop/test_bindings.py b/tests/interop/test_bindings.py index 1189e0b7..1e78ff43 100644 --- a/tests/interop/test_bindings.py +++ b/tests/interop/test_bindings.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from .utils import connect @@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds): # Test: `disconnect` from Go await p2pd.control.disconnect(host.get_id()) # FIXME: Failed to handle disconnect - # assert len(host.get_network().connections) == 0 + await asyncio.sleep(0.01) + assert len(host.get_network().connections) == 0 From f62f07bb9f2f7fab542f4494a9b75fa599685a32 Mon Sep 17 00:00:00 2001 From: mhchia Date: Fri, 13 Sep 2019 15:32:10 +0800 Subject: [PATCH 04/12] Handle `IncompleteRead` in `handle_incoming` --- libp2p/stream_muxer/mplex/mplex.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 7f822923..56684058 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -2,6 +2,7 @@ import asyncio from typing import Any # noqa: F401 from typing import Awaitable, Dict, List, Optional, Tuple +from libp2p.io.exceptions import IncompleteReadError from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream @@ -184,7 +185,11 @@ class Mplex(IMuxedConn): channel_id, flag, message = await self._wait_until_shutting_down_or_closed( self.read_message() ) - except (MplexUnavailable, ConnectionResetError) as error: + except ( + MplexUnavailable, + ConnectionResetError, + IncompleteReadError, + ) as error: print(f"!@# handle_incoming: read_message: exception={error}") break if channel_id is not None 
and flag is not None and message is not None: From b51c2939a82895409f8371b6402e29a1c662b396 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 14 Sep 2019 14:16:40 +0800 Subject: [PATCH 05/12] Handle exceptions inside `read_message` And remove the need of checking `None` for every read messages. --- libp2p/stream_muxer/mplex/mplex.py | 178 ++++++++++++++--------------- 1 file changed, 88 insertions(+), 90 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 56684058..3c7898a2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -178,107 +178,101 @@ class Mplex(IMuxedConn): """ Read a message off of the secured connection and add it to the corresponding message buffer """ - # TODO Deal with other types of messages using flag (currently _) while True: try: channel_id, flag, message = await self._wait_until_shutting_down_or_closed( self.read_message() ) - except ( - MplexUnavailable, - ConnectionResetError, - IncompleteReadError, - ) as error: + except MplexUnavailable as error: print(f"!@# handle_incoming: read_message: exception={error}") break - if channel_id is not None and flag is not None and message is not None: - stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - is_stream_id_seen: bool - stream: MplexStream - async with self.streams_lock: - is_stream_id_seen = stream_id in self.streams - if is_stream_id_seen: - stream = self.streams[stream_id] - # Other consequent stream message should wait until the stream get accepted - # TODO: Handle more tags, and refactor `HeaderTags` - if flag == HeaderTags.NewStream.value: - if is_stream_id_seen: - # `NewStream` for the same id is received twice... - # TODO: Shutdown - pass - mplex_stream = await self._initialize_stream( - stream_id, message.decode() + stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) + is_stream_id_seen: bool + stream: MplexStream + async with self.streams_lock: + is_stream_id_seen = stream_id in self.streams + if is_stream_id_seen: + stream = self.streams[stream_id] + if flag == HeaderTags.NewStream.value: + if is_stream_id_seen: + # `NewStream` for the same id is received twice... + # TODO: Shutdown + pass + mplex_stream = await self._initialize_stream( + stream_id, message.decode() + ) + try: + await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.put(mplex_stream) ) - try: - await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.put(mplex_stream) - ) - except MplexUnavailable: - break - elif flag in ( - HeaderTags.MessageInitiator.value, - HeaderTags.MessageReceiver.value, - ): - if not is_stream_id_seen: - # We receive a message of the stream `stream_id` which is not accepted - # before. It is abnormal. Possibly disconnect? - # TODO: Warn and emit logs about this. + except MplexUnavailable: + break + elif flag in ( + HeaderTags.MessageInitiator.value, + HeaderTags.MessageReceiver.value, + ): + if not is_stream_id_seen: + # We receive a message of the stream `stream_id` which is not accepted + # before. It is abnormal. Possibly disconnect? + # TODO: Warn and emit logs about this. + continue + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 continue - async with stream.close_lock: - if stream.event_remote_closed.is_set(): - # TODO: Warn "Received data from remote after stream was closed by them. 
(len = %d)" # noqa: E501 - continue - try: - await self._wait_until_shutting_down_or_closed( - stream.incoming_data.put(message) - ) - except MplexUnavailable: - break - elif flag in ( - HeaderTags.CloseInitiator.value, - HeaderTags.CloseReceiver.value, - ): - if not is_stream_id_seen: + try: + await self._wait_until_shutting_down_or_closed( + stream.incoming_data.put(message) + ) + except MplexUnavailable: + break + elif flag in ( + HeaderTags.CloseInitiator.value, + HeaderTags.CloseReceiver.value, + ): + if not is_stream_id_seen: + continue + # NOTE: If remote is already closed, then return: Technically a bug + # on the other side. We should consider killing the connection. + async with stream.close_lock: + if stream.event_remote_closed.is_set(): continue - # NOTE: If remote is already closed, then return: Technically a bug - # on the other side. We should consider killing the connection. - async with stream.close_lock: - if stream.event_remote_closed.is_set(): - continue - is_local_closed: bool - async with stream.close_lock: - stream.event_remote_closed.set() - is_local_closed = stream.event_local_closed.is_set() - # If local is also closed, both sides are closed. Then, we should clean up - # the entry of this stream, to avoid others from accessing it. - if is_local_closed: - async with self.streams_lock: - del self.streams[stream_id] - elif flag in ( - HeaderTags.ResetInitiator.value, - HeaderTags.ResetReceiver.value, - ): - if not is_stream_id_seen: - # This is *ok*. We forget the stream on reset. - continue - async with stream.close_lock: - if not stream.event_remote_closed.is_set(): - stream.event_reset.set() - - stream.event_remote_closed.set() - # If local is not closed, we should close it. - if not stream.event_local_closed.is_set(): - stream.event_local_closed.set() + is_local_closed: bool + async with stream.close_lock: + stream.event_remote_closed.set() + is_local_closed = stream.event_local_closed.is_set() + # If local is also closed, both sides are closed. Then, we should clean up + # the entry of this stream, to avoid others from accessing it. + if is_local_closed: async with self.streams_lock: del self.streams[stream_id] - else: - # TODO: logging - if is_stream_id_seen: - await stream.reset() + elif flag in ( + HeaderTags.ResetInitiator.value, + HeaderTags.ResetReceiver.value, + ): + if not is_stream_id_seen: + # This is *ok*. We forget the stream on reset. + continue + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_reset.set() + + stream.event_remote_closed.set() + # If local is not closed, we should close it. + if not stream.event_local_closed.is_set(): + stream.event_local_closed.set() + async with self.streams_lock: + del self.streams[stream_id] + else: + # TODO: logging + if is_stream_id_seen: + await stream.reset() # Force context switch await asyncio.sleep(0) + # If we enter here, it means this connection is shutting down. + # We should clean the things up. await self._cleanup() async def read_message(self) -> Tuple[int, int, bytes]: @@ -290,15 +284,19 @@ class Mplex(IMuxedConn): # FIXME: No timeout is used in Go implementation. # Timeout is set to a relatively small value to alleviate wait time to exit # loop in handle_incoming - header = await decode_uvarint_from_stream(self.secured_conn) - # TODO: Handle the case of EOF and other exceptions? 
try: + header = await decode_uvarint_from_stream(self.secured_conn) message = await asyncio.wait_for( read_varint_prefixed_bytes(self.secured_conn), timeout=5 ) - except asyncio.TimeoutError: - # TODO: Investigate what we should do if time is out. - return None, None, None + except (ConnectionResetError, IncompleteReadError) as error: + raise MplexUnavailable( + "failed to read messages correctly from the underlying connection" + ) from error + except asyncio.TimeoutError as error: + raise MplexUnavailable( + "failed to read more message body within the timeout" + ) from error flag = header & 0x07 channel_id = header >> 3 From 4a689c7d57d436d7e562b3c73b8bc2446fd37b99 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 14 Sep 2019 14:57:43 +0800 Subject: [PATCH 06/12] Fix error when reset If `Mplex` is cleanup first, `MplexStream.reset` possibly fails because `Mplex.streams` is set to `None` in `cleanup`. --- libp2p/stream_muxer/mplex/mplex_stream.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 87b039f9..8cabccc4 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -204,7 +204,11 @@ class MplexStream(IMuxedStream): self.event_remote_closed.set() async with self.mplex_conn.streams_lock: - del self.mplex_conn.streams[self.stream_id] + if ( + self.mplex_conn.streams is not None + and self.stream_id in self.mplex_conn.streams + ): + del self.mplex_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: From 5f064dd3295b674846bdc73fb65711170b6d8d37 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 14 Sep 2019 14:59:07 +0800 Subject: [PATCH 07/12] Refactor: get rid of single huge _handle_incoming --- libp2p/stream_muxer/mplex/mplex.py | 187 +++++++++++++++-------------- 1 file changed, 99 insertions(+), 88 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 3c7898a2..6781fed4 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -181,98 +181,13 @@ class Mplex(IMuxedConn): while True: try: - channel_id, flag, message = await self._wait_until_shutting_down_or_closed( - self.read_message() - ) - except MplexUnavailable as error: - print(f"!@# handle_incoming: read_message: exception={error}") + await self._handle_incoming_message() + except MplexUnavailable: break - stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - is_stream_id_seen: bool - stream: MplexStream - async with self.streams_lock: - is_stream_id_seen = stream_id in self.streams - if is_stream_id_seen: - stream = self.streams[stream_id] - if flag == HeaderTags.NewStream.value: - if is_stream_id_seen: - # `NewStream` for the same id is received twice... - # TODO: Shutdown - pass - mplex_stream = await self._initialize_stream( - stream_id, message.decode() - ) - try: - await self._wait_until_shutting_down_or_closed( - self.new_stream_queue.put(mplex_stream) - ) - except MplexUnavailable: - break - elif flag in ( - HeaderTags.MessageInitiator.value, - HeaderTags.MessageReceiver.value, - ): - if not is_stream_id_seen: - # We receive a message of the stream `stream_id` which is not accepted - # before. It is abnormal. Possibly disconnect? - # TODO: Warn and emit logs about this. 
- continue - async with stream.close_lock: - if stream.event_remote_closed.is_set(): - # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 - continue - try: - await self._wait_until_shutting_down_or_closed( - stream.incoming_data.put(message) - ) - except MplexUnavailable: - break - elif flag in ( - HeaderTags.CloseInitiator.value, - HeaderTags.CloseReceiver.value, - ): - if not is_stream_id_seen: - continue - # NOTE: If remote is already closed, then return: Technically a bug - # on the other side. We should consider killing the connection. - async with stream.close_lock: - if stream.event_remote_closed.is_set(): - continue - is_local_closed: bool - async with stream.close_lock: - stream.event_remote_closed.set() - is_local_closed = stream.event_local_closed.is_set() - # If local is also closed, both sides are closed. Then, we should clean up - # the entry of this stream, to avoid others from accessing it. - if is_local_closed: - async with self.streams_lock: - del self.streams[stream_id] - elif flag in ( - HeaderTags.ResetInitiator.value, - HeaderTags.ResetReceiver.value, - ): - if not is_stream_id_seen: - # This is *ok*. We forget the stream on reset. - continue - async with stream.close_lock: - if not stream.event_remote_closed.is_set(): - stream.event_reset.set() - - stream.event_remote_closed.set() - # If local is not closed, we should close it. - if not stream.event_local_closed.is_set(): - stream.event_local_closed.set() - async with self.streams_lock: - del self.streams[stream_id] - else: - # TODO: logging - if is_stream_id_seen: - await stream.reset() - # Force context switch await asyncio.sleep(0) # If we enter here, it means this connection is shutting down. - # We should clean the things up. + # We should clean things up. await self._cleanup() async def read_message(self) -> Tuple[int, int, bytes]: @@ -303,6 +218,102 @@ class Mplex(IMuxedConn): return channel_id, flag, message + async def _handle_incoming_message(self) -> None: + """ + Read and handle a new incoming message. + :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. + """ + channel_id, flag, message = await self._wait_until_shutting_down_or_closed( + self.read_message() + ) + stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) + + if flag == HeaderTags.NewStream.value: + await self._handle_new_stream(stream_id, message) + elif flag in ( + HeaderTags.MessageInitiator.value, + HeaderTags.MessageReceiver.value, + ): + await self._handle_message(stream_id, message) + elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): + await self._handle_close(stream_id) + elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): + await self._handle_reset(stream_id) + else: + # Receives messages with an unknown flag + # TODO: logging + async with self.streams_lock: + if stream_id in self.streams: + stream = self.streams[stream_id] + await stream.reset() + + async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: + async with self.streams_lock: + if stream_id in self.streams: + # `NewStream` for the same id is received twice... 
+ raise MplexUnavailable( + f"received NewStream message for existing stream: {stream_id}" + ) + mplex_stream = await self._initialize_stream(stream_id, message.decode()) + await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.put(mplex_stream) + ) + + async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: + async with self.streams_lock: + if stream_id not in self.streams: + # We receive a message of the stream `stream_id` which is not accepted + # before. It is abnormal. Possibly disconnect? + # TODO: Warn and emit logs about this. + return + stream = self.streams[stream_id] + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 + return + await self._wait_until_shutting_down_or_closed( + stream.incoming_data.put(message) + ) + + async def _handle_close(self, stream_id: StreamID) -> None: + async with self.streams_lock: + if stream_id not in self.streams: + # Ignore unmatched messages for now. + return + stream = self.streams[stream_id] + # NOTE: If remote is already closed, then return: Technically a bug + # on the other side. We should consider killing the connection. + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + return + is_local_closed: bool + async with stream.close_lock: + stream.event_remote_closed.set() + is_local_closed = stream.event_local_closed.is_set() + # If local is also closed, both sides are closed. Then, we should clean up + # the entry of this stream, to avoid others from accessing it. + if is_local_closed: + async with self.streams_lock: + del self.streams[stream_id] + + async def _handle_reset(self, stream_id: StreamID) -> None: + async with self.streams_lock: + if stream_id not in self.streams: + # This is *ok*. We forget the stream on reset. + return + stream = self.streams[stream_id] + + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_reset.set() + + stream.event_remote_closed.set() + # If local is not closed, we should close it. + if not stream.event_local_closed.is_set(): + stream.event_local_closed.set() + async with self.streams_lock: + del self.streams[stream_id] + async def _cleanup(self) -> None: if not self.event_shutting_down.is_set(): self.event_shutting_down.set() From 6923f257f6df7e66291f722064132f8a8e8112ae Mon Sep 17 00:00:00 2001 From: mhchia Date: Thu, 19 Sep 2019 16:07:53 +0800 Subject: [PATCH 08/12] Remove print --- libp2p/network/swarm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 5bbbe0a0..3c83846e 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -278,7 +278,6 @@ class Swarm(INetwork): return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: - print(f"!@# remove_conn: {swarm_conn}") peer_id = swarm_conn.conn.peer_id # TODO: Should be changed to remove the exact connection, # if we have several connections per peer in the future. 
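PATCH 07 above breaks the single monolithic `_handle_incoming` loop into one small coroutine per mplex header flag (`_handle_new_stream`, `_handle_message`, `_handle_close`, `_handle_reset`), so each message type is handled in isolation and `MplexUnavailable` can simply propagate out of any of them. The snippet below is only an illustrative, table-driven sketch of that dispatch idea under simplified assumptions: the `Demuxer` class, the `Flag` values and the handler bodies are invented for the example and are not py-libp2p's API, which keys off `HeaderTags` with an if/elif chain.

    import asyncio
    from enum import IntEnum
    from typing import Awaitable, Callable, Dict


    class Flag(IntEnum):
        # Illustrative values only; the real ones live in mplex's HeaderTags.
        NEW_STREAM = 0
        MESSAGE = 2
        CLOSE = 4
        RESET = 6


    Handler = Callable[[int, bytes], Awaitable[None]]


    class Demuxer:
        """Toy demultiplexer showing the per-flag handler split."""

        def __init__(self) -> None:
            # A dispatch table instead of a long if/elif chain.
            self._handlers: Dict[Flag, Handler] = {
                Flag.NEW_STREAM: self._handle_new_stream,
                Flag.MESSAGE: self._handle_message,
                Flag.CLOSE: self._handle_close,
                Flag.RESET: self._handle_reset,
            }

        async def handle_incoming_message(
            self, channel_id: int, flag: int, data: bytes
        ) -> None:
            try:
                handler = self._handlers[Flag(flag)]
            except ValueError:
                # Unknown flag: the real code resets the stream if it is known.
                return
            await handler(channel_id, data)

        async def _handle_new_stream(self, channel_id: int, data: bytes) -> None:
            print(f"new stream {channel_id}: name={data.decode()!r}")

        async def _handle_message(self, channel_id: int, data: bytes) -> None:
            print(f"data on stream {channel_id}: {data!r}")

        async def _handle_close(self, channel_id: int, data: bytes) -> None:
            print(f"remote closed stream {channel_id}")

        async def _handle_reset(self, channel_id: int, data: bytes) -> None:
            print(f"remote reset stream {channel_id}")


    if __name__ == "__main__":
        asyncio.run(Demuxer().handle_incoming_message(1, Flag.NEW_STREAM, b"/echo/1.0.0"))

Each handler stays short enough to reason about its locking (`streams_lock`, `stream.close_lock`) on its own, which is the main payoff of the refactor.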
From e7304538dae30e2ff4e011ad88bec124fbef0d9c Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 14 Sep 2019 23:37:01 +0800 Subject: [PATCH 09/12] Add test for `Swarm.close_peer` --- libp2p/host/basic_host.py | 8 ++- libp2p/network/connection/swarm_connection.py | 3 +- tests/factories.py | 71 +++++++++++++------ tests/network/conftest.py | 15 +++- tests/network/test_swarm.py | 49 +++++++++++++ tests/utils.py | 13 ++++ 6 files changed, 130 insertions(+), 29 deletions(-) create mode 100644 tests/network/test_swarm.py diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index ea78c988..862fd5c9 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -23,6 +23,10 @@ from .host_interface import IHost class BasicHost(IHost): + """ + BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation + on a stream with multistream-select right after a stream is initialized. + """ _network: INetwork _router: KadmeliaPeerRouter @@ -31,7 +35,6 @@ class BasicHost(IHost): multiselect: Multiselect multiselect_client: MultiselectClient - # default options constructor def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) @@ -69,6 +72,7 @@ class BasicHost(IHost): """ :return: all the multiaddr addresses this host is listening to """ + # TODO: We don't need "/p2p/{peer_id}" postfix actually. p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty())) addrs: List[multiaddr.Multiaddr] = [] @@ -87,8 +91,6 @@ class BasicHost(IHost): """ self.multiselect.add_handler(protocol_id, stream_handler) - # `protocol_ids` can be a list of `protocol_id` - # stream will decide which `protocol_id` to run on async def new_stream( self, peer_id: ID, protocol_ids: Sequence[TProtocol] ) -> INetStream: diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 50d09e79..15816fcb 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -50,11 +50,12 @@ class SwarmConn(INetConn): task.cancel() async def _handle_new_streams(self) -> None: - # TODO: Break the loop when anything wrong in the connection. while True: try: stream = await self.conn.accept_stream() except MuxedConnUnavailable: + # If there is anything wrong in the MuxedConn, + # we should break the loop and close the connection. break # Asynchronously handle the accepted stream, to avoid blocking the next stream. 
await self.run_task(self._handle_muxed_stream(stream)) diff --git a/tests/factories.py b/tests/factories.py index 0f69707a..efa16c88 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -8,13 +8,13 @@ from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost from libp2p.network.stream.net_stream_interface import INetStream +from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio -from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( @@ -22,7 +22,7 @@ from tests.pubsub.configs import ( GOSSIPSUB_PARAMS, GOSSIPSUB_PROTOCOL_ID, ) -from tests.utils import connect +from tests.utils import connect, connect_swarm def security_transport_factory( @@ -34,10 +34,29 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def swarm_factory(is_secure: bool): - key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(is_secure, key_pair) - return initialize_default_swarm(key_pair, sec_opt=sec_opt) +class SwarmFactory(factory.Factory): + class Meta: + model = Swarm + + @classmethod + def _create(cls, is_secure=False): + key_pair = generate_new_rsa_identity() + sec_opt = security_transport_factory(is_secure, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt) + + @classmethod + async def create_and_listen(cls, is_secure: bool) -> Swarm: + swarm = cls._create(is_secure) + await swarm.listen(LISTEN_MADDR) + return swarm + + @classmethod + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> Tuple[Swarm, ...]: + return await asyncio.gather( + *[cls.create_and_listen(is_secure) for _ in range(number)] + ) class HostFactory(factory.Factory): @@ -47,13 +66,12 @@ class HostFactory(factory.Factory): class Params: is_secure = False - network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) + network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure)) @classmethod - async def create_and_listen(cls) -> IHost: - host = cls() - await host.get_network().listen(LISTEN_MADDR) - return host + async def create_and_listen(cls, is_secure: bool) -> IHost: + swarm = await SwarmFactory.create_and_listen(is_secure) + return BasicHost(swarm) class FloodsubFactory(factory.Factory): @@ -87,24 +105,33 @@ class PubsubFactory(factory.Factory): cache_size = None -async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: +async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: + swarms = await SwarmFactory.create_batch_and_listen(2) + await connect_swarm(swarms[0], swarms[1]) + return swarms[0], swarms[1] + + +async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: hosts = await asyncio.gather( - *[HostFactory.create_and_listen(), HostFactory.create_and_listen()] + *[ + HostFactory.create_and_listen(is_secure), + HostFactory.create_and_listen(is_secure), + ] ) await connect(hosts[0], hosts[1]) return hosts[0], hosts[1] -async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: - host_0, host_1 = await host_pair_factory() - mplex_conn_0 = 
host_0.get_network().connections[host_1.get_id()] - mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] - return mplex_conn_0, host_0, mplex_conn_1, host_1 +# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: +# host_0, host_1 = await host_pair_factory() +# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] +# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] +# return mplex_conn_0, host_0, mplex_conn_1, host_1 -async def net_stream_pair_factory() -> Tuple[ - INetStream, BasicHost, INetStream, BasicHost -]: +async def net_stream_pair_factory( + is_secure: bool +) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: protocol_id = "/example/id/1" stream_1: INetStream @@ -114,7 +141,7 @@ async def net_stream_pair_factory() -> Tuple[ nonlocal stream_1 stream_1 = stream - host_0, host_1 = await host_pair_factory() + host_0, host_1 = await host_pair_factory(is_secure) host_1.set_stream_handler(protocol_id, handler) stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 10f77918..47d5c5f0 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -2,13 +2,22 @@ import asyncio import pytest -from tests.factories import net_stream_pair_factory +from tests.factories import net_stream_pair_factory, swarm_pair_factory @pytest.fixture -async def net_stream_pair(): - stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() +async def net_stream_pair(is_host_secure): + stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) try: yield stream_0, stream_1 finally: await asyncio.gather(*[host_0.close(), host_1.close()]) + + +@pytest.fixture +async def swarm_pair(is_host_secure): + swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) + try: + yield swarm_0, swarm_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py new file mode 100644 index 00000000..e531de08 --- /dev/null +++ b/tests/network/test_swarm.py @@ -0,0 +1,49 @@ +import asyncio + +import pytest + +from tests.factories import SwarmFactory +from tests.utils import connect_swarm + + +@pytest.mark.asyncio +async def test_swarm_close_peer(is_host_secure): + swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) + # 0 <> 1 <> 2 + await connect_swarm(swarms[0], swarms[1]) + await connect_swarm(swarms[1], swarms[2]) + + # peer 1 closes peer 0 + await swarms[1].close_peer(swarms[0].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 <> 2 + assert len(swarms[0].connections) == 0 + assert ( + len(swarms[1].connections) == 1 + and swarms[2].get_peer_id() in swarms[1].connections + ) + + # peer 1 is closed by peer 2 + await swarms[2].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + + await connect_swarm(swarms[0], swarms[1]) + # 0 <> 1 2 + assert ( + len(swarms[0].connections) == 1 + and swarms[1].get_peer_id() in swarms[0].connections + ) + assert ( + len(swarms[1].connections) == 1 + and swarms[0].get_peer_id() in swarms[1].connections + ) + # peer 0 closes peer 1 + await swarms[0].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) diff --git a/tests/utils.py 
b/tests/utils.py index e9d6c09f..4b4357db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from tests.constants import MAX_READ_LEN +async def connect_swarm(swarm_0, swarm_1): + peer_id = swarm_1.get_peer_id() + addrs = tuple( + addr + for transport in swarm_1.listeners.values() + for addr in transport.get_addrs() + ) + swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) + await swarm_0.dial_peer(peer_id) + assert swarm_0.get_peer_id() in swarm_1.connections + assert swarm_1.get_peer_id() in swarm_0.connections + + async def connect(node1, node2): """ Connect node1 to node2 From 276ac4d8ab1f5a7cb29c6c76261b0574a8434b17 Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 15 Sep 2019 14:59:59 +0800 Subject: [PATCH 10/12] Add initial test for `Swarm.close_peer` --- tests/factories.py | 32 +++++++++++++++++++------------- tests/network/test_swarm.py | 4 ++-- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/factories.py b/tests/factories.py index efa16c88..e9ec1961 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -6,7 +6,6 @@ import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost -from libp2p.host.host_interface import IHost from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub @@ -34,19 +33,19 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -class SwarmFactory(factory.Factory): +def SwarmFactory(is_secure: bool) -> Swarm: + key_pair = generate_new_rsa_identity() + sec_opt = security_transport_factory(False, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt) + + +class ListeningSwarmFactory(factory.Factory): class Meta: model = Swarm - @classmethod - def _create(cls, is_secure=False): - key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(is_secure, key_pair) - return initialize_default_swarm(key_pair, sec_opt=sec_opt) - @classmethod async def create_and_listen(cls, is_secure: bool) -> Swarm: - swarm = cls._create(is_secure) + swarm = SwarmFactory(is_secure) await swarm.listen(LISTEN_MADDR) return swarm @@ -69,9 +68,16 @@ class HostFactory(factory.Factory): network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure)) @classmethod - async def create_and_listen(cls, is_secure: bool) -> IHost: - swarm = await SwarmFactory.create_and_listen(is_secure) - return BasicHost(swarm) + async def create_and_listen(cls, is_secure: bool) -> BasicHost: + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1) + return BasicHost(swarms[0]) + + @classmethod + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> Tuple[BasicHost, ...]: + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number) + return tuple(BasicHost(swarm) for swarm in range(swarms)) class FloodsubFactory(factory.Factory): @@ -106,7 +112,7 @@ class PubsubFactory(factory.Factory): async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: - swarms = await SwarmFactory.create_batch_and_listen(2) + swarms = await ListeningSwarmFactory.create_batch_and_listen(2) await connect_swarm(swarms[0], swarms[1]) return swarms[0], swarms[1] diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index e531de08..3b27f435 100644 --- a/tests/network/test_swarm.py +++ 
b/tests/network/test_swarm.py @@ -2,13 +2,13 @@ import asyncio import pytest -from tests.factories import SwarmFactory +from tests.factories import ListeningSwarmFactory from tests.utils import connect_swarm @pytest.mark.asyncio async def test_swarm_close_peer(is_host_secure): - swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) # 0 <> 1 <> 2 await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[1], swarms[2]) From 0356380996a880ea7f8555b39441914d8d04d28e Mon Sep 17 00:00:00 2001 From: mhchia Date: Sun, 15 Sep 2019 20:44:48 +0800 Subject: [PATCH 11/12] Add tests for swarm, and debug Fix `swarm_pair_factory` --- libp2p/network/swarm.py | 11 ++++++++++ tests/factories.py | 2 +- tests/network/test_swarm.py | 44 +++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 3c83846e..272a8a94 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -261,12 +261,18 @@ class Swarm(INetwork): async def close_peer(self, peer_id: ID) -> None: if peer_id not in self.connections: return + # TODO: Should be changed to close multisple connections, + # if we have several connections per peer in the future. connection = self.connections[peer_id] await connection.close() logger.debug("successfully close the connection to peer %s", peer_id) async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn: + """ + Add a `IMuxedConn` to `Swarm` as a `SwarmConn`, notify "connected", + and start to monitor the connection for its new streams and disconnection. + """ swarm_conn = SwarmConn(muxed_conn, self) # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn @@ -278,7 +284,12 @@ class Swarm(INetwork): return swarm_conn def remove_conn(self, swarm_conn: SwarmConn) -> None: + """ + Simply remove the connection from Swarm's records, without closing the connection. + """ peer_id = swarm_conn.conn.peer_id + if peer_id not in self.connections: + return # TODO: Should be changed to remove the exact connection, # if we have several connections per peer in the future. del self.connections[peer_id] diff --git a/tests/factories.py b/tests/factories.py index e9ec1961..e39b12d3 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -112,7 +112,7 @@ class PubsubFactory(factory.Factory): async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: - swarms = await ListeningSwarmFactory.create_batch_and_listen(2) + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) await connect_swarm(swarms[0], swarms[1]) return swarms[0], swarms[1] diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index 3b27f435..cf8eadfa 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -2,10 +2,43 @@ import asyncio import pytest +from libp2p.network.exceptions import SwarmException from tests.factories import ListeningSwarmFactory from tests.utils import connect_swarm +@pytest.mark.asyncio +async def test_swarm_dial_peer(is_host_secure): + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) + # Test: No addr found. + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: len(addr) in the peerstore is 0. 
+ swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: Succeed if addrs of the peer_id are present in the peerstore. + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert swarms[0].get_peer_id() in swarms[1].connections + assert swarms[1].get_peer_id() in swarms[0].connections + + # Test: Reuse connections when we already have ones with a peer. + conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] + conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert conn is conn_to_1 + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) + + @pytest.mark.asyncio async def test_swarm_close_peer(is_host_secure): swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) @@ -47,3 +80,14 @@ async def test_swarm_close_peer(is_host_secure): # Clean up await asyncio.gather(*[swarm.close() for swarm in swarms]) + + +@pytest.mark.asyncio +async def test_swarm_remove_conn(swarm_pair): + swarm_0, swarm_1 = swarm_pair + conn_0 = swarm_0.connections[swarm_1.get_peer_id()] + swarm_0.remove_conn(conn_0) + assert swarm_1.get_peer_id() not in swarm_0.connections + # Test: Remove twice. There should not be errors. + swarm_0.remove_conn(conn_0) + assert swarm_1.get_peer_id() not in swarm_0.connections From 539047be2da1a36af9bce370cf24a7d06146481e Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 21 Sep 2019 18:17:00 +0800 Subject: [PATCH 12/12] Make `mplex.read_message` handle `RawConnError` --- libp2p/stream_muxer/mplex/mplex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 3d9497bc..ea7e47ac 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -4,6 +4,7 @@ from typing import Awaitable, Dict, List, Optional, Tuple from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError +from libp2p.network.connection.exceptions import RawConnError from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream @@ -203,8 +204,7 @@ class Mplex(IMuxedConn): message = await asyncio.wait_for( read_varint_prefixed_bytes(self.secured_conn), timeout=5 ) - # TODO: Catch RawConnError? - except (ParseError, IncompleteReadError) as error: + except (ParseError, RawConnError, IncompleteReadError) as error: raise MplexUnavailable( "failed to read messages correctly from the underlying connection" ) from error
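Taken together, the series converges on one pattern: every await inside `Mplex` that could otherwise block forever is raced against the muxer's two lifecycle events and loses with `MplexUnavailable` as soon as the muxer starts shutting down or finishes closing. Below is a minimal, self-contained sketch of that pattern, with simplified stand-ins (`Unavailable`, a bare helper function and a demo queue) for `MplexUnavailable`, `_wait_until_shutting_down_or_closed` and `new_stream_queue`; it is not the actual implementation.

    import asyncio
    from typing import Any, Awaitable


    class Unavailable(Exception):
        """Stand-in for MplexUnavailable: the muxer is shutting down or closed."""


    async def wait_until_shutting_down_or_closed(
        coro: Awaitable[Any],
        event_shutting_down: asyncio.Event,
        event_closed: asyncio.Event,
    ) -> Any:
        # Race the payload against both lifecycle events.
        task_coro = asyncio.ensure_future(coro)
        task_closed = asyncio.ensure_future(event_closed.wait())
        task_shutting_down = asyncio.ensure_future(event_shutting_down.wait())
        done, pending = await asyncio.wait(
            [task_coro, task_closed, task_shutting_down],
            return_when=asyncio.FIRST_COMPLETED,
        )
        # Cancel the losers so no task leaks.
        for fut in pending:
            fut.cancel()
        if task_closed in done:
            raise Unavailable("closed")
        if task_shutting_down in done:
            raise Unavailable("shutting down")
        return task_coro.result()


    async def demo() -> None:
        shutting_down = asyncio.Event()
        closed = asyncio.Event()
        queue: "asyncio.Queue[bytes]" = asyncio.Queue()

        async def consumer() -> None:
            try:
                await wait_until_shutting_down_or_closed(
                    queue.get(), shutting_down, closed
                )
            except Unavailable as error:
                print(f"consumer stopped: {error}")

        task = asyncio.ensure_future(consumer())
        await asyncio.sleep(0)  # let the consumer start waiting
        shutting_down.set()     # simulate Mplex.close() flipping the event
        await task


    if __name__ == "__main__":
        asyncio.run(demo())

Wrapping awaits such as `new_stream_queue.get()` or `read_message()` in this helper is what lets `accept_stream` and the `_handle_incoming` loop return promptly once `close()` sets `event_shutting_down`, instead of hanging on an await that will never complete after the remote side disconnects.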