Add type hints to gossipsub.py

This commit is contained in:
NIC619
2019-07-24 16:29:14 +08:00
parent 8eb6a230ff
commit b920955db6

View File

@ -1,6 +1,8 @@
import asyncio import asyncio
import random import random
from typing import ( from typing import (
Any,
Dict,
Iterable, Iterable,
List, List,
MutableSet, MutableSet,
@ -16,6 +18,7 @@ from libp2p.peer.id import (
from .mcache import MessageCache from .mcache import MessageCache
from .pb import rpc_pb2 from .pb import rpc_pb2
from .pubsub import Pubsub
from .pubsub_router_interface import IPubsubRouter from .pubsub_router_interface import IPubsubRouter
@ -24,11 +27,43 @@ class GossipSub(IPubsubRouter):
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
def __init__(self, protocols, degree, degree_low, degree_high, time_to_live, gossip_window=3, protocols: Sequence[str]
gossip_history=5, heartbeat_interval=120): pubsub: Pubsub
degree: int
degree_high: int
degree_low: int
time_to_live: int
# FIXME: Should be changed to `Dict[str, List[ID]]`
mesh: Dict[str, List[str]]
# FIXME: Should be changed to `Dict[str, List[ID]]`
fanout: Dict[str, List[str]]
time_since_last_publish: Dict[str, int]
#FIXME: Should be changed to List[ID]
peers_gossipsub: List[str]
#FIXME: Should be changed to List[ID]
peers_floodsub: List[str]
mcache: MessageCache
heartbeat_interval: int
def __init__(self,
protocols: Sequence[str],
degree: int,
degree_low: int,
degree_high: int,
time_to_live: int,
gossip_window: int=3,
gossip_history: int=5,
heartbeat_interval: int=120) -> None:
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
self.protocols = protocols self.protocols: List[str] = protocols
self.pubsub = None self.pubsub: Pubsub = None
# Store target degree, upper degree bound, and lower degree bound # Store target degree, upper degree bound, and lower degree bound
self.degree = degree self.degree = degree
@ -36,7 +71,7 @@ class GossipSub(IPubsubRouter):
self.degree_high = degree_high self.degree_high = degree_high
# Store time to live (for topics in fanout) # Store time to live (for topics in fanout)
self.time_to_live = time_to_live self.time_to_live: int = time_to_live
# Create topic --> list of peers mappings # Create topic --> list of peers mappings
self.mesh = {} self.mesh = {}
@ -56,13 +91,13 @@ class GossipSub(IPubsubRouter):
# Interface functions # Interface functions
def get_protocols(self): def get_protocols(self) -> List:
""" """
:return: the list of protocols supported by the router :return: the list of protocols supported by the router
""" """
return self.protocols return self.protocols
def attach(self, pubsub): def attach(self, pubsub: Pubsub) -> None:
""" """
Attach is invoked by the PubSub constructor to attach the router to a Attach is invoked by the PubSub constructor to attach the router to a
freshly initialized PubSub instance. freshly initialized PubSub instance.
@ -74,10 +109,11 @@ class GossipSub(IPubsubRouter):
# TODO: Start after delay # TODO: Start after delay
asyncio.ensure_future(self.heartbeat()) asyncio.ensure_future(self.heartbeat())
def add_peer(self, peer_id, protocol_id): def add_peer(self, peer_id: ID, protocol_id: str):
""" """
Notifies the router that a new peer has been connected Notifies the router that a new peer has been connected
:param peer_id: id of peer to add :param peer_id: id of peer to add
:param protocol_id: router protocol the peer speaks, e.g., floodsub, gossipsub
""" """
# Add peer to the correct peer list # Add peer to the correct peer list
@ -88,7 +124,7 @@ class GossipSub(IPubsubRouter):
elif peer_type == "flood": elif peer_type == "flood":
self.peers_floodsub.append(peer_id_str) self.peers_floodsub.append(peer_id_str)
def remove_peer(self, peer_id): def remove_peer(self, peer_id: ID) -> None:
""" """
Notifies the router that a peer has been disconnected Notifies the router that a peer has been disconnected
:param peer_id: id of peer to remove :param peer_id: id of peer to remove
@ -96,16 +132,18 @@ class GossipSub(IPubsubRouter):
peer_id_str = str(peer_id) peer_id_str = str(peer_id)
self.peers_to_protocol.remove(peer_id_str) self.peers_to_protocol.remove(peer_id_str)
async def handle_rpc(self, rpc, sender_peer_id): # FIXME: type of `sender_peer_id` should be changed to `ID`
async def handle_rpc(self, rpc: rpc_pb2.Message, sender_peer_id: str):
""" """
Invoked to process control messages in the RPC envelope. Invoked to process control messages in the RPC envelope.
It is invoked after subscriptions and payload messages have been processed It is invoked after subscriptions and payload messages have been processed
:param rpc: rpc message :param rpc: RPC message
:param sender_peer_id: id of the peer who sent the message
""" """
control_message = rpc.control control_message = rpc.control
sender_peer_id = str(sender_peer_id) sender_peer_id = str(sender_peer_id)
# Relay each rpc control to the appropriate handler # Relay each rpc control message to the appropriate handler
if control_message.ihave: if control_message.ihave:
for ihave in control_message.ihave: for ihave in control_message.ihave:
await self.handle_ihave(ihave, sender_peer_id) await self.handle_ihave(ihave, sender_peer_id)
@ -191,7 +229,7 @@ class GossipSub(IPubsubRouter):
# Excludes `msg_forwarder` and `origin` # Excludes `msg_forwarder` and `origin`
yield from send_to.difference([msg_forwarder, origin]) yield from send_to.difference([msg_forwarder, origin])
async def join(self, topic): async def join(self, topic: str) -> None:
# Note: the comments here are the near-exact algorithm description from the spec # Note: the comments here are the near-exact algorithm description from the spec
""" """
Join notifies the router that we want to receive and Join notifies the router that we want to receive and
@ -204,8 +242,9 @@ class GossipSub(IPubsubRouter):
# Create mesh[topic] if it does not yet exist # Create mesh[topic] if it does not yet exist
self.mesh[topic] = [] self.mesh[topic] = []
topic_in_fanout = topic in self.fanout topic_in_fanout: bool = topic in self.fanout
fanout_peers = self.fanout[topic] if topic_in_fanout else [] # FIXME: Should be changed to `List[ID]`
fanout_peers: List[str] = self.fanout[topic] if topic_in_fanout else []
fanout_size = len(fanout_peers) fanout_size = len(fanout_peers)
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree): if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
# There are less than D peers (let this number be x) # There are less than D peers (let this number be x)
@ -229,7 +268,7 @@ class GossipSub(IPubsubRouter):
if topic_in_fanout: if topic_in_fanout:
del self.fanout[topic] del self.fanout[topic]
async def leave(self, topic): async def leave(self, topic: str) -> None:
# Note: the comments here are the near-exact algorithm description from the spec # Note: the comments here are the near-exact algorithm description from the spec
""" """
Leave notifies the router that we are no longer interested in a topic. Leave notifies the router that we are no longer interested in a topic.
@ -247,7 +286,7 @@ class GossipSub(IPubsubRouter):
# Interface Helper Functions # Interface Helper Functions
@staticmethod @staticmethod
def get_peer_type(protocol_id): def get_peer_type(protocol_id: str) -> str:
# TODO: Do this in a better, more efficient way # TODO: Do this in a better, more efficient way
if "gossipsub" in protocol_id: if "gossipsub" in protocol_id:
return "gossip" return "gossip"
@ -255,7 +294,13 @@ class GossipSub(IPubsubRouter):
return "flood" return "flood"
return "unknown" return "unknown"
async def deliver_messages_to_peers(self, peers, msg_sender, origin_id, serialized_packet): # FIXME: type of `peers` should be changed to `List[ID]`
# FIXME: type of `msg_sender` and `origin_id` should be changed to `ID`
async def deliver_messages_to_peers(self,
peers: List[str],
msg_sender: str,
origin_id: str,
serialized_packet: bytes):
for peer_id_in_topic in peers: for peer_id_in_topic in peers:
# Forward to all peers that are not the # Forward to all peers that are not the
# message sender and are not the message origin # message sender and are not the message origin
@ -267,7 +312,7 @@ class GossipSub(IPubsubRouter):
await stream.write(serialized_packet) await stream.write(serialized_packet)
# Heartbeat # Heartbeat
async def heartbeat(self): async def heartbeat(self) -> None:
""" """
Call individual heartbeats. Call individual heartbeats.
Note: the heartbeats are called with awaits because each heartbeat depends on the Note: the heartbeats are called with awaits because each heartbeat depends on the
@ -281,7 +326,7 @@ class GossipSub(IPubsubRouter):
await asyncio.sleep(self.heartbeat_interval) await asyncio.sleep(self.heartbeat_interval)
async def mesh_heartbeat(self): async def mesh_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec # Note: the comments here are the exact pseudocode from the spec
for topic in self.mesh: for topic in self.mesh:
# Skip if no peers have subscribed to the topic # Skip if no peers have subscribed to the topic
@ -297,7 +342,8 @@ class GossipSub(IPubsubRouter):
self.mesh[topic], self.mesh[topic],
) )
fanout_peers_not_in_mesh = [ # FIXME: Should be changed to `List[ID]`
fanout_peers_not_in_mesh: List[str] = [
peer peer
for peer in selected_peers for peer in selected_peers
if peer not in self.mesh[topic] if peer not in self.mesh[topic]
@ -311,8 +357,12 @@ class GossipSub(IPubsubRouter):
if num_mesh_peers_in_topic > self.degree_high: if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic] # Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = GossipSub.select_from_minus(num_mesh_peers_in_topic - self.degree, # FIXME: Should be changed to `List[ID]`
self.mesh[topic], []) selected_peers: List[str] = GossipSub.select_from_minus(
num_mesh_peers_in_topic - self.degree,
self.mesh[topic],
[],
)
for peer in selected_peers: for peer in selected_peers:
# Remove peer from mesh[topic] # Remove peer from mesh[topic]
self.mesh[topic].remove(peer) self.mesh[topic].remove(peer)
@ -320,7 +370,7 @@ class GossipSub(IPubsubRouter):
# Emit PRUNE(topic) control message to peer # Emit PRUNE(topic) control message to peer
await self.emit_prune(topic, peer) await self.emit_prune(topic, peer)
async def fanout_heartbeat(self): async def fanout_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec # Note: the comments here are the exact pseudocode from the spec
for topic in self.fanout: for topic in self.fanout:
# If time since last published > ttl # If time since last published > ttl
@ -362,14 +412,14 @@ class GossipSub(IPubsubRouter):
# TODO: this line is a monster, can hopefully be simplified # TODO: this line is a monster, can hopefully be simplified
if (topic not in self.mesh or (peer not in self.mesh[topic]))\ if (topic not in self.mesh or (peer not in self.mesh[topic]))\
and (topic not in self.fanout or (peer not in self.fanout[topic])): and (topic not in self.fanout or (peer not in self.fanout[topic])):
msg_ids = [str(msg) for msg in msg_ids] msg_ids: List[str] = [str(msg) for msg in msg_ids]
await self.emit_ihave(topic, msg_ids, peer) await self.emit_ihave(topic, msg_ids, peer)
# TODO: Refactor and Dedup. This section is the roughly the same as the above. # TODO: Refactor and Dedup. This section is the roughly the same as the above.
# Do the same for fanout, for all topics not already hit in mesh # Do the same for fanout, for all topics not already hit in mesh
for topic in self.fanout: for topic in self.fanout:
if topic not in self.mesh: if topic not in self.mesh:
msg_ids = self.mcache.window(topic) msg_ids: List[str] = self.mcache.window(topic)
if msg_ids: if msg_ids:
# TODO: Make more efficient, possibly using a generator? # TODO: Make more efficient, possibly using a generator?
# Get all pubsub peers in topic and only add if they are gossipsub peers also # Get all pubsub peers in topic and only add if they are gossipsub peers also
@ -383,13 +433,13 @@ class GossipSub(IPubsubRouter):
for peer in peers_to_emit_ihave_to: for peer in peers_to_emit_ihave_to:
if peer not in self.mesh[topic] and peer not in self.fanout[topic]: if peer not in self.mesh[topic] and peer not in self.fanout[topic]:
msg_ids = [str(msg) for msg in msg_ids] msg_ids: List[str] = [str(msg) for msg in msg_ids]
await self.emit_ihave(topic, msg_ids, peer) await self.emit_ihave(topic, msg_ids, peer)
self.mcache.shift() self.mcache.shift()
@staticmethod @staticmethod
def select_from_minus(num_to_select, pool, minus): def select_from_minus(num_to_select: int, pool: Sequence[Any], minus: Sequence[Any]) -> List[Any]:
""" """
Select at most num_to_select subset of elements from the set (pool - minus) randomly. Select at most num_to_select subset of elements from the set (pool - minus) randomly.
:param num_to_select: number of elements to randomly select :param num_to_select: number of elements to randomly select
@ -400,10 +450,10 @@ class GossipSub(IPubsubRouter):
# Create selection pool, which is selection_pool = pool - minus # Create selection pool, which is selection_pool = pool - minus
if minus: if minus:
# Create a new selection pool by removing elements of minus # Create a new selection pool by removing elements of minus
selection_pool = [x for x in pool if x not in minus] selection_pool: List[Any] = [x for x in pool if x not in minus]
else: else:
# Don't create a new selection_pool if we are not subbing anything # Don't create a new selection_pool if we are not subbing anything
selection_pool = pool selection_pool: List[Any] = pool
# If num_to_select > size(selection_pool), then return selection_pool (which has the most # If num_to_select > size(selection_pool), then return selection_pool (which has the most
# possible elements s.t. the number of elements is less than num_to_select) # possible elements s.t. the number of elements is less than num_to_select)
@ -411,7 +461,7 @@ class GossipSub(IPubsubRouter):
return selection_pool return selection_pool
# Random selection # Random selection
selection = random.sample(selection_pool, num_to_select) selection: List[Any] = random.sample(selection_pool, num_to_select)
return selection return selection
@ -433,7 +483,7 @@ class GossipSub(IPubsubRouter):
# RPC handlers # RPC handlers
async def handle_ihave(self, ihave_msg, sender_peer_id): async def handle_ihave(self, ihave_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
""" """
Checks the seen set and requests unknown messages with an IWANT message. Checks the seen set and requests unknown messages with an IWANT message.
""" """
@ -442,29 +492,36 @@ class GossipSub(IPubsubRouter):
from_id_str = sender_peer_id from_id_str = sender_peer_id
# Get list of all seen (seqnos, from) from the (seqno, from) tuples in seen_messages cache # Get list of all seen (seqnos, from) from the (seqno, from) tuples in seen_messages cache
seen_seqnos_and_peers = [seqno_and_from seen_seqnos_and_peers = [
for seqno_and_from in self.pubsub.seen_messages.keys()] seqno_and_from
for seqno_and_from in self.pubsub.seen_messages.keys()
]
# Add all unknown message ids (ids that appear in ihave_msg but not in seen_seqnos) to list # Add all unknown message ids (ids that appear in ihave_msg but not in seen_seqnos) to list
# of messages we want to request # of messages we want to request
msg_ids_wanted = [msg_id for msg_id in ihave_msg.messageIDs # FIXME: Update type of message ID
if literal_eval(msg_id) not in seen_seqnos_and_peers] msg_ids_wanted = [
msg_id
for msg_id in ihave_msg.messageIDs
if literal_eval(msg_id) not in seen_seqnos_and_peers
]
# Request messages with IWANT message # Request messages with IWANT message
if msg_ids_wanted: if msg_ids_wanted:
await self.emit_iwant(msg_ids_wanted, from_id_str) await self.emit_iwant(msg_ids_wanted, from_id_str)
async def handle_iwant(self, iwant_msg, sender_peer_id): async def handle_iwant(self, iwant_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
""" """
Forwards all request messages that are present in mcache to the requesting peer. Forwards all request messages that are present in mcache to the requesting peer.
""" """
from_id_str = sender_peer_id from_id_str = sender_peer_id
msg_ids = [literal_eval(msg) for msg in iwant_msg.messageIDs] # FIXME: Update type of message ID
msgs_to_forward = [] msg_ids: List[Any] = [literal_eval(msg) for msg in iwant_msg.messageIDs]
msgs_to_forward: List = []
for msg_id_iwant in msg_ids: for msg_id_iwant in msg_ids:
# Check if the wanted message ID is present in mcache # Check if the wanted message ID is present in mcache
msg = self.mcache.get(msg_id_iwant) msg: rpc_pb2.Message = self.mcache.get(msg_id_iwant)
# Cache hit # Cache hit
if msg: if msg:
@ -476,12 +533,12 @@ class GossipSub(IPubsubRouter):
# because then the message will forwarded to peers in the topics contained in the messages. # because then the message will forwarded to peers in the topics contained in the messages.
# We should # We should
# 1) Package these messages into a single packet # 1) Package these messages into a single packet
packet = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.publish.extend(msgs_to_forward) packet.publish.extend(msgs_to_forward)
# 2) Serialize that packet # 2) Serialize that packet
rpc_msg = packet.SerializeToString() rpc_msg: bytes = packet.SerializeToString()
# 3) Get the stream to this peer # 3) Get the stream to this peer
# TODO: Should we pass in from_id or from_id_str here? # TODO: Should we pass in from_id or from_id_str here?
@ -490,8 +547,8 @@ class GossipSub(IPubsubRouter):
# 4) And write the packet to the stream # 4) And write the packet to the stream
await peer_stream.write(rpc_msg) await peer_stream.write(rpc_msg)
async def handle_graft(self, graft_msg, sender_peer_id): async def handle_graft(self, graft_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
topic = graft_msg.topicID topic: str = graft_msg.topicID
from_id_str = sender_peer_id from_id_str = sender_peer_id
@ -503,8 +560,8 @@ class GossipSub(IPubsubRouter):
# Respond with PRUNE if not subscribed to the topic # Respond with PRUNE if not subscribed to the topic
await self.emit_prune(topic, sender_peer_id) await self.emit_prune(topic, sender_peer_id)
async def handle_prune(self, prune_msg, sender_peer_id): async def handle_prune(self, prune_msg: rpc_pb2.Message, sender_peer_id: str) -> None:
topic = prune_msg.topicID topic: str = prune_msg.topicID
from_id_str = sender_peer_id from_id_str = sender_peer_id
@ -514,65 +571,65 @@ class GossipSub(IPubsubRouter):
# RPC emitters # RPC emitters
async def emit_ihave(self, topic, msg_ids, to_peer): async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: str) -> None:
""" """
Emit ihave message, sent to to_peer, for topic and msg_ids Emit ihave message, sent to to_peer, for topic and msg_ids
""" """
ihave_msg = rpc_pb2.ControlIHave() ihave_msg: rpc_pb2.ControlIHave = rpc_pb2.ControlIHave()
ihave_msg.messageIDs.extend(msg_ids) ihave_msg.messageIDs.extend(msg_ids)
ihave_msg.topicID = topic ihave_msg.topicID = topic
control_msg = rpc_pb2.ControlMessage() control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.ihave.extend([ihave_msg]) control_msg.ihave.extend([ihave_msg])
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_iwant(self, msg_ids, to_peer): async def emit_iwant(self, msg_ids: Any, to_peer: str) -> None:
""" """
Emit iwant message, sent to to_peer, for msg_ids Emit iwant message, sent to to_peer, for msg_ids
""" """
iwant_msg = rpc_pb2.ControlIWant() iwant_msg: rpc_pb2.ControlIWant = rpc_pb2.ControlIWant()
iwant_msg.messageIDs.extend(msg_ids) iwant_msg.messageIDs.extend(msg_ids)
control_msg = rpc_pb2.ControlMessage() control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.iwant.extend([iwant_msg]) control_msg.iwant.extend([iwant_msg])
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_graft(self, topic, to_peer): async def emit_graft(self, topic: str, to_peer: str) -> None:
""" """
Emit graft message, sent to to_peer, for topic Emit graft message, sent to to_peer, for topic
""" """
graft_msg = rpc_pb2.ControlGraft() graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft()
graft_msg.topicID = topic graft_msg.topicID = topic
control_msg = rpc_pb2.ControlMessage() control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.graft.extend([graft_msg]) control_msg.graft.extend([graft_msg])
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_prune(self, topic, to_peer): async def emit_prune(self, topic: str, to_peer: str) -> None:
""" """
Emit graft message, sent to to_peer, for topic Emit graft message, sent to to_peer, for topic
""" """
prune_msg = rpc_pb2.ControlPrune() prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune()
prune_msg.topicID = topic prune_msg.topicID = topic
control_msg = rpc_pb2.ControlMessage() control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
control_msg.prune.extend([prune_msg]) control_msg.prune.extend([prune_msg])
await self.emit_control_message(control_msg, to_peer) await self.emit_control_message(control_msg, to_peer)
async def emit_control_message(self, control_msg, to_peer): async def emit_control_message(self, control_msg: rpc_pb2.ControlMessage, to_peer: str) -> None:
# Add control message to packet # Add control message to packet
packet = rpc_pb2.RPC() packet: rpc_pb2.RPC = rpc_pb2.RPC()
packet.control.CopyFrom(control_msg) packet.control.CopyFrom(control_msg)
rpc_msg = packet.SerializeToString() rpc_msg: bytes = packet.SerializeToString()
# Get stream for peer from pubsub # Get stream for peer from pubsub
peer_stream = self.pubsub.peers[to_peer] peer_stream = self.pubsub.peers[to_peer]