Add SubscriptionAPI

And `TrioSubscriptionAPI`, to make subscription io-agnostic.
This commit is contained in:
mhchia
2019-12-17 18:17:28 +08:00
parent fb0519129d
commit 47d10e186f
12 changed files with 158 additions and 36 deletions

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, AsyncContextManager, AsyncIterable, List
from libp2p.peer.id import ID
from libp2p.typing import TProtocol
@ -10,6 +10,11 @@ if TYPE_CHECKING:
from .pubsub import Pubsub # noqa: F401
# TODO: Add interface for Pubsub
class IPubsub(ABC):
pass
class IPubsubRouter(ABC):
@abstractmethod
def get_protocols(self) -> List[TProtocol]:
@ -53,7 +58,6 @@ class IPubsubRouter(ABC):
:param rpc: rpc message
"""
# FIXME: Should be changed to type 'peer.ID'
@abstractmethod
async def publish(self, msg_forwarder: ID, pubsub_msg: rpc_pb2.Message) -> None:
"""
@ -80,3 +84,15 @@ class IPubsubRouter(ABC):
:param topic: topic to leave
"""
class ISubscriptionAPI(
AsyncContextManager["ISubscriptionAPI"], AsyncIterable[rpc_pb2.Message]
):
@abstractmethod
async def cancel(self) -> None:
...
@abstractmethod
async def get(self) -> rpc_pb2.Message:
...

View File

@ -8,9 +8,9 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/floodsub/1.0.0")

View File

@ -12,11 +12,11 @@ from libp2p.pubsub import floodsub
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed
from .abc import IPubsubRouter
from .exceptions import NoPubsubAttached
from .mcache import MessageCache
from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter
PROTOCOL_ID = TProtocol("/meshsub/1.0.0")

View File

@ -1,4 +1,3 @@
from abc import ABC
import logging
import math
import time
@ -30,12 +29,14 @@ from libp2p.peer.id import ID
from libp2p.typing import TProtocol
from libp2p.utils import encode_varint_prefixed, read_varint_prefixed_bytes
from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI
from .validators import signature_validator
if TYPE_CHECKING:
from .pubsub_router_interface import IPubsubRouter # noqa: F401
from .abc import IPubsubRouter # noqa: F401
from typing import Any # noqa: F401
@ -57,11 +58,6 @@ class TopicValidator(NamedTuple):
is_async: bool
# TODO: Add interface for Pubsub
class IPubsub(ABC):
pass
class Pubsub(IPubsub, Service):
host: IHost
@ -75,7 +71,7 @@ class Pubsub(IPubsub, Service):
# TODO: Implement `trio.abc.Channel`?
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
peer_topics: Dict[str, List[ID]]
peers: Dict[ID, INetStream]
@ -380,10 +376,7 @@ class Pubsub(IPubsub, Service):
# for each topic
await self.subscribed_topics_send[topic].send(publish_message)
# TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
async def subscribe(
self, topic_id: str
) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
async def subscribe(self, topic_id: str) -> ISubscriptionAPI:
"""
Subscribe ourself to a topic.
@ -396,14 +389,14 @@ class Pubsub(IPubsub, Service):
if topic_id in self.topic_ids:
return self.subscribed_topics_receive[topic_id]
# Map topic_id to a blocking channel
channels: Tuple[
"trio.MemorySendChannel[rpc_pb2.Message]",
"trio.MemoryReceiveChannel[rpc_pb2.Message]",
] = trio.open_memory_channel(math.inf)
send_channel, receive_channel = channels
subscription = TrioSubscriptionAPI(receive_channel)
self.subscribed_topics_send[topic_id] = send_channel
self.subscribed_topics_receive[topic_id] = receive_channel
self.subscribed_topics_receive[topic_id] = subscription
# Create subscribe message
packet: rpc_pb2.RPC = rpc_pb2.RPC()
@ -417,8 +410,8 @@ class Pubsub(IPubsub, Service):
# Tell router we are joining this topic
await self.router.join(topic_id)
# Return the trio channel for messages on this topic
return receive_channel
# Return the subscription for messages on this topic
return subscription
async def unsubscribe(self, topic_id: str) -> None:
"""

View File

@ -0,0 +1,39 @@
from types import TracebackType
from typing import AsyncIterator, Optional, Type
import trio
from .abc import ISubscriptionAPI
from .pb import rpc_pb2
class BaseSubscriptionAPI(ISubscriptionAPI):
async def __aenter__(self) -> "BaseSubscriptionAPI":
await trio.hazmat.checkpoint()
return self
async def __aexit__(
self,
exc_type: "Optional[Type[BaseException]]",
exc_value: "Optional[BaseException]",
traceback: "Optional[TracebackType]",
) -> None:
await self.cancel()
class TrioSubscriptionAPI(BaseSubscriptionAPI):
receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
def __init__(
self, receive_channel: "trio.MemoryReceiveChannel[rpc_pb2.Message]"
) -> None:
self.receive_channel = receive_channel
async def cancel(self) -> None:
await self.receive_channel.aclose()
def __aiter__(self) -> AsyncIterator[rpc_pb2.Message]:
return self.receive_channel.__aiter__()
async def get(self) -> rpc_pb2.Message:
return await self.receive_channel.receive()

View File

@ -17,10 +17,10 @@ from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
from libp2p.pubsub.abc import IPubsubRouter
from libp2p.pubsub.floodsub import FloodSub
from libp2p.pubsub.gossipsub import GossipSub
from libp2p.pubsub.pubsub import Pubsub
from libp2p.pubsub.pubsub_router_interface import IPubsubRouter
from libp2p.security.base_transport import BaseSecureTransport
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
import libp2p.security.secio.transport as secio

View File

@ -61,7 +61,7 @@ class DummyAccountNode(Service):
async def handle_incoming_msgs(self) -> None:
"""Handle all incoming messages on the CRYPTO_TOPIC from peers."""
while True:
incoming = await self.subscription.receive()
incoming = await self.subscription.get()
msg_comps = incoming.data.decode("utf-8").split(",")
if msg_comps[0] == "send":

View File

@ -250,7 +250,7 @@ async def perform_test_from_obj(obj, pubsub_factory) -> None:
# Look at each node in each topic
for node_id in topic_map[topic]:
# Get message from subscription queue
msg = await queues_map[node_id][topic].receive()
msg = await queues_map[node_id][topic].get()
assert data == msg.data
# Check the message origin
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id