Fix Pubsub

commit e9ab0646e3
parent bdbb7b2394
Author: mhchia
Date: 2019-12-03 17:27:49 +08:00
7 changed files with 568 additions and 523 deletions


@@ -1,11 +1,13 @@
-import asyncio
+from abc import ABC, abstractmethod
 import logging
+import math
 import time
 from typing import (
     TYPE_CHECKING,
     Awaitable,
     Callable,
     Dict,
+    KeysView,
     List,
     NamedTuple,
     Tuple,
@@ -13,8 +15,10 @@ from typing import (
     cast,
 )

+from async_service import Service
 import base58
 from lru import LRU
+import trio

 from libp2p.exceptions import ParseError, ValidationError
 from libp2p.host.host_interface import IHost
@@ -53,24 +57,24 @@ class TopicValidator(NamedTuple):
     is_async: bool


-class Pubsub:
+class BasePubsub(ABC):
+    pass
+
+
+class Pubsub(BasePubsub, Service):
     host: IHost
-    my_id: ID
     router: "IPubsubRouter"
-    peer_queue: "asyncio.Queue[ID]"
-    dead_peer_queue: "asyncio.Queue[ID]"
-    protocols: List[TProtocol]
-    incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
-    outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
+    peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
+    dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
     seen_messages: LRU
-    my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
+    # TODO: Implement `trio.abc.Channel`?
+    subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
+    subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]
     peer_topics: Dict[str, List[ID]]
     peers: Dict[ID, INetStream]
@@ -80,10 +84,8 @@ class Pubsub:
     # TODO: Be sure it is increased atomically every time.
     counter: int  # uint64
-    _tasks: List["asyncio.Future[Any]"]

     def __init__(
-        self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
+        self, host: IHost, router: "IPubsubRouter", cache_size: int = None
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -97,28 +99,26 @@ class Pubsub:
         """
         self.host = host
         self.router = router
-        self.my_id = my_id

         # Attach this new Pubsub object to the router
         self.router.attach(self)

+        peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
+        dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
+        # Only keep the receive channels in `Pubsub`.
+        # Therefore, we can only close from the receive side.
+        self.peer_receive_channel = peer_receive_channel
+        self.dead_peer_receive_channel = dead_peer_receive_channel
+
         # Register a notifee
-        self.peer_queue = asyncio.Queue()
-        self.dead_peer_queue = asyncio.Queue()
         self.host.get_network().register_notifee(
-            PubsubNotifee(self.peer_queue, self.dead_peer_queue)
+            PubsubNotifee(peer_send_channel, dead_peer_send_channel)
         )

         # Register stream handlers for each pubsub router protocol to handle
         # the pubsub streams opened on those protocols
-        self.protocols = self.router.get_protocols()
-        for protocol in self.protocols:
+        for protocol in router.protocols:
             self.host.set_stream_handler(protocol, self.stream_handler)

-        # Use asyncio queues for proper context switching
-        self.incoming_msgs_from_peers = asyncio.Queue()
-        self.outgoing_messages = asyncio.Queue()
-
         # keeps track of seen messages as LRU cache
         if cache_size is None:
             self.cache_size = 128
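
The notifee now writes into trio memory channels instead of asyncio queues. With capacity 0 the channel is unbuffered: send() blocks until a receiver calls receive(), so peer notifications exert backpressure rather than piling up. A minimal sketch of that handoff, with illustrative names only:

import trio

async def demo_handoff() -> None:
    # Capacity 0: send() waits until a receiver is ready.
    send_channel, receive_channel = trio.open_memory_channel(0)

    async def notify() -> None:
        await send_channel.send("new-peer-id")  # blocks until received

    async with trio.open_nursery() as nursery:
        nursery.start_soon(notify)
        peer = await receive_channel.receive()
        assert peer == "new-peer-id"

trio.run(demo_handoff)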
@@ -129,7 +129,8 @@ class Pubsub:
         # Map of topics we are subscribed to blocking queues
         # for when the given topic receives a message
-        self.my_topics = {}
+        self.subscribed_topics_send = {}
+        self.subscribed_topics_receive = {}

         # Map of topic to peers to keep track of what peers are subscribed to
         self.peer_topics = {}
@@ -142,16 +143,28 @@ class Pubsub:
         self.counter = time.time_ns()

-        self._tasks = []
-        # Call handle peer to keep waiting for updates to peer queue
-        self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
-        self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
+    async def run(self) -> None:
+        self.manager.run_daemon_task(self.handle_peer_queue)
+        self.manager.run_daemon_task(self.handle_dead_peer_queue)
+        await self.manager.wait_finished()
+
+    @property
+    def my_id(self) -> ID:
+        return self.host.get_id()
+
+    @property
+    def protocols(self) -> Tuple[TProtocol, ...]:
+        return tuple(self.router.get_protocols())
+
+    @property
+    def topic_ids(self) -> KeysView[str]:
+        return self.subscribed_topics_receive.keys()

     def get_hello_packet(self) -> rpc_pb2.RPC:
         """Generate subscription message with all topics we are subscribed to
         only send hello packet if we have subscribed topics."""
         packet = rpc_pb2.RPC()
-        for topic_id in self.my_topics:
+        for topic_id in self.topic_ids:
             packet.subscriptions.extend(
                 [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
             )
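
Pubsub now runs under async-service's Service: run() starts the two queue handlers as daemon tasks and parks on wait_finished(), and the service manager cancels daemon tasks when the service stops, which is what lets the old _tasks bookkeeping and close() disappear. A rough lifecycle sketch, assuming the async-service API (background_trio_service, run_daemon_task, is_running):

import trio
from async_service import Service, background_trio_service

class Heartbeat(Service):
    async def run(self) -> None:
        # Daemon tasks are cancelled for us when the service stops.
        self.manager.run_daemon_task(self._beat)
        await self.manager.wait_finished()

    async def _beat(self) -> None:
        while self.manager.is_running:
            await trio.sleep(1)

async def main() -> None:
    async with background_trio_service(Heartbeat()):
        await trio.sleep(3)  # service runs in the background here

trio.run(main)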
@@ -166,7 +179,7 @@ class Pubsub:
         """
         peer_id = stream.muxed_conn.peer_id

-        while True:
+        while self.manager.is_running:
             incoming: bytes = await read_varint_prefixed_bytes(stream)
             rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
             rpc_incoming.ParseFromString(incoming)
@@ -178,11 +191,7 @@ class Pubsub:
                 logger.debug(
                     "received `publish` message %s from peer %s", msg, peer_id
                 )
-                self._tasks.append(
-                    asyncio.ensure_future(
-                        self.push_msg(msg_forwarder=peer_id, msg=msg)
-                    )
-                )
+                self.manager.run_task(self.push_msg, peer_id, msg)

         if rpc_incoming.subscriptions:
             # deal with RPC.subscriptions
@@ -210,9 +219,6 @@ class Pubsub:
             )
         await self.router.handle_rpc(rpc_incoming, peer_id)

-        # Force context switch
-        await asyncio.sleep(0)
-
     def set_topic_validator(
         self, topic: str, validator: ValidatorFn, is_async_validator: bool
     ) -> None:
@@ -285,7 +291,6 @@ class Pubsub:
             logger.debug("Fail to add new peer %s: stream closed", peer_id)
             del self.peers[peer_id]
             return
-        # TODO: Check EOF of this stream.
         # TODO: Check if the peer in black list.
         try:
             self.router.add_peer(peer_id, stream.get_protocol())
@@ -311,23 +316,25 @@ class Pubsub:
     async def handle_peer_queue(self) -> None:
         """
-        Continuously read from peer queue and each time a new peer is found,
+        Continuously read from peer channel and each time a new peer is found,
         open a stream to the peer using a supported pubsub protocol

         TODO: Handle failure for when the peer does not support any of the
         pubsub protocols we support
         """
-        while True:
-            peer_id: ID = await self.peer_queue.get()
-            # Add Peer
-            self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
+        async with self.peer_receive_channel:
+            while self.manager.is_running:
+                peer_id: ID = await self.peer_receive_channel.receive()
+                # Add Peer
+                self.manager.run_task(self._handle_new_peer, peer_id)

     async def handle_dead_peer_queue(self) -> None:
-        """Continuously read from dead peer queue and close the stream between
+        """Continuously read from dead peer channel and close the stream between
         that peer and remove peer info from pubsub and pubsub router."""
-        while True:
-            peer_id: ID = await self.dead_peer_queue.get()
-            # Remove Peer
-            self._handle_dead_peer(peer_id)
+        async with self.dead_peer_receive_channel:
+            while self.manager.is_running:
+                peer_id: ID = await self.dead_peer_receive_channel.receive()
+                # Remove Peer
+                self._handle_dead_peer(peer_id)

     def handle_subscription(
         self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
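
Wrapping each loop in async with self.peer_receive_channel closes the receive side when the daemon task exits, so any later send() from the notifee fails fast with trio.BrokenResourceError instead of blocking forever. A sketch of the consumer pattern, with illustrative types:

import trio

async def consume(receive_channel: "trio.MemoryReceiveChannel[str]") -> None:
    # On exit the channel is closed; subsequent send() calls on the
    # paired send channel raise trio.BrokenResourceError.
    async with receive_channel:
        async for item in receive_channel:
            print("got", item)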
@@ -361,13 +368,16 @@ class Pubsub:
         # Check if this message has any topics that we are subscribed to
         for topic in publish_message.topicIDs:
-            if topic in self.my_topics:
+            if topic in self.topic_ids:
                 # we are subscribed to a topic this message was sent for,
                 # so add message to the subscription output queue
                 # for each topic
-                await self.my_topics[topic].put(publish_message)
+                await self.subscribed_topics_send[topic].send(publish_message)

-    async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
+    # TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
+    async def subscribe(
+        self, topic_id: str
+    ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
         """
         Subscribe ourself to a topic.
@@ -377,11 +387,13 @@ class Pubsub:
         logger.debug("subscribing to topic %s", topic_id)

         # Already subscribed
-        if topic_id in self.my_topics:
-            return self.my_topics[topic_id]
+        if topic_id in self.topic_ids:
+            return self.subscribed_topics_receive[topic_id]

-        # Map topic_id to blocking queue
-        self.my_topics[topic_id] = asyncio.Queue()
+        # Map topic_id to a blocking channel
+        send_channel, receive_channel = trio.open_memory_channel(math.inf)
+        self.subscribed_topics_send[topic_id] = send_channel
+        self.subscribed_topics_receive[topic_id] = receive_channel

         # Create subscribe message
         packet: rpc_pb2.RPC = rpc_pb2.RPC()
@@ -395,8 +407,8 @@ class Pubsub:
         # Tell router we are joining this topic
         await self.router.join(topic_id)

-        # Return the asyncio queue for messages on this topic
-        return self.my_topics[topic_id]
+        # Return the trio channel for messages on this topic
+        return receive_channel

     async def unsubscribe(self, topic_id: str) -> None:
         """
@@ -408,10 +420,14 @@ class Pubsub:
         logger.debug("unsubscribing from topic %s", topic_id)

         # Return if we already unsubscribed from the topic
-        if topic_id not in self.my_topics:
+        if topic_id not in self.topic_ids:
             return
-        # Remove topic_id from map if present
-        del self.my_topics[topic_id]
+        # Remove topic_id from the maps before yielding
+        send_channel = self.subscribed_topics_send[topic_id]
+        del self.subscribed_topics_send[topic_id]
+        del self.subscribed_topics_receive[topic_id]
+        # Only close the send side
+        await send_channel.aclose()

         # Create unsubscribe message
         packet: rpc_pb2.RPC = rpc_pb2.RPC()
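
Closing only the send side is deliberate: trio lets the receive side drain whatever is still buffered, and only then raises trio.EndOfChannel (an async for loop simply stops). A small sketch of that semantics:

import math
import trio

async def demo_drain() -> None:
    send_channel, receive_channel = trio.open_memory_channel(math.inf)
    await send_channel.send("last message")
    await send_channel.aclose()                # what unsubscribe() does
    print(await receive_channel.receive())     # buffered item still arrives
    try:
        await receive_channel.receive()
    except trio.EndOfChannel:
        print("subscription drained")

trio.run(demo_drain)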
@@ -453,13 +469,13 @@ class Pubsub:
             data=data,
             topicIDs=[topic_id],
             # Origin is ourself.
-            from_id=self.host.get_id().to_bytes(),
+            from_id=self.my_id.to_bytes(),
             seqno=self._next_seqno(),
         )

         # TODO: Sign with our signing key

-        await self.push_msg(self.host.get_id(), msg)
+        await self.push_msg(self.my_id, msg)

         logger.debug("successfully published message %s", msg)
@@ -470,12 +486,12 @@ class Pubsub:
         :param msg_forwarder: the peer who forwarded us the message.
         :param msg: the message.
         """
-        sync_topic_validators = []
-        async_topic_validator_futures: List[Awaitable[bool]] = []
+        sync_topic_validators: List[SyncValidatorFn] = []
+        async_topic_validators: List[AsyncValidatorFn] = []
         for topic_validator in self.get_msg_validators(msg):
             if topic_validator.is_async:
-                async_topic_validator_futures.append(
-                    cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
+                async_topic_validators.append(
+                    cast(AsyncValidatorFn, topic_validator.validator)
                 )
             else:
                 sync_topic_validators.append(
@@ -488,9 +504,20 @@ class Pubsub:
         # TODO: Implement throttle on async validators

-        if len(async_topic_validator_futures) > 0:
-            results = await asyncio.gather(*async_topic_validator_futures)
-            if not all(results):
+        if len(async_topic_validators) > 0:
+            # TODO: Use a better pattern
+            final_result = True
+
+            async def run_async_validator(func: AsyncValidatorFn) -> None:
+                nonlocal final_result
+                result = await func(msg_forwarder, msg)
+                final_result = final_result and result
+
+            async with trio.open_nursery() as nursery:
+                for validator in async_topic_validators:
+                    nursery.start_soon(run_async_validator, validator)
+
+            if not final_result:
                 raise ValidationError(f"Validation failed for msg={msg}")

     async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
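
trio has no asyncio.gather, so the nursery accumulates validator results through a nonlocal flag; the code's own TODO flags this as a stopgap. One conventional alternative, sketched here only as an illustration, collects results into a preallocated list:

import trio

async def gather_bools(coroutine_fns) -> bool:
    # Rough stand-in for asyncio.gather over bool-returning callables.
    results = [False] * len(coroutine_fns)

    async def run_one(index: int) -> None:
        results[index] = await coroutine_fns[index]()

    async with trio.open_nursery() as nursery:
        for i in range(len(coroutine_fns)):
            nursery.start_soon(run_one, i)
    return all(results)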
@@ -551,14 +578,4 @@ class Pubsub:
         self.seen_messages[msg_id] = 1

     def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
-        if not self.my_topics:
-            return False
-        return any(topic in self.my_topics for topic in msg.topicIDs)
-
-    async def close(self) -> None:
-        for task in self._tasks:
-            task.cancel()
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass
+        return any(topic in self.topic_ids for topic in msg.topicIDs)