diff --git a/tests/factories.py b/tests/factories.py index b4e8be23..b65b918d 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -37,8 +37,9 @@ def security_transport_factory( return {secio.ID: secio.Transport(key_pair)} -def SwarmFactory(is_secure: bool, muxer_opt: TMuxerOptions = None) -> Swarm: - key_pair = generate_new_rsa_identity() +def SwarmFactory( + is_secure: bool, key_pair: KeyPair, muxer_opt: TMuxerOptions = None +) -> Swarm: sec_opt = security_transport_factory(is_secure, key_pair) return initialize_default_swarm(key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt) @@ -50,15 +51,16 @@ class ListeningSwarmFactory(factory.Factory): @classmethod async def create_and_listen( cls, is_secure: bool, muxer_opt: TMuxerOptions = None - ) -> Swarm: - swarm = SwarmFactory(is_secure, muxer_opt=muxer_opt) + ) -> Tuple[Swarm, KeyPair]: + key_pair = generate_new_rsa_identity() + swarm = SwarmFactory(is_secure, key_pair, muxer_opt=muxer_opt) await swarm.listen(LISTEN_MADDR) - return swarm + return swarm, key_pair @classmethod async def create_batch_and_listen( cls, is_secure: bool, number: int, muxer_opt: TMuxerOptions = None - ) -> Tuple[Swarm, ...]: + ) -> Tuple[Tuple[Swarm, KeyPair], ...]: return await asyncio.gather( *[ cls.create_and_listen(is_secure, muxer_opt=muxer_opt) @@ -74,19 +76,26 @@ class HostFactory(factory.Factory): class Params: is_secure = False - network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure)) + network = factory.LazyAttribute(lambda o: SwarmFactory(o.is_secure, o.key_pair)) @classmethod async def create_and_listen(cls, is_secure: bool) -> BasicHost: - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, 1) - return BasicHost(swarms[0]) + swarms_and_keys = await ListeningSwarmFactory.create_batch_and_listen( + is_secure, 1 + ) + swarm, key_pair = swarms_and_keys[0] + return BasicHost(key_pair.public_key, swarm) @classmethod async def create_batch_and_listen( cls, is_secure: bool, number: int ) -> Tuple[BasicHost, ...]: - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_secure, number) - return tuple(BasicHost(swarm) for swarm in range(swarms)) + swarms_and_keys = await ListeningSwarmFactory.create_batch_and_listen( + is_secure, number + ) + return tuple( + BasicHost(key_pair.public_key, swarm) for swarm, key_pair in swarms_and_keys + ) class FloodsubFactory(factory.Factory): @@ -123,9 +132,10 @@ class PubsubFactory(factory.Factory): async def swarm_pair_factory( is_secure: bool, muxer_opt: TMuxerOptions = None ) -> Tuple[Swarm, Swarm]: - swarms = await ListeningSwarmFactory.create_batch_and_listen( + swarms_and_keys = await ListeningSwarmFactory.create_batch_and_listen( is_secure, 2, muxer_opt=muxer_opt ) + swarms = tuple(swarm for swarm, _key_pair in swarms_and_keys) await connect_swarm(swarms[0], swarms[1]) return swarms[0], swarms[1] diff --git a/tests/network/test_swarm.py b/tests/network/test_swarm.py index cf8eadfa..3d63d6b6 100644 --- a/tests/network/test_swarm.py +++ b/tests/network/test_swarm.py @@ -9,7 +9,10 @@ from tests.utils import connect_swarm @pytest.mark.asyncio async def test_swarm_dial_peer(is_host_secure): - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) + swarms_and_keys = await ListeningSwarmFactory.create_batch_and_listen( + is_host_secure, 3 + ) + swarms = tuple(swarm for swarm, _key_pair in swarms_and_keys) # Test: No addr found. with pytest.raises(SwarmException): await swarms[0].dial_peer(swarms[1].get_peer_id()) @@ -41,7 +44,10 @@ async def test_swarm_dial_peer(is_host_secure): @pytest.mark.asyncio async def test_swarm_close_peer(is_host_secure): - swarms = await ListeningSwarmFactory.create_batch_and_listen(is_host_secure, 3) + swarms_and_keys = await ListeningSwarmFactory.create_batch_and_listen( + is_host_secure, 3 + ) + swarms = tuple(swarm for swarm, _key_pair in swarms_and_keys) # 0 <> 1 <> 2 await connect_swarm(swarms[0], swarms[1]) await connect_swarm(swarms[1], swarms[2])