diff --git a/libp2p/pubsub/pubsub_router_interface.py b/libp2p/pubsub/abc.py similarity index 86% rename from libp2p/pubsub/pubsub_router_interface.py rename to libp2p/pubsub/abc.py index 99a9be75..19f9b2a6 100644 --- a/libp2p/pubsub/pubsub_router_interface.py +++ b/libp2p/pubsub/abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List from libp2p.peer.id import ID from libp2p.typing import TProtocol @@ -10,6 +10,11 @@ if TYPE_CHECKING: from .pubsub import Pubsub # noqa: F401 +# TODO: Add interface for Pubsub +class IPubsub(ABC): + pass + + class IPubsubRouter(ABC): @abstractmethod def get_protocols(self) -> List[TProtocol]: @@ -53,7 +58,6 @@ class IPubsubRouter(ABC): :param rpc: rpc message """ - # FIXME: Should be changed to type 'peer.ID' @abstractmethod async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: """ @@ -80,3 +84,15 @@ class IPubsubRouter(ABC): :param topic: topic to leave """ + + +class ISubscriptionAPI( + AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] +): + @abstractmethod + async def cancel(self) -> None: + ... + + @abstractmethod + async def get(self) -> rpc_pb2.Message: + ... diff --git a/libp2p/pubsub/floodsub.py b/libp2p/pubsub/floodsub.py index 9e323eb2..06300eec 100644 --- a/libp2p/pubsub/floodsub.py +++ b/libp2p/pubsub/floodsub.py @@ -8,9 +8,9 @@ from libp2p.peer.id import ID from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .abc import IPubsubRouter from .pb import rpc_pb2 from .pubsub import Pubsub -from .pubsub_router_interface import IPubsubRouter PROTOCOL_ID = TProtocol("/floodsub/1.0.0") diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index df0f83f4..df886db5 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -12,11 +12,11 @@ from libp2p.pubsub import floodsub from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed +from .abc import IPubsubRouter from .exceptions import NoPubsubAttached from .mcache import MessageCache from .pb import rpc_pb2 from .pubsub import Pubsub -from .pubsub_router_interface import IPubsubRouter PROTOCOL_ID = TProtocol("/meshsub/1.0.0") diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 0c3b162c..71b82f48 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,4 +1,3 @@ -from abc import ABC import logging import math import time @@ -30,12 +29,14 @@ from libp2p.peer.id import ID from libp2p.typing import TProtocol from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes +from .abc import IPubsub, ISubscriptionAPI from .pb import rpc_pb2 from .pubsub_notifee import PubsubNotifee +from .subscription import TrioSubscriptionAPI from .validators import signature_validator if TYPE_CHECKING: - from .pubsub_router_interface import IPubsubRouter # noqa: F401 + from .abc import IPubsubRouter # noqa: F401 from typing import Any # noqa: F401 @@ -57,11 +58,6 @@ class TopicValidator(NamedTuple): is_async: bool -# TODO: Add interface for Pubsub -class IPubsub(ABC): - pass - - class Pubsub(IPubsub, Service): host: IHost @@ -75,7 +71,7 @@ class Pubsub(IPubsub, Service): # TODO: Implement `trio.abc.Channel`? subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] - subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"] + subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"] peer_topics: Dict[str, List[ID]] peers: Dict[ID, INetStream] @@ -380,10 +376,7 @@ class Pubsub(IPubsub, Service): # for each topic await self.subscribed_topics_send[topic].send(publish_message) - # TODO: Change to return an `AsyncIterable` to be I/O-agnostic? - async def subscribe( - self, topic_id: str - ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]": + async def subscribe(self, topic_id: str) -> ISubscriptionAPI: """ Subscribe ourself to a topic. @@ -396,14 +389,14 @@ class Pubsub(IPubsub, Service): if topic_id in self.topic_ids: return self.subscribed_topics_receive[topic_id] - # Map topic_id to a blocking channel channels: Tuple[ "trio.MemorySendChannel[rpc_pb2.Message]", "trio.MemoryReceiveChannel[rpc_pb2.Message]", ] = trio.open_memory_channel(math.inf) send_channel, receive_channel = channels + subscription = TrioSubscriptionAPI(receive_channel) self.subscribed_topics_send[topic_id] = send_channel - self.subscribed_topics_receive[topic_id] = receive_channel + self.subscribed_topics_receive[topic_id] = subscription # Create subscribe message packet: rpc_pb2.RPC = rpc_pb2.RPC() @@ -417,8 +410,8 @@ class Pubsub(IPubsub, Service): # Tell router we are joining this topic await self.router.join(topic_id) - # Return the trio channel for messages on this topic - return receive_channel + # Return the subscription for messages on this topic + return subscription async def unsubscribe(self, topic_id: str) -> None: """ diff --git a/libp2p/pubsub/subscription.py b/libp2p/pubsub/subscription.py new file mode 100644 index 00000000..1d88d09b --- /dev/null +++ b/libp2p/pubsub/subscription.py @@ -0,0 +1,39 @@ +from types import TracebackType +from typing import AsyncIterator, Optional, Type + +import trio + +from .abc import ISubscriptionAPI +from .pb import rpc_pb2 + + +class BaseSubscriptionAPI(ISubscriptionAPI): + async def __aenter__(self) -> "BaseSubscriptionAPI": + await trio.hazmat.checkpoint() + return self + + async def __aexit__( + self, + exc_type: "Optional[Type[BaseException]]", + exc_value: "Optional[BaseException]", + traceback: "Optional[TracebackType]", + ) -> None: + await self.cancel() + + +class TrioSubscriptionAPI(BaseSubscriptionAPI): + receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + + def __init__( + self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + ) -> None: + self.receive_channel = receive_channel + + async def cancel(self) -> None: + await self.receive_channel.aclose() + + def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]: + return self.receive_channel.__aiter__() + + async def get(self) -> rpc_pb2.Message: + return await self.receive_channel.receive() diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index a9eb6a53..52208da4 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -17,10 +17,10 @@ from libp2p.network.stream.net_stream_interface import INetStream from libp2p.network.swarm import Swarm from libp2p.peer.id import ID from libp2p.peer.peerstore import PeerStore +from libp2p.pubsub.abc import IPubsubRouter from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.pubsub import Pubsub -from libp2p.pubsub.pubsub_router_interface import IPubsubRouter from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport import libp2p.security.secio.transport as secio diff --git a/libp2p/tools/pubsub/dummy_account_node.py b/libp2p/tools/pubsub/dummy_account_node.py index 5a61ed69..9079ac20 100644 --- a/libp2p/tools/pubsub/dummy_account_node.py +++ b/libp2p/tools/pubsub/dummy_account_node.py @@ -61,7 +61,7 @@ class DummyAccountNode(Service): async def handle_incoming_msgs(self) -> None: """Handle all incoming messages on the CRYPTO_TOPIC from peers.""" while True: - incoming = await self.subscription.receive() + incoming = await self.subscription.get() msg_comps = incoming.data.decode("utf-8").split(",") if msg_comps[0] == "send": diff --git a/libp2p/tools/pubsub/floodsub_integration_test_settings.py b/libp2p/tools/pubsub/floodsub_integration_test_settings.py index 58a5b242..0d25586e 100644 --- a/libp2p/tools/pubsub/floodsub_integration_test_settings.py +++ b/libp2p/tools/pubsub/floodsub_integration_test_settings.py @@ -250,7 +250,7 @@ async def perform_test_from_obj(obj, pubsub_factory) -> None: # Look at each node in each topic for node_id in topic_map[topic]: # Get message from subscription queue - msg = await queues_map[node_id][topic].receive() + msg = await queues_map[node_id][topic].get() assert data == msg.data # Check the message origin assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id diff --git a/tests/pubsub/test_floodsub.py b/tests/pubsub/test_floodsub.py index dbeb6833..148c001b 100644 --- a/tests/pubsub/test_floodsub.py +++ b/tests/pubsub/test_floodsub.py @@ -27,7 +27,7 @@ async def test_simple_two_nodes(): await pubsubs_fsub[0].publish(topic, data) - res_b = await sub_b.receive() + res_b = await sub_b.get() # Check that the msg received by node_b is the same # as the message sent by node_a @@ -75,12 +75,9 @@ async def test_lru_cache_two_nodes(monkeypatch): await trio.sleep(0.25) for index in expected_received_indices: - res_b = await sub_b.receive() + res_b = await sub_b.get() assert res_b.data == _make_testing_data(index) - with pytest.raises(trio.WouldBlock): - sub_b.receive_nowait() - @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.trio diff --git a/tests/pubsub/test_gossipsub.py b/tests/pubsub/test_gossipsub.py index b1ed3af1..e9d789a9 100644 --- a/tests/pubsub/test_gossipsub.py +++ b/tests/pubsub/test_gossipsub.py @@ -196,7 +196,7 @@ async def test_dense(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: - msg = await queue.receive() + msg = await queue.get() assert msg.data == msg_content @@ -229,7 +229,7 @@ async def test_fanout(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for sub in subs: - msg = await sub.receive() + msg = await sub.get() assert msg.data == msg_content # Subscribe message origin @@ -248,7 +248,7 @@ async def test_fanout(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for sub in subs: - msg = await sub.receive() + msg = await sub.get() assert msg.data == msg_content @@ -287,7 +287,7 @@ async def test_fanout_maintenance(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: - msg = await queue.receive() + msg = await queue.get() assert msg.data == msg_content for sub in pubsubs_gsub: @@ -319,7 +319,7 @@ async def test_fanout_maintenance(): await trio.sleep(0.5) # Assert that all blocking queues receive the message for queue in queues: - msg = await queue.receive() + msg = await queue.get() assert msg.data == msg_content @@ -346,5 +346,5 @@ async def test_gossip_propagation(): await trio.sleep(2) # should be able to read message - msg = await queue_1.receive() + msg = await queue_1.get() assert msg.data == msg_content diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index ea04788d..c4f00011 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -384,7 +384,7 @@ async def test_handle_talk(): len(pubsubs_fsub[0].topic_ids) == 1 and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC] ) - assert (await sub.receive()) == msg_0 + assert (await sub.get()) == msg_0 @pytest.mark.trio @@ -486,7 +486,7 @@ async def test_push_msg(monkeypatch): with trio.fail_after(0.1): await event.wait() # Test: Subscribers are notified when `push_msg` new messages. - assert (await sub.receive()) == msg_1 + assert (await sub.get()) == msg_1 with mock_router_publish() as event: # Test: add a topic validator and `push_msg` the message that diff --git a/tests/pubsub/test_subscription.py b/tests/pubsub/test_subscription.py new file mode 100644 index 00000000..c5a20eda --- /dev/null +++ b/tests/pubsub/test_subscription.py @@ -0,0 +1,77 @@ +import math + +import pytest +import trio + +from libp2p.pubsub.pb import rpc_pb2 +from libp2p.pubsub.subscription import TrioSubscriptionAPI + +GET_TIMEOUT = 0.001 + + +def make_trio_subscription(): + send_channel, receive_channel = trio.open_memory_channel(math.inf) + return send_channel, TrioSubscriptionAPI(receive_channel) + + +def make_pubsub_msg(): + return rpc_pb2.Message() + + +async def send_something(send_channel): + msg = make_pubsub_msg() + await send_channel.send(msg) + return msg + + +@pytest.mark.trio +async def test_trio_subscription_get(): + send_channel, sub = make_trio_subscription() + data_0 = await send_something(send_channel) + data_1 = await send_something(send_channel) + assert data_0 == await sub.get() + assert data_1 == await sub.get() + # No more message + with pytest.raises(trio.TooSlowError): + with trio.fail_after(GET_TIMEOUT): + await sub.get() + + +@pytest.mark.trio +async def test_trio_subscription_iter(): + send_channel, sub = make_trio_subscription() + received_data = [] + + async def iter_subscriptions(subscription): + async for data in sub: + received_data.append(data) + + async with trio.open_nursery() as nursery: + nursery.start_soon(iter_subscriptions, sub) + await send_something(send_channel) + await send_something(send_channel) + await send_channel.aclose() + + assert len(received_data) == 2 + + +@pytest.mark.trio +async def test_trio_subscription_cancel(): + send_channel, sub = make_trio_subscription() + await sub.cancel() + # Test: If the subscription is cancelled, `send_channel` should be broken. + with pytest.raises(trio.BrokenResourceError): + await send_something(send_channel) + # Test: No side effect when cancelled twice. + await sub.cancel() + + +@pytest.mark.trio +async def test_trio_subscription_async_context_manager(): + send_channel, sub = make_trio_subscription() + async with sub: + # Test: `sub` is not cancelled yet, so `send_something` works fine. + await send_something(send_channel) + # Test: `sub` is cancelled, `send_something` fails + with pytest.raises(trio.BrokenResourceError): + await send_something(send_channel)