diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index 10b83a16..78e6ead6 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -43,11 +43,15 @@ class SwarmConn(INetConn): # We *could* optimize this but it really isn't worth it. for stream in self.streams: await stream.reset() - # Schedule `self._notify_disconnected` to make it execute after `close` is finished. - asyncio.ensure_future(self._notify_disconnected()) for task in self._tasks: task.cancel() + try: + await task + except asyncio.CancelledError: + pass + # Schedule `self._notify_disconnected` to make it execute after `close` is finished. + asyncio.ensure_future(self._notify_disconnected()) async def _handle_new_streams(self) -> None: while True: @@ -70,7 +74,6 @@ class SwarmConn(INetConn): async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: net_stream = NetStream(muxed_stream) self.streams.add(net_stream) - # Call notifiers since event occurred for notifee in self.swarm.notifees: await notifee.opened_stream(self.swarm, net_stream) return net_stream @@ -91,3 +94,7 @@ class SwarmConn(INetConn): async def get_streams(self) -> Tuple[NetStream, ...]: return tuple(self.streams) + + # TODO: Called by `Stream` whenever it is time to remove the stream. + def remove_stream(self, stream: NetStream) -> None: + self.streams.remove(stream) diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 3ae7c9c3..018ef6dd 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -66,3 +66,7 @@ class NetStream(INetStream): async def reset(self) -> None: await self.muxed_stream.reset() + + # TODO: `remove`: Called by close and write when the stream is in specific states. + # It notify `ClosedStream` after `SwarmConn.remove_stream` is called. + # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 diff --git a/tests/factories.py b/tests/factories.py index e39b12d3..af4d529b 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -6,6 +6,7 @@ import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.network.connection.swarm_connection import SwarmConn from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.pubsub.floodsub import FloodSub @@ -128,11 +129,13 @@ async def host_pair_factory(is_secure) -> Tuple[BasicHost, BasicHost]: return hosts[0], hosts[1] -# async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: -# host_0, host_1 = await host_pair_factory() -# mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] -# mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] -# return mplex_conn_0, host_0, mplex_conn_1, host_1 +async def swarm_conn_pair_factory( + is_secure +) -> Tuple[SwarmConn, Swarm, SwarmConn, Swarm]: + swarms = await swarm_pair_factory(is_secure) + conn_0 = swarms[0].connections[swarms[1].get_peer_id()] + conn_1 = swarms[1].connections[swarms[0].get_peer_id()] + return conn_0, swarms[0], conn_1, swarms[1] async def net_stream_pair_factory( diff --git a/tests/network/conftest.py b/tests/network/conftest.py index 47d5c5f0..018e822d 100644 --- a/tests/network/conftest.py +++ b/tests/network/conftest.py @@ -2,7 +2,11 @@ import asyncio import pytest -from tests.factories import net_stream_pair_factory, swarm_pair_factory +from tests.factories import ( + net_stream_pair_factory, + swarm_conn_pair_factory, + swarm_pair_factory, +) @pytest.fixture @@ -21,3 +25,12 @@ async def swarm_pair(is_host_secure): yield swarm_0, swarm_1 finally: await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) + + +@pytest.fixture +async def swarm_conn_pair(is_host_secure): + conn_0, swarm_0, conn_1, swarm_1 = await swarm_conn_pair_factory(is_host_secure) + try: + yield conn_0, conn_1 + finally: + await asyncio.gather(*[swarm_0.close(), swarm_1.close()]) diff --git a/tests/network/test_swarm_conn.py b/tests/network/test_swarm_conn.py new file mode 100644 index 00000000..f9974e14 --- /dev/null +++ b/tests/network/test_swarm_conn.py @@ -0,0 +1,43 @@ +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_swarm_conn_close(swarm_conn_pair): + conn_0, conn_1 = swarm_conn_pair + + assert not conn_0.event_closed.is_set() + assert not conn_1.event_closed.is_set() + + await conn_0.close() + + await asyncio.sleep(0.01) + + assert conn_0.event_closed.is_set() + assert conn_1.event_closed.is_set() + assert conn_0 not in conn_0.swarm.connections.values() + assert conn_1 not in conn_1.swarm.connections.values() + + +@pytest.mark.asyncio +async def test_swarm_conn_streams(swarm_conn_pair): + conn_0, conn_1 = swarm_conn_pair + + assert len(await conn_0.get_streams()) == 0 + assert len(await conn_1.get_streams()) == 0 + + stream_0_0 = await conn_0.new_stream() + await asyncio.sleep(0.01) + assert len(await conn_0.get_streams()) == 1 + assert len(await conn_1.get_streams()) == 1 + + stream_0_1 = await conn_0.new_stream() + await asyncio.sleep(0.01) + assert len(await conn_0.get_streams()) == 2 + assert len(await conn_1.get_streams()) == 2 + + conn_0.remove_stream(stream_0_0) + assert len(await conn_0.get_streams()) == 1 + conn_0.remove_stream(stream_0_1) + assert len(await conn_0.get_streams()) == 0