mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'master' into feature/porting-to-trio
This commit is contained in:
@ -81,16 +81,20 @@ class FloodSub(IPubsubRouter):
|
||||
:param pubsub_msg: pubsub message in protobuf.
|
||||
"""
|
||||
|
||||
peers_gen = self._get_peers_to_send(
|
||||
pubsub_msg.topicIDs,
|
||||
msg_forwarder=msg_forwarder,
|
||||
origin=ID(pubsub_msg.from_id),
|
||||
peers_gen = set(
|
||||
self._get_peers_to_send(
|
||||
pubsub_msg.topicIDs,
|
||||
msg_forwarder=msg_forwarder,
|
||||
origin=ID(pubsub_msg.from_id),
|
||||
)
|
||||
)
|
||||
rpc_msg = rpc_pb2.RPC(publish=[pubsub_msg])
|
||||
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
for peer_id in peers_gen:
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
@ -98,6 +102,7 @@ class FloodSub(IPubsubRouter):
|
||||
await stream.write(encode_varint_prefixed(rpc_msg.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to publish message to %s: stream closed", peer_id)
|
||||
self.pubsub._handle_dead_peer(peer_id)
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from ast import literal_eval
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, Iterable, List, Sequence, Set
|
||||
from typing import Any, DefaultDict, Dict, Iterable, List, Sequence, Set, Tuple
|
||||
|
||||
from async_service import Service
|
||||
import trio
|
||||
@ -33,18 +34,18 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
time_to_live: int
|
||||
|
||||
mesh: Dict[str, List[ID]]
|
||||
fanout: Dict[str, List[ID]]
|
||||
mesh: Dict[str, Set[ID]]
|
||||
fanout: Dict[str, Set[ID]]
|
||||
|
||||
peers_to_protocol: Dict[ID, str]
|
||||
# The protocol peer supports
|
||||
peer_protocol: Dict[ID, TProtocol]
|
||||
|
||||
time_since_last_publish: Dict[str, int]
|
||||
|
||||
peers_gossipsub: List[ID]
|
||||
peers_floodsub: List[ID]
|
||||
# TODO: Add `time_since_last_publish`
|
||||
# Create topic --> time since last publish map.
|
||||
|
||||
mcache: MessageCache
|
||||
|
||||
heartbeat_initial_delay: float
|
||||
heartbeat_interval: int
|
||||
|
||||
def __init__(
|
||||
@ -56,6 +57,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
time_to_live: int,
|
||||
gossip_window: int = 3,
|
||||
gossip_history: int = 5,
|
||||
heartbeat_initial_delay: float = 0.1,
|
||||
heartbeat_interval: int = 120,
|
||||
) -> None:
|
||||
self.protocols = list(protocols)
|
||||
@ -74,18 +76,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.fanout = {}
|
||||
|
||||
# Create peer --> protocol mapping
|
||||
self.peers_to_protocol = {}
|
||||
|
||||
# Create topic --> time since last publish map
|
||||
self.time_since_last_publish = {}
|
||||
|
||||
self.peers_gossipsub = []
|
||||
self.peers_floodsub = []
|
||||
self.peer_protocol = {}
|
||||
|
||||
# Create message cache
|
||||
self.mcache = MessageCache(gossip_window, gossip_history)
|
||||
|
||||
# Create heartbeat timer
|
||||
self.heartbeat_initial_delay = heartbeat_initial_delay
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
|
||||
async def run(self) -> None:
|
||||
@ -122,18 +119,13 @@ class GossipSub(IPubsubRouter, Service):
|
||||
"""
|
||||
logger.debug("adding peer %s with protocol %s", peer_id, protocol_id)
|
||||
|
||||
if protocol_id == PROTOCOL_ID:
|
||||
self.peers_gossipsub.append(peer_id)
|
||||
elif protocol_id == floodsub.PROTOCOL_ID:
|
||||
self.peers_floodsub.append(peer_id)
|
||||
else:
|
||||
if protocol_id not in (PROTOCOL_ID, floodsub.PROTOCOL_ID):
|
||||
# We should never enter here. Becuase the `protocol_id` is registered by your pubsub
|
||||
# instance in multistream-select, but it is not the protocol that gossipsub supports.
|
||||
# In this case, probably we registered gossipsub to a wrong `protocol_id`
|
||||
# in multistream-select, or wrong versions.
|
||||
# TODO: Better handling
|
||||
raise Exception(f"protocol is not supported: protocol_id={protocol_id}")
|
||||
self.peers_to_protocol[peer_id] = protocol_id
|
||||
raise ValueError(f"Protocol={protocol_id} is not supported.")
|
||||
self.peer_protocol[peer_id] = protocol_id
|
||||
|
||||
def remove_peer(self, peer_id: ID) -> None:
|
||||
"""
|
||||
@ -143,13 +135,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
"""
|
||||
logger.debug("removing peer %s", peer_id)
|
||||
|
||||
if peer_id in self.peers_gossipsub:
|
||||
self.peers_gossipsub.remove(peer_id)
|
||||
elif peer_id in self.peers_floodsub:
|
||||
self.peers_floodsub.remove(peer_id)
|
||||
for topic in self.mesh:
|
||||
self.mesh[topic].discard(peer_id)
|
||||
for topic in self.fanout:
|
||||
self.fanout[topic].discard(peer_id)
|
||||
|
||||
if peer_id in self.peers_to_protocol:
|
||||
del self.peers_to_protocol[peer_id]
|
||||
self.peer_protocol.pop(peer_id, None)
|
||||
|
||||
async def handle_rpc(self, rpc: rpc_pb2.RPC, sender_peer_id: ID) -> None:
|
||||
"""
|
||||
@ -189,6 +180,8 @@ class GossipSub(IPubsubRouter, Service):
|
||||
logger.debug("publishing message %s", pubsub_msg)
|
||||
|
||||
for peer_id in peers_gen:
|
||||
if peer_id not in self.pubsub.peers:
|
||||
continue
|
||||
stream = self.pubsub.peers[peer_id]
|
||||
# FIXME: We should add a `WriteMsg` similar to write delimited messages.
|
||||
# Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/master/comm.go#L107
|
||||
@ -215,36 +208,41 @@ class GossipSub(IPubsubRouter, Service):
|
||||
continue
|
||||
|
||||
# floodsub peers
|
||||
for peer_id in self.pubsub.peer_topics[topic]:
|
||||
# FIXME: `gossipsub.peers_floodsub` can be changed to `gossipsub.peers` in go.
|
||||
# This will improve the efficiency when searching for a peer's protocol id.
|
||||
if peer_id in self.peers_floodsub:
|
||||
send_to.add(peer_id)
|
||||
floodsub_peers: Set[ID] = set(
|
||||
peer_id
|
||||
for peer_id in self.pubsub.peer_topics[topic]
|
||||
if self.peer_protocol[peer_id] == floodsub.PROTOCOL_ID
|
||||
)
|
||||
send_to.update(floodsub_peers)
|
||||
|
||||
# gossipsub peers
|
||||
in_topic_gossipsub_peers: List[ID] = None
|
||||
# TODO: Do we need to check `topic in self.pubsub.my_topics`?
|
||||
gossipsub_peers: Set[ID] = set()
|
||||
if topic in self.mesh:
|
||||
in_topic_gossipsub_peers = self.mesh[topic]
|
||||
gossipsub_peers = self.mesh[topic]
|
||||
else:
|
||||
# TODO(robzajac): Is topic DEFINITELY supposed to be in fanout if we are not
|
||||
# subscribed?
|
||||
# I assume there could be short periods between heartbeats where topic may not
|
||||
# be but we should check that this path gets hit appropriately
|
||||
|
||||
if (topic not in self.fanout) or (len(self.fanout[topic]) == 0):
|
||||
# If no peers in fanout, choose some peers from gossipsub peers in topic.
|
||||
self.fanout[topic] = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, []
|
||||
)
|
||||
in_topic_gossipsub_peers = self.fanout[topic]
|
||||
for peer_id in in_topic_gossipsub_peers:
|
||||
send_to.add(peer_id)
|
||||
# When we publish to a topic that we have not subscribe to, we randomly pick
|
||||
# `self.degree` number of peers who have subscribed to the topic and add them
|
||||
# as our `fanout` peers.
|
||||
topic_in_fanout: bool = topic in self.fanout
|
||||
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (
|
||||
topic_in_fanout and fanout_size < self.degree
|
||||
):
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Combine fanout peers with selected peers
|
||||
fanout_peers.update(
|
||||
self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree - fanout_size, fanout_peers
|
||||
)
|
||||
)
|
||||
self.fanout[topic] = fanout_peers
|
||||
gossipsub_peers = fanout_peers
|
||||
send_to.update(gossipsub_peers)
|
||||
# Excludes `msg_forwarder` and `origin`
|
||||
yield from send_to.difference([msg_forwarder, origin])
|
||||
|
||||
async def join(self, topic: str) -> None:
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
"""
|
||||
Join notifies the router that we want to receive and forward messages
|
||||
in a topic. It is invoked after the subscription announcement.
|
||||
@ -256,10 +254,10 @@ class GossipSub(IPubsubRouter, Service):
|
||||
if topic in self.mesh:
|
||||
return
|
||||
# Create mesh[topic] if it does not yet exist
|
||||
self.mesh[topic] = []
|
||||
self.mesh[topic] = set()
|
||||
|
||||
topic_in_fanout: bool = topic in self.fanout
|
||||
fanout_peers: List[ID] = self.fanout[topic] if topic_in_fanout else []
|
||||
fanout_peers: Set[ID] = self.fanout[topic] if topic_in_fanout else set()
|
||||
fanout_size = len(fanout_peers)
|
||||
if not topic_in_fanout or (topic_in_fanout and fanout_size < self.degree):
|
||||
# There are less than D peers (let this number be x)
|
||||
@ -270,16 +268,14 @@ class GossipSub(IPubsubRouter, Service):
|
||||
topic, self.degree - fanout_size, fanout_peers
|
||||
)
|
||||
# Combine fanout peers with selected peers
|
||||
fanout_peers += selected_peers
|
||||
fanout_peers.update(selected_peers)
|
||||
|
||||
# Add fanout peers to mesh and notifies them with a GRAFT(topic) control message.
|
||||
for peer in fanout_peers:
|
||||
if peer not in self.mesh[topic]:
|
||||
self.mesh[topic].append(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
self.mesh[topic].add(peer)
|
||||
await self.emit_graft(topic, peer)
|
||||
|
||||
if topic_in_fanout:
|
||||
del self.fanout[topic]
|
||||
self.fanout.pop(topic, None)
|
||||
|
||||
async def leave(self, topic: str) -> None:
|
||||
# Note: the comments here are the near-exact algorithm description from the spec
|
||||
@ -298,7 +294,75 @@ class GossipSub(IPubsubRouter, Service):
|
||||
await self.emit_prune(topic, peer)
|
||||
|
||||
# Forget mesh[topic]
|
||||
del self.mesh[topic]
|
||||
self.mesh.pop(topic, None)
|
||||
|
||||
async def _emit_control_msgs(
|
||||
self,
|
||||
peers_to_graft: Dict[ID, List[str]],
|
||||
peers_to_prune: Dict[ID, List[str]],
|
||||
peers_to_gossip: Dict[ID, Dict[str, List[str]]],
|
||||
) -> None:
|
||||
graft_msgs: List[rpc_pb2.ControlGraft] = []
|
||||
prune_msgs: List[rpc_pb2.ControlPrune] = []
|
||||
ihave_msgs: List[rpc_pb2.ControlIHave] = []
|
||||
# Starting with GRAFT messages
|
||||
for peer, topics in peers_to_graft.items():
|
||||
for topic in topics:
|
||||
graft_msg: rpc_pb2.ControlGraft = rpc_pb2.ControlGraft(topicID=topic)
|
||||
graft_msgs.append(graft_msg)
|
||||
|
||||
# If there are also PRUNE messages to send to this peer
|
||||
if peer in peers_to_prune:
|
||||
for topic in peers_to_prune[peer]:
|
||||
prune_msg: rpc_pb2.ControlPrune = rpc_pb2.ControlPrune(
|
||||
topicID=topic
|
||||
)
|
||||
prune_msgs.append(prune_msg)
|
||||
del peers_to_prune[peer]
|
||||
|
||||
# If there are also IHAVE messages to send to this peer
|
||||
if peer in peers_to_gossip:
|
||||
for topic in peers_to_gossip[peer]:
|
||||
ihave_msg: rpc_pb2.ControlIHave = rpc_pb2.ControlIHave(
|
||||
messageIDs=peers_to_gossip[peer][topic], topicID=topic
|
||||
)
|
||||
ihave_msgs.append(ihave_msg)
|
||||
del peers_to_gossip[peer]
|
||||
|
||||
control_msg = self.pack_control_msgs(ihave_msgs, graft_msgs, prune_msgs)
|
||||
await self.emit_control_message(control_msg, peer)
|
||||
|
||||
# Next with PRUNE messages
|
||||
for peer, topics in peers_to_prune.items():
|
||||
prune_msgs = []
|
||||
for topic in topics:
|
||||
prune_msg = rpc_pb2.ControlPrune(topicID=topic)
|
||||
prune_msgs.append(prune_msg)
|
||||
|
||||
# If there are also IHAVE messages to send to this peer
|
||||
if peer in peers_to_gossip:
|
||||
ihave_msgs = []
|
||||
for topic in peers_to_gossip[peer]:
|
||||
ihave_msg = rpc_pb2.ControlIHave(
|
||||
messageIDs=peers_to_gossip[peer][topic], topicID=topic
|
||||
)
|
||||
ihave_msgs.append(ihave_msg)
|
||||
del peers_to_gossip[peer]
|
||||
|
||||
control_msg = self.pack_control_msgs(ihave_msgs, None, prune_msgs)
|
||||
await self.emit_control_message(control_msg, peer)
|
||||
|
||||
# Fianlly IHAVE messages
|
||||
for peer in peers_to_gossip:
|
||||
ihave_msgs = []
|
||||
for topic in peers_to_gossip[peer]:
|
||||
ihave_msg = rpc_pb2.ControlIHave(
|
||||
messageIDs=peers_to_gossip[peer][topic], topicID=topic
|
||||
)
|
||||
ihave_msgs.append(ihave_msg)
|
||||
|
||||
control_msg = self.pack_control_msgs(ihave_msgs, None, None)
|
||||
await self.emit_control_message(control_msg, peer)
|
||||
|
||||
# Heartbeat
|
||||
async def heartbeat(self) -> None:
|
||||
@ -308,16 +372,29 @@ class GossipSub(IPubsubRouter, Service):
|
||||
Note: the heartbeats are called with awaits because each heartbeat depends on the
|
||||
state changes in the preceding heartbeat
|
||||
"""
|
||||
# Start after a delay. Ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L410 # Noqa: E501
|
||||
await trio.sleep(self.heartbeat_initial_delay)
|
||||
while True:
|
||||
# Maintain mesh and keep track of which peers to send GRAFT or PRUNE to
|
||||
peers_to_graft, peers_to_prune = self.mesh_heartbeat()
|
||||
# Maintain fanout
|
||||
self.fanout_heartbeat()
|
||||
# Get the peers to send IHAVE to
|
||||
peers_to_gossip = self.gossip_heartbeat()
|
||||
# Pack GRAFT, PRUNE and IHAVE for the same peer into one control message and send it
|
||||
await self._emit_control_msgs(
|
||||
peers_to_graft, peers_to_prune, peers_to_gossip
|
||||
)
|
||||
|
||||
await self.mesh_heartbeat()
|
||||
await self.fanout_heartbeat()
|
||||
await self.gossip_heartbeat()
|
||||
self.mcache.shift()
|
||||
|
||||
await trio.sleep(self.heartbeat_interval)
|
||||
|
||||
async def mesh_heartbeat(self) -> None:
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
def mesh_heartbeat(
|
||||
self
|
||||
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
|
||||
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
|
||||
for topic in self.mesh:
|
||||
# Skip if no peers have subscribed to the topic
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
@ -330,41 +407,43 @@ class GossipSub(IPubsubRouter, Service):
|
||||
topic, self.degree - num_mesh_peers_in_topic, self.mesh[topic]
|
||||
)
|
||||
|
||||
fanout_peers_not_in_mesh: List[ID] = [
|
||||
peer for peer in selected_peers if peer not in self.mesh[topic]
|
||||
]
|
||||
for peer in fanout_peers_not_in_mesh:
|
||||
for peer in selected_peers:
|
||||
# Add peer to mesh[topic]
|
||||
self.mesh[topic].append(peer)
|
||||
self.mesh[topic].add(peer)
|
||||
|
||||
# Emit GRAFT(topic) control message to peer
|
||||
await self.emit_graft(topic, peer)
|
||||
peers_to_graft[peer].append(topic)
|
||||
|
||||
if num_mesh_peers_in_topic > self.degree_high:
|
||||
# Select |mesh[topic]| - D peers from mesh[topic]
|
||||
selected_peers = self.select_from_minus(
|
||||
num_mesh_peers_in_topic - self.degree, self.mesh[topic], []
|
||||
num_mesh_peers_in_topic - self.degree, self.mesh[topic], set()
|
||||
)
|
||||
for peer in selected_peers:
|
||||
# Remove peer from mesh[topic]
|
||||
self.mesh[topic].remove(peer)
|
||||
self.mesh[topic].discard(peer)
|
||||
|
||||
# Emit PRUNE(topic) control message to peer
|
||||
await self.emit_prune(topic, peer)
|
||||
peers_to_prune[peer].append(topic)
|
||||
return peers_to_graft, peers_to_prune
|
||||
|
||||
async def fanout_heartbeat(self) -> None:
|
||||
def fanout_heartbeat(self) -> None:
|
||||
# Note: the comments here are the exact pseudocode from the spec
|
||||
for topic in self.fanout:
|
||||
# If time since last published > ttl
|
||||
# TODO: there's no way time_since_last_publish gets set anywhere yet
|
||||
if (
|
||||
topic in self.time_since_last_publish
|
||||
and self.time_since_last_publish[topic] > self.time_to_live
|
||||
):
|
||||
# Delete topic entry if it's not in `pubsub.peer_topics`
|
||||
# or (TODO) if it's time-since-last-published > ttl
|
||||
if topic not in self.pubsub.peer_topics:
|
||||
# Remove topic from fanout
|
||||
del self.fanout[topic]
|
||||
del self.time_since_last_publish[topic]
|
||||
else:
|
||||
# Check if fanout peers are still in the topic and remove the ones that are not
|
||||
# ref: https://github.com/libp2p/go-libp2p-pubsub/blob/01b9825fbee1848751d90a8469e3f5f43bac8466/gossipsub.go#L498-L504 # noqa: E501
|
||||
in_topic_fanout_peers = [
|
||||
peer
|
||||
for peer in self.fanout[topic]
|
||||
if peer in self.pubsub.peer_topics[topic]
|
||||
]
|
||||
self.fanout[topic] = set(in_topic_fanout_peers)
|
||||
num_fanout_peers_in_topic = len(self.fanout[topic])
|
||||
|
||||
# If |fanout[topic]| < D
|
||||
@ -376,53 +455,43 @@ class GossipSub(IPubsubRouter, Service):
|
||||
self.fanout[topic],
|
||||
)
|
||||
# Add the peers to fanout[topic]
|
||||
self.fanout[topic].extend(selected_peers)
|
||||
self.fanout[topic].update(selected_peers)
|
||||
|
||||
async def gossip_heartbeat(self) -> None:
|
||||
def gossip_heartbeat(self) -> DefaultDict[ID, Dict[str, List[str]]]:
|
||||
peers_to_gossip: DefaultDict[ID, Dict[str, List[str]]] = defaultdict(dict)
|
||||
for topic in self.mesh:
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# TODO: Make more efficient, possibly using a generator?
|
||||
# Get all pubsub peers in a topic and only add them if they are gossipsub peers too
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, []
|
||||
topic, self.degree, self.mesh[topic]
|
||||
)
|
||||
|
||||
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
# TODO: this line is a monster, can hopefully be simplified
|
||||
if (
|
||||
topic not in self.mesh or (peer not in self.mesh[topic])
|
||||
) and (
|
||||
topic not in self.fanout or (peer not in self.fanout[topic])
|
||||
):
|
||||
msg_id_strs = [str(msg_id) for msg_id in msg_ids]
|
||||
await self.emit_ihave(topic, msg_id_strs, peer)
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
|
||||
# TODO: Refactor and Dedup. This section is the roughly the same as the above.
|
||||
# Do the same for fanout, for all topics not already hit in mesh
|
||||
for topic in self.fanout:
|
||||
if topic not in self.mesh:
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# TODO: Make more efficient, possibly using a generator?
|
||||
# Get all pubsub peers in topic and only add if they are gossipsub peers also
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, []
|
||||
)
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
if peer not in self.fanout[topic]:
|
||||
msg_id_strs = [str(msg) for msg in msg_ids]
|
||||
await self.emit_ihave(topic, msg_id_strs, peer)
|
||||
|
||||
self.mcache.shift()
|
||||
msg_ids = self.mcache.window(topic)
|
||||
if msg_ids:
|
||||
# Get all pubsub peers in topic and only add if they are gossipsub peers also
|
||||
if topic in self.pubsub.peer_topics:
|
||||
# Select D peers from peers.gossipsub[topic]
|
||||
peers_to_emit_ihave_to = self._get_in_topic_gossipsub_peers_from_minus(
|
||||
topic, self.degree, self.fanout[topic]
|
||||
)
|
||||
msg_id_strs = [str(msg) for msg in msg_ids]
|
||||
for peer in peers_to_emit_ihave_to:
|
||||
peers_to_gossip[peer][topic] = msg_id_strs
|
||||
return peers_to_gossip
|
||||
|
||||
@staticmethod
|
||||
def select_from_minus(
|
||||
num_to_select: int, pool: Sequence[Any], minus: Sequence[Any]
|
||||
num_to_select: int, pool: Iterable[Any], minus: Iterable[Any]
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Select at most num_to_select subset of elements from the set (pool - minus) randomly.
|
||||
@ -441,7 +510,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
|
||||
# If num_to_select > size(selection_pool), then return selection_pool (which has the most
|
||||
# possible elements s.t. the number of elements is less than num_to_select)
|
||||
if num_to_select > len(selection_pool):
|
||||
if num_to_select >= len(selection_pool):
|
||||
return selection_pool
|
||||
|
||||
# Random selection
|
||||
@ -450,16 +519,14 @@ class GossipSub(IPubsubRouter, Service):
|
||||
return selection
|
||||
|
||||
def _get_in_topic_gossipsub_peers_from_minus(
|
||||
self, topic: str, num_to_select: int, minus: Sequence[ID]
|
||||
self, topic: str, num_to_select: int, minus: Iterable[ID]
|
||||
) -> List[ID]:
|
||||
gossipsub_peers_in_topic = [
|
||||
gossipsub_peers_in_topic = set(
|
||||
peer_id
|
||||
for peer_id in self.pubsub.peer_topics[topic]
|
||||
if peer_id in self.peers_gossipsub
|
||||
]
|
||||
return self.select_from_minus(
|
||||
num_to_select, gossipsub_peers_in_topic, list(minus)
|
||||
if self.peer_protocol[peer_id] == PROTOCOL_ID
|
||||
)
|
||||
return self.select_from_minus(num_to_select, gossipsub_peers_in_topic, minus)
|
||||
|
||||
# RPC handlers
|
||||
|
||||
@ -517,6 +584,12 @@ class GossipSub(IPubsubRouter, Service):
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# 3) Get the stream to this peer
|
||||
if sender_peer_id not in self.pubsub.peers:
|
||||
logger.debug(
|
||||
"Fail to responed to iwant request from %s: peer record not exist",
|
||||
sender_peer_id,
|
||||
)
|
||||
return
|
||||
peer_stream = self.pubsub.peers[sender_peer_id]
|
||||
|
||||
# 4) And write the packet to the stream
|
||||
@ -537,7 +610,7 @@ class GossipSub(IPubsubRouter, Service):
|
||||
# Add peer to mesh for topic
|
||||
if topic in self.mesh:
|
||||
if sender_peer_id not in self.mesh[topic]:
|
||||
self.mesh[topic].append(sender_peer_id)
|
||||
self.mesh[topic].add(sender_peer_id)
|
||||
else:
|
||||
# Respond with PRUNE if not subscribed to the topic
|
||||
await self.emit_prune(topic, sender_peer_id)
|
||||
@ -547,12 +620,27 @@ class GossipSub(IPubsubRouter, Service):
|
||||
) -> None:
|
||||
topic: str = prune_msg.topicID
|
||||
|
||||
# Remove peer from mesh for topic, if peer is in topic
|
||||
if topic in self.mesh and sender_peer_id in self.mesh[topic]:
|
||||
self.mesh[topic].remove(sender_peer_id)
|
||||
# Remove peer from mesh for topic
|
||||
if topic in self.mesh:
|
||||
self.mesh[topic].discard(sender_peer_id)
|
||||
|
||||
# RPC emitters
|
||||
|
||||
def pack_control_msgs(
|
||||
self,
|
||||
ihave_msgs: List[rpc_pb2.ControlIHave],
|
||||
graft_msgs: List[rpc_pb2.ControlGraft],
|
||||
prune_msgs: List[rpc_pb2.ControlPrune],
|
||||
) -> rpc_pb2.ControlMessage:
|
||||
control_msg: rpc_pb2.ControlMessage = rpc_pb2.ControlMessage()
|
||||
if ihave_msgs:
|
||||
control_msg.ihave.extend(ihave_msgs)
|
||||
if graft_msgs:
|
||||
control_msg.graft.extend(graft_msgs)
|
||||
if prune_msgs:
|
||||
control_msg.prune.extend(prune_msgs)
|
||||
return control_msg
|
||||
|
||||
async def emit_ihave(self, topic: str, msg_ids: Any, to_peer: ID) -> None:
|
||||
"""Emit ihave message, sent to to_peer, for topic and msg_ids."""
|
||||
|
||||
@ -608,6 +696,11 @@ class GossipSub(IPubsubRouter, Service):
|
||||
rpc_msg: bytes = packet.SerializeToString()
|
||||
|
||||
# Get stream for peer from pubsub
|
||||
if to_peer not in self.pubsub.peers:
|
||||
logger.debug(
|
||||
"Fail to emit control message to %s: peer record not exist", to_peer
|
||||
)
|
||||
return
|
||||
peer_stream = self.pubsub.peers[to_peer]
|
||||
|
||||
# Write rpc to stream
|
||||
|
||||
@ -96,8 +96,7 @@ class MessageCache:
|
||||
last_entries: List[CacheEntry] = self.history[len(self.history) - 1]
|
||||
|
||||
for entry in last_entries:
|
||||
if entry.mid in self.msgs:
|
||||
del self.msgs[entry.mid]
|
||||
self.msgs.pop(entry.mid)
|
||||
|
||||
i: int = len(self.history) - 2
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import (
|
||||
KeysView,
|
||||
List,
|
||||
NamedTuple,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
@ -19,6 +20,7 @@ import base58
|
||||
from lru import LRU
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.keys import PrivateKey
|
||||
from libp2p.exceptions import ParseError, ValidationError
|
||||
from libp2p.host.host_interface import IHost
|
||||
from libp2p.io.exceptions import IncompleteReadError
|
||||
@ -33,7 +35,7 @@ from .abc import IPubsub, ISubscriptionAPI
|
||||
from .pb import rpc_pb2
|
||||
from .pubsub_notifee import PubsubNotifee
|
||||
from .subscription import TrioSubscriptionAPI
|
||||
from .validators import signature_validator
|
||||
from .validators import PUBSUB_SIGNING_PREFIX, signature_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abc import IPubsubRouter # noqa: F401
|
||||
@ -73,16 +75,23 @@ class Pubsub(IPubsub, Service):
|
||||
subscribed_topics_send: Dict[str, "trio.MemorySendChannel[rpc_pb2.Message]"]
|
||||
subscribed_topics_receive: Dict[str, "TrioSubscriptionAPI"]
|
||||
|
||||
peer_topics: Dict[str, List[ID]]
|
||||
peer_topics: Dict[str, Set[ID]]
|
||||
peers: Dict[ID, INetStream]
|
||||
|
||||
topic_validators: Dict[str, TopicValidator]
|
||||
|
||||
# TODO: Be sure it is increased atomically everytime.
|
||||
counter: int # uint64
|
||||
|
||||
# Indicate if we should enforce signature verification
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey
|
||||
|
||||
def __init__(
|
||||
self, host: IHost, router: "IPubsubRouter", cache_size: int = None
|
||||
self,
|
||||
host: IHost,
|
||||
router: "IPubsubRouter",
|
||||
cache_size: int = None,
|
||||
strict_signing: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Construct a new Pubsub object, which is responsible for handling all
|
||||
@ -126,6 +135,12 @@ class Pubsub(IPubsub, Service):
|
||||
else:
|
||||
self.cache_size = cache_size
|
||||
|
||||
self.strict_signing = strict_signing
|
||||
if strict_signing:
|
||||
self.sign_key = self.host.get_private_key()
|
||||
else:
|
||||
self.sign_key = None
|
||||
|
||||
self.seen_messages = LRU(self.cache_size)
|
||||
|
||||
# Map of topics we are subscribed to blocking queues
|
||||
@ -142,7 +157,7 @@ class Pubsub(IPubsub, Service):
|
||||
# Map of topic to topic validator
|
||||
self.topic_validators = {}
|
||||
|
||||
self.counter = time.time_ns()
|
||||
self.counter = int(time.time())
|
||||
|
||||
async def run(self) -> None:
|
||||
self.manager.run_daemon_task(self.handle_peer_queue)
|
||||
@ -239,8 +254,7 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
:param topic: the topic to remove validator from
|
||||
"""
|
||||
if topic in self.topic_validators:
|
||||
del self.topic_validators[topic]
|
||||
self.topic_validators.pop(topic, None)
|
||||
|
||||
def get_msg_validators(self, msg: rpc_pb2.Message) -> Tuple[TopicValidator, ...]:
|
||||
"""
|
||||
@ -282,24 +296,22 @@ class Pubsub(IPubsub, Service):
|
||||
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||
return
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
|
||||
# Send hello packet
|
||||
hello = self.get_hello_packet()
|
||||
try:
|
||||
await stream.write(encode_varint_prefixed(hello.SerializeToString()))
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
# TODO: Check if the peer in black list.
|
||||
try:
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
except Exception as error:
|
||||
logger.debug("fail to add new peer %s, error %s", peer_id, error)
|
||||
del self.peers[peer_id]
|
||||
return
|
||||
|
||||
self.peers[peer_id] = stream
|
||||
|
||||
logger.debug("added new peer %s", peer_id)
|
||||
|
||||
def _handle_dead_peer(self, peer_id: ID) -> None:
|
||||
@ -309,19 +321,16 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
for topic in self.peer_topics:
|
||||
if peer_id in self.peer_topics[topic]:
|
||||
self.peer_topics[topic].remove(peer_id)
|
||||
self.peer_topics[topic].discard(peer_id)
|
||||
|
||||
self.router.remove_peer(peer_id)
|
||||
|
||||
logger.debug("removed dead peer %s", peer_id)
|
||||
|
||||
async def handle_peer_queue(self) -> None:
|
||||
"""
|
||||
Continuously read from peer channel and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol
|
||||
TODO: Handle failure for when the peer does not support any of the
|
||||
pubsub protocols we support
|
||||
"""
|
||||
"""Continuously read from peer queue and each time a new peer is found,
|
||||
open a stream to the peer using a supported pubsub protocol pubsub
|
||||
protocols we support."""
|
||||
async with self.peer_receive_channel:
|
||||
while self.manager.is_running:
|
||||
peer_id: ID = await self.peer_receive_channel.receive()
|
||||
@ -351,14 +360,14 @@ class Pubsub(IPubsub, Service):
|
||||
"""
|
||||
if sub_message.subscribe:
|
||||
if sub_message.topicid not in self.peer_topics:
|
||||
self.peer_topics[sub_message.topicid] = [origin_id]
|
||||
self.peer_topics[sub_message.topicid] = set([origin_id])
|
||||
elif origin_id not in self.peer_topics[sub_message.topicid]:
|
||||
# Add peer to topic
|
||||
self.peer_topics[sub_message.topicid].append(origin_id)
|
||||
self.peer_topics[sub_message.topicid].add(origin_id)
|
||||
else:
|
||||
if sub_message.topicid in self.peer_topics:
|
||||
if origin_id in self.peer_topics[sub_message.topicid]:
|
||||
self.peer_topics[sub_message.topicid].remove(origin_id)
|
||||
self.peer_topics[sub_message.topicid].discard(origin_id)
|
||||
|
||||
# FIXME(mhchia): Change the function name?
|
||||
async def handle_talk(self, publish_message: rpc_pb2.Message) -> None:
|
||||
@ -476,7 +485,13 @@ class Pubsub(IPubsub, Service):
|
||||
seqno=self._next_seqno(),
|
||||
)
|
||||
|
||||
# TODO: Sign with our signing key
|
||||
if self.strict_signing:
|
||||
priv_key = self.sign_key
|
||||
signature = priv_key.sign(
|
||||
PUBSUB_SIGNING_PREFIX.encode() + msg.SerializeToString()
|
||||
)
|
||||
msg.key = self.host.get_public_key().serialize()
|
||||
msg.signature = signature
|
||||
|
||||
await self.push_msg(self.my_id, msg)
|
||||
|
||||
@ -536,18 +551,17 @@ class Pubsub(IPubsub, Service):
|
||||
|
||||
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||
|
||||
# TODO: Check if signing is required and if so signature should be attached.
|
||||
|
||||
# If the message is processed before, return(i.e., don't further process the message).
|
||||
if self._is_msg_seen(msg):
|
||||
return
|
||||
|
||||
# TODO: - Validate the message. If failed, reject it.
|
||||
# Validate the signature of the message
|
||||
# FIXME: `signature_validator` is currently a stub.
|
||||
if not signature_validator(msg.key, msg.SerializeToString()):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
# Check if signing is required and if so validate the signature
|
||||
if self.strict_signing:
|
||||
# Validate the signature of the message
|
||||
if not signature_validator(msg):
|
||||
logger.debug("Signature validation failed for msg: %s", msg)
|
||||
return
|
||||
|
||||
# Validate the message with registered topic validators.
|
||||
# If the validation failed, return(i.e., don't further process the message).
|
||||
try:
|
||||
|
||||
@ -1,10 +1,41 @@
|
||||
# FIXME: Replace the type of `pubkey` with a custom type `Pubkey`
|
||||
def signature_validator(pubkey: bytes, msg: bytes) -> bool:
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.serialization import deserialize_public_key
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .pb import rpc_pb2
|
||||
|
||||
logger = logging.getLogger("libp2p.pubsub")
|
||||
|
||||
PUBSUB_SIGNING_PREFIX = "libp2p-pubsub:"
|
||||
|
||||
|
||||
def signature_validator(msg: rpc_pb2.Message) -> bool:
|
||||
"""
|
||||
Verify the message against the given public key.
|
||||
|
||||
:param pubkey: the public key which signs the message.
|
||||
:param msg: the message signed.
|
||||
"""
|
||||
# TODO: Implement the signature validation
|
||||
return True
|
||||
# Check if signature is attached
|
||||
if msg.signature == b"":
|
||||
logger.debug("Reject because no signature attached for msg: %s", msg)
|
||||
return False
|
||||
|
||||
# Validate if message sender matches message signer,
|
||||
# i.e., check if `msg.key` matches `msg.from_id`
|
||||
msg_pubkey = deserialize_public_key(msg.key)
|
||||
if ID.from_pubkey(msg_pubkey) != msg.from_id:
|
||||
logger.debug(
|
||||
"Reject because signing key does not match sender ID for msg: %s", msg
|
||||
)
|
||||
return False
|
||||
# First, construct the original payload that's signed by 'msg.key'
|
||||
msg_without_key_sig = rpc_pb2.Message(
|
||||
data=msg.data, topicIDs=msg.topicIDs, from_id=msg.from_id, seqno=msg.seqno
|
||||
)
|
||||
payload = PUBSUB_SIGNING_PREFIX.encode() + msg_without_key_sig.SerializeToString()
|
||||
try:
|
||||
return msg_pubkey.verify(payload, msg.signature)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user