From 037b95252daa397cbe3579aad770a70159b38d0e Mon Sep 17 00:00:00 2001 From: mhchia Date: Mon, 29 Jul 2019 22:49:48 +0800 Subject: [PATCH] Add tests for `Pubsub` - `test_get_hello_packet` - `test_continuously_read_stream` - `test_publish` - `test_push_msg` --- tests/pubsub/test_pubsub.py | 194 +++++++++++++++++++++++++++++++++++- tests/pubsub/utils.py | 8 +- 2 files changed, 198 insertions(+), 4 deletions(-) diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 4de5cf8d..497fd862 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -1,16 +1,23 @@ import asyncio +import io +from typing import NamedTuple import pytest +from libp2p.peer.id import ID from libp2p.pubsub.pb import rpc_pb2 from tests.utils import ( connect, ) +from .utils import ( + make_pubsub_msg, +) + TESTING_TOPIC = "TEST_SUBSCRIBE" -TESTIND_DATA = b"data" +TESTING_DATA = b"data" @pytest.mark.parametrize( @@ -101,3 +108,188 @@ async def test_get_hello_packet(pubsubs_fsub): for topic in topic_ids: assert topic in topic_ids_in_hello + +class FakeNetStream: + _queue: asyncio.Queue + + class FakeMplexConn(NamedTuple): + peer_id: ID = ID(b"\x12\x20" + b"\x00" * 32) + + mplex_conn = FakeMplexConn() + + def __init__(self) -> None: + self._queue = asyncio.Queue() + + async def read(self) -> bytes: + buf = io.BytesIO() + while not self._queue.empty(): + buf.write(await self._queue.get()) + return buf.getvalue() + + async def write(self, data: bytes) -> int: + for i in data: + await self._queue.put(i.to_bytes(1, 'big')) + return len(data) + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_continuously_read_stream(pubsubs_fsub, monkeypatch): + s = FakeNetStream() + + await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + + event_push_msg = asyncio.Event() + event_handle_subscription = asyncio.Event() + event_handle_rpc = asyncio.Event() + + async def mock_push_msg(msg_forwarder, msg): + event_push_msg.set() + + def mock_handle_subscription(origin_id, sub_message): + event_handle_subscription.set() + + async def mock_handle_rpc(rpc, sender_peer_id): + event_handle_rpc.set() + + monkeypatch.setattr(pubsubs_fsub[0], "push_msg", mock_push_msg) + monkeypatch.setattr(pubsubs_fsub[0], "handle_subscription", mock_handle_subscription) + monkeypatch.setattr(pubsubs_fsub[0].router, "handle_rpc", mock_handle_rpc) + + async def wait_for_event_occurring(event): + try: + await asyncio.wait_for(event.wait(), timeout=0.01) + except asyncio.TimeoutError as error: + event.clear() + raise asyncio.TimeoutError( + f"Event {event} is not set before the timeout. " + "This indicates the mocked functions are not called properly." + ) from error + else: + event.clear() + + # Kick off the task `continuously_read_stream` + task = asyncio.ensure_future(pubsubs_fsub[0].continuously_read_stream(s)) + + # Test: `push_msg` is called when publishing to a subscribed topic. + publish_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message( + topicIDs=[TESTING_TOPIC] + )], + ) + await s.write(publish_subscribed_topic.SerializeToString()) + await wait_for_event_occurring(event_push_msg) + # Make sure the other events are not emitted. + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_handle_subscription) + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_handle_rpc) + + # Test: `push_msg` is not called when publishing to a topic-not-subscribed. + publish_not_subscribed_topic = rpc_pb2.RPC( + publish=[rpc_pb2.Message( + topicIDs=["NOT_SUBSCRIBED"] + )], + ) + await s.write(publish_not_subscribed_topic.SerializeToString()) + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_push_msg) + + # Test: `handle_subscription` is called when a subscription message is received. + subscription_msg = rpc_pb2.RPC( + subscriptions=[rpc_pb2.RPC.SubOpts()], + ) + await s.write(subscription_msg.SerializeToString()) + await wait_for_event_occurring(event_handle_subscription) + # Make sure the other events are not emitted. + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_push_msg) + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_handle_rpc) + + # Test: `handle_rpc` is called when a control message is received. + control_msg = rpc_pb2.RPC(control=rpc_pb2.ControlMessage()) + await s.write(control_msg.SerializeToString()) + await wait_for_event_occurring(event_handle_rpc) + # Make sure the other events are not emitted. + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_push_msg) + with pytest.raises(asyncio.TimeoutError): + await wait_for_event_occurring(event_handle_subscription) + + task.cancel() + + +@pytest.mark.parametrize( + "num_hosts", + (2,), +) +@pytest.mark.asyncio +async def test_publish(pubsubs_fsub, monkeypatch): + msg_forwarders = [] + msgs = [] + + async def push_msg(msg_forwarder, msg): + msg_forwarders.append(msg_forwarder) + msgs.append(msg) + monkeypatch.setattr(pubsubs_fsub[0], "push_msg", push_msg) + + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + await pubsubs_fsub[0].publish(TESTING_TOPIC, TESTING_DATA) + + assert len(msgs) == 2, "`push_msg` should be called every time `publish` is called" + assert (msg_forwarders[0] == msg_forwarders[1]) and (msg_forwarders[1] == pubsubs_fsub[0].my_id) + assert msgs[0].seqno != msgs[1].seqno, "`seqno` should be different every time" + + +@pytest.mark.parametrize( + "num_hosts", + (1,), +) +@pytest.mark.asyncio +async def test_push_msg(pubsubs_fsub, monkeypatch): + # pylint: disable=protected-access + msg_0 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x00" * 8, + ) + + event = asyncio.Event() + + async def router_publish(*args, **kwargs): + event.set() + monkeypatch.setattr(pubsubs_fsub[0].router, "publish", router_publish) + + # Test: `msg` is not seen before `push_msg`, and is seen after `push_msg`. + assert not pubsubs_fsub[0]._is_msg_seen(msg_0) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + assert pubsubs_fsub[0]._is_msg_seen(msg_0) + # Test: Ensure `router.publish` is called in `push_msg` + await asyncio.wait_for(event.wait(), timeout=0.1) + + # Test: `push_msg` the message again and it will be reject. + # `router_publish` is not called then. + event.clear() + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_0) + await asyncio.sleep(0.01) + assert not event.is_set() + + sub = await pubsubs_fsub[0].subscribe(TESTING_TOPIC) + # Test: `push_msg` succeeds with another unseen msg. + msg_1 = make_pubsub_msg( + origin_id=pubsubs_fsub[0].my_id, + topic_ids=[TESTING_TOPIC], + data=TESTING_DATA, + seqno=b"\x11" * 8, + ) + assert not pubsubs_fsub[0]._is_msg_seen(msg_1) + await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg_1) + assert pubsubs_fsub[0]._is_msg_seen(msg_1) + await asyncio.wait_for(event.wait(), timeout=0.1) + # Test: Subscribers are notified when `push_msg` new messages. + assert (await sub.get()) == msg_1 diff --git a/tests/pubsub/utils.py b/tests/pubsub/utils.py index a0cc0274..b83c854c 100644 --- a/tests/pubsub/utils.py +++ b/tests/pubsub/utils.py @@ -14,6 +14,8 @@ from libp2p.pubsub.pubsub import Pubsub from tests.utils import connect +from .configs import LISTEN_MADDR + def message_id_generator(start_val): """ @@ -80,13 +82,13 @@ async def create_libp2p_hosts(num_hosts): tasks_create = [] for i in range(0, num_hosts): # Create node - tasks_create.append(new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"])) + tasks_create.append(new_node(transport_opt=[str(LISTEN_MADDR)])) hosts = await asyncio.gather(*tasks_create) tasks_listen = [] for node in hosts: # Start listener - tasks_listen.append(asyncio.ensure_future(node.get_network().listen(multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/0")))) + tasks_listen.append(node.get_network().listen(LISTEN_MADDR)) await asyncio.gather(*tasks_listen) return hosts @@ -109,7 +111,7 @@ def create_pubsub_and_gossipsub_instances( degree_low, degree_high, time_to_live, gossip_window, gossip_history, heartbeat_interval) - pubsub = Pubsub(node, gossipsub, "a") + pubsub = Pubsub(node, gossipsub, node.get_id()) pubsubs.append(pubsub) gossipsubs.append(gossipsub)