From e7304538dae30e2ff4e011ad88bec124fbef0d9c Mon Sep 17 00:00:00 2001 From: mhchia Date: Sat, 14 Sep 2019 23:37:01 +0800 Subject: [PATCH] Add test for `Swarm.close_peer` --- libp2p/host/basic_host.py | 8 ++- libp2p/network/connection/swarm_connection.py | 3 +- tests/factories.py | 71 +++++++++++++------ tests/network/conftest.py | 15 +++- tests/network/test_swarm.py | 49 +++++++++++++ tests/utils.py | 13 ++++ 6 files changed, 130 insertions(+), 29 deletions(-) create mode 100644 tests/network/test_swarm.py diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index ea78c988..862fd5c9 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -23,6 +23,10 @@ from .host_interface import IHost class BasicHost(IHost): + """ + BasicHost is a wrapper of a `INetwork` implementation. It performs protocol negotiation + on a stream with multistream-select right after a stream is initialized. + """ _network: INetwork _router: KadmeliaPeerRouter @@ -31,7 +35,6 @@ class BasicHost(IHost): multiselect: Multiselect multiselect_client: MultiselectClient - # default options constructor def __init__(self, network: INetwork, router: KadmeliaPeerRouter = None) -> None: self._network = network self._network.set_stream_handler(self._swarm_stream_handler) @@ -69,6 +72,7 @@ class BasicHost(IHost): """ :return: all the multiaddr addresses this host is listening to """ + # TODO: We don't need "/p2p/{peer_id}" postfix actually. p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty())) addrs: List[multiaddr.Multiaddr] = [] @@ -87,8 +91,6 @@ class BasicHost(IHost): """ self.multiselect.add_handler(protocol_id, stream_handler) - # `protocol_ids` can be a list of `protocol_id` - # stream will decide which `protocol_id` to run on async def new_stream( self, peer_id: ID, protocol_ids: Sequence[TProtocol] ) -> INetStream: diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 50d09e79..15816fcb 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -50,11 +50,12 @@ class SwarmConn(INetConn): task.cancel() async def _handle_new_streams(self) -> None: - # TODO: Break the loop when anything wrong in the connection. while True: try: stream = await self.conn.accept_stream() except MuxedConnUnavailable: + # If there is anything wrong in the MuxedConn, + # we should break the loop and close the connection. break # Asynchronously handle the accepted stream, to avoid blocking the next stream. await self.run_task(self._handle_muxed_stream(stream)) diff --git a/tests/factories.py b/tests/factories.py index 0f69707a..efa16c88 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -8,13 +8,13 @@ from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost from libp2p.host.host_interface import IHost from libp2p.network.stream.net_stream_interface import INetStream +from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio -from libp2p.stream_muxer.mplex.mplex import Mplex from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( @@ -22,7 +22,7 @@ from tests.pubsub.configs import ( GOSSIPSUB_PARAMS, GOSSIPSUB_PROTOCOL_ID, ) -from tests.utils import connect +from tests.utils import connect, connect_swarm def security_transport_factory( @@ -34,10 +34,29 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def swarm_factory(is_secure: bool): - key_pair = generate_new_rsa_identity() - sec_opt = security_transport_factory(is_secure, key_pair) - return initialize_default_swarm(key_pair, sec_opt=sec_opt) +class SwarmFactory(factory.Factory): + class Meta: + model = Swarm + + @classmethod + def _create(cls, is_secure=False): + key_pair = generate_new_rsa_identity() + sec_opt = security_transport_factory(is_secure, key_pair) + return initialize_default_swarm(key_pair, sec_opt=sec_opt) + + @classmethod + async def create_and_listen(cls, is_secure: bool) -> Swarm: + swarm = cls._create(is_secure) + await swarm.listen(LISTEN_MADDR) + return swarm + + @classmethod + async def create_batch_and_listen( + cls, is_secure: bool, number: int + ) -> Tuple[Swarm, ...]: + return await asyncio.gather( + *[cls.create_and_listen(is_secure) for _ in range(number)] + ) class HostFactory(factory.Factory): @@ -47,13 +66,12 @@ class HostFactory(factory.Factory): class Params: is_secure = False - network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) + network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure)) @classmethod - async def create_and_listen(cls) -> IHost: - host = cls() - await host.get_network().listen(LISTEN_MADDR) - return host + async def create_and_listen(cls, is_secure: bool) -> IHost: + swarm = await SwarmFactory.create_and_listen(is_secure) + return BasicHost(swarm) class FloodsubFactory(factory.Factory): @@ -87,24 +105,33 @@ class PubsubFactory(factory.Factory): cache_size = None -async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: +async def swarm_pair_factory(is_secure: bool) -> Tuple[Swarm, Swarm]: + swarms = await SwarmFactory.create_batch_and_listen(2) + await connect_swarm(swarms[0], swarms[1]) + return swarms[0], swarms[1] + + +async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: hosts = await asyncio.gather( - *[HostFactory.create_and_listen(), HostFactory.create_and_listen()] + *[ + HostFactory.create_and_listen(is_secure), + HostFactory.create_and_listen(is_secure), + ] ) await connect(hosts[0], hosts[1]) return hosts[0], hosts[1] -async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: - host_0, host_1 = await host_pair_factory() - mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] - mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] - return mplex_conn_0, host_0, mplex_conn_1, host_1 +# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: +# host_0, host_1 = await host_pair_factory() +# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] +# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] +# return mplex_conn_0, host_0, mplex_conn_1, host_1 -async def net_stream_pair_factory() -> Tuple[ - INetStream, BasicHost, INetStream, BasicHost -]: +async def net_stream_pair_factory( + is_secure: bool +) -> Tuple[INetStream, BasicHost, INetStream, BasicHost]: protocol_id = "/example/id/1" stream_1: INetStream @@ -114,7 +141,7 @@ async def net_stream_pair_factory() -> Tuple[ nonlocal stream_1 stream_1 = stream - host_0, host_1 = await host_pair_factory() + host_0, host_1 = await host_pair_factory(is_secure) host_1.set_stream_handler(protocol_id, handler) stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 10f77918..47d5c5f0 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -2,13 +2,22 @@ import asyncio import pytest -from tests.factories import net_stream_pair_factory +from tests.factories import net_stream_pair_factory, swarm_pair_factory @pytest.fixture -async def net_stream_pair(): - stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() +async def net_stream_pair(is_host_secure): + stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory(is_host_secure) try: yield stream_0, stream_1 finally: await asyncio.gather(*[host_0.close(), host_1.close()]) + + +@pytest.fixture +async def swarm_pair(is_host_secure): + swarm_0, swarm_1 = await swarm_pair_factory(is_host_secure) + try: + yield swarm_0, swarm_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py new file mode 100644 index 00000000..e531de08 --- /dev/null +++ b/tests/network/test_swarm.py @@ -0,0 +1,49 @@ +import asyncio + +import pytest + +from tests.factories import SwarmFactory +from tests.utils import connect_swarm + + +@pytest.mark.asyncio +async def test_swarm_close_peer(is_host_secure): + swarms = await SwarmFactory.create_batch_and_listen(is_host_secure, 3) + # 0 <> 1 <> 2 + await connect_swarm(swarms[0], swarms[1]) + await connect_swarm(swarms[1], swarms[2]) + + # peer 1 closes peer 0 + await swarms[1].close_peer(swarms[0].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 <> 2 + assert len(swarms[0].connections) == 0 + assert ( + len(swarms[1].connections) == 1 + and swarms[2].get_peer_id() in swarms[1].connections + ) + + # peer 1 is closed by peer 2 + await swarms[2].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + + await connect_swarm(swarms[0], swarms[1]) + # 0 <> 1 2 + assert ( + len(swarms[0].connections) == 1 + and swarms[1].get_peer_id() in swarms[0].connections + ) + assert ( + len(swarms[1].connections) == 1 + and swarms[0].get_peer_id() in swarms[1].connections + ) + # peer 0 closes peer 1 + await swarms[0].close_peer(swarms[1].get_peer_id()) + await asyncio.sleep(0.01) + # 0 1 2 + assert len(swarms[1].connections) == 0 and len(swarms[2].connections) == 0 + + # Clean up + await asyncio.gather(*[swarm.close() for swarm in swarms]) diff --git a/tests/utils.py b/tests/utils.py index e9d6c09f..4b4357db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,19 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from tests.constants import MAX_READ_LEN +async def connect_swarm(swarm_0, swarm_1): + peer_id = swarm_1.get_peer_id() + addrs = tuple( + addr + for transport in swarm_1.listeners.values() + for addr in transport.get_addrs() + ) + swarm_0.peerstore.add_addrs(peer_id, addrs, 10000) + await swarm_0.dial_peer(peer_id) + assert swarm_0.get_peer_id() in swarm_1.connections + assert swarm_1.get_peer_id() in swarm_0.connections + + async def connect(node1, node2): """ Connect node1 to node2