Fix Pubsub

This commit migrates `Pubsub` from hand-rolled asyncio tasks and queues to trio memory channels, with its lifecycle managed by an async-service `Service`.
@@ -1,11 +1,13 @@
-import asyncio
+from abc import ABC, abstractmethod
 import logging
+import math
 import time
 from typing import (
     TYPE_CHECKING,
     Awaitable,
     Callable,
     Dict,
+    KeysView,
     List,
     NamedTuple,
     Tuple,
@@ -13,8 +15,10 @@ from typing import (
     cast,
 )

+from async_service import Service
 import base58
 from lru import LRU
+import trio

 from libp2p.exceptions import ParseError, ValidationError
 from libp2p.host.host_interface import IHost
@@ -53,24 +57,24 @@ class TopicValidator(NamedTuple):
     is_async: bool


-class Pubsub:
+class BasePubsub(ABC):
+    pass
+
+
+class Pubsub(BasePubsub, Service):

     host: IHost
-    my_id: ID

     router: "IPubsubRouter"

-    peer_queue: "asyncio.Queue[ID]"
-    dead_peer_queue: "asyncio.Queue[ID]"
-
-    protocols: List[TProtocol]
-
-    incoming_msgs_from_peers: "asyncio.Queue[rpc_pb2.Message]"
-    outgoing_messages: "asyncio.Queue[rpc_pb2.Message]"
+    peer_receive_channel: "trio.MemoryReceiveChannel[ID]"
+    dead_peer_receive_channel: "trio.MemoryReceiveChannel[ID]"

     seen_messages: LRU

-    my_topics: Dict[str, "asyncio.Queue[rpc_pb2.Message]"]
+    # TODO: Implement `trio.abc.Channel`?
+    subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
+    subscribed_topics_receive: Dict[str, "trio.MemoryReceiveChannel[rpc_pb2.Message]"]

     peer_topics: Dict[str, List[ID]]
     peers: Dict[ID, INetStream]
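The new base classes hand Pubsub's lifecycle to async-service: `run()` declares daemon tasks, the service's manager owns them, and cancellation on context exit replaces the explicit task bookkeeping and `close()` removed later in this diff. A minimal sketch of that pattern, assuming async-service's `Service`/`background_trio_service` API as used here (the `Ticker` class and its behavior are hypothetical):

    import trio
    from async_service import Service, background_trio_service


    class Ticker(Service):
        def __init__(self) -> None:
            self.count = 0

        async def _tick(self) -> None:
            # A daemon task: it runs until the service is cancelled.
            while self.manager.is_running:
                self.count += 1
                await trio.sleep(0.1)

        async def run(self) -> None:
            # Daemon tasks live and die with the service, so there is no
            # task list to track and cancel by hand.
            self.manager.run_daemon_task(self._tick)
            await self.manager.wait_finished()


    async def main() -> None:
        service = Ticker()
        # Leaving the context block cancels the service and all its tasks.
        async with background_trio_service(service):
            await trio.sleep(0.5)
        print("ticks observed:", service.count)


    trio.run(main)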
@@ -80,10 +84,8 @@ class Pubsub:
     # TODO: Be sure it is increased atomically everytime.
     counter: int  # uint64

-    _tasks: List["asyncio.Future[Any]"]
-
     def __init__(
-        self, host: IHost, router: "IPubsubRouter", my_id: ID, cache_size: int = None
+        self, host: IHost, router: "IPubsubRouter", cache_size: int = None
     ) -> None:
         """
         Construct a new Pubsub object, which is responsible for handling all
@@ -97,28 +99,26 @@ class Pubsub:
         """
         self.host = host
         self.router = router
-        self.my_id = my_id

         # Attach this new Pubsub object to the router
         self.router.attach(self)

+        peer_send_channel, peer_receive_channel = trio.open_memory_channel(0)
+        dead_peer_send_channel, dead_peer_receive_channel = trio.open_memory_channel(0)
+        # Only keep the receive channels in `Pubsub`.
+        # Therefore, we can only close from the receive side.
+        self.peer_receive_channel = peer_receive_channel
+        self.dead_peer_receive_channel = dead_peer_receive_channel
         # Register a notifee
-        self.peer_queue = asyncio.Queue()
-        self.dead_peer_queue = asyncio.Queue()
         self.host.get_network().register_notifee(
-            PubsubNotifee(self.peer_queue, self.dead_peer_queue)
+            PubsubNotifee(peer_send_channel, dead_peer_send_channel)
         )

         # Register stream handlers for each pubsub router protocol to handle
         # the pubsub streams opened on those protocols
-        self.protocols = self.router.get_protocols()
-        for protocol in self.protocols:
+        for protocol in router.protocols:
             self.host.set_stream_handler(protocol, self.stream_handler)

-        # Use asyncio queues for proper context switching
-        self.incoming_msgs_from_peers = asyncio.Queue()
-        self.outgoing_messages = asyncio.Queue()
-
         # keeps track of seen messages as LRU cache
         if cache_size is None:
             self.cache_size = 128
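The notifee now holds the send ends while `Pubsub` keeps only the receive ends of two zero-buffer channels, so peer notifications block the notifier until pubsub actually picks them up, and closing the receive side is enough to stop producers. A reduced sketch of that wiring (`notifier` is a toy stand-in for the real `PubsubNotifee` machinery):

    import trio


    async def notifier(send_channel: trio.MemorySendChannel) -> None:
        async with send_channel:
            for peer in ("peer-a", "peer-b"):
                # With a buffer size of 0 this blocks until the consumer receives.
                await send_channel.send(peer)


    async def main() -> None:
        send_channel, receive_channel = trio.open_memory_channel(0)
        async with trio.open_nursery() as nursery:
            nursery.start_soon(notifier, send_channel)
            async with receive_channel:
                print(await receive_channel.receive())  # peer-a
                print(await receive_channel.receive())  # peer-b


    trio.run(main)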
@@ -129,7 +129,8 @@ class Pubsub:

         # Map of topics we are subscribed to blocking queues
         # for when the given topic receives a message
-        self.my_topics = {}
+        self.subscribed_topics_send = {}
+        self.subscribed_topics_receive = {}

         # Map of topic to peers to keep track of what peers are subscribed to
         self.peer_topics = {}
@@ -142,16 +143,28 @@ class Pubsub:

         self.counter = time.time_ns()

-        self._tasks = []
-        # Call handle peer to keep waiting for updates to peer queue
-        self._tasks.append(asyncio.ensure_future(self.handle_peer_queue()))
-        self._tasks.append(asyncio.ensure_future(self.handle_dead_peer_queue()))
+    async def run(self) -> None:
+        self.manager.run_daemon_task(self.handle_peer_queue)
+        self.manager.run_daemon_task(self.handle_dead_peer_queue)
+        await self.manager.wait_finished()
+
+    @property
+    def my_id(self) -> ID:
+        return self.host.get_id()
+
+    @property
+    def protocols(self) -> Tuple[TProtocol, ...]:
+        return tuple(self.router.get_protocols())
+
+    @property
+    def topic_ids(self) -> KeysView[str]:
+        return self.subscribed_topics_receive.keys()

     def get_hello_packet(self) -> rpc_pb2.RPC:
         """Generate subscription message with all topics we are subscribed to
         only send hello packet if we have subscribed topics."""
         packet = rpc_pb2.RPC()
-        for topic_id in self.my_topics:
+        for topic_id in self.topic_ids:
             packet.subscriptions.extend(
                 [rpc_pb2.RPC.SubOpts(subscribe=True, topicid=topic_id)]
             )
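`my_id`, `protocols`, and `topic_ids` become properties derived from existing state, so nothing has to be kept in sync by hand. Returning a `KeysView` in particular is a live, zero-copy view of the subscription map, as this small illustration shows:

    # A dict's keys view tracks the dict: no copy, always current.
    subscriptions = {}
    topic_ids = subscriptions.keys()

    subscriptions["topic-1"] = object()  # stand-in for a receive channel
    assert "topic-1" in topic_ids        # the view sees the new key

    del subscriptions["topic-1"]
    assert "topic-1" not in topic_ids    # and the removal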
@@ -166,7 +179,7 @@ class Pubsub:
         """
         peer_id = stream.muxed_conn.peer_id

-        while True:
+        while self.manager.is_running:
             incoming: bytes = await read_varint_prefixed_bytes(stream)
             rpc_incoming: rpc_pb2.RPC = rpc_pb2.RPC()
             rpc_incoming.ParseFromString(incoming)
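Each RPC on the stream is framed with an unsigned-varint length prefix, which is what `read_varint_prefixed_bytes` decodes before handing the bytes to protobuf. An illustrative synchronous version of that framing (the real helper is async and reads from an `INetStream`; the function names below are made up):

    import io


    def read_uvarint(stream: io.BufferedIOBase) -> int:
        # Unsigned LEB128: 7 payload bits per byte, high bit set means
        # "more bytes follow".
        shift = 0
        result = 0
        while True:
            byte = stream.read(1)[0]
            result |= (byte & 0x7F) << shift
            if byte & 0x80 == 0:  # high bit clear: last varint byte
                return result
            shift += 7


    def read_varint_prefixed(stream: io.BufferedIOBase) -> bytes:
        length = read_uvarint(stream)
        return stream.read(length)


    # 0x03 = varint length 3, followed by the 3-byte payload.
    frame = io.BytesIO(bytes([0x03]) + b"abc")
    assert read_varint_prefixed(frame) == b"abc"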
@@ -178,11 +191,7 @@ class Pubsub:
                     logger.debug(
                         "received `publish` message %s from peer %s", msg, peer_id
                     )
-                    self._tasks.append(
-                        asyncio.ensure_future(
-                            self.push_msg(msg_forwarder=peer_id, msg=msg)
-                        )
-                    )
+                    self.manager.run_task(self.push_msg, peer_id, msg)

             if rpc_incoming.subscriptions:
                 # deal with RPC.subscriptions
@@ -210,9 +219,6 @@ class Pubsub:
                 )
             await self.router.handle_rpc(rpc_incoming, peer_id)

-            # Force context switch
-            await asyncio.sleep(0)
-
     def set_topic_validator(
         self, topic: str, validator: ValidatorFn, is_async_validator: bool
     ) -> None:
@@ -285,7 +291,6 @@ class Pubsub:
             logger.debug("Fail to add new peer %s: stream closed", peer_id)
             del self.peers[peer_id]
             return
-        # TODO: Check EOF of this stream.
         # TODO: Check if the peer in black list.
         try:
             self.router.add_peer(peer_id, stream.get_protocol())
@@ -311,23 +316,25 @@ class Pubsub:

     async def handle_peer_queue(self) -> None:
         """
-        Continuously read from peer queue and each time a new peer is found,
+        Continuously read from peer channel and each time a new peer is found,
         open a stream to the peer using a supported pubsub protocol
         TODO: Handle failure for when the peer does not support any of the
         pubsub protocols we support
         """
-        while True:
-            peer_id: ID = await self.peer_queue.get()
-            # Add Peer
-            self._tasks.append(asyncio.ensure_future(self._handle_new_peer(peer_id)))
+        async with self.peer_receive_channel:
+            while self.manager.is_running:
+                peer_id: ID = await self.peer_receive_channel.receive()
+                # Add Peer
+                self.manager.run_task(self._handle_new_peer, peer_id)

     async def handle_dead_peer_queue(self) -> None:
-        """Continuously read from dead peer queue and close the stream between
+        """Continuously read from dead peer channel and close the stream between
         that peer and remove peer info from pubsub and pubsub router."""
-        while True:
-            peer_id: ID = await self.dead_peer_queue.get()
-            # Remove Peer
-            self._handle_dead_peer(peer_id)
+        async with self.dead_peer_receive_channel:
+            while self.manager.is_running:
+                peer_id: ID = await self.dead_peer_receive_channel.receive()
+                # Remove Peer
+                self._handle_dead_peer(peer_id)

     def handle_subscription(
         self, origin_id: ID, sub_message: rpc_pb2.RPC.SubOpts
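Both loops wrap the receive channel in `async with` and call `receive()` until the manager stops. An equivalent trio idiom is to iterate the channel with `async for`, which ends by itself once every send end is closed; a sketch with toy names (the receive channel returned by `subscribe` below can be consumed the same way):

    import trio


    async def producer(send_channel: trio.MemorySendChannel) -> None:
        async with send_channel:
            await send_channel.send("new-peer")
        # The channel is closed here, so the consumer's loop ends on its own.


    async def main() -> None:
        send_channel, receive_channel = trio.open_memory_channel(0)
        async with trio.open_nursery() as nursery:
            nursery.start_soon(producer, send_channel)
            async with receive_channel:
                async for peer_id in receive_channel:
                    print("handling", peer_id)


    trio.run(main)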
@@ -361,13 +368,16 @@ class Pubsub:

         # Check if this message has any topics that we are subscribed to
         for topic in publish_message.topicIDs:
-            if topic in self.my_topics:
+            if topic in self.topic_ids:
                 # we are subscribed to a topic this message was sent for,
                 # so add message to the subscription output queue
                 # for each topic
-                await self.my_topics[topic].put(publish_message)
+                await self.subscribed_topics_send[topic].send(publish_message)

-    async def subscribe(self, topic_id: str) -> "asyncio.Queue[rpc_pb2.Message]":
+    # TODO: Change to return an `AsyncIterable` to be I/O-agnostic?
+    async def subscribe(
+        self, topic_id: str
+    ) -> "trio.MemoryReceiveChannel[rpc_pb2.Message]":
         """
         Subscribe ourself to a topic.

@@ -377,11 +387,13 @@ class Pubsub:
         logger.debug("subscribing to topic %s", topic_id)

         # Already subscribed
-        if topic_id in self.my_topics:
-            return self.my_topics[topic_id]
+        if topic_id in self.topic_ids:
+            return self.subscribed_topics_receive[topic_id]

-        # Map topic_id to blocking queue
-        self.my_topics[topic_id] = asyncio.Queue()
+        # Map topic_id to a blocking channel
+        send_channel, receive_channel = trio.open_memory_channel(math.inf)
+        self.subscribed_topics_send[topic_id] = send_channel
+        self.subscribed_topics_receive[topic_id] = receive_channel

         # Create subscribe message
         packet: rpc_pb2.RPC = rpc_pb2.RPC()
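`math.inf` gives each subscription an unbounded buffer, so forwarding a message to a slow subscriber never blocks the router; the trade-off is that a stalled subscriber can grow memory without bound. A sketch:

    import math

    import trio


    async def main() -> None:
        send_channel, receive_channel = trio.open_memory_channel(math.inf)
        # All sends complete immediately, even with nobody receiving yet.
        for i in range(1000):
            await send_channel.send(i)
        # The subscriber drains at its own pace later.
        assert await receive_channel.receive() == 0


    trio.run(main)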
@@ -395,8 +407,8 @@ class Pubsub:
         # Tell router we are joining this topic
         await self.router.join(topic_id)

-        # Return the asyncio queue for messages on this topic
-        return self.my_topics[topic_id]
+        # Return the trio channel for messages on this topic
+        return receive_channel

     async def unsubscribe(self, topic_id: str) -> None:
         """
@@ -408,10 +420,14 @@ class Pubsub:
         logger.debug("unsubscribing from topic %s", topic_id)

         # Return if we already unsubscribed from the topic
-        if topic_id not in self.my_topics:
+        if topic_id not in self.topic_ids:
             return
-        # Remove topic_id from map if present
-        del self.my_topics[topic_id]
+        # Remove topic_id from the maps before yielding
+        send_channel = self.subscribed_topics_send[topic_id]
+        del self.subscribed_topics_send[topic_id]
+        del self.subscribed_topics_receive[topic_id]
+        # Only close the send side
+        await send_channel.aclose()

         # Create unsubscribe message
         packet: rpc_pb2.RPC = rpc_pb2.RPC()
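`unsubscribe` closes only the send side, which lets a subscriber drain whatever is already buffered and then observe a clean end-of-channel rather than an error; for example:

    import math

    import trio


    async def main() -> None:
        send_channel, receive_channel = trio.open_memory_channel(math.inf)
        await send_channel.send("last-message")
        await send_channel.aclose()

        # Buffered messages are still delivered after the close...
        assert await receive_channel.receive() == "last-message"
        try:
            await receive_channel.receive()
        except trio.EndOfChannel:
            # ...and only then does the subscriber see end-of-channel.
            print("unsubscribed: no more messages")


    trio.run(main)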
@@ -453,13 +469,13 @@ class Pubsub:
             data=data,
             topicIDs=[topic_id],
             # Origin is ourself.
-            from_id=self.host.get_id().to_bytes(),
+            from_id=self.my_id.to_bytes(),
             seqno=self._next_seqno(),
         )

         # TODO: Sign with our signing key

-        await self.push_msg(self.host.get_id(), msg)
+        await self.push_msg(self.my_id, msg)

         logger.debug("successfully published message %s", msg)

@@ -470,12 +486,12 @@ class Pubsub:
         :param msg_forwarder: the peer who forward us the message.
         :param msg: the message.
         """
-        sync_topic_validators = []
-        async_topic_validator_futures: List[Awaitable[bool]] = []
+        sync_topic_validators: List[SyncValidatorFn] = []
+        async_topic_validators: List[AsyncValidatorFn] = []
         for topic_validator in self.get_msg_validators(msg):
             if topic_validator.is_async:
-                async_topic_validator_futures.append(
-                    cast(Awaitable[bool], topic_validator.validator(msg_forwarder, msg))
+                async_topic_validators.append(
+                    cast(AsyncValidatorFn, topic_validator.validator)
                 )
             else:
                 sync_topic_validators.append(
@@ -488,9 +504,20 @@ class Pubsub:

         # TODO: Implement throttle on async validators

-        if len(async_topic_validator_futures) > 0:
-            results = await asyncio.gather(*async_topic_validator_futures)
-            if not all(results):
+        if len(async_topic_validators) > 0:
+            # TODO: Use a better pattern
+            final_result = True
+
+            async def run_async_validator(func: AsyncValidatorFn) -> None:
+                nonlocal final_result
+                result = await func(msg_forwarder, msg)
+                final_result = final_result and result
+
+            async with trio.open_nursery() as nursery:
+                for validator in async_topic_validators:
+                    nursery.start_soon(run_async_validator, validator)
+
+            if not final_result:
                 raise ValidationError(f"Validation failed for msg={msg}")

     async def push_msg(self, msg_forwarder: ID, msg: rpc_pb2.Message) -> None:
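The nursery replaces `asyncio.gather`: all async validators run concurrently, and their boolean results are ANDed once the nursery exits. The same pattern reduced to its core (the validator bodies here are made up):

    import trio


    async def run_validators(validators, msg) -> bool:
        final_result = True

        async def run_one(validate) -> None:
            nonlocal final_result
            result = await validate(msg)
            final_result = final_result and result

        # The nursery block only exits after every validator has finished.
        async with trio.open_nursery() as nursery:
            for validate in validators:
                nursery.start_soon(run_one, validate)
        return final_result


    async def main() -> None:
        async def accept(msg) -> bool:
            return True

        async def reject(msg) -> bool:
            return False

        print(await run_validators([accept, reject], "msg"))  # False


    trio.run(main)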
@@ -551,14 +578,4 @@ class Pubsub:
         self.seen_messages[msg_id] = 1

     def _is_subscribed_to_msg(self, msg: rpc_pb2.Message) -> bool:
-        if not self.my_topics:
-            return False
-        return any(topic in self.my_topics for topic in msg.topicIDs)
-
-    async def close(self) -> None:
-        for task in self._tasks:
-            task.cancel()
-            try:
-                await task
-            except asyncio.CancelledError:
-                pass
+        return any(topic in self.topic_ids for topic in msg.topicIDs)