Add SubscriptionAPI

And `TrioSubscriptionAPI`, to make subscription io-agnostic.
This commit is contained in:
mhchia
2019-12-17 18:17:28 +08:00
parent fb0519129d
commit 47d10e186f
12 changed files with 158 additions and 36 deletions

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
@ -10,6 +10,11 @@ if TYPE_CHECKING:
from .pubsub import Pubsub # noqa: F401 from .pubsub import Pubsub # noqa: F401
# TODO: Add interface for Pubsub
class IPubsub(ABC):
pass
class IPubsubRouter(ABC): class IPubsubRouter(ABC):
@abstractmethod @abstractmethod
def get_protocols(self) -> List[TProtocol]: def get_protocols(self) -> List[TProtocol]:
@ -53,7 +58,6 @@ class IPubsubRouter(ABC):
:param rpc: rpc message :param rpc: rpc message
""" """
# FIXME: Should be changed to type 'peer.ID'
@abstractmethod @abstractmethod
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None: async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
""" """
@ -80,3 +84,15 @@ class IPubsubRouter(ABC):
:param topic: topic to leave :param topic: topic to leave
""" """
class ISubscriptionAPI(
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
):
@abstractmethod
async def cancel(self) -> None:
...
@abstractmethod
async def get(self) -> rpc_pb2.Message:
...

View File

@ -8,9 +8,9 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub import Pubsub from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/floodsub/1.0.0") PROTOCOL_ID = TProtocol("/floodsub/1.0.0")

View File

@ -12,11 +12,11 @@ from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .exceptions import NoPubsubAttached from .exceptions import NoPubsubAttached
from .mcache import MessageCache from .mcache import MessageCache
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub import Pubsub from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/meshsub/1.0.0") PROTOCOL_ID = TProtocol("/meshsub/1.0.0")

View File

