Merge branch 'master' into feature/porting-to-trio

This commit is contained in:
mhchia
2019-12-24 02:19:43 +08:00
112 changed files with 3868 additions and 1946 deletions

View File

@ -81,16 +81,20 @@ class FloodSub(IPubsubRouter):
:param pubsub_msg: pubsub message in protobuf.
"""
peers_gen = self._get_peers_to_send(
pubsub_msg.topicIDs,
msg_forwarder=msg_forwarder,
origin=ID(pubsub_msg.from_id),
peers_gen = set(
self._get_peers_to_send(
pubsub_msg.topicIDs,
msg_forwarder=msg_forwarder,
origin=ID(pubsub_msg.from_id),
)
)
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
if peer_id not in self.pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
@ -98,6 +102,7 @@ class FloodSub(IPubsubRouter):
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
except StreamClosed:
logger.debug("Fail to publish message to %s: stream closed", peer_id)
self.pubsub._handle_dead_peer(peer_id)
async def join(self, topic: str) -> None:
"""

View File

@ -1,7 +1,8 @@
from ast import literal_eval
from collections import defaultdict
import logging
import random
from typing import Any, Dict, Iterable, List, Sequence, Set
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
from async_service import Service
import trio
@ -33,18 +34,18 @@ class GossipSub(IPubsubRouter, Service):
time_to_live: int
mesh: Dict[str, List[ID]]
fanout: Dict[str, List[ID]]
mesh: Dict[str, Set[ID]]
fanout: Dict[str, Set[ID]]
peers_to_protocol: Dict[ID, str]
# The protocol peer supports
peer_protocol: Dict[ID, TProtocol]
time_since_last_publish: Dict[str, int]
peers_gossipsub: List[ID]
peers_floodsub: List[ID]
# TODO: Add `time_since_last_publish`
# Create topic --> time since last publish map.
mcache: MessageCache
heartbeat_initial_delay: float
heartbeat_interval: int
def __init__(
@ -56,6 +57,7 @@ class GossipSub(IPubsubRouter, Service):
time_to_live: int,
gossip_window: int = 3,
gossip_history: int = 5,
heartbeat_initial_delay: float = 0.1,
heartbeat_interval: int = 120,
) -> None:
self.protocols = list(protocols)
@ -74,18 +76,13 @@ class GossipSub(IPubsubRouter, Service):
self.fanout = {}
# Create peer --> protocol mapping
self.peers_to_protocol = {}
# Create topic --> time since last publish map
self.time_since_last_publish = {}
self.peers_gossipsub = []
self.peers_floodsub = []
self.peer_protocol = {}
# Create message cache
self.mcache = MessageCache(gossip_window, gossip_history)
# Create heartbeat timer
self.heartbeat_initial_delay = heartbeat_initial_delay
self.heartbeat_interval = heartbeat_interval
async def run(self) -> None:
@ -122,18 +119,13 @@ class GossipSub(IPubsubRouter, Service):
"""
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
if protocol_id == PROTOCOL_ID:
self.peers_gossipsub.append(peer_id)
elif protocol_id == floodsub.PROTOCOL_ID:
self.peers_floodsub.append(peer_id)
else:
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
# We should never enter here. Becuase the `protocol_id` is registered by your pubsub
# instance in multistream-select, but it is not the protocol that gossipsub supports.
# In this case, probably we registered gossipsub to a wrong `protocol_id`
# in multistream-select, or wrong versions.
# TODO: Better handling
raise Exception(f"protocol is not supported: protocol_id={protocol_id}")
self.peers_to_protocol[peer_id] = protocol_id
raise ValueError(f"Protocol={protocol_id} is not supported.")
self.peer_protocol[peer_id] = protocol_id
def remove_peer(self, peer_id: ID) -> None:
"""
@ -143,13 +135,12 @@ class GossipSub(IPubsubRouter, Service):
"""
logger.debug("removing peer %s", peer_id)
if peer_id in self.peers_gossipsub:
self.peers_gossipsub.remove(peer_id)
elif peer_id in self.peers_floodsub:
self.peers_floodsub.remove(peer_id)
for topic in self.mesh:
self.mesh[topic].discard(peer_id)
for topic in self.fanout:
self.fanout[topic].discard(peer_id)
if peer_id in self.peers_to_protocol:
del self.peers_to_protocol[peer_id]
self.peer_protocol.pop(peer_id, None)
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
"""
@ -189,6 +180,8 @@ class GossipSub(IPubsubRouter, Service):
logger.debug("publishing message %s", pubsub_msg)
for peer_id in peers_gen:
if peer_id not in self.pubsub.peers:
continue
stream = self.pubsub.peers[peer_id]
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
@ -215,36 +208,41 @@ class GossipSub(IPubsubRouter, Service):
continue
# floodsub peers
for peer_id in self.pubsub.peer_topics[topic]:
# FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
# This will improve the efficiency when searching for a peer's protocol id.
if peer_id in self.peers_floodsub:
send_to.add(peer_id)
floodsub_peers: Set[ID] = set(
peer_id
for peer_id in self.pubsub.peer_topics[topic]
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
)
send_to.update(floodsub_peers)
# gossipsub peers
in_topic_gossipsub_peers: List[ID] = None
# TODO: Do we need to check `topic in self.pubsub.my_topics`?
gossipsub_peers: Set[ID] = set()
if topic in self.mesh:
in_topic_gossipsub_peers = self.mesh[topic]
gossipsub_peers = self.mesh[topic]
else:
# TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not
# subscribed?
# I assume there could be short periods between heartbeats where topic may not
# be but we should check that this path gets hit appropriately
if (topic not in self.fanout) or (len(self.fanout[topic]) == 0):
# If no peers in fanout, choose some peers from gossipsub peers in topic.
self.fanout[topic] = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, []
)
in_topic_gossipsub_peers = self.fanout[topic]
for peer_id in in_topic_gossipsub_peers:
send_to.add(peer_id)
# When we publish to a topic that we have not subscribe to, we randomly pick
# `self.degree` number of peers who have subscribed to the topic and add them
# as our `fanout` peers.
topic_in_fanout: bool = topic in self.fanout
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
fanout_size = len(fanout_peers)
if not topic_in_fanout or (
topic_in_fanout and fanout_size < self.degree
):
if topic in self.pubsub.peer_topics:
# Combine fanout peers with selected peers
fanout_peers.update(
self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree - fanout_size, fanout_peers
)
)
self.fanout[topic] = fanout_peers
gossipsub_peers = fanout_peers
send_to.update(gossipsub_peers)
# Excludes `msg_forwarder` and `origin`
yield from send_to.difference([msg_forwarder, origin])
async def join(self, topic: str) -> None:
# Note: the comments here are the near-exact algorithm description from the spec
"""
Join notifies the router that we want to receive and forward messages
in a topic. It is invoked after the subscription announcement.
@ -256,10 +254,10 @@ class GossipSub(IPubsubRouter, Service):
if topic in self.mesh:
return
# Create mesh[topic] if it does not yet exist
self.mesh[topic] = []
self.mesh[topic] = set()
topic_in_fanout: bool = topic in self.fanout
fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else []
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
fanout_size = len(fanout_peers)
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
# There are less than D peers (let this number be x)
@ -270,16 +268,14 @@ class GossipSub(IPubsubRouter, Service):
topic, self.degree - fanout_size, fanout_peers
)
# Combine fanout peers with selected peers
fanout_peers += selected_peers
fanout_peers.update(selected_peers)
# Add fanout peers to mesh and notifies them with a GRAFT(topic) control message.
for peer in fanout_peers:
if peer not in self.mesh[topic]:
self.mesh[topic].append(peer)
await self.emit_graft(topic, peer)
self.mesh[topic].add(peer)
await self.emit_graft(topic, peer)
if topic_in_fanout:
del self.fanout[topic]
self.fanout.pop(topic, None)
async def leave(self, topic: str) -> None:
# Note: the comments here are the near-exact algorithm description from the spec
@ -298,7 +294,75 @@ class GossipSub(IPubsubRouter, Service):
await self.emit_prune(topic, peer)
# Forget mesh[topic]
del self.mesh[topic]
self.mesh.pop(topic, None)
async def _emit_control_msgs(
self,
peers_to_graft: Dict[ID, List[str]],
peers_to_prune: Dict[ID, List[str]],
peers_to_gossip: Dict[ID, Dict[str, List[str]]],
) -> None:
graft_msgs: List[rpc_pb2.ControlGraft] = []
prune_msgs: List[rpc_pb2.ControlPrune] = []
ihave_msgs: List[rpc_pb2.ControlIHave] = []
# Starting with GRAFT messages
for peer, topics in peers_to_graft.items():
for topic in topics:
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft(topicID=topic)
graft_msgs.append(graft_msg)
# If there are also PRUNE messages to send to this peer
if peer in peers_to_prune:
for topic in peers_to_prune[peer]:
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune(
topicID=topic
)
prune_msgs.append(prune_msg)
del peers_to_prune[peer]
# If there are also IHAVE messages to send to this peer
if peer in peers_to_gossip:
for topic in peers_to_gossip[peer]:
ihave_msg: rpc_pb2.ControlIHave = rpc_pb2.ControlIHave(
messageIDs=peers_to_gossip[peer][topic], topicID=topic
)
ihave_msgs.append(ihave_msg)
del peers_to_gossip[peer]
control_msg = self.pack_control_msgs(ihave_msgs, graft_msgs, prune_msgs)
await self.emit_control_message(control_msg, peer)
# Next with PRUNE messages
for peer, topics in peers_to_prune.items():
prune_msgs = []
for topic in topics:
prune_msg = rpc_pb2.ControlPrune(topicID=topic)
prune_msgs.append(prune_msg)
# If there are also IHAVE messages to send to this peer
if peer in peers_to_gossip:
ihave_msgs = []
for topic in peers_to_gossip[peer]:
ihave_msg = rpc_pb2.ControlIHave(
messageIDs=peers_to_gossip[peer][topic], topicID=topic
)
ihave_msgs.append(ihave_msg)
del peers_to_gossip[peer]
control_msg = self.pack_control_msgs(ihave_msgs, None, prune_msgs)
await self.emit_control_message(control_msg, peer)
# Fianlly IHAVE messages
for peer in peers_to_gossip:
ihave_msgs = []
for topic in peers_to_gossip[peer]:
ihave_msg = rpc_pb2.ControlIHave(
messageIDs=peers_to_gossip[peer][topic], topicID=topic
)
ihave_msgs.append(ihave_msg)
control_msg = self.pack_control_msgs(ihave_msgs, None, None)
await self.emit_control_message(control_msg, peer)
# Heartbeat
async def heartbeat(self) -> None:
@ -308,16 +372,29 @@ class GossipSub(IPubsubRouter, Service):
Note: the heartbeats are called with awaits because each heartbeat depends on the
state changes in the preceding heartbeat
"""
# Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501
await trio.sleep(self.heartbeat_initial_delay)
while True:
# Maintain mesh and keep track of which peers to send GRAFT or PRUNE to
peers_to_graft, peers_to_prune = self.mesh_heartbeat()
# Maintain fanout
self.fanout_heartbeat()
# Get the peers to send IHAVE to
peers_to_gossip = self.gossip_heartbeat()
# Pack GRAFT, PRUNE and IHAVE for the same peer into one control message and send it
await self._emit_control_msgs(
peers_to_graft, peers_to_prune, peers_to_gossip
)
await self.mesh_heartbeat()
await self.fanout_heartbeat()
await self.gossip_heartbeat()
self.mcache.shift()
await trio.sleep(self.heartbeat_interval)
async def mesh_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec
def mesh_heartbeat(
self
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
for topic in self.mesh:
# Skip if no peers have subscribed to the topic
if topic not in self.pubsub.peer_topics:
@ -330,41 +407,43 @@ class GossipSub(IPubsubRouter, Service):
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic]
)
fanout_peers_not_in_mesh: List[ID] = [
peer for peer in selected_peers if peer not in self.mesh[topic]
]
for peer in fanout_peers_not_in_mesh:
for peer in selected_peers:
# Add peer to mesh[topic]
self.mesh[topic].append(peer)
self.mesh[topic].add(peer)
# Emit GRAFT(topic) control message to peer
await self.emit_graft(topic, peer)
peers_to_graft[peer].append(topic)
if num_mesh_peers_in_topic > self.degree_high:
# Select |mesh[topic]| - D peers from mesh[topic]
selected_peers = self.select_from_minus(
num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
)
for peer in selected_peers:
# Remove peer from mesh[topic]
self.mesh[topic].remove(peer)
self.mesh[topic].discard(peer)
# Emit PRUNE(topic) control message to peer
await self.emit_prune(topic, peer)
peers_to_prune[peer].append(topic)
return peers_to_graft, peers_to_prune
async def fanout_heartbeat(self) -> None:
def fanout_heartbeat(self) -> None:
# Note: the comments here are the exact pseudocode from the spec
for topic in self.fanout:
# If time since last published > ttl
# TODO: there's no way time_since_last_publish gets set anywhere yet
if (
topic in self.time_since_last_publish
and self.time_since_last_publish[topic] > self.time_to_live
):
# Delete topic entry if it's not in `pubsub.peer_topics`
# or (TODO) if it's time-since-last-published > ttl
if topic not in self.pubsub.peer_topics:
# Remove topic from fanout
del self.fanout[topic]
del self.time_since_last_publish[topic]
else:
# Check if fanout peers are still in the topic and remove the ones that are not
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
in_topic_fanout_peers = [
peer
for peer in self.fanout[topic]
if peer in self.pubsub.peer_topics[topic]
]
self.fanout[topic] = set(in_topic_fanout_peers)
num_fanout_peers_in_topic = len(self.fanout[topic])
# If |fanout[topic]| < D
@ -376,53 +455,43 @@ class GossipSub(IPubsubRouter, Service):
self.fanout[topic],
)
# Add the peers to fanout[topic]
self.fanout[topic].extend(selected_peers)
self.fanout[topic].update(selected_peers)
async def gossip_heartbeat(self) -> None:
def gossip_heartbeat(self) -> DefaultDict[ID, Dict[str, List[str]]]:
peers_to_gossip: DefaultDict[ID, Dict[str, List[str]]] = defaultdict(dict)
for topic in self.mesh:
msg_ids = self.mcache.window(topic)
if msg_ids:
# TODO: Make more efficient, possibly using a generator?
# Get all pubsub peers in a topic and only add them if they are gossipsub peers too
if topic in self.pubsub.peer_topics:
# Select D peers from peers.gossipsub[topic]
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, []
topic, self.degree, self.mesh[topic]
)
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
for peer in peers_to_emit_ihave_to:
# TODO: this line is a monster, can hopefully be simplified
if (
topic not in self.mesh or (peer not in self.mesh[topic])
) and (
topic not in self.fanout or (peer not in self.fanout[topic])
):
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
await self.emit_ihave(topic, msg_id_strs, peer)
peers_to_gossip[peer][topic] = msg_id_strs
# TODO: Refactor and Dedup. This section is the roughly the same as the above.
# Do the same for fanout, for all topics not already hit in mesh
for topic in self.fanout:
if topic not in self.mesh:
msg_ids = self.mcache.window(topic)
if msg_ids:
# TODO: Make more efficient, possibly using a generator?
# Get all pubsub peers in topic and only add if they are gossipsub peers also
if topic in self.pubsub.peer_topics:
# Select D peers from peers.gossipsub[topic]
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, []
)
for peer in peers_to_emit_ihave_to:
if peer not in self.fanout[topic]:
msg_id_strs = [str(msg) for msg in msg_ids]
await self.emit_ihave(topic, msg_id_strs, peer)
self.mcache.shift()
msg_ids = self.mcache.window(topic)
if msg_ids:
# Get all pubsub peers in topic and only add if they are gossipsub peers also
if topic in self.pubsub.peer_topics:
# Select D peers from peers.gossipsub[topic]
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
topic, self.degree, self.fanout[topic]
)
msg_id_strs = [str(msg) for msg in msg_ids]
for peer in peers_to_emit_ihave_to:
peers_to_gossip[peer][topic] = msg_id_strs
return peers_to_gossip
@staticmethod
def select_from_minus(
num_to_select: int, pool: Sequence[Any], minus: Sequence[Any]
num_to_select: int, pool: Iterable[Any], minus: Iterable[Any]
) -> List[Any]:
"""
Select at most num_to_select subset of elements from the set (pool - minus) randomly.
@ -441,7 +510,7 @@ class GossipSub(IPubsubRouter, Service):
# If num_to_select > size(selection_pool), then return selection_pool (which has the most
# possible elements s.t. the number of elements is less than num_to_select)
if num_to_select > len(selection_pool):
if num_to_select >= len(selection_pool):
return selection_pool
# Random selection
@ -450,16 +519,14 @@ class GossipSub(IPubsubRouter, Service):
return selection
def _get_in_topic_gossipsub_peers_from_minus(
self, topic: str, num_to_select: int, minus: Sequence[ID]
self, topic: str, num_to_select: int, minus: Iterable[ID]
) -> List[ID]:
gossipsub_peers_in_topic = [
gossipsub_peers_in_topic = set(
peer_id
for peer_id in self.pubsub.peer_topics[topic]
if peer_id in self.peers_gossipsub
]
return self.select_from_minus(
num_to_select, gossipsub_peers_in_topic, list(minus)
if self.peer_protocol[peer_id] == PROTOCOL_ID
)
return self.select_from_minus(num_to_select, gossipsub_peers_in_topic, minus)
# RPC handlers
@ -517,6 +584,12 @@ class GossipSub(IPubsubRouter, Service):
rpc_msg: bytes = packet.SerializeToString()
# 3) Get the stream to this peer
if sender_peer_id not in self.pubsub.peers:
logger.debug(
"Fail to responed to iwant request from %s: peer record not exist",
sender_peer_id,
)
return
peer_stream = self.pubsub.peers[sender_peer_id]
# 4) And write the packet to the stream
@ -537,7 +610,7 @@ class GossipSub(IPubsubRouter, Service):
# Add peer to mesh for topic
if topic in self.mesh:
if sender_peer_id not in self.mesh[topic]:
self.mesh[topic].append(sender_peer_id)
self.mesh[topic].add(sender_peer_id)
else:
# Respond with PRUNE if not subscribed to the topic
await self.emit_prune(topic, sender_peer_id)
@ -547,12 +620,27 @@ class GossipSub(IPubsubRouter, Service):
) -> None:
topic: str = prune_msg.topicID
# Remove peer from mesh for topic, if peer is in topic
if topic in self.mesh and sender_peer_id in self.mesh[topic]:
self.mesh[topic].remove(sender_peer_id)
# Remove peer from mesh for topic
if topic in self.mesh:
self.mesh[topic].discard(sender_peer_id)
# RPC emitters
def pack_control_msgs(
self,
ihave_msgs: List[rpc_pb2.ControlIHave],
graft_msgs: List[rpc_pb2.ControlGraft],
prune_msgs: List[rpc_pb2.ControlPrune],
) -> rpc_pb2.ControlMessage:
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
if ihave_msgs:
control_msg.ihave.extend(ihave_msgs)
if graft_msgs:
control_msg.graft.extend(graft_msgs)
if prune_msgs:
control_msg.prune.extend(prune_msgs)
return control_msg
async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: ID) -> None:
"""Emit ihave message, sent to to_peer, for topic and msg_ids."""
@ -608,6 +696,11 @@ class GossipSub(IPubsubRouter, Service):
rpc_msg: bytes = packet.SerializeToString()
# Get stream for peer from pubsub
if to_peer not in self.pubsub.peers:
logger.debug(
"Fail to emit control message to %s: peer record not exist", to_peer
)
return
peer_stream = self.pubsub.peers[to_peer]
# Write rpc to stream

View File

@ -96,8 +96,7 @@ class MessageCache:
last_entries: List[CacheEntry] = self.history[len(self.history) - 1]
for entry in last_entries:
if entry.mid in self.msgs:
del self.msgs[entry.mid]
self.msgs.pop(entry.mid)
i: int = len(self.history) - 2

View File

@ -9,6 +9,7 @@ from typing import (
KeysView,
List,
NamedTuple,
Set,
Tuple,
Union,
cast,
@ -19,6 +20,7 @@ import base58
from lru import LRU
import trio
from libp2p.crypto.keys import PrivateKey
from libp2p.exceptions import ParseError, ValidationError
from libp2p.host.host_interface import IHost
from libp2p.io.exceptions import IncompleteReadError
@ -33,7 +35,7 @@ from .abc import IPubsub, ISubscriptionAPI
from .pb import rpc_pb2
from .pubsub_notifee import PubsubNotifee
from .subscription import TrioSubscriptionAPI
from .validators import signature_validator
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
if TYPE_CHECKING:
from .abc import IPubsubRouter # noqa: F401
@ -73,16 +75,23 @@ class Pubsub(IPubsub, Service):
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
peer_topics: Dict[str, List[ID]]
peer_topics: Dict[str, Set[ID]]
peers: Dict[ID, INetStream]
topic_validators: Dict[str, TopicValidator]
# TODO: Be sure it is increased atomically everytime.
counter: int # uint64
# Indicate if we should enforce signature verification
strict_signing: bool
sign_key: PrivateKey
def __init__(
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
self,
host: IHost,
router: "IPubsubRouter",
cache_size: int = None,
strict_signing: bool = True,
) -> None:
"""
Construct a new Pubsub object, which is responsible for handling all
@ -126,6 +135,12 @@ class Pubsub(IPubsub, Service):
else:
self.cache_size = cache_size
self.strict_signing = strict_signing
if strict_signing:
self.sign_key = self.host.get_private_key()
else:
self.sign_key = None
self.seen_messages = LRU(self.cache_size)
# Map of topics we are subscribed to blocking queues
@ -142,7 +157,7 @@ class Pubsub(IPubsub, Service):
# Map of topic to topic validator
self.topic_validators = {}
self.counter = time.time_ns()
self.counter = int(time.time())
async def run(self) -> None:
self.manager.run_daemon_task(self.handle_peer_queue)
@ -239,8 +254,7 @@ class Pubsub(IPubsub, Service):
:param topic: the topic to remove validator from
"""
if topic in self.topic_validators:
del self.topic_validators[topic]
self.topic_validators.pop(topic, None)
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
"""
@ -282,24 +296,22 @@ class Pubsub(IPubsub, Service):
logger.debug("fail to add new peer %s, error %s", peer_id, error)
return
self.peers[peer_id] = stream
# Send hello packet
hello = self.get_hello_packet()
try:
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
except StreamClosed:
logger.debug("Fail to add new peer %s: stream closed", peer_id)
del self.peers[peer_id]
return
# TODO: Check if the peer in black list.
try:
self.router.add_peer(peer_id, stream.get_protocol())
except Exception as error:
logger.debug("fail to add new peer %s, error %s", peer_id, error)
del self.peers[peer_id]
return
self.peers[peer_id] = stream
logger.debug("added new peer %s", peer_id)
def _handle_dead_peer(self, peer_id: ID) -> None:
@ -309,19 +321,16 @@ class Pubsub(IPubsub, Service):
for topic in self.peer_topics:
if peer_id in self.peer_topics[topic]:
self.peer_topics[topic].remove(peer_id)
self.peer_topics[topic].discard(peer_id)
self.router.remove_peer(peer_id)
logger.debug("removed dead peer %s", peer_id)
async def handle_peer_queue(self) -> None:
"""
Continuously read from peer channel and each time a new peer is found,
open a stream to the peer using a supported pubsub protocol
TODO: Handle failure for when the peer does not support any of the
pubsub protocols we support
"""
"""Continuously read from peer queue and each time a new peer is found,
open a stream to the peer using a supported pubsub protocol pubsub
protocols we support."""
async with self.peer_receive_channel:
while self.manager.is_running:
peer_id: ID = await self.peer_receive_channel.receive()
@ -351,14 +360,14 @@ class Pubsub(IPubsub, Service):
"""
if sub_message.subscribe:
if sub_message.topicid not in self.peer_topics:
self.peer_topics[sub_message.topicid] = [origin_id]
self.peer_topics[sub_message.topicid] = set([origin_id])
elif origin_id not in self.peer_topics[sub_message.topicid]:
# Add peer to topic
self.peer_topics[sub_message.topicid].append(origin_id)
self.peer_topics[sub_message.topicid].add(origin_id)
else:
if sub_message.topicid in self.peer_topics:
if origin_id in self.peer_topics[sub_message.topicid]:
self.peer_topics[sub_message.topicid].remove(origin_id)
self.peer_topics[sub_message.topicid].discard(origin_id)
# FIXME(mhchia): Change the function name?
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
@ -476,7 +485,13 @@ class Pubsub(IPubsub, Service):
seqno=self._next_seqno(),
)
# TODO: Sign with our signing key
if self.strict_signing:
priv_key = self.sign_key
signature = priv_key.sign(
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
)
msg.key = self.host.get_public_key().serialize()
msg.signature = signature
await self.push_msg(self.my_id, msg)
@ -536,18 +551,17 @@ class Pubsub(IPubsub, Service):
# TODO: Check if the `from` is in the blacklist. If yes, reject.
# TODO: Check if signing is required and if so signature should be attached.
# If the message is processed before, return(i.e., don't further process the message).
if self._is_msg_seen(msg):
return
# TODO: - Validate the message. If failed, reject it.
# Validate the signature of the message
# FIXME: `signature_validator` is currently a stub.
if not signature_validator(msg.key, msg.SerializeToString()):
logger.debug("Signature validation failed for msg: %s", msg)
return
# Check if signing is required and if so validate the signature
if self.strict_signing:
# Validate the signature of the message
if not signature_validator(msg):
logger.debug("Signature validation failed for msg: %s", msg)
return
# Validate the message with registered topic validators.
# If the validation failed, return(i.e., don't further process the message).
try:

View File

@ -1,10 +1,41 @@
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
def signature_validator(pubkey: bytes, msg: bytes) -> bool:
import logging
from libp2p.crypto.serialization import deserialize_public_key
from libp2p.peer.id import ID
from .pb import rpc_pb2
logger = logging.getLogger("libp2p.pubsub")
PUBSUB_SIGNING_PREFIX = "libp2p-pubsub:"
def signature_validator(msg: rpc_pb2.Message) -> bool:
"""
Verify the message against the given public key.
:param pubkey: the public key which signs the message.
:param msg: the message signed.
"""
# TODO: Implement the signature validation
return True
# Check if signature is attached
if msg.signature == b"":
logger.debug("Reject because no signature attached for msg: %s", msg)
return False
# Validate if message sender matches message signer,
# i.e., check if `msg.key` matches `msg.from_id`
msg_pubkey = deserialize_public_key(msg.key)
if ID.from_pubkey(msg_pubkey) != msg.from_id:
logger.debug(
"Reject because signing key does not match sender ID for msg: %s", msg
)
return False
# First, construct the original payload that's signed by 'msg.key'
msg_without_key_sig = rpc_pb2.Message(
data=msg.data, topicIDs=msg.topicIDs, from_id=msg.from_id, seqno=msg.seqno
)
payload = PUBSUB_SIGNING_PREFIX.encode() + msg_without_key_sig.SerializeToString()
try:
return msg_pubkey.verify(payload, msg.signature)
except Exception:
return False