diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 126359aa..af4225f4 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -30,6 +30,10 @@ logger.setLevel(logging.DEBUG) class BasicHost(IHost): + """ + BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation + on a stream with multistream-select right after a stream is initialized. + """ _network: INetwork _router: KadmeliaPeerRouter @@ -38,7 +42,6 @@ class BasicHost(IHost): multiselect: Multiselect multiselect_client: MultiselectClient - # default options constructor def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) @@ -76,6 +79,7 @@ class BasicHost(IHost): """ :return: all the multiaddr addresses this host is listening to """ + # TODO: We don't need "/p2p/{peer_id}" postfix actually. p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty())) addrs: List[multiaddr.Multiaddr] = [] @@ -94,8 +98,6 @@ class BasicHost(IHost): """ self.multiselect.add_handler(protocol_id, stream_handler) - # `protocol_ids` can be a list of `protocol_id` - # stream will decide which `protocol_id` to run on async def new_stream( self, peer_id: ID, protocol_ids: Sequence[TProtocol] ) -> INetStream: diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index 144c1a85..f193f435 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -37,8 +37,8 @@ class RawConnection(IRawConnection): async with self._drain_lock: try: await self.writer.drain() - except ConnectionResetError: - raise RawConnError() + except ConnectionResetError as error: + raise RawConnError(error) async def read(self, n: int = -1) -> bytes: """ diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index b72fd256..15816fcb 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, List, Set, Tuple from libp2p.network.connection.net_connection_interface import INetConn from libp2p.network.stream.net_stream import NetStream from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.exceptions import MuxedConnUnavailable if TYPE_CHECKING: from libp2p.network.swarm import Swarm # noqa: F401 @@ -34,17 +35,28 @@ class SwarmConn(INetConn): if self.event_closed.is_set(): return self.event_closed.set() + self.swarm.remove_conn(self) + await self.conn.close() + + # This is just for cleaning up state. The connection has already been closed. + # We *could* optimize this but it really isn't worth it. + for stream in self.streams: + await stream.reset() + # Schedule `self._notify_disconnected` to make it execute after `close` is finished. + asyncio.ensure_future(self._notify_disconnected()) + for task in self._tasks: task.cancel() - # TODO: Reset streams for local. - # TODO: Notify closed. - async def _handle_new_streams(self) -> None: - # TODO: Break the loop when anything wrong in the connection. while True: - stream = await self.conn.accept_stream() + try: + stream = await self.conn.accept_stream() + except MuxedConnUnavailable: + # If there is anything wrong in the MuxedConn, + # we should break the loop and close the connection. + break # Asynchronously handle the accepted stream, to avoid blocking the next stream. await self.run_task(self._handle_muxed_stream(stream)) @@ -57,11 +69,16 @@ class SwarmConn(INetConn): async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) + self.streams.add(net_stream) # Call notifiers since event occurred for notifee in self.swarm.notifees: await notifee.opened_stream(self.swarm, net_stream) return net_stream + async def _notify_disconnected(self) -> None: + for notifee in self.swarm.notifees: + await notifee.disconnected(self.swarm, self.conn) + async def start(self) -> None: await self.run_task(self._handle_new_streams()) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 7a1b0990..32997203 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -272,13 +272,18 @@ class Swarm(INetwork): async def close_peer(self, peer_id: ID) -> None: if peer_id not in self.connections: return + # TODO: Should be changed to close multisple connections, + # if we have several connections per peer in the future. connection = self.connections[peer_id] - del self.connections[peer_id] await connection.close() logger.debug("successfully close the connection to peer %s", peer_id) async def add_conn(self, muxed_conn: IMuxedConn) -> SwarmConn: + """ + Add a `IMuxedConn` to `Swarm` as a `SwarmConn`, notify "connected", + and start to monitor the connection for its new streams and disconnection. + """ swarm_conn = SwarmConn(muxed_conn, self) # Store muxed_conn with peer id self.connections[muxed_conn.peer_id] = swarm_conn @@ -288,3 +293,14 @@ class Swarm(INetwork): await notifee.connected(self, muxed_conn) await swarm_conn.start() return swarm_conn + + def remove_conn(self, swarm_conn: SwarmConn) -> None: + """ + Simply remove the connection from Swarm's records, without closing the connection. + """ + peer_id = swarm_conn.conn.peer_id + if peer_id not in self.connections: + return + # TODO: Should be changed to remove the exact connection, + # if we have several connections per peer in the future. + del self.connections[peer_id] diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py index 861319a4..ce0f92e3 100644 --- a/libp2p/stream_muxer/exceptions.py +++ b/libp2p/stream_muxer/exceptions.py @@ -5,7 +5,7 @@ class MuxedConnError(BaseLibp2pError): pass -class MuxedConnShutdown(MuxedConnError): +class MuxedConnUnavailable(MuxedConnError): pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 154c3719..a7be76ee 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,6 +1,6 @@ from libp2p.stream_muxer.exceptions import ( MuxedConnError, - MuxedConnShutdown, + MuxedConnUnavailable, MuxedStreamClosed, MuxedStreamEOF, MuxedStreamReset, @@ -11,7 +11,7 @@ class MplexError(MuxedConnError): pass -class MplexShutdown(MuxedConnShutdown): +class MplexUnavailable(MuxedConnUnavailable): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 38858bbf..ea7e47ac 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -1,9 +1,10 @@ import asyncio from typing import Any # noqa: F401 -from typing import Dict, List, Optional, Tuple +from typing import Awaitable, Dict, List, Optional, Tuple from libp2p.exceptions import ParseError from libp2p.io.exceptions import IncompleteReadError +from libp2p.network.connection.exceptions import RawConnError from libp2p.peer.id import ID from libp2p.security.secure_conn_interface import ISecureConn from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream @@ -17,6 +18,7 @@ from libp2p.utils import ( from .constants import HeaderTags from .datastructures import StreamID +from .exceptions import MplexUnavailable from .mplex_stream import MplexStream MPLEX_PROTOCOL_ID = TProtocol("/mplex/6.7.0") @@ -36,7 +38,8 @@ class Mplex(IMuxedConn): streams: Dict[StreamID, MplexStream] streams_lock: asyncio.Lock new_stream_queue: "asyncio.Queue[IMuxedStream]" - shutdown: asyncio.Event + event_shutting_down: asyncio.Event + event_closed: asyncio.Event _tasks: List["asyncio.Future[Any]"] @@ -60,7 +63,8 @@ class Mplex(IMuxedConn): self.streams = {} self.streams_lock = asyncio.Lock() self.new_stream_queue = asyncio.Queue() - self.shutdown = asyncio.Event() + self.event_shutting_down = asyncio.Event() + self.event_closed = asyncio.Event() self._tasks = [] @@ -75,16 +79,20 @@ class Mplex(IMuxedConn): """ close the stream muxer and underlying secured connection """ - for task in self._tasks: - task.cancel() + if self.event_shutting_down.is_set(): + return + # Set the `event_shutting_down`, to allow graceful shutdown. + self.event_shutting_down.set() await self.secured_conn.close() + # Blocked until `close` is finally set. + await self.event_closed.wait() def is_closed(self) -> bool: """ check connection is fully closed :return: true if successful """ - raise NotImplementedError() + return self.event_closed.is_set() def _get_next_channel_id(self) -> int: """ @@ -114,11 +122,29 @@ class Mplex(IMuxedConn): await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) return stream + async def _wait_until_shutting_down_or_closed(self, coro: Awaitable[Any]) -> Any: + task_coro = asyncio.ensure_future(coro) + task_wait_closed = asyncio.ensure_future(self.event_closed.wait()) + task_wait_shutting_down = asyncio.ensure_future(self.event_shutting_down.wait()) + done, pending = await asyncio.wait( + [task_coro, task_wait_closed, task_wait_shutting_down], + return_when=asyncio.FIRST_COMPLETED, + ) + for fut in pending: + fut.cancel() + if task_wait_closed in done: + raise MplexUnavailable("Mplex is closed") + if task_wait_shutting_down in done: + raise MplexUnavailable("Mplex is shutting down") + return task_coro.result() + async def accept_stream(self) -> IMuxedStream: """ accepts a muxed stream opened by the other end """ - return await self.new_stream_queue.get() + return await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.get() + ) async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID @@ -137,7 +163,9 @@ class Mplex(IMuxedConn): _bytes = header + encode_varint_prefixed(data) - return await self.write_to_stream(_bytes) + return await self._wait_until_shutting_down_or_closed( + self.write_to_stream(_bytes) + ) async def write_to_stream(self, _bytes: bytes) -> int: """ @@ -152,89 +180,17 @@ class Mplex(IMuxedConn): """ Read a message off of the secured connection and add it to the corresponding message buffer """ - # TODO Deal with other types of messages using flag (currently _) while True: - channel_id, flag, message = await self.read_message() - if channel_id is not None and flag is not None and message is not None: - stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) - is_stream_id_seen: bool - stream: MplexStream - async with self.streams_lock: - is_stream_id_seen = stream_id in self.streams - if is_stream_id_seen: - stream = self.streams[stream_id] - # Other consequent stream message should wait until the stream get accepted - # TODO: Handle more tags, and refactor `HeaderTags` - if flag == HeaderTags.NewStream.value: - if is_stream_id_seen: - # `NewStream` for the same id is received twice... - # TODO: Shutdown - pass - mplex_stream = await self._initialize_stream( - stream_id, message.decode() - ) - # TODO: Check if `self` is shutdown. - await self.new_stream_queue.put(mplex_stream) - elif flag in ( - HeaderTags.MessageInitiator.value, - HeaderTags.MessageReceiver.value, - ): - if not is_stream_id_seen: - # We receive a message of the stream `stream_id` which is not accepted - # before. It is abnormal. Possibly disconnect? - # TODO: Warn and emit logs about this. - continue - async with stream.close_lock: - if stream.event_remote_closed.is_set(): - # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 - continue - await stream.incoming_data.put(message) - elif flag in ( - HeaderTags.CloseInitiator.value, - HeaderTags.CloseReceiver.value, - ): - if not is_stream_id_seen: - continue - # NOTE: If remote is already closed, then return: Technically a bug - # on the other side. We should consider killing the connection. - async with stream.close_lock: - if stream.event_remote_closed.is_set(): - continue - is_local_closed: bool - async with stream.close_lock: - stream.event_remote_closed.set() - is_local_closed = stream.event_local_closed.is_set() - # If local is also closed, both sides are closed. Then, we should clean up - # the entry of this stream, to avoid others from accessing it. - if is_local_closed: - async with self.streams_lock: - del self.streams[stream_id] - elif flag in ( - HeaderTags.ResetInitiator.value, - HeaderTags.ResetReceiver.value, - ): - if not is_stream_id_seen: - # This is *ok*. We forget the stream on reset. - continue - async with stream.close_lock: - if not stream.event_remote_closed.is_set(): - # TODO: Why? Only if remote is not closed before then reset. - stream.event_reset.set() - - stream.event_remote_closed.set() - # If local is not closed, we should close it. - if not stream.event_local_closed.is_set(): - stream.event_local_closed.set() - async with self.streams_lock: - del self.streams[stream_id] - else: - # TODO: logging - if is_stream_id_seen: - await stream.reset() - + try: + await self._handle_incoming_message() + except MplexUnavailable: + break # Force context switch await asyncio.sleep(0) + # If we enter here, it means this connection is shutting down. + # We should clean things up. + await self._cleanup() async def read_message(self) -> Tuple[int, int, bytes]: """ @@ -243,21 +199,130 @@ class Mplex(IMuxedConn): """ # FIXME: No timeout is used in Go implementation. - # Timeout is set to a relatively small value to alleviate wait time to exit - # loop in handle_incoming try: header = await decode_uvarint_from_stream(self.secured_conn) - except ParseError: - return None, None, None - try: message = await asyncio.wait_for( read_varint_prefixed_bytes(self.secured_conn), timeout=5 ) - except (ParseError, IncompleteReadError, asyncio.TimeoutError): - # TODO: Investigate what we should do if time is out. - return None, None, None + except (ParseError, RawConnError, IncompleteReadError) as error: + raise MplexUnavailable( + "failed to read messages correctly from the underlying connection" + ) from error + except asyncio.TimeoutError as error: + raise MplexUnavailable( + "failed to read more message body within the timeout" + ) from error flag = header & 0x07 channel_id = header >> 3 return channel_id, flag, message + + async def _handle_incoming_message(self) -> None: + """ + Read and handle a new incoming message. + :raise MplexUnavailable: `Mplex` encounters fatal error or is shutting down. + """ + channel_id, flag, message = await self._wait_until_shutting_down_or_closed( + self.read_message() + ) + stream_id = StreamID(channel_id=channel_id, is_initiator=bool(flag & 1)) + + if flag == HeaderTags.NewStream.value: + await self._handle_new_stream(stream_id, message) + elif flag in ( + HeaderTags.MessageInitiator.value, + HeaderTags.MessageReceiver.value, + ): + await self._handle_message(stream_id, message) + elif flag in (HeaderTags.CloseInitiator.value, HeaderTags.CloseReceiver.value): + await self._handle_close(stream_id) + elif flag in (HeaderTags.ResetInitiator.value, HeaderTags.ResetReceiver.value): + await self._handle_reset(stream_id) + else: + # Receives messages with an unknown flag + # TODO: logging + async with self.streams_lock: + if stream_id in self.streams: + stream = self.streams[stream_id] + await stream.reset() + + async def _handle_new_stream(self, stream_id: StreamID, message: bytes) -> None: + async with self.streams_lock: + if stream_id in self.streams: + # `NewStream` for the same id is received twice... + raise MplexUnavailable( + f"received NewStream message for existing stream: {stream_id}" + ) + mplex_stream = await self._initialize_stream(stream_id, message.decode()) + await self._wait_until_shutting_down_or_closed( + self.new_stream_queue.put(mplex_stream) + ) + + async def _handle_message(self, stream_id: StreamID, message: bytes) -> None: + async with self.streams_lock: + if stream_id not in self.streams: + # We receive a message of the stream `stream_id` which is not accepted + # before. It is abnormal. Possibly disconnect? + # TODO: Warn and emit logs about this. + return + stream = self.streams[stream_id] + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 + return + await self._wait_until_shutting_down_or_closed( + stream.incoming_data.put(message) + ) + + async def _handle_close(self, stream_id: StreamID) -> None: + async with self.streams_lock: + if stream_id not in self.streams: + # Ignore unmatched messages for now. + return + stream = self.streams[stream_id] + # NOTE: If remote is already closed, then return: Technically a bug + # on the other side. We should consider killing the connection. + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + return + is_local_closed: bool + async with stream.close_lock: + stream.event_remote_closed.set() + is_local_closed = stream.event_local_closed.is_set() + # If local is also closed, both sides are closed. Then, we should clean up + # the entry of this stream, to avoid others from accessing it. + if is_local_closed: + async with self.streams_lock: + del self.streams[stream_id] + + async def _handle_reset(self, stream_id: StreamID) -> None: + async with self.streams_lock: + if stream_id not in self.streams: + # This is *ok*. We forget the stream on reset. + return + stream = self.streams[stream_id] + + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_reset.set() + + stream.event_remote_closed.set() + # If local is not closed, we should close it. + if not stream.event_local_closed.is_set(): + stream.event_local_closed.set() + async with self.streams_lock: + del self.streams[stream_id] + + async def _cleanup(self) -> None: + if not self.event_shutting_down.is_set(): + self.event_shutting_down.set() + async with self.streams_lock: + for stream in self.streams.values(): + async with stream.close_lock: + if not stream.event_remote_closed.is_set(): + stream.event_remote_closed.set() + stream.event_reset.set() + stream.event_local_closed.set() + self.streams = None + self.event_closed.set() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 87b039f9..8cabccc4 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -204,7 +204,11 @@ class MplexStream(IMuxedStream): self.event_remote_closed.set() async with self.mplex_conn.streams_lock: - del self.mplex_conn.streams[self.stream_id] + if ( + self.mplex_conn.streams is not None + and self.stream_id in self.mplex_conn.streams + ): + del self.mplex_conn.streams[self.stream_id] # TODO deadline not in use def set_deadline(self, ttl: int) -> bool: diff --git a/libp2p/utils.py b/libp2p/utils.py index 8362a5ac..0e15b567 100644 --- a/libp2p/utils.py +++ b/libp2p/utils.py @@ -41,14 +41,8 @@ async def decode_uvarint_from_stream(reader: Reader) -> int: if shift > SHIFT_64_BIT_MAX: raise ParseError("TODO: better exception msg: Integer is too large...") - byte = await reader.read(1) - - try: - value = byte[0] - except IndexError: - raise ParseError( - "Unexpected end of stream while parsing LEB128 encoded integer" - ) + byte = await read_exactly(reader, 1) + value = byte[0] res += (value & LOW_MASK) << shift diff --git a/tests/factories.py b/tests/factories.py index 0f69707a..e39b12d3 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -6,15 +6,14 @@ import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost -from libp2p.host.host_interface import IHost from libp2p.network.stream.net_stream_interface import INetStream +from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio -from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( @@ -22,7 +21,7 @@ from tests.pubsub.configs import ( GOSSIPSUB_PARAMS, GOSSIPSUB_PROTOCOL_ID, ) -from tests.utils import connect +from tests.utils import connect, connect_swarm def security_transport_factory( @@ -34,12 +33,31 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def swarm_factory(is_secure: bool): +def SwarmFactory(is_secure: bool) -> Swarm: key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(is_secure, key_pair) + sec_opt = security_transport_factory(False, key_pair) return initialize_default_swarm(key_pair, sec_opt=sec_opt) +class ListeningSwarmFactory(factory.Factory): + class Meta: + model = Swarm + + @classmethod + async def create_and_listen(cls, is_secure: bool) -> Swarm: + swarm = SwarmFactory(is_secure) + await swarm.listen(LISTEN_MADDR) + return swarm + + @classmethod + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> Tuple[Swarm, ...]: + return await asyncio.gather( + *[cls.create_and_listen(is_secure) for _ in range(number)] + ) + + class HostFactory(factory.Factory): class Meta: model = BasicHost @@ -47,13 +65,19 @@ class HostFactory(factory.Factory): class Params: is_secure = False - network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) + network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure)) @classmethod - async def create_and_listen(cls) -> IHost: - host = cls() - await host.get_network().listen(LISTEN_MADDR) - return host + async def create_and_listen(cls, is_secure: bool) -> BasicHost: + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1) + return BasicHost(swarms[0]) + + @classmethod + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> Tuple[BasicHost, ...]: + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number) + return tuple(BasicHost(swarm) for swarm in range(swarms)) class FloodsubFactory(factory.Factory): @@ -87,24 +111,33 @@ class PubsubFactory(factory.Factory): cache_size = None -async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: +async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 2) + await connect_swarm(swarms[0], swarms[1]) + return swarms[0], swarms[1] + + +async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: hosts = await asyncio.gather( - *[HostFactory.create_and_listen(), HostFactory.create_and_listen()] + *[ + HostFactory.create_and_listen(is_secure), + HostFactory.create_and_listen(is_secure), + ] ) await connect(hosts[0], hosts[1]) return hosts[0], hosts[1] -async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: - host_0, host_1 = await host_pair_factory() - mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] - mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] - return mplex_conn_0, host_0, mplex_conn_1, host_1 +# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: +# host_0, host_1 = await host_pair_factory() +# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] +# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] +# return mplex_conn_0, host_0, mplex_conn_1, host_1 -async def net_stream_pair_factory() -> Tuple[ - INetStream, BasicHost, INetStream, BasicHost -]: +async def net_stream_pair_factory( + is_secure: bool +) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: protocol_id = "/example/id/1" stream_1: INetStream @@ -114,7 +147,7 @@ async def net_stream_pair_factory() -> Tuple[ nonlocal stream_1 stream_1 = stream - host_0, host_1 = await host_pair_factory() + host_0, host_1 = await host_pair_factory(is_secure) host_1.set_stream_handler(protocol_id, handler) stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) diff --git a/tests/interop/test_bindings.py b/tests/interop/test_bindings.py index 1189e0b7..1e78ff43 100644 --- a/tests/interop/test_bindings.py +++ b/tests/interop/test_bindings.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from .utils import connect @@ -21,4 +23,5 @@ async def test_connect(hosts, p2pds): # Test: `disconnect` from Go await p2pd.control.disconnect(host.get_id()) # FIXME: Failed to handle disconnect - # assert len(host.get_network().connections) == 0 + await asyncio.sleep(0.01) + assert len(host.get_network().connections) == 0 diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 10f77918..47d5c5f0 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -2,13 +2,22 @@ import asyncio import pytest -from tests.factories import net_stream_pair_factory +from tests.factories import net_stream_pair_factory, swarm_pair_factory @pytest.fixture -async def net_stream_pair(): - stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() +async def net_stream_pair(is_host_secure): + stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) try: yield stream_0, stream_1 finally: await asyncio.gather(*[host_0.close(), host_1.close()]) + + +@pytest.fixture +async def swarm_pair(is_host_secure): + swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) + try: + yield swarm_0, swarm_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py new file mode 100644 index 00000000..cf8eadfa --- /dev/null +++ b/tests/network/test_swarm.py @@ -0,0 +1,93 @@ +import asyncio + +import pytest + +from libp2p.network.exceptions import SwarmException +from tests.factories import ListeningSwarmFactory +from tests.utils import connect_swarm + + +@pytest.mark.asyncio +async def test_swarm_dial_peer(is_host_secure): + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) + # Test: No addr found. + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: len(addr) in the peerstore is 0. + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), [], 10000) + with pytest.raises(SwarmException): + await swarms[0].dial_peer(swarms[1].get_peer_id()) + + # Test: Succeed if addrs of the peer_id are present in the peerstore. + addrs = tuple( + addr + for transport in swarms[1].listeners.values() + for addr in transport.get_addrs() + ) + swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs, 10000) + await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert swarms[0].get_peer_id() in swarms[1].connections + assert swarms[1].get_peer_id() in swarms[0].connections + + # Test: Reuse connections when we already have ones with a peer. + conn_to_1 = swarms[0].connections[swarms[1].get_peer_id()] + conn = await swarms[0].dial_peer(swarms[1].get_peer_id()) + assert conn is conn_to_1 + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) + + +@pytest.mark.asyncio +async def test_swarm_close_peer(is_host_secure): + swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) + # 0 <> 1 <> 2 + await connect_swarm(swarms[0], swarms[1]) + await connect_swarm(swarms[1], swarms[2]) + + # peer 1 closes peer 0 + await swarms[1].close_peer(swarms[0].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 <> 2 + assert len(swarms[0].connections) == 0 + assert ( + len(swarms[1].connections) == 1 + and swarms[2].get_peer_id() in swarms[1].connections + ) + + # peer 1 is closed by peer 2 + await swarms[2].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + + await connect_swarm(swarms[0], swarms[1]) + # 0 <> 1 2 + assert ( + len(swarms[0].connections) == 1 + and swarms[1].get_peer_id() in swarms[0].connections + ) + assert ( + len(swarms[1].connections) == 1 + and swarms[0].get_peer_id() in swarms[1].connections + ) + # peer 0 closes peer 1 + await swarms[0].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) + + +@pytest.mark.asyncio +async def test_swarm_remove_conn(swarm_pair): + swarm_0, swarm_1 = swarm_pair + conn_0 = swarm_0.connections[swarm_1.get_peer_id()] + swarm_0.remove_conn(conn_0) + assert swarm_1.get_peer_id() not in swarm_0.connections + # Test: Remove twice. There should not be errors. + swarm_0.remove_conn(conn_0) + assert swarm_1.get_peer_id() not in swarm_0.connections diff --git a/tests/utils.py b/tests/utils.py index e9d6c09f..4b4357db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from tests.constants import MAX_READ_LEN +async def connect_swarm(swarm_0, swarm_1): + peer_id = swarm_1.get_peer_id() + addrs = tuple( + addr + for transport in swarm_1.listeners.values() + for addr in transport.get_addrs() + ) + swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) + await swarm_0.dial_peer(peer_id) + assert swarm_0.get_peer_id() in swarm_1.connections + assert swarm_1.get_peer_id() in swarm_0.connections + + async def connect(node1, node2): """ Connect node1 to node2