@ -1,4 +1,3 @@
from abc import ABC
import logging import logging
import math import math
import time import time
@ -30,12 +29,14 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI
from .validators import signature_validator from .validators import signature_validator
if TYPE_CHECKING: if TYPE_CHECKING:
from .pubsub_router_interface import IPubsubRouter # noqa: F401 from .abc import IPubsubRouter # noqa: F401
from typing import Any # noqa: F401 from typing import Any # noqa: F401
@ -57,11 +58,6 @@ class TopicValidator(NamedTuple):
is_async: bool is_async: bool
# TODO: Add interface for Pubsub
class IPubsub(ABC):
pass
class Pubsub(IPubsub, Service): class Pubsub(IPubsub, Service):
host: IHost host: IHost
@ -75,7 +71,7 @@ class Pubsub(IPubsub, Service):
# TODO: Implement `trio.abc.Channel`? # TODO: Implement `trio.abc.Channel`?
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"] subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"] subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
peer_topics: Dict[str, List[ID]] peer_topics: Dict[str, List[ID]]
peers: Dict[ID, INetStream] peers: Dict[ID, INetStream]
@ -380,10 +376,7 @@ class Pubsub(IPubsub, Service):
# for each topic # for each topic
await self.subscribed_topics_send[topic].send(publish_message) await self.subscribed_topics_send[topic].send(publish_message)
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic? async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
async def subscribe(
self, topic_id: str
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
""" """
Subscribe ourself to a topic. Subscribe ourself to a topic.
@ -396,14 +389,14 @@ class Pubsub(IPubsub, Service):
if topic_id in self.topic_ids: if topic_id in self.topic_ids:
return self.subscribed_topics_receive[topic_id] return self.subscribed_topics_receive[topic_id]
# Map topic_id to a blocking channel
channels: Tuple[ channels: Tuple[
"trio.MemorySendChannel[rpc_pb2.Message]", "trio.MemorySendChannel[rpc_pb2.Message]",
"trio.MemoryReceiveChannel[rpc_pb2.Message]", "trio.MemoryReceiveChannel[rpc_pb2.Message]",
] = trio.open_memory_channel(math.inf) ] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels send_channel, receive_channel = channels
subscription = TrioSubscriptionAPI(receive_channel)
self.subscribed_topics_send[topic_id] = send_channel self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel self.subscribed_topics_receive[topic_id] = subscription
# Create subscribe message # Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -417,8 +410,8 @@ class Pubsub(IPubsub, Service):
# Tell router we are joining this topic # Tell router we are joining this topic
await self.router.join(topic_id) await self.router.join(topic_id)
# Return the trio channel for messages on this topic # Return the subscription for messages on this topic
return receive_channel return subscription
async def unsubscribe(self, topic_id: str) -> None: async def unsubscribe(self, topic_id: str) -> None:
""" """

View File

@ -0,0 +1,39 @@
from types import TracebackType
from typing import AsyncIterator, Optional, Type
import trio
from .abc import ISubscriptionAPI
from .pb import rpc_pb2
class BaseSubscriptionAPI(ISubscriptionAPI):
async def __aenter__(self) -> "BaseSubscriptionAPI":
await trio.hazmat.checkpoint()
return self
async def __aexit__(
self,
exc_type: "Optional[Type[BaseException]]",
exc_value: "Optional[BaseException]",
traceback: "Optional[TracebackType]",
) -> None:
await self.cancel()
class TrioSubscriptionAPI(BaseSubscriptionAPI):
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
def __init__(
self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
) -> None:
self.receive_channel = receive_channel
async def cancel(self) -> None:
await self.receive_channel.aclose()
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
return self.receive_channel.__aiter__()
async def get(self) -> rpc_pb2.Message:
return await self.receive_channel.receive()

View File

@ -17,10 +17,10 @@ from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.abc import IPubsubRouter
from libp2p.pubsub.floodsub import FloodSub from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
from libp2p.security.base_transport import BaseSecureTransport from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio import libp2p.security.secio.transport as secio

View File

@ -61,7 +61,7 @@ class DummyAccountNode(Service):
async def handle_incoming_msgs(self) -> None: async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers.""" """Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True: while True:
incoming = await self.subscription.receive() incoming = await self.subscription.get()
msg_comps = incoming.data.decode("utf-8").split(",") msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send": if msg_comps[0] == "send":

View File

@ -250,7 +250,7 @@ async def perform_test_from_obj(obj, pubsub_factory) -> None:
# Look at each node in each topic # Look at each node in each topic
for node_id in topic_map[topic]: for node_id in topic_map[topic]:
# Get message from subscription queue # Get message from subscription queue
msg = await queues_map[node_id][topic].receive() msg = await queues_map[node_id][topic].get()
assert data == msg.data assert data == msg.data
# Check the message origin # Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id

View File

@ -27,7 +27,7 @@ async def test_simple_two_nodes():
await pubsubs_fsub[0].publish(topic, data) await pubsubs_fsub[0].publish(topic, data)
res_b = await sub_b.receive() res_b = await sub_b.get()
# Check that the msg received by node_b is the same # Check that the msg received by node_b is the same
# as the message sent by node_a # as the message sent by node_a
@ -75,12 +75,9 @@ async def test_lru_cache_two_nodes(monkeypatch):
await trio.sleep(0.25) await trio.sleep(0.25)
for index in expected_received_indices: for index in expected_received_indices:
res_b = await sub_b.receive() res_b = await sub_b.get()
assert res_b.data == _make_testing_data(index) assert res_b.data == _make_testing_data(index)
with pytest.raises(trio.WouldBlock):
sub_b.receive_nowait()
@pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params) @pytest.mark.parametrize("test_case_obj", floodsub_protocol_pytest_params)
@pytest.mark.trio @pytest.mark.trio

View File

