From be2c0f122a1bfa47051700045a557a02933f65de Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 9 Sep 2019 15:45:35 +0800 Subject: [PATCH 1/8] Fix close behavior --- libp2p/network/stream/exceptions.py | 17 +++++++ libp2p/network/stream/net_stream.py | 25 ++++++++-- libp2p/network/stream/net_stream_interface.py | 2 +- libp2p/stream_muxer/exceptions.py | 25 ++++++++++ libp2p/stream_muxer/mplex/exceptions.py | 20 ++++++-- libp2p/stream_muxer/mplex/mplex.py | 4 ++ libp2p/stream_muxer/mplex/mplex_stream.py | 28 +++++++---- tests/factories.py | 49 ++++++++++++++++++- 8 files changed, 149 insertions(+), 21 deletions(-) create mode 100644 libp2p/network/stream/exceptions.py create mode 100644 libp2p/stream_muxer/exceptions.py diff --git a/libp2p/network/stream/exceptions.py b/libp2p/network/stream/exceptions.py new file mode 100644 index 00000000..58f3ddf4 --- /dev/null +++ b/libp2p/network/stream/exceptions.py @@ -0,0 +1,17 @@ +from libp2p.exceptions import BaseLibp2pError + + +class StreamError(BaseLibp2pError): + pass + + +class StreamEOF(StreamError, EOFError): + pass + + +class StreamReset(StreamError): + pass + + +class StreamClosed(StreamError): + pass diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 7383f736..4dedab72 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,9 +1,18 @@ from libp2p.stream_muxer.abc import IMuxedConn, IMuxedStream +from libp2p.stream_muxer.exceptions import ( + MuxedStreamClosed, + MuxedStreamEOF, + MuxedStreamReset, +) from libp2p.typing import TProtocol +from .exceptions import StreamClosed, StreamEOF, StreamReset from .net_stream_interface import INetStream +# TODO: Handle exceptions from `muxed_stream` +# TODO: Add stream state +# - Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 class NetStream(INetStream): muxed_stream: IMuxedStream @@ -35,14 +44,22 @@ class NetStream(INetStream): :param n: number of bytes to read :return: bytes of input """ - return await self.muxed_stream.read(n) + try: + return await self.muxed_stream.read(n) + except MuxedStreamEOF as error: + raise StreamEOF from error + except MuxedStreamReset as error: + raise StreamReset from error async def write(self, data: bytes) -> int: """ write to stream :return: number of bytes written """ - return await self.muxed_stream.write(data) + try: + return await self.muxed_stream.write(data) + except MuxedStreamClosed as error: + raise StreamClosed from error async def close(self) -> None: """ @@ -51,5 +68,5 @@ class NetStream(INetStream): """ await self.muxed_stream.close() - async def reset(self) -> bool: - return await self.muxed_stream.reset() + async def reset(self) -> None: + await self.muxed_stream.reset() diff --git a/libp2p/network/stream/net_stream_interface.py b/libp2p/network/stream/net_stream_interface.py index aaa775a3..53ce0386 100644 --- a/libp2p/network/stream/net_stream_interface.py +++ b/libp2p/network/stream/net_stream_interface.py @@ -23,7 +23,7 @@ class INetStream(ReadWriteCloser): """ @abstractmethod - async def reset(self) -> bool: + async def reset(self) -> None: """ Close both ends of the stream. """ diff --git a/libp2p/stream_muxer/exceptions.py b/libp2p/stream_muxer/exceptions.py new file mode 100644 index 00000000..861319a4 --- /dev/null +++ b/libp2p/stream_muxer/exceptions.py @@ -0,0 +1,25 @@ +from libp2p.exceptions import BaseLibp2pError + + +class MuxedConnError(BaseLibp2pError): + pass + + +class MuxedConnShutdown(MuxedConnError): + pass + + +class MuxedStreamError(BaseLibp2pError): + pass + + +class MuxedStreamReset(MuxedStreamError): + pass + + +class MuxedStreamEOF(MuxedStreamError, EOFError): + pass + + +class MuxedStreamClosed(MuxedStreamError): + pass diff --git a/libp2p/stream_muxer/mplex/exceptions.py b/libp2p/stream_muxer/mplex/exceptions.py index 11663e2e..154c3719 100644 --- a/libp2p/stream_muxer/mplex/exceptions.py +++ b/libp2p/stream_muxer/mplex/exceptions.py @@ -1,17 +1,27 @@ -from libp2p.exceptions import BaseLibp2pError +from libp2p.stream_muxer.exceptions import ( + MuxedConnError, + MuxedConnShutdown, + MuxedStreamClosed, + MuxedStreamEOF, + MuxedStreamReset, +) -class MplexError(BaseLibp2pError): +class MplexError(MuxedConnError): pass -class MplexStreamReset(MplexError): +class MplexShutdown(MuxedConnShutdown): pass -class MplexStreamEOF(MplexError, EOFError): +class MplexStreamReset(MuxedStreamReset): pass -class MplexShutdown(MplexError): +class MplexStreamEOF(MuxedStreamEOF): + pass + + +class MplexStreamClosed(MuxedStreamClosed): pass diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 1e8823a9..c75000d2 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -188,6 +188,10 @@ class Mplex(IMuxedConn): # before. It is abnormal. Possibly disconnect? # TODO: Warn and emit logs about this. continue + async with stream.close_lock: + if stream.event_remote_closed.is_set(): + # TODO: Warn "Received data from remote after stream was closed by them. (len = %d)" # noqa: E501 + continue await stream.incoming_data.put(message) elif flag in ( HeaderTags.CloseInitiator.value, diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 18c8ff02..547d7b84 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,11 +1,11 @@ import asyncio -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from libp2p.stream_muxer.abc import IMuxedStream from .constants import HeaderTags from .datastructures import StreamID -from .exceptions import MplexStreamEOF, MplexStreamReset +from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset if TYPE_CHECKING: from libp2p.stream_muxer.mplex.mplex import Mplex @@ -58,20 +58,24 @@ class MplexStream(IMuxedStream): done, pending = await asyncio.wait( # type: ignore [ self.event_reset.wait(), - self.event_remote_closed.wait(), self.incoming_data.get(), + self.event_remote_closed.wait(), ], return_when=asyncio.FIRST_COMPLETED, ) + for fut in pending: + fut.cancel() if self.event_reset.is_set(): raise MplexStreamReset + done_task = tuple(done)[0] + if done_task._coro.__qualname__ == "Queue.get": + data = done_task.result() + self._buf.extend(data) + return if self.event_remote_closed.is_set(): raise MplexStreamEOF # TODO: Handle timeout when deadline is used. - data = tuple(done)[0].result() - self._buf.extend(data) - async def _read_until_eof(self) -> bytes: while True: try: @@ -99,13 +103,15 @@ class MplexStream(IMuxedStream): raise MplexStreamReset if n == -1: return await self._read_until_eof() - if len(self._buf) == 0: + if len(self._buf) == 0 and self.incoming_data.empty(): await self._wait_for_data() - # Read up to `n` bytes. + # Either `buf` is not empty or `incoming_data` is not empty now. + # Try to put enough incoming data into `self._buf`. while len(self._buf) < n: - if self.incoming_data.empty() or self.event_remote_closed.is_set(): + try: + self._buf.extend(self.incoming_data.get_nowait()) + except asyncio.QueueEmpty: break - self._buf.extend(await self.incoming_data.get()) payload = self._buf[:n] self._buf = self._buf[len(payload) :] return bytes(payload) @@ -115,6 +121,8 @@ class MplexStream(IMuxedStream): write to stream :return: number of bytes written """ + if self.event_local_closed.is_set(): + raise MplexStreamClosed(f"cannot write to closed stream: data={data}") flag = ( HeaderTags.MessageInitiator if self.is_initiator diff --git a/tests/factories.py b/tests/factories.py index 240bdb83..e161e25d 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,22 +1,29 @@ -from typing import Dict +import asyncio +from typing import Dict, Tuple import factory from libp2p import generate_new_rsa_identity, initialize_default_swarm from libp2p.crypto.keys import KeyPair from libp2p.host.basic_host import BasicHost +from libp2p.host.host_interface import IHost +from libp2p.network.stream.net_stream_interface import INetStream from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio +from libp2p.stream_muxer.mplex.mplex import Mplex +from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.typing import TProtocol +from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( FLOODSUB_PROTOCOL_ID, GOSSIPSUB_PARAMS, GOSSIPSUB_PROTOCOL_ID, ) +from tests.utils import connect def security_transport_factory( @@ -43,6 +50,12 @@ class HostFactory(factory.Factory): network = factory.LazyAttribute(lambda o: swarm_factory(o.is_secure)) + @classmethod + async def create_and_listen(cls) -> IHost: + host = cls() + await host.get_network().listen(LISTEN_MADDR) + return host + class FloodsubFactory(factory.Factory): class Meta: @@ -73,3 +86,37 @@ class PubsubFactory(factory.Factory): router = None my_id = factory.LazyAttribute(lambda obj: obj.host.get_id()) cache_size = None + + +async def host_pair_factory() -> Tuple[BasicHost, BasicHost]: + hosts = await asyncio.gather( + *[HostFactory.create_and_listen(), HostFactory.create_and_listen()] + ) + await connect(hosts[0], hosts[1]) + return hosts[0], hosts[1] + + +async def connection_pair_factory() -> Tuple[Mplex, BasicHost, Mplex, BasicHost]: + host_0, host_1 = await host_pair_factory() + mplex_conn_0 = host_0.get_network().connections[host_1.get_id()] + mplex_conn_1 = host_1.get_network().connections[host_0.get_id()] + return mplex_conn_0, host_0, mplex_conn_1, host_1 + + +async def net_stream_pair_factory() -> Tuple[ + INetStream, BasicHost, INetStream, BasicHost +]: + protocol_id = "/example/id/1" + + stream_1: INetStream + + # Just a proxy, we only care about the stream + def handler(stream: INetStream) -> None: + nonlocal stream_1 + stream_1 = stream + + host_0, host_1 = await host_pair_factory() + host_1.set_stream_handler(protocol_id, handler) + + stream_0 = await host_0.new_stream(host_1.get_id(), [protocol_id]) + return stream_0, host_0, stream_1, host_1 From 0ab548aee5cbf0d3b9952ea925766b225439a2b2 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 9 Sep 2019 16:58:58 +0800 Subject: [PATCH 2/8] Add the missing tests --- tests/network/__init__.py | 0 tests/network/conftest.py | 14 ++++ tests/network/test_net_stream.py | 111 +++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 tests/network/__init__.py create mode 100644 tests/network/conftest.py create mode 100644 tests/network/test_net_stream.py diff --git a/tests/network/__init__.py b/tests/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/network/conftest.py b/tests/network/conftest.py new file mode 100644 index 00000000..10f77918 --- /dev/null +++ b/tests/network/conftest.py @@ -0,0 +1,14 @@ +import asyncio + +import pytest + +from tests.factories import net_stream_pair_factory + + +@pytest.fixture +async def net_stream_pair(): + stream_0, host_0, stream_1, host_1 = await net_stream_pair_factory() + try: + yield stream_0, stream_1 + finally: + await asyncio.gather(*[host_0.close(), host_1.close()]) diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py new file mode 100644 index 00000000..e7029125 --- /dev/null +++ b/tests/network/test_net_stream.py @@ -0,0 +1,111 @@ +import asyncio + +import pytest + +from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset +from tests.constants import MAX_READ_LEN + +DATA = b"data_123" + +# TODO: Move `muxed_stream` specific(currently we are using `MplexStream`) tests to its +# own file, after `generic_protocol_handler` is refactored out of `Mplex`. + + +@pytest.mark.asyncio +async def test_net_stream_read_write(net_stream_pair): + stream_0, stream_1 = net_stream_pair + assert ( + stream_0.protocol_id is not None + and stream_0.protocol_id == stream_1.protocol_id + ) + await stream_0.write(DATA) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_net_stream_read_until_eof(net_stream_pair): + read_bytes = bytearray() + stream_0, stream_1 = net_stream_pair + + async def read_until_eof(): + read_bytes.extend(await stream_1.read()) + + task = asyncio.ensure_future(read_until_eof()) + + expected_data = bytearray() + + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + # Test: `read` doesn't return before `close` is called. + await stream_0.write(DATA) + expected_data.extend(DATA) + await asyncio.sleep(0.01) + assert len(read_bytes) == 0 + + # Test: Close the stream, `read` returns, and receive previous sent data. + await stream_0.close() + await asyncio.sleep(0.01) + assert read_bytes == expected_data + + task.cancel() + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_closed(net_stream_pair): + stream_0, stream_1 = net_stream_pair + assert not stream_1.muxed_stream.event_remote_closed.is_set() + await stream_0.write(DATA) + await stream_0.close() + await asyncio.sleep(0.01) + assert stream_1.muxed_stream.event_remote_closed.is_set() + assert (await stream_1.read(MAX_READ_LEN)) == DATA + with pytest.raises(StreamEOF): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_local_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.reset() + with pytest.raises(StreamReset): + await stream_0.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.write(DATA) + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + with pytest.raises(StreamReset): + await stream_1.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_closed(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.write(DATA) + await stream_0.close() + with pytest.raises(StreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.reset() + with pytest.raises(StreamClosed): + await stream_0.write(DATA) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_remote_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_1.reset() + await asyncio.sleep(0.01) + with pytest.raises(StreamClosed): + await stream_0.write(DATA) From df312f3e5798ca48c6d93ddbd3544edcfddc24e0 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 9 Sep 2019 17:05:06 +0800 Subject: [PATCH 3/8] Fix linting --- libp2p/stream_muxer/mplex/mplex_stream.py | 6 ++++-- tests/factories.py | 1 - 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 547d7b84..19a2637c 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -8,6 +8,7 @@ from .datastructures import StreamID from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset if TYPE_CHECKING: + from typing import Any # noqa: F401 from libp2p.stream_muxer.mplex.mplex import Mplex @@ -67,8 +68,9 @@ class MplexStream(IMuxedStream): fut.cancel() if self.event_reset.is_set(): raise MplexStreamReset - done_task = tuple(done)[0] - if done_task._coro.__qualname__ == "Queue.get": + done_task = cast("asyncio.Task[Any]", tuple(done)[0]) + # TODO: `_coro` is not in `asyncio.Task`'s typeshed. + if done_task._coro.__qualname__ == "Queue.get": # type: ignore data = done_task.result() self._buf.extend(data) return diff --git a/tests/factories.py b/tests/factories.py index e161e25d..0f69707a 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -15,7 +15,6 @@ from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio from libp2p.stream_muxer.mplex.mplex import Mplex -from libp2p.stream_muxer.mplex.mplex_stream import MplexStream from libp2p.typing import TProtocol from tests.configs import LISTEN_MADDR from tests.pubsub.configs import ( From e5eb01d22b38a146dd4c9060192ce90217c84602 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 9 Sep 2019 22:48:49 +0800 Subject: [PATCH 4/8] Fix stream read --- libp2p/stream_muxer/mplex/mplex_stream.py | 28 ++++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 19a2637c..8c219659 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,5 +1,5 @@ import asyncio -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from libp2p.stream_muxer.abc import IMuxedStream @@ -8,7 +8,6 @@ from .datastructures import StreamID from .exceptions import MplexStreamClosed, MplexStreamEOF, MplexStreamReset if TYPE_CHECKING: - from typing import Any # noqa: F401 from libp2p.stream_muxer.mplex.mplex import Mplex @@ -66,16 +65,33 @@ class MplexStream(IMuxedStream): ) for fut in pending: fut.cancel() + if self.event_reset.is_set(): raise MplexStreamReset - done_task = cast("asyncio.Task[Any]", tuple(done)[0]) - # TODO: `_coro` is not in `asyncio.Task`'s typeshed. - if done_task._coro.__qualname__ == "Queue.get": # type: ignore + + if len(done) != 1: + raise Exception(f"Should be exactly 1 job in {done}.") + done_task = tuple(done)[0] + # NOTE: Ignore type check because the typeshed for `asyncio.Task` does not + # have the field `_coro`. + coro_qualname = done_task._coro.__qualname__ # type: ignore + # If `qualname == "Queue.get"` then there is incoming data. We can add it to the buffer. + if coro_qualname == "Queue.get": data = done_task.result() self._buf.extend(data) return + if self.event_remote_closed.is_set(): raise MplexStreamEOF + + # If the task is not `Queue.get`, then it must be `Event.wait`. + # However, it is abnormal that `Event.wait` is unblocked without any of the event + # (remote_closed and reset) is set. Then it is highly possible that the task + # is cancelled. + raise Exception( + "Should not enter here. " + f"It is highly possible that `done_task` is cancelled. `done_task`={done_task}" + ) # TODO: Handle timeout when deadline is used. async def _read_until_eof(self) -> bytes: @@ -107,7 +123,7 @@ class MplexStream(IMuxedStream): return await self._read_until_eof() if len(self._buf) == 0 and self.incoming_data.empty(): await self._wait_for_data() - # Either `buf` is not empty or `incoming_data` is not empty now. + # Now we are sure we have something to read. # Try to put enough incoming data into `self._buf`. while len(self._buf) < n: try: From a45eb76421a6ef527a1a488fc974df56ed985857 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 9 Sep 2019 22:52:47 +0800 Subject: [PATCH 5/8] Suppress all exceptions in clean up. --- tests/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index a26ebc55..8ae72d66 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,6 +17,7 @@ async def connect(node1, node2): await node1.connect(info) +# FIXME: Should be deprecated, since it also kills the main task. async def cleanup(): pending = asyncio.all_tasks() for task in pending: @@ -24,7 +25,9 @@ async def cleanup(): # Now we should await task to execute it's cancellation. # Cancelled task raises asyncio.CancelledError that we can suppress: - with suppress(asyncio.CancelledError): + # NOTE: Changed from `asyncio.CancelledError` to `Exception`, to suppress all exceptions + # including the one in `run_until_complete`. + with suppress(Exception): await task From bb0da41edafb12ac076a4b2314bd2a305eaea711 Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 9 Sep 2019 23:09:33 +0800 Subject: [PATCH 6/8] Remove `cleanup` `cleanup` cancels all tasks in the loop, including the main one run by `run_until_complete` --- tests/examples/test_chat.py | 4 +--- tests/libp2p/test_libp2p.py | 9 +-------- tests/libp2p/test_notify.py | 9 +-------- tests/protocol_muxer/test_protocol_muxer.py | 5 +---- .../floodsub_integration_test_settings.py | 3 +-- tests/pubsub/test_dummyaccount_demo.py | 3 +-- tests/pubsub/test_floodsub.py | 4 +--- tests/pubsub/test_gossipsub.py | 17 +---------------- tests/security/test_security_multistream.py | 3 +-- tests/utils.py | 17 ----------------- 10 files changed, 9 insertions(+), 65 deletions(-) diff --git a/tests/examples/test_chat.py b/tests/examples/test_chat.py index 75d8ec71..18a172c8 100644 --- a/tests/examples/test_chat.py +++ b/tests/examples/test_chat.py @@ -4,7 +4,7 @@ import pytest from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.utils import set_up_nodes_by_transport_opt PROTOCOL_ID = "/chat/1.0.0" @@ -101,5 +101,3 @@ async def test_chat(test): await host_b.connect(info) await test(host_a, host_b) - - await cleanup() diff --git a/tests/libp2p/test_libp2p.py b/tests/libp2p/test_libp2p.py index 8090f5ea..793444c0 100644 --- a/tests/libp2p/test_libp2p.py +++ b/tests/libp2p/test_libp2p.py @@ -3,7 +3,7 @@ import pytest from libp2p.peer.peerinfo import info_from_p2p_addr from tests.constants import MAX_READ_LEN -from tests.utils import cleanup, set_up_nodes_by_transport_opt +from tests.utils import set_up_nodes_by_transport_opt @pytest.mark.asyncio @@ -34,7 +34,6 @@ async def test_simple_messages(): assert response == ("ack:" + message) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -69,7 +68,6 @@ async def test_double_response(): assert response2 == ("ack2:" + message) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -120,7 +118,6 @@ async def test_multiple_streams(): ) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -183,7 +180,6 @@ async def test_multiple_streams_same_initiator_different_protocols(): ) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -264,7 +260,6 @@ async def test_multiple_streams_two_initiators(): ) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -326,7 +321,6 @@ async def test_triangle_nodes_connection(): assert response == ("ack:" + message) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -353,4 +347,3 @@ async def test_host_connect(): assert addr.encapsulate(ma_node_b) in node_b.get_addrs() # Success, terminate pending tasks. - await cleanup() diff --git a/tests/libp2p/test_notify.py b/tests/libp2p/test_notify.py index e21030ab..b9a8707e 100644 --- a/tests/libp2p/test_notify.py +++ b/tests/libp2p/test_notify.py @@ -17,7 +17,7 @@ from libp2p.crypto.rsa import create_new_key_pair from libp2p.host.basic_host import BasicHost from libp2p.network.notifee_interface import INotifee from tests.constants import MAX_READ_LEN -from tests.utils import cleanup, perform_two_host_set_up +from tests.utils import perform_two_host_set_up ACK = "ack:" @@ -91,7 +91,6 @@ async def test_one_notifier(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -138,7 +137,6 @@ async def test_one_notifier_on_two_nodes(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -203,7 +201,6 @@ async def test_one_notifier_on_two_nodes_with_listen(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -235,7 +232,6 @@ async def test_two_notifiers(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -271,7 +267,6 @@ async def test_ten_notifiers(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -325,7 +320,6 @@ async def test_ten_notifiers_on_two_nodes(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -355,4 +349,3 @@ async def test_invalid_notifee(): assert response == expected_resp # Success, terminate pending tasks. - await cleanup() diff --git a/tests/protocol_muxer/test_protocol_muxer.py b/tests/protocol_muxer/test_protocol_muxer.py index 7830aaa8..d7523ac2 100644 --- a/tests/protocol_muxer/test_protocol_muxer.py +++ b/tests/protocol_muxer/test_protocol_muxer.py @@ -1,7 +1,7 @@ import pytest from libp2p.protocol_muxer.exceptions import MultiselectClientError -from tests.utils import cleanup, echo_stream_handler, set_up_nodes_by_transport_opt +from tests.utils import echo_stream_handler, set_up_nodes_by_transport_opt # TODO: Add tests for multiple streams being opened on different # protocols through the same connection @@ -35,7 +35,6 @@ async def perform_simple_test( assert expected_selected_protocol == stream.get_protocol() # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio @@ -52,7 +51,6 @@ async def test_single_protocol_fails(): await perform_simple_test("", ["/echo/1.0.0"], ["/potato/1.0.0"]) # Cleanup not reached on error - await cleanup() @pytest.mark.asyncio @@ -83,4 +81,3 @@ async def test_multiple_protocol_fails(): await perform_simple_test("", protocols_for_client, protocols_for_listener) # Cleanup not reached on error - await cleanup() diff --git a/tests/pubsub/floodsub_integration_test_settings.py b/tests/pubsub/floodsub_integration_test_settings.py index d96fc2b9..0a533e28 100644 --- a/tests/pubsub/floodsub_integration_test_settings.py +++ b/tests/pubsub/floodsub_integration_test_settings.py @@ -4,7 +4,7 @@ import pytest from tests.configs import LISTEN_MADDR from tests.factories import PubsubFactory -from tests.utils import cleanup, connect +from tests.utils import connect from .configs import FLOODSUB_PROTOCOL_ID @@ -258,4 +258,3 @@ async def perform_test_from_obj(obj, router_factory): assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id # Success, terminate pending tasks. - await cleanup() diff --git a/tests/pubsub/test_dummyaccount_demo.py b/tests/pubsub/test_dummyaccount_demo.py index b365134f..edc2f51e 100644 --- a/tests/pubsub/test_dummyaccount_demo.py +++ b/tests/pubsub/test_dummyaccount_demo.py @@ -3,7 +3,7 @@ from threading import Thread import pytest -from tests.utils import cleanup, connect +from tests.utils import connect from .dummy_account_node import DummyAccountNode @@ -64,7 +64,6 @@ async def perform_test(num_nodes, adjacency_map, action_func, assertion_func): assertion_func(dummy_node) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index 7e079d11..c6d28bfb 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -4,7 +4,7 @@ import pytest from libp2p.peer.id import ID from tests.factories import FloodsubFactory -from tests.utils import cleanup, connect +from tests.utils import connect from .floodsub_integration_test_settings import ( floodsub_protocol_pytest_params, @@ -36,7 +36,6 @@ async def test_simple_two_nodes(pubsubs_fsub): assert res_b.topicIDs == [topic] # Success, terminate pending tasks. - await cleanup() # Initialize Pubsub with a cache_size of 4 @@ -82,7 +81,6 @@ async def test_lru_cache_two_nodes(pubsubs_fsub, monkeypatch): assert sub_b.empty() # Success, terminate pending tasks. - await cleanup() @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index 7a0efc2c..95775be9 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -3,7 +3,7 @@ import random import pytest -from tests.utils import cleanup, connect +from tests.utils import connect from .configs import GossipsubParams from .utils import dense_connect, one_to_all_connect @@ -61,8 +61,6 @@ async def test_join(num_hosts, hosts, pubsubs_gsub): assert hosts[i].get_id() not in gossipsubs[central_node_index].mesh[topic] assert topic not in gossipsubs[i].mesh - await cleanup() - @pytest.mark.parametrize("num_hosts", (1,)) @pytest.mark.asyncio @@ -81,8 +79,6 @@ async def test_leave(pubsubs_gsub): # Test re-leave await gossipsub.leave(topic) - await cleanup() - @pytest.mark.parametrize("num_hosts", (2,)) @pytest.mark.asyncio @@ -133,8 +129,6 @@ async def test_handle_graft(pubsubs_gsub, hosts, event_loop, monkeypatch): # Check that bob is now alice's mesh peer assert id_bob in gossipsubs[index_alice].mesh[topic] - await cleanup() - @pytest.mark.parametrize( "num_hosts, gossipsub_params", ((2, GossipsubParams(heartbeat_interval=3)),) @@ -174,8 +168,6 @@ async def test_handle_prune(pubsubs_gsub, hosts): assert id_alice not in gossipsubs[index_bob].mesh[topic] assert id_bob in gossipsubs[index_alice].mesh[topic] - await cleanup() - @pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.asyncio @@ -210,7 +202,6 @@ async def test_dense(num_hosts, pubsubs_gsub, hosts): for queue in queues: msg = await queue.get() assert msg.data == msg_content - await cleanup() @pytest.mark.parametrize("num_hosts", (10,)) @@ -268,8 +259,6 @@ async def test_fanout(hosts, pubsubs_gsub): msg = await queue.get() assert msg.data == msg_content - await cleanup() - @pytest.mark.parametrize("num_hosts", (10,)) @pytest.mark.asyncio @@ -340,8 +329,6 @@ async def test_fanout_maintenance(hosts, pubsubs_gsub): msg = await queue.get() assert msg.data == msg_content - await cleanup() - @pytest.mark.parametrize( "num_hosts, gossipsub_params", @@ -380,5 +367,3 @@ async def test_gossip_propagation(hosts, pubsubs_gsub): # should be able to read message msg = await queue_1.get() assert msg.data == msg_content - - await cleanup() diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index 1d87e7b6..c8e83c1e 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -6,7 +6,7 @@ from libp2p import new_node from libp2p.crypto.rsa import create_new_key_pair from libp2p.security.insecure.transport import InsecureSession, InsecureTransport from tests.configs import LISTEN_MADDR -from tests.utils import cleanup, connect +from tests.utils import connect # TODO: Add tests for multiple streams being opened on different # protocols through the same connection @@ -57,7 +57,6 @@ async def perform_simple_test( assertion_func(node2_conn.secured_conn) # Success, terminate pending tasks. - await cleanup() @pytest.mark.asyncio diff --git a/tests/utils.py b/tests/utils.py index 8ae72d66..e9d6c09f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,3 @@ -import asyncio -from contextlib import suppress - import multiaddr from libp2p import new_node @@ -17,20 +14,6 @@ async def connect(node1, node2): await node1.connect(info) -# FIXME: Should be deprecated, since it also kills the main task. -async def cleanup(): - pending = asyncio.all_tasks() - for task in pending: - task.cancel() - - # Now we should await task to execute it's cancellation. - # Cancelled task raises asyncio.CancelledError that we can suppress: - # NOTE: Changed from `asyncio.CancelledError` to `Exception`, to suppress all exceptions - # including the one in `run_until_complete`. - with suppress(Exception): - await task - - async def set_up_nodes_by_transport_opt(transport_opt_list): nodes_list = [] for transport_opt in transport_opt_list: From df87f5adb939e89045af61dc50bc06527e4f1aeb Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 10 Sep 2019 17:51:39 +0800 Subject: [PATCH 7/8] Add tests against the daemon for close/reset --- libp2p/stream_muxer/mplex/mplex_stream.py | 1 - tests/interop/conftest.py | 71 ++++++++++++++++++++++ tests/interop/test_net_stream.py | 74 +++++++++++++++++++++++ tests/network/test_net_stream.py | 11 ++++ 4 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 tests/interop/test_net_stream.py diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 8c219659..4ae9e4ad 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -112,7 +112,6 @@ class MplexStream(IMuxedStream): :param n: number of bytes to read :return: bytes actually read """ - # TODO: Add exceptions and handle/raise them in this class. if n < 0 and n != -1: raise ValueError( f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" diff --git a/tests/interop/conftest.py b/tests/interop/conftest.py index 7261ee7b..c280a4c4 100644 --- a/tests/interop/conftest.py +++ b/tests/interop/conftest.py @@ -2,13 +2,16 @@ import asyncio import sys from typing import Union +from p2pclient.datastructures import StreamInfo import pexpect import pytest +from libp2p.io.abc import ReadWriteCloser from tests.factories import FloodsubFactory, GossipsubFactory, PubsubFactory from tests.pubsub.configs import GOSSIPSUB_PARAMS from .daemon import Daemon, make_p2pd +from .utils import connect @pytest.fixture @@ -78,3 +81,71 @@ def pubsubs(num_hosts, hosts, is_gossipsub): ) yield _pubsubs # TODO: Clean up + + +class DaemonStream(ReadWriteCloser): + stream_info: StreamInfo + reader: asyncio.StreamReader + writer: asyncio.StreamWriter + + def __init__( + self, + stream_info: StreamInfo, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + self.stream_info = stream_info + self.reader = reader + self.writer = writer + + async def close(self) -> None: + self.writer.close() + await self.writer.wait_closed() + + async def read(self, n: int = -1) -> bytes: + return await self.reader.read(n) + + async def write(self, data: bytes) -> int: + return self.writer.write(data) + + +@pytest.fixture +async def is_to_fail_daemon_stream(): + return False + + +@pytest.fixture +async def py_to_daemon_stream_pair(hosts, p2pds, is_to_fail_daemon_stream): + assert len(hosts) >= 1 + assert len(p2pds) >= 1 + host = hosts[0] + p2pd = p2pds[0] + protocol_id = "/protocol/id/123" + stream_py = None + stream_daemon = None + event_stream_handled = asyncio.Event() + await connect(host, p2pd) + + async def daemon_stream_handler(stream_info, reader, writer): + nonlocal stream_daemon + stream_daemon = DaemonStream(stream_info, reader, writer) + event_stream_handled.set() + + await p2pd.control.stream_handler(protocol_id, daemon_stream_handler) + + if is_to_fail_daemon_stream: + # FIXME: This is a workaround to make daemon reset the stream. + # We intentionally close the listener on the python side, it makes the connection from + # daemon to us fail, and therefore the daemon resets the opened stream on their side. + # Reference: https://github.com/libp2p/go-libp2p-daemon/blob/b95e77dbfcd186ccf817f51e95f73f9fd5982600/stream.go#L47-L50 # noqa: E501 + # We need it because we want to test against `stream_py` after the remote side(daemon) + # is reset. This should be removed after the API `stream.reset` is exposed in daemon + # some day. + listener = p2pds[0].control.control.listener + listener.close() + await listener.wait_closed() + stream_py = await host.new_stream(p2pd.peer_id, [protocol_id]) + if not is_to_fail_daemon_stream: + await event_stream_handled.wait() + # NOTE: If `is_to_fail_daemon_stream == True`, `stream_daemon == None`. + yield stream_py, stream_daemon diff --git a/tests/interop/test_net_stream.py b/tests/interop/test_net_stream.py new file mode 100644 index 00000000..01713396 --- /dev/null +++ b/tests/interop/test_net_stream.py @@ -0,0 +1,74 @@ +import asyncio + +import pytest + +from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset +from tests.constants import MAX_READ_LEN + +DATA = b"data" + + +@pytest.mark.asyncio +async def test_net_stream_read_write(py_to_daemon_stream_pair, p2pds): + stream_py, stream_daemon = py_to_daemon_stream_pair + assert ( + stream_py.protocol_id is not None + and stream_py.protocol_id == stream_daemon.stream_info.proto + ) + await stream_py.write(DATA) + assert (await stream_daemon.read(MAX_READ_LEN)) == DATA + + +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_closed(py_to_daemon_stream_pair, p2pds): + stream_py, stream_daemon = py_to_daemon_stream_pair + await stream_daemon.write(DATA) + await stream_daemon.close() + await asyncio.sleep(0.01) + assert (await stream_py.read(MAX_READ_LEN)) == DATA + # EOF + with pytest.raises(StreamEOF): + await stream_py.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_read_after_local_reset(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await stream_py.reset() + with pytest.raises(StreamReset): + await stream_py.read(MAX_READ_LEN) + + +@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_reset(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await asyncio.sleep(0.01) + with pytest.raises(StreamReset): + await stream_py.read(MAX_READ_LEN) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_closed(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await stream_py.write(DATA) + await stream_py.close() + with pytest.raises(StreamClosed): + await stream_py.write(DATA) + + +@pytest.mark.asyncio +async def test_net_stream_write_after_local_reset(py_to_daemon_stream_pair, p2pds): + stream_py, stream_daemon = py_to_daemon_stream_pair + await stream_py.reset() + with pytest.raises(StreamClosed): + await stream_py.write(DATA) + + +@pytest.mark.parametrize("is_to_fail_daemon_stream", (True,)) +@pytest.mark.asyncio +async def test_net_stream_write_after_remote_reset(py_to_daemon_stream_pair, p2pds): + stream_py, _ = py_to_daemon_stream_pair + await asyncio.sleep(0.01) + with pytest.raises(StreamClosed): + await stream_py.write(DATA) diff --git a/tests/network/test_net_stream.py b/tests/network/test_net_stream.py index e7029125..80bed6ce 100644 --- a/tests/network/test_net_stream.py +++ b/tests/network/test_net_stream.py @@ -85,6 +85,17 @@ async def test_net_stream_read_after_remote_reset(net_stream_pair): await stream_1.read(MAX_READ_LEN) +@pytest.mark.asyncio +async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair): + stream_0, stream_1 = net_stream_pair + await stream_0.write(DATA) + await stream_0.close() + await stream_0.reset() + # Sleep to let `stream_1` receive the message. + await asyncio.sleep(0.01) + assert (await stream_1.read(MAX_READ_LEN)) == DATA + + @pytest.mark.asyncio async def test_net_stream_write_after_local_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair From 31fb4e0b690bcd10dd3ef2d0747fc4a4efc62625 Mon Sep 17 00:00:00 2001 From: mhchia Date: Tue, 10 Sep 2019 23:38:45 +0800 Subject: [PATCH 8/8] Rewrite `_wait_for_data`, to handle task precisely Make the futures first, and then we can compare them with the return value from `asyncio.wait`. --- libp2p/stream_muxer/mplex/mplex_stream.py | 54 ++++++++++++----------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 4ae9e4ad..87b039f9 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -55,43 +55,45 @@ class MplexStream(IMuxedStream): return self.stream_id.is_initiator async def _wait_for_data(self) -> None: + task_event_reset = asyncio.ensure_future(self.event_reset.wait()) + task_incoming_data_get = asyncio.ensure_future(self.incoming_data.get()) + task_event_remote_closed = asyncio.ensure_future( + self.event_remote_closed.wait() + ) done, pending = await asyncio.wait( # type: ignore - [ - self.event_reset.wait(), - self.incoming_data.get(), - self.event_remote_closed.wait(), - ], + [task_event_reset, task_incoming_data_get, task_event_remote_closed], return_when=asyncio.FIRST_COMPLETED, ) for fut in pending: fut.cancel() - if self.event_reset.is_set(): - raise MplexStreamReset + if task_event_reset in done: + if self.event_reset.is_set(): + raise MplexStreamReset + else: + # However, it is abnormal that `Event.wait` is unblocked without any of the flag + # is set. The task is probably cancelled. + raise Exception( + "Should not enter here. " + f"It is probably because {task_event_remote_closed} is cancelled." + ) - if len(done) != 1: - raise Exception(f"Should be exactly 1 job in {done}.") - done_task = tuple(done)[0] - # NOTE: Ignore type check because the typeshed for `asyncio.Task` does not - # have the field `_coro`. - coro_qualname = done_task._coro.__qualname__ # type: ignore - # If `qualname == "Queue.get"` then there is incoming data. We can add it to the buffer. - if coro_qualname == "Queue.get": - data = done_task.result() + if task_incoming_data_get in done: + data = task_incoming_data_get.result() self._buf.extend(data) return - if self.event_remote_closed.is_set(): - raise MplexStreamEOF + if task_event_remote_closed in done: + if self.event_remote_closed.is_set(): + raise MplexStreamEOF + else: + # However, it is abnormal that `Event.wait` is unblocked without any of the flag + # is set. The task is probably cancelled. + raise Exception( + "Should not enter here. " + f"It is probably because {task_event_remote_closed} is cancelled." + ) - # If the task is not `Queue.get`, then it must be `Event.wait`. - # However, it is abnormal that `Event.wait` is unblocked without any of the event - # (remote_closed and reset) is set. Then it is highly possible that the task - # is cancelled. - raise Exception( - "Should not enter here. " - f"It is highly possible that `done_task` is cancelled. `done_task`={done_task}" - ) # TODO: Handle timeout when deadline is used. async def _read_until_eof(self) -> bytes: