diff --git a/libp2p/pubsub/abc.py b/libp2p/pubsub/abc.py index e4b75840..da37b6a1 100644 --- a/libp2p/pubsub/abc.py +++ b/libp2p/pubsub/abc.py @@ -24,7 +24,7 @@ class ISubscriptionAPI( AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message] ): @abstractmethod - async def cancel(self) -> None: + async def unsubscribe(self) -> None: ... @abstractmethod diff --git a/libp2p/pubsub/pubsub.py b/libp2p/pubsub/pubsub.py index 03140742..26c4b4ff 100644 --- a/libp2p/pubsub/pubsub.py +++ b/libp2p/pubsub/pubsub.py @@ -1,3 +1,4 @@ +import functools import logging import math import time @@ -387,9 +388,14 @@ class Pubsub(Service, IPubsub): if topic_id in self.topic_ids: return self.subscribed_topics_receive[topic_id] - channels = trio.open_memory_channel[rpc_pb2.Message](math.inf) - send_channel, receive_channel = channels - subscription = TrioSubscriptionAPI(receive_channel) + send_channel, receive_channel = trio.open_memory_channel[rpc_pb2.Message]( + math.inf + ) + + subscription = TrioSubscriptionAPI( + receive_channel, + unsubscribe_fn=functools.partial(self.unsubscribe, topic_id), + ) self.subscribed_topics_send[topic_id] = send_channel self.subscribed_topics_receive[topic_id] = subscription diff --git a/libp2p/pubsub/subscription.py b/libp2p/pubsub/subscription.py index 1d88d09b..e3c926cc 100644 --- a/libp2p/pubsub/subscription.py +++ b/libp2p/pubsub/subscription.py @@ -5,6 +5,7 @@ import trio from .abc import ISubscriptionAPI from .pb import rpc_pb2 +from .typing import UnsubscribeFn class BaseSubscriptionAPI(ISubscriptionAPI): @@ -18,19 +19,25 @@ class BaseSubscriptionAPI(ISubscriptionAPI): exc_value: "Optional[BaseException]", traceback: "Optional[TracebackType]", ) -> None: - await self.cancel() + await self.unsubscribe() class TrioSubscriptionAPI(BaseSubscriptionAPI): receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + unsubscribe_fn: UnsubscribeFn def __init__( - self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]" + self, + receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]", + unsubscribe_fn: UnsubscribeFn, ) -> None: self.receive_channel = receive_channel + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + self.unsubscribe_fn = unsubscribe_fn # type: ignore - async def cancel(self) -> None: - await self.receive_channel.aclose() + async def unsubscribe(self) -> None: + # Ignore type here since mypy complains: https://github.com/python/mypy/issues/2427 + await self.unsubscribe_fn() # type: ignore def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]: return self.receive_channel.__aiter__() diff --git a/libp2p/pubsub/typing.py b/libp2p/pubsub/typing.py index c352d529..33297a9f 100644 --- a/libp2p/pubsub/typing.py +++ b/libp2p/pubsub/typing.py @@ -7,3 +7,5 @@ from .pb import rpc_pb2 SyncValidatorFn = Callable[[ID, rpc_pb2.Message], bool] AsyncValidatorFn = Callable[[ID, rpc_pb2.Message], Awaitable[bool]] ValidatorFn = Union[SyncValidatorFn, AsyncValidatorFn] + +UnsubscribeFn = Callable[[], Awaitable[None]] diff --git a/tests/pubsub/test_pubsub.py b/tests/pubsub/test_pubsub.py index 4bea6dd2..1e9d670a 100644 --- a/tests/pubsub/test_pubsub.py +++ b/tests/pubsub/test_pubsub.py @@ -413,7 +413,39 @@ async def test_message_all_peers(monkeypatch, is_host_secure): @pytest.mark.trio -async def test_publish(monkeypatch): +async def test_subscribe_and_publish(): + async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub: + pubsub = pubsubs_fsub[0] + + list_data = [b"d0", b"d1"] + event_receive_data_started = trio.Event() + + async def publish_data(topic): + await event_receive_data_started.wait() + for data in list_data: + await pubsub.publish(topic, data) + + async def receive_data(topic): + i = 0 + event_receive_data_started.set() + assert topic not in pubsub.topic_ids + subscription = await pubsub.subscribe(topic) + async with subscription: + assert topic in pubsub.topic_ids + async for msg in subscription: + assert msg.data == list_data[i] + i += 1 + if i == len(list_data): + break + assert topic not in pubsub.topic_ids + + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_data, TESTING_TOPIC) + nursery.start_soon(publish_data, TESTING_TOPIC) + + +@pytest.mark.trio +async def test_publish_push_msg_is_called(monkeypatch): msg_forwarders = [] msgs = [] diff --git a/tests/pubsub/test_subscription.py b/tests/pubsub/test_subscription.py index c5a20eda..a0a6c10c 100644 --- a/tests/pubsub/test_subscription.py +++ b/tests/pubsub/test_subscription.py @@ -11,7 +11,14 @@ GET_TIMEOUT = 0.001 def make_trio_subscription(): send_channel, receive_channel = trio.open_memory_channel(math.inf) - return send_channel, TrioSubscriptionAPI(receive_channel) + + async def unsubscribe_fn(): + await send_channel.aclose() + + return ( + send_channel, + TrioSubscriptionAPI(receive_channel, unsubscribe_fn=unsubscribe_fn), + ) def make_pubsub_msg(): @@ -56,14 +63,14 @@ async def test_trio_subscription_iter(): @pytest.mark.trio -async def test_trio_subscription_cancel(): +async def test_trio_subscription_unsubscribe(): send_channel, sub = make_trio_subscription() - await sub.cancel() - # Test: If the subscription is cancelled, `send_channel` should be broken. - with pytest.raises(trio.BrokenResourceError): + await sub.unsubscribe() + # Test: If the subscription is unsubscribed, `send_channel` should be closed. + with pytest.raises(trio.ClosedResourceError): await send_something(send_channel) # Test: No side effect when cancelled twice. - await sub.cancel() + await sub.unsubscribe() @pytest.mark.trio @@ -73,5 +80,5 @@ async def test_trio_subscription_async_context_manager(): # Test: `sub` is not cancelled yet, so `send_something` works fine. await send_something(send_channel) # Test: `sub` is cancelled, `send_something` fails - with pytest.raises(trio.BrokenResourceError): + with pytest.raises(trio.ClosedResourceError): await send_something(send_channel)