@ -196,7 +196,7 @@ async def test_dense():
await trio.sleep(0.5) await trio.sleep(0.5)
# Assert that all blocking queues receive the message # Assert that all blocking queues receive the message
for queue in queues: for queue in queues:
msg = await queue.receive() msg = await queue.get()
assert msg.data == msg_content assert msg.data == msg_content
@ -229,7 +229,7 @@ async def test_fanout():
await trio.sleep(0.5) await trio.sleep(0.5)
# Assert that all blocking queues receive the message # Assert that all blocking queues receive the message
for sub in subs: for sub in subs:
msg = await sub.receive() msg = await sub.get()
assert msg.data == msg_content assert msg.data == msg_content
# Subscribe message origin # Subscribe message origin
@ -248,7 +248,7 @@ async def test_fanout():
await trio.sleep(0.5) await trio.sleep(0.5)
# Assert that all blocking queues receive the message # Assert that all blocking queues receive the message
for sub in subs: for sub in subs:
msg = await sub.receive() msg = await sub.get()
assert msg.data == msg_content assert msg.data == msg_content
@ -287,7 +287,7 @@ async def test_fanout_maintenance():
await trio.sleep(0.5) await trio.sleep(0.5)
# Assert that all blocking queues receive the message # Assert that all blocking queues receive the message
for queue in queues: for queue in queues:
msg = await queue.receive() msg = await queue.get()
assert msg.data == msg_content assert msg.data == msg_content
for sub in pubsubs_gsub: for sub in pubsubs_gsub:
@ -319,7 +319,7 @@ async def test_fanout_maintenance():
await trio.sleep(0.5) await trio.sleep(0.5)
# Assert that all blocking queues receive the message # Assert that all blocking queues receive the message
for queue in queues: for queue in queues:
msg = await queue.receive() msg = await queue.get()
assert msg.data == msg_content assert msg.data == msg_content
@ -346,5 +346,5 @@ async def test_gossip_propagation():
await trio.sleep(2) await trio.sleep(2)
# should be able to read message # should be able to read message
msg = await queue_1.receive() msg = await queue_1.get()
assert msg.data == msg_content assert msg.data == msg_content

View File

@ -384,7 +384,7 @@ async def test_handle_talk():
len(pubsubs_fsub[0].topic_ids) == 1 len(pubsubs_fsub[0].topic_ids) == 1
and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC] and sub == pubsubs_fsub[0].subscribed_topics_receive[TESTING_TOPIC]
) )
assert (await sub.receive()) == msg_0 assert (await sub.get()) == msg_0
@pytest.mark.trio @pytest.mark.trio
@ -486,7 +486,7 @@ async def test_push_msg(monkeypatch):
with trio.fail_after(0.1): with trio.fail_after(0.1):
await event.wait() await event.wait()
# Test: Subscribers are notified when `push_msg` new messages. # Test: Subscribers are notified when `push_msg` new messages.
assert (await sub.receive()) == msg_1 assert (await sub.get()) == msg_1
with mock_router_publish() as event: with mock_router_publish() as event:
# Test: add a topic validator and `push_msg` the message that # Test: add a topic validator and `push_msg` the message that

View File

@ -0,0 +1,77 @@
import math
import pytest
import trio
from libp2p.pubsub.pb import rpc_pb2
from libp2p.pubsub.subscription import TrioSubscriptionAPI
GET_TIMEOUT = 0.001
def make_trio_subscription():
send_channel, receive_channel = trio.open_memory_channel(math.inf)
return send_channel, TrioSubscriptionAPI(receive_channel)
def make_pubsub_msg():
return rpc_pb2.Message()
async def send_something(send_channel):
msg = make_pubsub_msg()
await send_channel.send(msg)
return msg
@pytest.mark.trio
async def test_trio_subscription_get():
send_channel, sub = make_trio_subscription()
data_0 = await send_something(send_channel)
data_1 = await send_something(send_channel)
assert data_0 == await sub.get()
assert data_1 == await sub.get()
# No more message
with pytest.raises(trio.TooSlowError):
with trio.fail_after(GET_TIMEOUT):
await sub.get()
@pytest.mark.trio
async def test_trio_subscription_iter():
send_channel, sub = make_trio_subscription()
received_data = []
async def iter_subscriptions(subscription):
async for data in sub:
received_data.append(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(iter_subscriptions, sub)
await send_something(send_channel)
await send_something(send_channel)
await send_channel.aclose()
assert len(received_data) == 2
@pytest.mark.trio
async def test_trio_subscription_cancel():
send_channel, sub = make_trio_subscription()
await sub.cancel()
# Test: If the subscription is cancelled, `send_channel` should be broken.
with pytest.raises(trio.BrokenResourceError):
await send_something(send_channel)
# Test: No side effect when cancelled twice.
await sub.cancel()
@pytest.mark.trio
async def test_trio_subscription_async_context_manager():
send_channel, sub = make_trio_subscription()
async with sub:
# Test: `sub` is not cancelled yet, so `send_something` works fine.
await send_something(send_channel)
# Test: `sub` is cancelled, `send_something` fails
with pytest.raises(trio.BrokenResourceError):
await send_something(send_channel)